From 91152a93afcfa0e55d851c96a89d9c961508d523 Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Thu, 5 May 2022 15:53:04 +0800 Subject: [PATCH 1/3] fix: sharding alias where clause --- sharding.go | 81 ++++++++++++++++++++++++++++-------------------- sharding_test.go | 38 +++++++++++++++++++---- 2 files changed, 80 insertions(+), 39 deletions(-) diff --git a/sharding.go b/sharding.go index 6d2aa5d..9bce440 100644 --- a/sharding.go +++ b/sharding.go @@ -322,13 +322,14 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, var value interface{} var id int64 var keyFind bool + var isQualifiedRefExpr bool if isInsert { value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...) if err != nil { return } } else { - value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...) + value, id, keyFind, isQualifiedRefExpr, err = s.nonInsertValue(tableName, r.ShardingKey, condition, args...) if err != nil { return } @@ -381,6 +382,9 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, stQuery = stmt.String() case *sqlparser.SelectStatement: ftQuery = stmt.String() + if isQualifiedRefExpr { + newTable.Alias = &sqlparser.Ident{Name: tableName} + } stmt.FromItems = newTable stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name) stQuery = stmt.String() @@ -425,43 +429,54 @@ func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sql return } -func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) { +func (s *Sharding) nonInsertValue(tableName, key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, isQualifiedRefExpr bool, err error) { + var handleExprIdent = func(n *sqlparser.BinaryExpr, x *sqlparser.Ident) (err error) { + if x.Name == key && n.Op == sqlparser.EQ { + keyFind = true + switch expr := n.Y.(type) { + case *sqlparser.BindExpr: + value = args[expr.Pos] + case *sqlparser.StringLit: + value = expr.Value + case *sqlparser.NumberLit: + value = expr.Value + default: + return sqlparser.ErrNotImplemented + } + return nil + } else if x.Name == "id" && n.Op == sqlparser.EQ { + switch expr := n.Y.(type) { + case *sqlparser.BindExpr: + v := args[expr.Pos] + var ok bool + if id, ok = v.(int64); !ok { + return fmt.Errorf("ID should be int64 type") + } + case *sqlparser.NumberLit: + id, err = strconv.ParseInt(expr.Value, 10, 64) + if err != nil { + return err + } + default: + return ErrInvalidID + } + return nil + } + return nil + } + err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { if n, ok := node.(*sqlparser.BinaryExpr); ok { if x, ok := n.X.(*sqlparser.Ident); ok { - if x.Name == key && n.Op == sqlparser.EQ { - keyFind = true - switch expr := n.Y.(type) { - case *sqlparser.BindExpr: - value = args[expr.Pos] - case *sqlparser.StringLit: - value = expr.Value - case *sqlparser.NumberLit: - value = expr.Value - default: - return sqlparser.ErrNotImplemented - } - return nil - } else if x.Name == "id" && n.Op == sqlparser.EQ { - switch expr := n.Y.(type) { - case *sqlparser.BindExpr: - v := args[expr.Pos] - var ok bool - if id, ok = v.(int64); !ok { - return fmt.Errorf("ID should be int64 type") - } - case *sqlparser.NumberLit: - id, err = strconv.ParseInt(expr.Value, 10, 64) - if err != nil { - return err - } - default: - return ErrInvalidID - } - return nil + return handleExprIdent(n, x) + } else if ref, ok := n.X.(*sqlparser.QualifiedRef); ok { + if ref.Table.Name == tableName { + isQualifiedRefExpr = true + return handleExprIdent(n, ref.Column) } } } + return nil }), condition) if err != nil { @@ -469,7 +484,7 @@ func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ... } if !keyFind && id == 0 { - return nil, 0, keyFind, ErrMissingShardingKey + return nil, 0, keyFind, false, ErrMissingShardingKey } return diff --git a/sharding_test.go b/sharding_test.go index 9eb7465..6c6dd41 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -17,6 +17,12 @@ import ( "gorm.io/plugin/dbresolver" ) +type User struct { + ID int64 + Name string + Orders []Order +} + type Order struct { ID int64 `gorm:"primarykey"` UserID int64 @@ -116,7 +122,7 @@ func init() { fmt.Println("Clean only tables ...") dropTables() fmt.Println("AutoMigrate tables ...") - err := db.AutoMigrate(&Order{}, &Category{}) + err := db.AutoMigrate(&Order{}, &Category{}, &User{}) if err != nil { panic(err) } @@ -143,7 +149,7 @@ func init() { } func dropTables() { - tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"} + tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories", "users"} for _, table := range tables { db.Exec("DROP TABLE IF EXISTS " + table) dbRead.Exec("DROP TABLE IF EXISTS " + table) @@ -153,7 +159,7 @@ func dropTables() { } func TestMigrate(t *testing.T) { - targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"} + targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories", "users"} sort.Strings(targetTables) // origin tables @@ -162,18 +168,18 @@ func TestMigrate(t *testing.T) { assert.Equal(t, tables, targetTables) // drop table - db.Migrator().DropTable(Order{}, &Category{}) + db.Migrator().DropTable(Order{}, &Category{}, &User{}) tables, _ = db.Migrator().GetTables() assert.Equal(t, len(tables), 0) // auto migrate - db.AutoMigrate(&Order{}, &Category{}) + db.AutoMigrate(&Order{}, &Category{}, &User{}) tables, _ = db.Migrator().GetTables() sort.Strings(tables) assert.Equal(t, tables, targetTables) // auto migrate again - err := db.AutoMigrate(&Order{}, &Category{}) + err := db.AutoMigrate(&Order{}, &Category{}, &User{}) assert.Equal(t, err, nil) } @@ -381,6 +387,26 @@ func TestReadWriteSplitting(t *testing.T) { assert.Equal(t, "iPhone", order.Product) } +func TestAssociation(t *testing.T) { + user := User{ + Name: "association_user", + Orders: []Order{ + {Product: "association_product_1"}, + }, + } + + var err error + err = db.Create(&user).Error + assert.Equal(t, err, nil) + + var user1 User + err = db.Preload("Orders").Find(&user1).Error + assert.Equal(t, err, nil) + assert.Equal(t, user1.Name, user.Name) + assert.Equal(t, len(user1.Orders), len(user.Orders)) + assert.Equal(t, user1.Orders[0].Product, user.Orders[0].Product) +} + func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) { t.Helper() assert.Equal(t, toDialect(expected), middleware.LastQuery()) From 3043694224622abb1408364350eb5ab300f94793 Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Thu, 5 May 2022 16:14:29 +0800 Subject: [PATCH 2/3] fix: use user alias table name --- sharding.go | 18 ++++++++---------- sharding_test.go | 10 ++++++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sharding.go b/sharding.go index 9bce440..5a1a4d1 100644 --- a/sharding.go +++ b/sharding.go @@ -322,14 +322,14 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, var value interface{} var id int64 var keyFind bool - var isQualifiedRefExpr bool + var aliasTable string if isInsert { value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...) if err != nil { return } } else { - value, id, keyFind, isQualifiedRefExpr, err = s.nonInsertValue(tableName, r.ShardingKey, condition, args...) + value, id, keyFind, aliasTable, err = s.nonInsertValue(tableName, r.ShardingKey, condition, args...) if err != nil { return } @@ -382,8 +382,8 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, stQuery = stmt.String() case *sqlparser.SelectStatement: ftQuery = stmt.String() - if isQualifiedRefExpr { - newTable.Alias = &sqlparser.Ident{Name: tableName} + if aliasTable != "" { + newTable.Alias = &sqlparser.Ident{Name: aliasTable} } stmt.FromItems = newTable stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name) @@ -429,7 +429,7 @@ func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sql return } -func (s *Sharding) nonInsertValue(tableName, key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, isQualifiedRefExpr bool, err error) { +func (s *Sharding) nonInsertValue(tableName, key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, alias string, err error) { var handleExprIdent = func(n *sqlparser.BinaryExpr, x *sqlparser.Ident) (err error) { if x.Name == key && n.Op == sqlparser.EQ { keyFind = true @@ -470,10 +470,8 @@ func (s *Sharding) nonInsertValue(tableName, key string, condition sqlparser.Exp if x, ok := n.X.(*sqlparser.Ident); ok { return handleExprIdent(n, x) } else if ref, ok := n.X.(*sqlparser.QualifiedRef); ok { - if ref.Table.Name == tableName { - isQualifiedRefExpr = true - return handleExprIdent(n, ref.Column) - } + alias = ref.Table.Name + return handleExprIdent(n, ref.Column) } } @@ -484,7 +482,7 @@ func (s *Sharding) nonInsertValue(tableName, key string, condition sqlparser.Exp } if !keyFind && id == 0 { - return nil, 0, keyFind, false, ErrMissingShardingKey + return nil, 0, keyFind, "", ErrMissingShardingKey } return diff --git a/sharding_test.go b/sharding_test.go index 6c6dd41..0ce3e02 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -387,6 +387,16 @@ func TestReadWriteSplitting(t *testing.T) { assert.Equal(t, "iPhone", order.Product) } +func TestSelectAlias(t *testing.T) { + sql := toDialect(`SELECT * FROM "orders" WHERE orders.user_id = 101`) + tx := db.Raw(sql).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM orders_1 AS orders WHERE orders.user_id = 101`, tx) + + sql1 := toDialect(`SELECT * FROM "orders" AS "o" WHERE o.user_id = 101`) + tx1 := db.Raw(sql1).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM orders_1 AS o WHERE o.user_id = 101`, tx1) +} + func TestAssociation(t *testing.T) { user := User{ Name: "association_user", From b79201739cb8cdbaf0a9be0f00d2fe24f45e6d61 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Feb 2023 16:57:02 +0800 Subject: [PATCH 3/3] Update sharding.go --- sharding.go | 81 +++++++++++------------------------------------------ 1 file changed, 16 insertions(+), 65 deletions(-) diff --git a/sharding.go b/sharding.go index 6e2510d..ba3cec8 100644 --- a/sharding.go +++ b/sharding.go @@ -353,25 +353,21 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, newTable = &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}} fillID := true - if isInsert { - for _, name := range insertNames { - if name.Name == "id" { - fillID = false - break - } - } - if fillID { - tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1)) - if err != nil { - return ftQuery, stQuery, tableName, err - } - id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) - columnNames = append(insertNames, &sqlparser.Ident{Name: "id"}) - insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)}) + for _, name := range insertNames { + if name.Name == "id" { + fillID = false + break } } - if fillID { + tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1)) + if err != nil { + return ftQuery, stQuery, tableName, err + } + id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) + columnNames = append(insertNames, &sqlparser.Ident{Name: "id"}) + insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)}) + insertStmt.ColumnNames = columnNames inserExpression.Exprs = insertValues } @@ -397,6 +393,10 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, switch stmt := expr.(type) { case *sqlparser.SelectStatement: + if aliasTable != "" { + newTable.Alias = &sqlparser.Ident{Name: aliasTable} + } + ftQuery = stmt.String() stmt.FromItems = newTable stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name) @@ -428,55 +428,6 @@ func getSuffix(value interface{}, id int64, keyFind bool, r Config) (suffix stri } suffix = r.ShardingAlgorithmByPrimaryKey(id) } - - newTable := &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}} - - fillID := true - if isInsert { - for _, name := range insertNames { - if name.Name == "id" { - fillID = false - break - } - } - if fillID { - tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1)) - if err != nil { - return ftQuery, stQuery, tableName, err - } - id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) - insertNames = append(insertNames, &sqlparser.Ident{Name: "id"}) - insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)}) - } - } - - switch stmt := expr.(type) { - case *sqlparser.InsertStatement: - if fillID { - stmt.ColumnNames = insertNames - stmt.Expressions[0].Exprs = insertValues - } - ftQuery = stmt.String() - stmt.TableName = newTable - stQuery = stmt.String() - case *sqlparser.SelectStatement: - ftQuery = stmt.String() - if aliasTable != "" { - newTable.Alias = &sqlparser.Ident{Name: aliasTable} - } - stmt.FromItems = newTable - stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name) - stQuery = stmt.String() - case *sqlparser.UpdateStatement: - ftQuery = stmt.String() - stmt.TableName = newTable - stQuery = stmt.String() - case *sqlparser.DeleteStatement: - ftQuery = stmt.String() - stmt.TableName = newTable - stQuery = stmt.String() - } - return }