From c9077e55730a100c808dc8158d1ca0ca772ad12f Mon Sep 17 00:00:00 2001 From: Cory Redmond Date: Tue, 23 Mar 2021 23:22:22 +0000 Subject: [PATCH] Allow recursion using `,recurse` --- sqlstruct.go | 129 +++++++++++++++++++++++++--------------------- sqlstruct_test.go | 15 ++++++ 2 files changed, 86 insertions(+), 58 deletions(-) diff --git a/sqlstruct.go b/sqlstruct.go index bd9756b..a5a1f71 100644 --- a/sqlstruct.go +++ b/sqlstruct.go @@ -9,75 +9,87 @@ the Go standard library's database/sql package. The package matches struct field names to SQL query column names. A field can also specify a matching column with "sql" tag, if it's different from field name. Unexported fields or fields marked with `sql:"-"` are ignored, just like -with "encoding/json" package. +with "encoding/json" package. Fields marked with `sql:",recurse"` are treated as +embedded structs and are recursively scanned. For example: - type T struct { - F1 string - F2 string `sql:"field2"` - F3 string `sql:"-"` - } + type T1 struct { + F4 string `sql:"field4"` + } - rows, err := db.Query(fmt.Sprintf("SELECT %s FROM tablename", sqlstruct.Columns(T{}))) - ... + type T2 struct { + F5 string `sql:"field5"` + } + + type T struct { + F1 string + F2 string `sql:"field2"` + F3 string `sql:"-"` + fieldT1 T1 `sql:",recurse"` + T2 + } - for rows.Next() { - var t T - err = sqlstruct.Scan(&t, rows) - ... - } + rows, err := db.Query(fmt.Sprintf("SELECT %s FROM tablename", sqlstruct.Columns(T{}))) + ... - err = rows.Err() // get any errors encountered during iteration + for rows.Next() { + var t T + err = sqlstruct.Scan(&t, rows) + ... + } + + err = rows.Err() // get any errors encountered during iteration Aliased tables in a SQL statement may be scanned into a specific structure identified by the same alias, using the ColumnsAliased and ScanAliased functions: - type User struct { - Id int `sql:"id"` - Username string `sql:"username"` - Email string `sql:"address"` - Name string `sql:"name"` - HomeAddress *Address `sql:"-"` - } - - type Address struct { - Id int `sql:"id"` - City string `sql:"city"` - Street string `sql:"address"` - } - - ... - - var user User - var address Address - sql := ` + type User struct { + Id int `sql:"id"` + Username string `sql:"username"` + Email string `sql:"address"` + Name string `sql:"name"` + HomeAddress *Address `sql:"-"` + } + + type Address struct { + Id int `sql:"id"` + City string `sql:"city"` + Street string `sql:"address"` + } + + ... + + var user User + var address Address + sql := ` + SELECT %s, %s FROM users AS u INNER JOIN address AS a ON a.id = u.address_id WHERE u.username = ? ` - sql = fmt.Sprintf(sql, sqlstruct.ColumnsAliased(*user, "u"), sqlstruct.ColumnsAliased(*address, "a")) - rows, err := db.Query(sql, "gedi") - if err != nil { - log.Fatal(err) - } - defer rows.Close() - if rows.Next() { - err = sqlstruct.ScanAliased(&user, rows, "u") - if err != nil { - log.Fatal(err) - } - err = sqlstruct.ScanAliased(&address, rows, "a") - if err != nil { - log.Fatal(err) - } - user.HomeAddress = address - } - fmt.Printf("%+v", *user) - // output: "{Id:1 Username:gedi Email:gediminas.morkevicius@gmail.com Name:Gedas HomeAddress:0xc21001f570}" - fmt.Printf("%+v", *user.HomeAddress) - // output: "{Id:2 City:Vilnius Street:Plento 34}" + sql = fmt.Sprintf(sql, sqlstruct.ColumnsAliased(*user, "u"), sqlstruct.ColumnsAliased(*address, "a")) + rows, err := db.Query(sql, "gedi") + if err != nil { + log.Fatal(err) + } + defer rows.Close() + if rows.Next() { + err = sqlstruct.ScanAliased(&user, rows, "u") + if err != nil { + log.Fatal(err) + } + err = sqlstruct.ScanAliased(&address, rows, "a") + if err != nil { + log.Fatal(err) + } + user.HomeAddress = address + } + fmt.Printf("%+v", *user) + // output: "{Id:1 Username:gedi Email:gediminas.morkevicius@gmail.com Name:Gedas HomeAddress:0xc21001f570}" + fmt.Printf("%+v", *user.HomeAddress) + // output: "{Id:2 City:Vilnius Street:Plento 34}" */ package sqlstruct @@ -97,7 +109,7 @@ import ( // The default mapper converts field names to lower case. If instead you would prefer // field names converted to snake case, simply assign sqlstruct.ToSnakeCase to the variable: // -// sqlstruct.NameMapper = sqlstruct.ToSnakeCase +// sqlstruct.NameMapper = sqlstruct.ToSnakeCase // // Alternatively for a custom mapping, any func(string) string can be used instead. var NameMapper func(string) string = strings.ToLower @@ -145,8 +157,8 @@ func getFieldInfo(typ reflect.Type) fieldInfo { continue } - // Handle embedded structs - if f.Anonymous && f.Type.Kind() == reflect.Struct { + // Handle embedded and recurse tagged structs + if (f.Anonymous || strings.EqualFold(tag, ",recurse")) && f.Type.Kind() == reflect.Struct { for k, v := range getFieldInfo(f.Type) { finfo[k] = append([]int{i}, v...) } @@ -198,7 +210,8 @@ func Columns(s interface{}) string { // given alias. // // For each field in the given struct it will generate a statement like: -// alias.field AS alias_field +// +// alias.field AS alias_field // // It is intended to be used in conjunction with the ScanAliased function. func ColumnsAliased(s interface{}, alias string) string { diff --git a/sqlstruct_test.go b/sqlstruct_test.go index be6e940..4847fff 100644 --- a/sqlstruct_test.go +++ b/sqlstruct_test.go @@ -25,6 +25,11 @@ type testType2 struct { FieldSec string `sql:"field_sec"` } +type testType3 struct { + FieldA string `sql:"field_a"` + EmbeddedType EmbeddedType `sql:",recurse"` +} + // testRows is a mock version of sql.Rows which can only scan strings type testRows struct { columns []string @@ -67,6 +72,16 @@ func TestColumns(t *testing.T) { } } +func TestColumnDeep(t *testing.T) { + var v testType3 + e := "field_a, field_e" + c := Columns(v) + + if c != e { + t.Errorf("expected %q got %q", e, c) + } +} + func TestColumnsAliased(t *testing.T) { var t1 testType var t2 testType2