diff --git a/oracle/common.go b/oracle/common.go index 70db86a..aec9d74 100644 --- a/oracle/common.go +++ b/oracle/common.go @@ -79,7 +79,7 @@ func getOracleArrayType(field *schema.Field, values []any) string { case schema.Bytes: return "TABLE OF BLOB" default: - return "TABLE OF VARCHAR2(4000)" // Safe default + return "TABLE OF " + strings.ToUpper(string(field.DataType)) } } @@ -110,11 +110,47 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field { return nil } +// Extra data types to determine the destination type for OUT parameters +// when using a serializer +const ( + Timestamp schema.DataType = "timestamp" + TimestampWithTimeZone schema.DataType = "timestamp with time zone" +) + // Create typed destination for OUT parameters func createTypedDestination(f *schema.Field) interface{} { if f == nil { - var s string - return &s + return new(string) + } + + // If the field has a serializer, the field type may not be directly related to the column type in the database. + // In this case, determine the destination type using the field's data type, which is the column type in the + // database. + if f.Serializer != nil { + dt := strings.ToLower(string(f.DataType)) + switch schema.DataType(dt) { + case schema.Bool: + return new(bool) + case schema.Uint: + return new(uint64) + case schema.Int: + return new(int64) + case schema.Float: + return new(float64) + case schema.String: + return new(string) + case Timestamp: + fallthrough + case TimestampWithTimeZone: + fallthrough + case schema.Time: + return new(time.Time) + case schema.Bytes: + return new([]byte) + default: + // Fallback + return new(string) + } } ft := f.FieldType @@ -163,8 +199,7 @@ func createTypedDestination(f *schema.Field) interface{} { } // Fallback - var s string - return &s + return new(string) } // Convert values for Oracle-specific types @@ -218,6 +253,13 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ return nil } + // Deserialize data into objects when a serializer is used + if field.Serializer != nil { + serializerField := field.NewValuePool.Get().(sql.Scanner) + serializerField.Scan(value) + return serializerField + } + targetType := field.FieldType var converted any diff --git a/oracle/update.go b/oracle/update.go index 735eb9d..386abb7 100644 --- a/oracle/update.go +++ b/oracle/update.go @@ -168,7 +168,6 @@ func checkMissingWhereConditions(db *gorm.DB) { } // Has non-soft-delete equality condition, this is valid hasMeaningfulConditions = true - break case clause.IN: // Has IN condition with values, this is valid if len(e.Values) > 0 { @@ -187,11 +186,9 @@ func checkMissingWhereConditions(db *gorm.DB) { } // Has non-soft-delete expression condition, consider it valid hasMeaningfulConditions = true - break case clause.AndConditions, clause.OrConditions: // Complex conditions are likely valid (but we could be more thorough here) hasMeaningfulConditions = true - break case clause.Where: // Handle nested WHERE clauses - recursively check their expressions if len(e.Exprs) > 0 { @@ -208,7 +205,6 @@ func checkMissingWhereConditions(db *gorm.DB) { default: // Unknown condition types - assume they're meaningful for safety hasMeaningfulConditions = true - break } // If we found meaningful conditions, we can stop checking diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 110485a..c7f3742 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -60,8 +60,8 @@ type SerializerStruct struct { Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` - CreatedTime int64 `gorm:"serializer:unixtime;type:timestamp"` // store time in db, use int as field type - UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamp"` // store time in db, use int as field type + CreatedTime int64 `gorm:"serializer:unixtime;type:timestamp with time zone"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamp with time zone"` // store time in db, use int as field type CustomSerializerString string `gorm:"serializer:custom"` EncryptedString EncryptedString } @@ -122,7 +122,9 @@ func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst r } func TestSerializer(t *testing.T) { - schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + if _, ok := schema.GetSerializer("custom"); !ok { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + } DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) @@ -168,8 +170,82 @@ func TestSerializer(t *testing.T) { } } +// Issue 48: https://github.com/oracle-samples/gorm-oracle/issues/48 +func TestSerializerBulkInsert(t *testing.T) { + if _, ok := schema.GetSerializer("custom"); !ok { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + } + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt := createdAt.Unix() + + data := []SerializerStruct{ + { + Name: []byte("jinzhu"), + Roles: []string{"r1", "r2"}, + Roles3: &Roles{}, + Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + UpdatedTime: &updatedAt, + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Kenmawr", + IsIntern: false, + }, + CustomSerializerString: "world", + }, + { + Name: []byte("john"), + Roles: []string{"l1", "l2"}, + Roles3: &Roles{}, + Contracts: map[string]interface{}{"name": "john", "age": 20}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + UpdatedTime: &updatedAt, + JobInfo: Job{ + Title: "manager", + Number: 7710, + Location: "Redwood City", + IsIntern: false, + }, + CustomSerializerString: "foo", + }, + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result []SerializerStruct + if err := DB.Find(&result).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + tests.AssertEqual(t, result, data) + + // Update all the "roles" columns to "n1" + if err := DB.Model(&SerializerStruct{}).Where("\"roles\" IS NOT NULL").Update("roles", []string{"n1"}).Error; err != nil { + t.Fatalf("failed to update data's roles, got error %v", err) + } + + var count int64 + if err := DB.Model(&SerializerStruct{}).Where("\"roles\" = ?", "n1").Count(&count).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + tests.AssertEqual(t, count, 2) +} + func TestSerializerZeroValue(t *testing.T) { - schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + if _, ok := schema.GetSerializer("custom"); !ok { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + } DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) @@ -200,7 +276,9 @@ func TestSerializerZeroValue(t *testing.T) { } func TestSerializerAssignFirstOrCreate(t *testing.T) { - schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + if _, ok := schema.GetSerializer("custom"); !ok { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + } DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)