diff --git a/sharding.go b/sharding.go index dd173be..ba3cec8 100644 --- a/sharding.go +++ b/sharding.go @@ -322,13 +322,14 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, return } + var value interface{} + var id int64 + var keyFind bool + var aliasTable string var suffix string if isInsert { var newTable *sqlparser.TableName for _, inserExpression := range inserExpressions { - var value interface{} - var id int64 - var keyFind bool columnNames := insertNames insertValues := inserExpression.Exprs value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...) @@ -352,25 +353,21 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, newTable = &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}} fillID := true - if isInsert { - for _, name := range insertNames { - if name.Name == "id" { - fillID = false - break - } - } - if fillID { - tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1)) - if err != nil { - return ftQuery, stQuery, tableName, err - } - id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) - columnNames = append(insertNames, &sqlparser.Ident{Name: "id"}) - insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)}) + for _, name := range insertNames { + if name.Name == "id" { + fillID = false + break } } - if fillID { + tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1)) + if err != nil { + return ftQuery, stQuery, tableName, err + } + id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) + columnNames = append(insertNames, &sqlparser.Ident{Name: "id"}) + insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)}) + insertStmt.ColumnNames = columnNames inserExpression.Exprs = insertValues } @@ -381,10 +378,8 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, stQuery = insertStmt.String() } else { - var value interface{} - var id int64 - var keyFind bool - value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...) + value, id, keyFind, aliasTable, err = s.nonInsertValue(tableName, r.ShardingKey, condition, args...) + if err != nil { return } @@ -398,6 +393,10 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, switch stmt := expr.(type) { case *sqlparser.SelectStatement: + if aliasTable != "" { + newTable.Alias = &sqlparser.Ident{Name: aliasTable} + } + ftQuery = stmt.String() stmt.FromItems = newTable stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name) @@ -460,43 +459,52 @@ func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sql return } -func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) { +func (s *Sharding) nonInsertValue(tableName, key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, alias string, err error) { + var handleExprIdent = func(n *sqlparser.BinaryExpr, x *sqlparser.Ident) (err error) { + if x.Name == key && n.Op == sqlparser.EQ { + keyFind = true + switch expr := n.Y.(type) { + case *sqlparser.BindExpr: + value = args[expr.Pos] + case *sqlparser.StringLit: + value = expr.Value + case *sqlparser.NumberLit: + value = expr.Value + default: + return sqlparser.ErrNotImplemented + } + return nil + } else if x.Name == "id" && n.Op == sqlparser.EQ { + switch expr := n.Y.(type) { + case *sqlparser.BindExpr: + v := args[expr.Pos] + var ok bool + if id, ok = v.(int64); !ok { + return fmt.Errorf("ID should be int64 type") + } + case *sqlparser.NumberLit: + id, err = strconv.ParseInt(expr.Value, 10, 64) + if err != nil { + return err + } + default: + return ErrInvalidID + } + return nil + } + return nil + } + err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { if n, ok := node.(*sqlparser.BinaryExpr); ok { if x, ok := n.X.(*sqlparser.Ident); ok { - if x.Name == key && n.Op == sqlparser.EQ { - keyFind = true - switch expr := n.Y.(type) { - case *sqlparser.BindExpr: - value = args[expr.Pos] - case *sqlparser.StringLit: - value = expr.Value - case *sqlparser.NumberLit: - value = expr.Value - default: - return sqlparser.ErrNotImplemented - } - return nil - } else if x.Name == "id" && n.Op == sqlparser.EQ { - switch expr := n.Y.(type) { - case *sqlparser.BindExpr: - v := args[expr.Pos] - var ok bool - if id, ok = v.(int64); !ok { - return fmt.Errorf("ID should be int64 type") - } - case *sqlparser.NumberLit: - id, err = strconv.ParseInt(expr.Value, 10, 64) - if err != nil { - return err - } - default: - return ErrInvalidID - } - return nil - } + return handleExprIdent(n, x) + } else if ref, ok := n.X.(*sqlparser.QualifiedRef); ok { + alias = ref.Table.Name + return handleExprIdent(n, ref.Column) } } + return nil }), condition) if err != nil { @@ -504,7 +512,7 @@ func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ... } if !keyFind && id == 0 { - return nil, 0, keyFind, ErrMissingShardingKey + return nil, 0, keyFind, "", ErrMissingShardingKey } return diff --git a/sharding_test.go b/sharding_test.go index 5821533..4e25923 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -18,6 +18,12 @@ import ( "gorm.io/plugin/dbresolver" ) +type User struct { + ID int64 + Name string + Orders []Order +} + type Order struct { ID int64 `gorm:"primarykey"` UserID int64 @@ -117,7 +123,7 @@ func init() { fmt.Println("Clean only tables ...") dropTables() fmt.Println("AutoMigrate tables ...") - err := db.AutoMigrate(&Order{}, &Category{}) + err := db.AutoMigrate(&Order{}, &Category{}, &User{}) if err != nil { panic(err) } @@ -144,7 +150,7 @@ func init() { } func dropTables() { - tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"} + tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories", "users"} for _, table := range tables { db.Exec("DROP TABLE IF EXISTS " + table) dbRead.Exec("DROP TABLE IF EXISTS " + table) @@ -154,7 +160,7 @@ func dropTables() { } func TestMigrate(t *testing.T) { - targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"} + targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories", "users"} sort.Strings(targetTables) // origin tables @@ -163,18 +169,18 @@ func TestMigrate(t *testing.T) { assert.Equal(t, tables, targetTables) // drop table - db.Migrator().DropTable(Order{}, &Category{}) + db.Migrator().DropTable(Order{}, &Category{}, &User{}) tables, _ = db.Migrator().GetTables() assert.Equal(t, len(tables), 0) // auto migrate - db.AutoMigrate(&Order{}, &Category{}) + db.AutoMigrate(&Order{}, &Category{}, &User{}) tables, _ = db.Migrator().GetTables() sort.Strings(tables) assert.Equal(t, tables, targetTables) // auto migrate again - err := db.AutoMigrate(&Order{}, &Category{}) + err := db.AutoMigrate(&Order{}, &Category{}, &User{}) assert.Equal(t, err, nil) } @@ -397,6 +403,36 @@ func TestReadWriteSplitting(t *testing.T) { assert.Equal(t, "iPhone", order.Product) } +func TestSelectAlias(t *testing.T) { + sql := toDialect(`SELECT * FROM "orders" WHERE orders.user_id = 101`) + tx := db.Raw(sql).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM orders_1 AS orders WHERE orders.user_id = 101`, tx) + + sql1 := toDialect(`SELECT * FROM "orders" AS "o" WHERE o.user_id = 101`) + tx1 := db.Raw(sql1).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM orders_1 AS o WHERE o.user_id = 101`, tx1) +} + +func TestAssociation(t *testing.T) { + user := User{ + Name: "association_user", + Orders: []Order{ + {Product: "association_product_1"}, + }, + } + + var err error + err = db.Create(&user).Error + assert.Equal(t, err, nil) + + var user1 User + err = db.Preload("Orders").Find(&user1).Error + assert.Equal(t, err, nil) + assert.Equal(t, user1.Name, user.Name) + assert.Equal(t, len(user1.Orders), len(user.Orders)) + assert.Equal(t, user1.Orders[0].Product, user.Orders[0].Product) +} + func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) { t.Helper() assert.Equal(t, toDialect(expected), middleware.LastQuery())