Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 64 additions & 56 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,14 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
return
}

var value interface{}
var id int64
var keyFind bool
var aliasTable string
var suffix string
if isInsert {
var newTable *sqlparser.TableName
for _, inserExpression := range inserExpressions {
var value interface{}
var id int64
var keyFind bool
columnNames := insertNames
insertValues := inserExpression.Exprs
value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...)
Expand All @@ -352,25 +353,21 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
newTable = &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}

fillID := true
if isInsert {
for _, name := range insertNames {
if name.Name == "id" {
fillID = false
break
}
}
if fillID {
tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1))
if err != nil {
return ftQuery, stQuery, tableName, err
}
id := r.PrimaryKeyGeneratorFn(int64(tblIdx))
columnNames = append(insertNames, &sqlparser.Ident{Name: "id"})
insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)})
for _, name := range insertNames {
if name.Name == "id" {
fillID = false
break
}
}

if fillID {
tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1))
if err != nil {
return ftQuery, stQuery, tableName, err
}
id := r.PrimaryKeyGeneratorFn(int64(tblIdx))
columnNames = append(insertNames, &sqlparser.Ident{Name: "id"})
insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)})

insertStmt.ColumnNames = columnNames
inserExpression.Exprs = insertValues
}
Expand All @@ -381,10 +378,8 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
stQuery = insertStmt.String()

} else {
var value interface{}
var id int64
var keyFind bool
value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...)
value, id, keyFind, aliasTable, err = s.nonInsertValue(tableName, r.ShardingKey, condition, args...)

if err != nil {
return
}
Expand All @@ -398,6 +393,10 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,

switch stmt := expr.(type) {
case *sqlparser.SelectStatement:
if aliasTable != "" {
newTable.Alias = &sqlparser.Ident{Name: aliasTable}
}

ftQuery = stmt.String()
stmt.FromItems = newTable
stmt.OrderBy = replaceOrderByTableName(stmt.OrderBy, tableName, newTable.Name.Name)
Expand Down Expand Up @@ -460,51 +459,60 @@ func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sql
return
}

func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) {
func (s *Sharding) nonInsertValue(tableName, key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, alias string, err error) {
var handleExprIdent = func(n *sqlparser.BinaryExpr, x *sqlparser.Ident) (err error) {
if x.Name == key && n.Op == sqlparser.EQ {
keyFind = true
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
value = args[expr.Pos]
case *sqlparser.StringLit:
value = expr.Value
case *sqlparser.NumberLit:
value = expr.Value
default:
return sqlparser.ErrNotImplemented
}
return nil
} else if x.Name == "id" && n.Op == sqlparser.EQ {
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
v := args[expr.Pos]
var ok bool
if id, ok = v.(int64); !ok {
return fmt.Errorf("ID should be int64 type")
}
case *sqlparser.NumberLit:
id, err = strconv.ParseInt(expr.Value, 10, 64)
if err != nil {
return err
}
default:
return ErrInvalidID
}
return nil
}
return nil
}

err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error {
if n, ok := node.(*sqlparser.BinaryExpr); ok {
if x, ok := n.X.(*sqlparser.Ident); ok {
if x.Name == key && n.Op == sqlparser.EQ {
keyFind = true
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
value = args[expr.Pos]
case *sqlparser.StringLit:
value = expr.Value
case *sqlparser.NumberLit:
value = expr.Value
default:
return sqlparser.ErrNotImplemented
}
return nil
} else if x.Name == "id" && n.Op == sqlparser.EQ {
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
v := args[expr.Pos]
var ok bool
if id, ok = v.(int64); !ok {
return fmt.Errorf("ID should be int64 type")
}
case *sqlparser.NumberLit:
id, err = strconv.ParseInt(expr.Value, 10, 64)
if err != nil {
return err
}
default:
return ErrInvalidID
}
return nil
}
return handleExprIdent(n, x)
} else if ref, ok := n.X.(*sqlparser.QualifiedRef); ok {
alias = ref.Table.Name
return handleExprIdent(n, ref.Column)
}
}

return nil
}), condition)
if err != nil {
return
}

if !keyFind && id == 0 {
return nil, 0, keyFind, ErrMissingShardingKey
return nil, 0, keyFind, "", ErrMissingShardingKey
}

return
Expand Down
48 changes: 42 additions & 6 deletions sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ import (
"gorm.io/plugin/dbresolver"
)

type User struct {
ID int64
Name string
Orders []Order
}

type Order struct {
ID int64 `gorm:"primarykey"`
UserID int64
Expand Down Expand Up @@ -117,7 +123,7 @@ func init() {
fmt.Println("Clean only tables ...")
dropTables()
fmt.Println("AutoMigrate tables ...")
err := db.AutoMigrate(&Order{}, &Category{})
err := db.AutoMigrate(&Order{}, &Category{}, &User{})
if err != nil {
panic(err)
}
Expand All @@ -144,7 +150,7 @@ func init() {
}

func dropTables() {
tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories", "users"}
for _, table := range tables {
db.Exec("DROP TABLE IF EXISTS " + table)
dbRead.Exec("DROP TABLE IF EXISTS " + table)
Expand All @@ -154,7 +160,7 @@ func dropTables() {
}

func TestMigrate(t *testing.T) {
targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories", "users"}
sort.Strings(targetTables)

// origin tables
Expand All @@ -163,18 +169,18 @@ func TestMigrate(t *testing.T) {
assert.Equal(t, tables, targetTables)

// drop table
db.Migrator().DropTable(Order{}, &Category{})
db.Migrator().DropTable(Order{}, &Category{}, &User{})
tables, _ = db.Migrator().GetTables()
assert.Equal(t, len(tables), 0)

// auto migrate
db.AutoMigrate(&Order{}, &Category{})
db.AutoMigrate(&Order{}, &Category{}, &User{})
tables, _ = db.Migrator().GetTables()
sort.Strings(tables)
assert.Equal(t, tables, targetTables)

// auto migrate again
err := db.AutoMigrate(&Order{}, &Category{})
err := db.AutoMigrate(&Order{}, &Category{}, &User{})
assert.Equal(t, err, nil)
}

Expand Down Expand Up @@ -397,6 +403,36 @@ func TestReadWriteSplitting(t *testing.T) {
assert.Equal(t, "iPhone", order.Product)
}

func TestSelectAlias(t *testing.T) {
sql := toDialect(`SELECT * FROM "orders" WHERE orders.user_id = 101`)
tx := db.Raw(sql).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM orders_1 AS orders WHERE orders.user_id = 101`, tx)

sql1 := toDialect(`SELECT * FROM "orders" AS "o" WHERE o.user_id = 101`)
tx1 := db.Raw(sql1).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM orders_1 AS o WHERE o.user_id = 101`, tx1)
}

func TestAssociation(t *testing.T) {
user := User{
Name: "association_user",
Orders: []Order{
{Product: "association_product_1"},
},
}

var err error
err = db.Create(&user).Error
assert.Equal(t, err, nil)

var user1 User
err = db.Preload("Orders").Find(&user1).Error
assert.Equal(t, err, nil)
assert.Equal(t, user1.Name, user.Name)
assert.Equal(t, len(user1.Orders), len(user.Orders))
assert.Equal(t, user1.Orders[0].Product, user.Orders[0].Product)
}

func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) {
t.Helper()
assert.Equal(t, toDialect(expected), middleware.LastQuery())
Expand Down