From d41fb3acdcfefc80e1ca24e3a4f1d0e3c39ba252 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 11 Sep 2021 16:22:35 +0800 Subject: [PATCH 01/83] Refactor dummy driver QuoteTo method --- utils/tests/dummy_dialecter.go | 48 +++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index b8452ef9a..84fdd2b6e 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -31,9 +31,51 @@ func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v in } func (DummyDialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteByte('`') - writer.WriteString(str) - writer.WriteByte('`') + var ( + underQuoted, selfQuoted bool + continuousBacktick int8 + shiftDelimiter int8 + ) + + for _, v := range []byte(str) { + switch v { + case '`': + continuousBacktick++ + if continuousBacktick == 2 { + writer.WriteString("``") + continuousBacktick = 0 + } + case '.': + if continuousBacktick > 0 || !selfQuoted { + shiftDelimiter = 0 + underQuoted = false + continuousBacktick = 0 + writer.WriteString("`") + } + writer.WriteByte(v) + continue + default: + if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { + writer.WriteByte('`') + underQuoted = true + if selfQuoted = continuousBacktick > 0; selfQuoted { + continuousBacktick -= 1 + } + } + + for ; continuousBacktick > 0; continuousBacktick -= 1 { + writer.WriteString("``") + } + + writer.WriteByte(v) + } + shiftDelimiter++ + } + + if continuousBacktick > 0 && !selfQuoted { + writer.WriteString("``") + } + writer.WriteString("`") } func (DummyDialector) Explain(sql string, vars ...interface{}) string { From 61b018cb942900fad2bf179818d4e2c0497435e9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Sep 2021 11:17:54 +0800 Subject: [PATCH 02/83] Fix count with selected * --- finisher_api.go | 2 +- tests/count_test.go | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 34e1596bb..741a94561 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -399,7 +399,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if tx.Statement.Distinct { expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} - } else { + } else if dbName != "*" { expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} } } diff --git a/tests/count_test.go b/tests/count_test.go index dd25f8b65..de06d0eb7 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -112,7 +112,7 @@ func TestCount(t *testing.T) { if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( "(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", ).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { - t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + t.Fatalf("Count should work, but got err %v", err) } expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} @@ -123,9 +123,15 @@ func TestCount(t *testing.T) { AssertEqual(t, users, expects) var count9 int64 - if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { + if err := DB.Scopes(func(tx *gorm.DB) *gorm.DB { return tx.Table("users") }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { - t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) + t.Fatalf("Count should work, but got err %v", err) } + + var count10 int64 + if err := DB.Model(&User{}).Select("*").Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count10).Error; err != nil || count10 != 3 { + t.Fatalf("Count should be 3, but got count: %v err %v", count10, err) + } + } From 12bbde89e683d85181b0344ff71f44d3148bf9cd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Sep 2021 14:04:19 +0800 Subject: [PATCH 03/83] Fix Scan with interface --- finisher_api.go | 7 ++++++- scan.go | 20 ++++++++++++-------- schema/schema.go | 6 +++++- tests/scan_test.go | 37 +++++++++++++++++++++++++++++++++++-- 4 files changed, 58 insertions(+), 12 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 741a94561..d273093f1 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -506,7 +506,12 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx.Statement.Dest = dest tx.Statement.ReflectValue = reflect.ValueOf(dest) for tx.Statement.ReflectValue.Kind() == reflect.Ptr { - tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + elem := tx.Statement.ReflectValue.Elem() + if !elem.IsValid() { + elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) + tx.Statement.ReflectValue.Set(elem) + } + tx.Statement.ReflectValue = elem } Scan(rows, tx, true) return tx.Error diff --git a/scan.go b/scan.go index 2beecd453..20bdde9e5 100644 --- a/scan.go +++ b/scan.go @@ -97,11 +97,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } default: Schema := db.Statement.Schema + reflectValue := db.Statement.ReflectValue + if reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } - switch db.Statement.ReflectValue.Kind() { + switch reflectValue.Kind() { case reflect.Slice, reflect.Array: var ( - reflectValueType = db.Statement.ReflectValue.Type().Elem() + reflectValueType = reflectValue.Type().Elem() isPtr = reflectValueType.Kind() == reflect.Ptr fields = make([]*schema.Field, len(columns)) joinFields [][2]*schema.Field @@ -111,7 +115,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { reflectValueType = reflectValueType.Elem() } - db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) if Schema != nil { if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { @@ -186,13 +190,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } if isPtr { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem)) } else { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) + db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem.Elem())) } } case reflect.Struct, reflect.Ptr: - if db.Statement.ReflectValue.Type() != Schema.ModelType { + if reflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } @@ -220,11 +224,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for idx, column := range columns { if field := Schema.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.ReflectValue, values[idx]) + field.Set(reflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + relValue := rel.Field.ReflectValueOf(reflectValue) value := reflect.ValueOf(values[idx]).Elem() if relValue.Kind() == reflect.Ptr && relValue.IsNil() { diff --git a/schema/schema.go b/schema/schema.go index faba2e21c..c425070b3 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -77,7 +77,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - modelType := reflect.ValueOf(dest).Type() + modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() + } + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } diff --git a/tests/scan_test.go b/tests/scan_test.go index 67d5f385a..aacad8272 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -29,8 +29,9 @@ func TestScan(t *testing.T) { } var resPointer *result - DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer) - if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } @@ -70,6 +71,38 @@ func TestScan(t *testing.T) { if uint(id) != user2.ID { t.Errorf("Failed to scan to customized data type") } + + var resInt interface{} + resInt = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3) + } + + var resInt2 interface{} + resInt2 = &User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3) + } + + var resInt3 interface{} + resInt3 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3) + } + + var resInt4 interface{} + resInt4 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { + t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3) + } } func TestScanRows(t *testing.T) { From da16a8aac6c3620532f5ad6d1fedf20fca2c1cf6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Sep 2021 15:29:49 +0800 Subject: [PATCH 04/83] Update updated_at when upserting with Create OnConflict --- callbacks/create.go | 21 ++++++++++++-- schema/field.go | 15 ++++++---- tests/upsert_test.go | 66 +++++++++++++++++++++++++++++--------------- 3 files changed, 71 insertions(+), 31 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 8a3c593cc..a2944319c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -227,6 +227,8 @@ func AfterCreate(db *gorm.DB) { // ConvertToCreateValues convert to create values func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { + curTime := stmt.DB.NowFunc() + switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) @@ -240,7 +242,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) _, updateTrackTime = stmt.Get("gorm:update_track_time") - curTime = stmt.DB.NowFunc() isZero bool ) stmt.Settings.Delete("gorm:update_track_time") @@ -352,13 +353,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if field := stmt.Schema.LookUpField(column.Name); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if field.AutoUpdateTime > 0 { + assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} + switch field.AutoUpdateTime { + case schema.UnixNanosecond: + assignment.Value = curTime.UnixNano() + case schema.UnixMillisecond: + assignment.Value = curTime.UnixNano() / 1e6 + case schema.UnixSecond: + assignment.Value = curTime.Unix() + } + + onConflict.DoUpdates = append(onConflict.DoUpdates, assignment) + } else { + columns = append(columns, column.Name) + } } } } } - onConflict.DoUpdates = clause.AssignmentColumns(columns) + onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) // use primary fields as default OnConflict columns if len(onConflict.Columns) == 0 { diff --git a/schema/field.go b/schema/field.go index ce0e3c130..f3189c7a3 100644 --- a/schema/field.go +++ b/schema/field.go @@ -21,9 +21,10 @@ type TimeType int64 var TimeReflectType = reflect.TypeOf(time.Time{}) const ( - UnixSecond TimeType = 1 - UnixMillisecond TimeType = 2 - UnixNanosecond TimeType = 3 + UnixTime TimeType = 1 + UnixSecond TimeType = 2 + UnixMillisecond TimeType = 3 + UnixNanosecond TimeType = 4 ) const ( @@ -251,7 +252,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { - if strings.ToUpper(v) == "NANO" { + if field.DataType == Time { + field.AutoCreateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond } else if strings.ToUpper(v) == "MILLI" { field.AutoCreateTime = UnixMillisecond @@ -261,7 +264,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { - if strings.ToUpper(v) == "NANO" { + if field.DataType == Time { + field.AutoUpdateTime = UnixTime + } else if strings.ToUpper(v) == "NANO" { field.AutoUpdateTime = UnixNanosecond } else if strings.ToUpper(v) == "MILLI" { field.AutoUpdateTime = UnixMillisecond diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 867110d88..0e247caa1 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -66,6 +66,26 @@ func TestUpsert(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } + + var user = *GetUser("upsert_on_conflict", Config{}) + user.Age = 20 + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error %v", err) + } + + var user2 User + DB.First(&user2, user.ID) + user2.Age = 30 + time.Sleep(time.Second) + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil { + t.Fatalf("failed to onconflict create user, got error %v", err) + } else { + var user3 User + DB.First(&user3, user.ID) + if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() { + t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt) + } + } } func TestUpsertSlice(t *testing.T) { @@ -152,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) { } } - // lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} - // if err := DB.Save(&lang).Error; err != nil { - // t.Errorf("Failed to create, got error %v", err) - // } - - // var result Language - // if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { - // t.Errorf("Failed to query lang, got error %v", err) - // } else { - // AssertEqual(t, result, lang) - // } - - // lang.Name += "_new" - // if err := DB.Save(&lang).Error; err != nil { - // t.Errorf("Failed to create, got error %v", err) - // } - - // var result2 Language - // if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { - // t.Errorf("Failed to query lang, got error %v", err) - // } else { - // AssertEqual(t, result2, lang) - // } + lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + + lang.Name += "_new" + if err := DB.Save(&lang).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + var result2 Language + if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result2, lang) + } } func TestFindOrInitialize(t *testing.T) { From ab355336cbedde681f852318c9cb9b78ef633ea1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Sep 2021 18:35:14 +0800 Subject: [PATCH 05/83] Fix scan with interface --- scan.go | 6 ++++-- tests/scan_test.go | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/scan.go b/scan.go index 20bdde9e5..4570380d8 100644 --- a/scan.go +++ b/scan.go @@ -190,11 +190,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } if isPtr { - db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem)) + reflectValue = reflect.Append(reflectValue, elem) } else { - db.Statement.ReflectValue.Set(reflect.Append(reflectValue, elem.Elem())) + reflectValue = reflect.Append(reflectValue, elem.Elem()) } } + + db.Statement.ReflectValue.Set(reflectValue) case reflect.Struct, reflect.Ptr: if reflectValue.Type() != Schema.ModelType { Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) diff --git a/tests/scan_test.go b/tests/scan_test.go index aacad8272..59fc6de5d 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -103,6 +103,14 @@ func TestScan(t *testing.T) { } else if rus := resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3) } + + var resInt5 interface{} + resInt5 = []User{} + if err := DB.Table("users").Select("id, name, age").Where("id IN ?", []uint{user1.ID, user2.ID, user3.ID}).Find(&resInt5).Error; err != nil { + t.Fatalf("Failed to query with pointer of value, got error %v", err) + } else if rus := resInt5.([]User); len(rus) != 3 { + t.Fatalf("Scan into struct should work, got %+v, len %v", resInt5, len(rus)) + } } func TestScanRows(t *testing.T) { From d67120a1551629a8da0199c9f96a379c13221a38 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Sep 2021 21:25:29 +0800 Subject: [PATCH 06/83] Bump gorm.io/driver/sqlite from 1.1.4 to 1.1.5 in /tests (#4701) Bumps [gorm.io/driver/sqlite](https://github.com/go-gorm/sqlite) from 1.1.4 to 1.1.5. - [Release notes](https://github.com/go-gorm/sqlite/releases) - [Commits](https://github.com/go-gorm/sqlite/compare/v1.1.4...v1.1.5) --- updated-dependencies: - dependency-name: gorm.io/driver/sqlite dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index d7ab65ad7..77e88ca99 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,9 +8,9 @@ require ( github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.0 - gorm.io/driver/sqlite v1.1.4 + gorm.io/driver/sqlite v1.1.5 gorm.io/driver/sqlserver v1.0.9 - gorm.io/gorm v1.21.14 + gorm.io/gorm v1.21.15 ) replace gorm.io/gorm => ../ From 199c8529b6c4e447ddbab9ae3edad137d954d36f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Sep 2021 21:33:38 +0800 Subject: [PATCH 07/83] Bump gorm.io/driver/postgres from 1.1.0 to 1.1.1 in /tests (#4699) Bumps [gorm.io/driver/postgres](https://github.com/go-gorm/postgres) from 1.1.0 to 1.1.1. - [Release notes](https://github.com/go-gorm/postgres/releases) - [Commits](https://github.com/go-gorm/postgres/compare/v1.1.0...v1.1.1) --- updated-dependencies: - dependency-name: gorm.io/driver/postgres dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index 77e88ca99..c4e27024e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 - gorm.io/driver/postgres v1.1.0 + gorm.io/driver/postgres v1.1.1 gorm.io/driver/sqlite v1.1.5 gorm.io/driver/sqlserver v1.0.9 gorm.io/gorm v1.21.15 From 5202529ea147916a5b6e331c5d39f60859df2360 Mon Sep 17 00:00:00 2001 From: Jim Date: Mon, 20 Sep 2021 09:40:48 -0400 Subject: [PATCH 08/83] fix (clause/expression): Allow sql stmt terminator (#4693) Allow the sql stmt terminator ";" at the end of a named parameter. Example: select * from table_name where name == @name; --- clause/expression.go | 2 +- clause/expression_test.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/clause/expression.go b/clause/expression.go index f7b93f4c3..e914b7b30 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -121,7 +121,7 @@ func (expr NamedExpr) Build(builder Builder) { if v == '@' && !inName { inName = true name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' { + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) diff --git a/clause/expression_test.go b/clause/expression_test.go index 050748654..eadd96ea7 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -89,6 +89,11 @@ func TestNamedExpr(t *testing.T) { SQL: "create table ? (? ?, ? ?)", Vars: []interface{}{}, Result: "create table ? (? ?, ? ?)", + }, { + SQL: "name1 = @name AND name2 = @name;", + Vars: []interface{}{sql.Named("name", "jinzhu")}, + Result: "name1 = ? AND name2 = ?;", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }} for idx, result := range results { From 6864a241504bc251b249c9bd3b85c803b0df90ce Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Mon, 27 Sep 2021 22:11:29 +0800 Subject: [PATCH 09/83] fix:remove the tableName judgment in pluck (#4731) --- finisher_api.go | 2 -- tests/distinct_test.go | 6 ++++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index d273093f1..e98efc92a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -483,8 +483,6 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { column = f.DBName } } - } else if tx.Statement.Table == "" { - tx.AddError(ErrModelValueRequired) } if len(tx.Statement.Selects) != 1 { diff --git a/tests/distinct_test.go b/tests/distinct_test.go index 29a320ff7..f97738a77 100644 --- a/tests/distinct_test.go +++ b/tests/distinct_test.go @@ -31,6 +31,12 @@ func TestDistinct(t *testing.T) { AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) + var names2 []string + DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table("users") + }).Where("name like ?", "distinct%").Order("name").Pluck("name", &names2) + AssertEqual(t, names2, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) + var results []User if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { t.Errorf("failed to query users, got error: %v", err) From 002bf78ea787f1df8ef3dd084e4854a9da8fedce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Sep 2021 21:43:12 +0800 Subject: [PATCH 10/83] Fix Join condition with DB, close #4719 --- chainable_api.go | 2 +- tests/joins_test.go | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index 01ab25977..23e601102 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -179,8 +179,8 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { if db, ok := args[0].(*DB); ok { if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where}) + return } - return } } diff --git a/tests/joins_test.go b/tests/joins_test.go index e560f38ad..25fa20b43 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -102,6 +102,12 @@ func TestJoinConds(t *testing.T) { if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) } + + iv := DB.Table(`table_invoices`).Select(`seller, SUM(total) as total, SUM(paid) as paid, SUM(balance) as balance`).Group(`seller`) + stmt = dryDB.Table(`table_employees`).Select(`id, name, iv.total, iv.paid, iv.balance`).Joins(`LEFT JOIN (?) AS iv ON iv.seller = table_employees.id`, iv).Scan(&user).Statement + if !regexp.MustCompile("SELECT id, name, iv.total, iv.paid, iv.balance FROM .table_employees. LEFT JOIN \\(SELECT seller, SUM\\(total\\) as total, SUM\\(paid\\) as paid, SUM\\(balance\\) as balance FROM .table_invoices. GROUP BY .seller.\\) AS iv ON iv.seller = table_employees.id").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } } func TestJoinOn(t *testing.T) { From c4a2e891daee9fa5ba4305b3594d2e155a17a082 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Sep 2021 22:37:15 +0800 Subject: [PATCH 11/83] Fix Join condition with DB --- chainable_api.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 23e601102..173479d30 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -175,10 +175,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if len(args) > 0 { + if len(args) == 1 { if db, ok := args[0].(*DB); ok { if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where}) + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) return } } From 851fea0221ff6ab53e3b9ce2d127c2126bd9a6f0 Mon Sep 17 00:00:00 2001 From: River Date: Wed, 29 Sep 2021 14:02:35 +0800 Subject: [PATCH 12/83] fix: QuoteTo not fully support raw mode (#4735) * fix: QuoteTo not fully support raw mode * fix: table alias without AS * test: clause.Column/Table quote test * fix: revert table alias quote --- clause/expression_test.go | 28 ++++++++++++++++++++++++++++ statement.go | 30 +++++++++++++++++------------- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/clause/expression_test.go b/clause/expression_test.go index eadd96ea7..4826db381 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -94,6 +94,34 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? AND name2 = ?;", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, + Result: "`table`.`col`", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Raw: true}}, + Result: "table.col", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: clause.PrimaryKey, Raw: true}}, + Result: "table.id", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias"}}, + Result: "`table`.`col` AS `alias`", + }, { + SQL: "?", + Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias", Raw: true}}, + Result: "table.col AS alias", + }, { + SQL: "?", + Vars: []interface{}{clause.Table{Name: "table", Alias: "alias"}}, + Result: "`table` `alias`", + }, { + SQL: "?", + Vars: []interface{}{clause.Table{Name: "table", Alias: "alias", Raw: true}}, + Result: "table alias", }} for idx, result := range results { diff --git a/statement.go b/statement.go index 383634434..347f88ff0 100644 --- a/statement.go +++ b/statement.go @@ -75,30 +75,36 @@ func (stmt *Statement) WriteQuoted(value interface{}) { // QuoteTo write quoted value to writer func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { + write := func(raw bool, str string) { + if raw { + writer.WriteString(str) + } else { + stmt.DB.Dialector.QuoteTo(writer, str) + } + } + switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) } else { - stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + write(v.Raw, stmt.Table) } - } else if v.Raw { - writer.WriteString(v.Name) } else { - stmt.DB.Dialector.QuoteTo(writer, v.Name) + write(v.Raw, v.Name) } if v.Alias != "" { writer.WriteByte(' ') - stmt.DB.Dialector.QuoteTo(writer, v.Alias) + write(v.Raw, v.Alias) } case clause.Column: if v.Table != "" { if v.Table == clause.CurrentTable { - stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + write(v.Raw, stmt.Table) } else { - stmt.DB.Dialector.QuoteTo(writer, v.Table) + write(v.Raw, v.Table) } writer.WriteByte('.') } @@ -107,19 +113,17 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if stmt.Schema == nil { stmt.DB.AddError(ErrModelValueRequired) } else if stmt.Schema.PrioritizedPrimaryField != nil { - stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) + write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { - stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) + write(v.Raw, stmt.Schema.DBNames[0]) } - } else if v.Raw { - writer.WriteString(v.Name) } else { - stmt.DB.Dialector.QuoteTo(writer, v.Name) + write(v.Raw, v.Name) } if v.Alias != "" { writer.WriteString(" AS ") - stmt.DB.Dialector.QuoteTo(writer, v.Alias) + write(v.Raw, v.Alias) } case []clause.Column: writer.WriteByte('(') From 0b6bd3393484da7cf3b2befd4f620f6e6e5d1b9d Mon Sep 17 00:00:00 2001 From: s-takehana Date: Fri, 8 Oct 2021 11:51:53 +0900 Subject: [PATCH 13/83] Update `tests.yml` (#4741) --- .github/workflows/tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d5ee1e88f..700af759d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.17', '1.16', '1.15'] + go: ['1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -39,7 +39,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.17', '1.16', '1.15'] + go: ['1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -82,8 +82,8 @@ jobs: postgres: strategy: matrix: - dbversion: ['postgres:latest', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.17', '1.16', '1.15'] + dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] + go: ['1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -125,7 +125,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.17', '1.16', '1.15'] + go: ['1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} From 57d927d04673a850910934aa3672cfd18749939b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 8 Oct 2021 10:54:50 +0800 Subject: [PATCH 14/83] Bump gorm.io/driver/postgres from 1.1.1 to 1.1.2 in /tests (#4740) Bumps [gorm.io/driver/postgres](https://github.com/go-gorm/postgres) from 1.1.1 to 1.1.2. - [Release notes](https://github.com/go-gorm/postgres/releases) - [Commits](https://github.com/go-gorm/postgres/compare/v1.1.1...v1.1.2) --- updated-dependencies: - dependency-name: gorm.io/driver/postgres dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index c4e27024e..5484d6ad7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 - gorm.io/driver/postgres v1.1.1 + gorm.io/driver/postgres v1.1.2 gorm.io/driver/sqlite v1.1.5 gorm.io/driver/sqlserver v1.0.9 gorm.io/gorm v1.21.15 From 5d91ddac8c01aeff48e2402efeb11fcb697b37a0 Mon Sep 17 00:00:00 2001 From: Paras Waykole Date: Fri, 8 Oct 2021 08:29:55 +0530 Subject: [PATCH 15/83] fixed belongs_to & has_one reversed if field same (proper fix) (#4694) * fixed belongs_to & has_one reversed if field same * hasmany same foreign key bug fixed and test added * belongsToSameForeignKey fixed and reverted old fix --- schema/relationship.go | 12 ++++----- schema/relationship_test.go | 54 +++++++++++++++++++++++++++++++++---- utils/utils.go | 12 --------- 3 files changed, 54 insertions(+), 24 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 84556baee..5699ec5f1 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -7,7 +7,6 @@ import ( "github.com/jinzhu/inflection" "gorm.io/gorm/clause" - "gorm.io/gorm/utils" ) // RelationshipType relationship type @@ -78,6 +77,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { schema.buildPolymorphicRelation(relation, field, polymorphic) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) + } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { + schema.guessRelation(relation, field, guessBelongs) } else { switch field.IndirectFieldType.Kind() { case reflect.Struct: @@ -405,14 +406,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { - ff := foreignSchema.LookUpField(foreignKey) - pf := primarySchema.LookUpField(foreignKey) - isKeySame := utils.ExistsIn(foreignKey, &relation.primaryKeys) - if ff == nil || (pf != nil && ff != nil && schema == primarySchema && primarySchema != foreignSchema && !isKeySame && field.IndirectFieldType.Kind() == reflect.Struct) { + if f := foreignSchema.LookUpField(foreignKey); f != nil { + foreignFields = append(foreignFields, f) + } else { reguessOrErr() return - } else { - foreignFields = append(foreignFields, ff) } } } else { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index d0ffc28a4..cb616fc07 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -144,6 +144,25 @@ func TestHasOneOverrideReferences(t *testing.T) { }) } +func TestHasOneOverrideReferences2(t *testing.T) { + + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + ProfileID uint `gorm:"column:profile_id"` + Profile *Profile `gorm:"foreignKey:ID;references:ProfileID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ProfileID", "User", "ID", "Profile", "", true}}, + }) +} + func TestHasOneWithOnlyReferences(t *testing.T) { type Profile struct { gorm.Model @@ -483,22 +502,47 @@ func TestSameForeignKey(t *testing.T) { ) } -func TestBelongsToWithSameForeignKey(t *testing.T) { +func TestBelongsToSameForeignKey(t *testing.T) { + + type User struct { + gorm.Model + Name string + UUID string + } + + type UserAux struct { + gorm.Model + Aux string + UUID string + User User `gorm:"ForeignKey:UUID;references:UUID;belongsTo"` + } + + checkStructRelation(t, &UserAux{}, + Relation{ + Name: "User", Type: schema.BelongsTo, Schema: "UserAux", FieldSchema: "User", + References: []Reference{ + {"UUID", "User", "UUID", "UserAux", "", false}, + }, + }, + ) +} + +func TestHasOneWithSameForeignKey(t *testing.T) { type Profile struct { gorm.Model Name string - ProfileRefer int + ProfileRefer int // not used in relationship } type User struct { gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileRefer"` + Profile Profile `gorm:"ForeignKey:ID;references:ProfileRefer"` ProfileRefer int } checkStructRelation(t, &User{}, Relation{ - Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", - References: []Reference{{"ID", "Profile", "ProfileRefer", "User", "", false}}, + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ProfileRefer", "User", "ID", "Profile", "", true}}, }) } diff --git a/utils/utils.go b/utils/utils.go index 1110c7a7a..9c238ac55 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -114,15 +114,3 @@ func ToString(value interface{}) string { } return "" } - -func ExistsIn(a string, list *[]string) bool { - if list == nil { - return false - } - for _, b := range *list { - if b == a { - return true - } - } - return false -} From c13f3011f9d1076103e1cbb7cef89fd7b7620e1f Mon Sep 17 00:00:00 2001 From: heige Date: Fri, 8 Oct 2021 11:05:50 +0800 Subject: [PATCH 16/83] feat: adjust SetupJoinTable func if..else code (#4680) --- gorm.go | 54 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/gorm.go b/gorm.go index 7f7bad262..71cd01e82 100644 --- a/gorm.go +++ b/gorm.go @@ -387,43 +387,45 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac modelSchema, joinSchema *schema.Schema ) - if err := stmt.Parse(model); err == nil { - modelSchema = stmt.Schema - } else { + err := stmt.Parse(model) + if err != nil { return err } + modelSchema = stmt.Schema - if err := stmt.Parse(joinTable); err == nil { - joinSchema = stmt.Schema - } else { + err = stmt.Parse(joinTable) + if err != nil { return err } + joinSchema = stmt.Schema - if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { - for _, ref := range relation.References { - if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { - f.DataType = ref.ForeignKey.DataType - f.GORMDataType = ref.ForeignKey.GORMDataType - if f.Size == 0 { - f.Size = ref.ForeignKey.Size - } - ref.ForeignKey = f - } else { - return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) - } + relation, ok := modelSchema.Relationships.Relations[field] + isRelation := ok && relation.JoinTable != nil + if !isRelation { + return fmt.Errorf("failed to found relation: %s", field) + } + + for _, ref := range relation.References { + f := joinSchema.LookUpField(ref.ForeignKey.DBName) + if f == nil { + return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) } - for name, rel := range relation.JoinTable.Relationships.Relations { - if _, ok := joinSchema.Relationships.Relations[name]; !ok { - rel.Schema = joinSchema - joinSchema.Relationships.Relations[name] = rel - } + f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size } + ref.ForeignKey = f + } - relation.JoinTable = joinSchema - } else { - return fmt.Errorf("failed to found relation: %s", field) + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } } + relation.JoinTable = joinSchema return nil } From e3fc49a694520c722fb301ba149102803eb86912 Mon Sep 17 00:00:00 2001 From: heige Date: Fri, 8 Oct 2021 11:16:58 +0800 Subject: [PATCH 17/83] feat: ajust PreparedStmtDB unlock location and BuildCondition if logic (#4681) --- prepare_stmt.go | 19 +++++++++++-------- statement.go | 12 +++++++++--- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 5faea9950..88bec4e95 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -32,14 +32,14 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { func (db *PreparedStmtDB) Close() { db.Mux.Lock() + defer db.Mux.Unlock() + for _, query := range db.PreparedSQL { if stmt, ok := db.Stmts[query]; ok { delete(db.Stmts, query) go stmt.Close() } } - - db.Mux.Unlock() } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { @@ -51,9 +51,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.RUnlock() db.Mux.Lock() + defer db.Mux.Unlock() + // double check if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - db.Mux.Unlock() return stmt, nil } else if ok { go stmt.Close() @@ -64,7 +65,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} db.PreparedSQL = append(db.PreparedSQL, query) } - defer db.Mux.Unlock() return db.Stmts[query], err } @@ -83,9 +83,9 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. result, err = stmt.ExecContext(ctx, args...) if err != nil { db.Mux.Lock() + defer db.Mux.Unlock() go stmt.Close() delete(db.Stmts, query) - db.Mux.Unlock() } } return result, err @@ -97,9 +97,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . rows, err = stmt.QueryContext(ctx, args...) if err != nil { db.Mux.Lock() + defer db.Mux.Unlock() + go stmt.Close() delete(db.Stmts, query) - db.Mux.Unlock() } } return rows, err @@ -138,9 +139,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.Mux.Unlock() } } return result, err @@ -152,9 +154,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.Mux.Unlock() } } return rows, err diff --git a/statement.go b/statement.go index 347f88ff0..3b76f653a 100644 --- a/statement.go +++ b/statement.go @@ -271,13 +271,19 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if _, err := strconv.Atoi(s); err != nil { if s == "" && len(args) == 0 { return nil - } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { + } + + if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition return []clause.Expression{clause.Expr{SQL: s, Vars: args}} - } else if len(args) > 0 && strings.Contains(s, "@") { + } + + if len(args) > 0 && strings.Contains(s, "@") { // looks like a named query return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} - } else if len(args) == 1 { + } + + if len(args) == 1 { return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} } } From b46e2afc4a5fca825c959545b92eef9cd8c83d53 Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Fri, 8 Oct 2021 13:47:01 +0800 Subject: [PATCH 18/83] fix : update miss where's condition when primary key use "<-:create" tag (#4738) * fix:update miss where condition * fix:rename test case --- callbacks/update.go | 4 ++-- tests/upsert_test.go | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 7d5ea4a4f..a0a2c579f 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -235,7 +235,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { - if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable { + if field := updatingSchema.LookUpField(dbName); field != nil { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { value, isZero := field.ValueOf(updatingValue) @@ -252,7 +252,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { isZero = false } - if ok || !isZero { + if (ok || !isZero) && field.Updatable { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) assignValue(field, value) } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 0e247caa1..a7b53ab7c 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -309,3 +309,22 @@ func TestFindOrCreate(t *testing.T) { t.Errorf("belongs to association should be saved") } } + +func TestUpdateWithMissWhere(t *testing.T) { + type User struct { + ID uint `gorm:"column:id;<-:create"` + Name string `gorm:"column:name"` + } + user := User{ID: 1, Name: "king"} + tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) + + if err := tx.Error; err != nil { + t.Fatalf("failed to update user,missing where condtion,err=%+v", err) + + } + + if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", tx.Statement.SQL.String()) + } + +} From d4c838c1cefcd16d94b9c629b3a841cc24e28328 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 8 Oct 2021 17:31:58 +0800 Subject: [PATCH 19/83] Upgrade sqlite driver --- tests/go.mod | 2 +- tests/migrate_test.go | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 5484d6ad7..6df53d7f9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.2 - gorm.io/driver/sqlite v1.1.5 + gorm.io/driver/sqlite v1.1.6 gorm.io/driver/sqlserver v1.0.9 gorm.io/gorm v1.21.15 ) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 599ca8503..ba2714785 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -357,10 +357,6 @@ func TestMigrateColumns(t *testing.T) { } func TestMigrateConstraint(t *testing.T) { - if DB.Dialector.Name() == "sqlite" { - t.Skip() - } - names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Team", "fk_users_team", "Languages", "fk_users_languages"} for _, name := range names { From 6312d86c54db2da8b9874163564a86637d5c869c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 8 Oct 2021 17:51:27 +0800 Subject: [PATCH 20/83] Support specify select/omit columns with table --- statement.go | 7 +++++++ statement_test.go | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/statement.go b/statement.go index 3b76f653a..bea4f7f07 100644 --- a/statement.go +++ b/statement.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "regexp" "sort" "strconv" "strings" @@ -627,6 +628,8 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } +var nameMatcher = regexp.MustCompile(`\.[\W]?(.+?)[\W]?$`) + // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} @@ -647,6 +650,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { + results[matches[1]] = true } else { results[column] = true } @@ -662,6 +667,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false + } else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 { + results[matches[1]] = false } else { results[omit] = false } diff --git a/statement_test.go b/statement_test.go index 03ad81dc6..3f099d611 100644 --- a/statement_test.go +++ b/statement_test.go @@ -34,3 +34,16 @@ func TestWhereCloneCorruption(t *testing.T) { }) } } + +func TestNameMatcher(t *testing.T) { + for k, v := range map[string]string{ + "table.name": "name", + "`table`.`name`": "name", + "'table'.'name'": "name", + "'table'.name": "name", + } { + if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + } + } +} From bfda75d0991f15200af1768bd9fe32040c219a29 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 9 Oct 2021 10:42:41 +0800 Subject: [PATCH 21/83] Support specify select/omit columns with table --- statement.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/statement.go b/statement.go index bea4f7f07..c631031ed 100644 --- a/statement.go +++ b/statement.go @@ -628,7 +628,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`\.[\W]?(.+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { From 418c60c83cf8472d883bb9ab8b9821444e7c8f0a Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Sat, 9 Oct 2021 16:55:45 +0800 Subject: [PATCH 22/83] fixed: clauseSelect.Columns missed when use Join And execute multiple query. (#4757) --- callbacks/query.go | 13 ++++++------- tests/joins_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 1cfd618cf..0eee2a439 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -95,7 +95,12 @@ func BuildQuerySQL(db *gorm.DB) { } // inline joins - if len(db.Statement.Joins) != 0 { + joins := []clause.Join{} + if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + joins = fromClause.Joins + } + + if len(db.Statement.Joins) != 0 || len(joins) != 0 { if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { @@ -103,12 +108,6 @@ func BuildQuerySQL(db *gorm.DB) { } } - joins := []clause.Join{} - - if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { - joins = fromClause.Joins - } - for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ diff --git a/tests/joins_test.go b/tests/joins_test.go index 25fa20b43..ca8477dc9 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -157,3 +157,30 @@ func TestJoinsWithSelect(t *testing.T) { t.Errorf("Should find all two pets with Join select, got %+v", results) } } + +func TestJoinCount(t *testing.T) { + companyA := Company{Name: "A"} + companyB := Company{Name: "B"} + DB.Create(&companyA) + DB.Create(&companyB) + + user := User{Name: "kingGo", CompanyID: &companyB.ID} + DB.Create(&user) + + query := DB.Model(&User{}).Joins("Company") + //Bug happens when .Count is called on a query. + //Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. + var total int64 + query.Count(&total) + + var result User + + // Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id + if err := query.First(&result, user.ID).Error; err != nil { + t.Fatalf("Failed, got error: %v", err) + } + + if result.ID != user.ID { + t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) + } +} From ec58e3319feef549f3f0b01235e3254559b5828c Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Tue, 12 Oct 2021 21:19:08 +0800 Subject: [PATCH 23/83] fixed:panic when create value from nil struct pointer. (#4771) * fixed:create nil pointer * fixed:panic when create value from nil struct pointer. --- schema/schema.go | 7 ++++++- tests/create_test.go | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/schema/schema.go b/schema/schema.go index c425070b3..60a434faa 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -77,7 +77,12 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - modelType := reflect.Indirect(reflect.ValueOf(dest)).Type() + value := reflect.ValueOf(dest) + if value.Kind() == reflect.Ptr && value.IsNil() { + value = reflect.New(value.Type().Elem()) + } + modelType := reflect.Indirect(value).Type() + if modelType.Kind() == reflect.Interface { modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() } diff --git a/tests/create_test.go b/tests/create_test.go index bd968ea8f..060f78af2 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -517,3 +517,12 @@ func TestCreateFromSubQuery(t *testing.T) { t.Errorf("invalid insert SQL, got %v", result.Statement.SQL.String()) } } + +func TestCreateNilPointer(t *testing.T) { + var user *User + + err := DB.Create(user).Error + if err == nil || err != gorm.ErrInvalidValue { + t.Fatalf("it is not ErrInvalidValue") + } +} From 696092e2875d222304cf2bf00b8d1361f0c128d2 Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Wed, 13 Oct 2021 14:41:33 +0800 Subject: [PATCH 24/83] update tests' go.mod and tests_all.sh (#4774) --- tests/go.mod | 4 ++-- tests/tests_all.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 6df53d7f9..e18dc1dc4 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,8 +9,8 @@ require ( gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.2 gorm.io/driver/sqlite v1.1.6 - gorm.io/driver/sqlserver v1.0.9 - gorm.io/gorm v1.21.15 + gorm.io/driver/sqlserver v1.1.0 + gorm.io/gorm v1.21.16 ) replace gorm.io/gorm => ../ diff --git a/tests/tests_all.sh b/tests/tests_all.sh index f5657df18..79e0b5b71 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -9,7 +9,7 @@ fi if [ -d tests ] then cd tests - go get -u ./... + go get -u -t ./... go mod download go mod tidy cd .. From 19cf645dbd3e83b1d797911d900f0e248fc554bd Mon Sep 17 00:00:00 2001 From: Jim Date: Sun, 12 Sep 2021 06:42:48 -0400 Subject: [PATCH 25/83] feat: Convert SQL nulls to zero values (ConvertNullToZeroValues) Makes it the default behavior to convert SQL null values to zero values for model fields which are not pointers. --- callbacks/create.go | 32 +++++++++++++-- tests/gorm_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 3 deletions(-) create mode 100644 tests/gorm_test.go diff --git a/callbacks/create.go b/callbacks/create.go index a2944319c..ebfc84263 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -159,6 +159,7 @@ func CreateWithReturning(db *gorm.DB) { break } + resetFields := map[int]*schema.Field{} for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) @@ -172,22 +173,47 @@ func CreateWithReturning(db *gorm.DB) { goto BEGIN } - values[idx] = fieldValue.Addr().Interface() + if field.FieldType.Kind() == reflect.Ptr { + values[idx] = fieldValue.Addr().Interface() + } else { + reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) + reflectValue.Elem().Set(fieldValue.Addr()) + values[idx] = reflectValue.Interface() + resetFields[idx] = field + } } db.RowsAffected++ if err := rows.Scan(values...); err != nil { db.AddError(err) } + + for idx, field := range resetFields { + if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { + field.ReflectValueOf(reflectValue).Set(v) + } + } } case reflect.Struct: + resetFields := map[int]*schema.Field{} for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + if field.FieldType.Kind() == reflect.Ptr { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } else { + reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) + reflectValue.Elem().Set(field.ReflectValueOf(db.Statement.ReflectValue).Addr()) + values[idx] = reflectValue.Interface() + resetFields[idx] = field + } } - if rows.Next() { db.RowsAffected++ db.AddError(rows.Scan(values...)) + for idx, field := range resetFields { + if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { + field.ReflectValueOf(db.Statement.ReflectValue).Set(v) + } + } } } } else { diff --git a/tests/gorm_test.go b/tests/gorm_test.go new file mode 100644 index 000000000..39741439f --- /dev/null +++ b/tests/gorm_test.go @@ -0,0 +1,98 @@ +package tests_test + +import ( + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "testing" +) + +func TestReturningWithNullToZeroValues(t *testing.T) { + dialect := DB.Dialector.Name() + switch dialect { + case "mysql", "sqlserver": + // these dialects do not support the "returning" clause + return + default: + // This user struct will leverage the existing users table, but override + // the Name field to default to null. + type user struct { + gorm.Model + Name string `gorm:"default:null"` + } + u1 := user{} + c := DB.Callback().Create().Get("gorm:create") + t.Cleanup(func() { + DB.Callback().Create().Replace("gorm:create", c) + }) + DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true})) + + if results := DB.Create(&u1); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } + + got := user{} + results := DB.First(&got, "id = ?", u1.ID) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("first expects: %v, got %v", u1, got) + } + + results = DB.Select("id, name").Find(&got) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if got.ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1, got) + } + + u1.Name = "jinzhu" + if results := DB.Save(&u1); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + u1 = user{} // important to reinitialize this before creating it again + u2 := user{} + db := DB.Session(&gorm.Session{CreateBatchSize: 10}) + + if results := db.Create([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on create: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } else if u1.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u1.ID) + } else if u2.ID == 0 { + t.Fatalf("ID expects : not equal 0, got %v", u2.ID) + } + + var gotUsers []user + results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers) + if results.Error != nil { + t.Fatalf("errors happened on first: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected) + } else if gotUsers[0].ID != u1.ID { + t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID) + } else if gotUsers[1].ID != u2.ID { + t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID) + } + + u1.Name = "Jinzhu" + u2.Name = "Zhang" + if results := DB.Save([]*user{&u1, &u2}); results.Error != nil { + t.Fatalf("errors happened on update: %v", results.Error) + } else if results.RowsAffected != 2 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + } +} From b27095e8a1994f48f9099242d191acd43542e458 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 13 Oct 2021 21:01:32 +0800 Subject: [PATCH 26/83] Refactor Convert SQL null values to zero values for model fields which are not pointers #4710 --- callbacks/create.go | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index ebfc84263..c889caf66 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -149,8 +149,11 @@ func CreateWithReturning(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - c := db.Statement.Clauses["ON CONFLICT"] - onConflict, _ := c.Expression.(clause.OnConflict) + var ( + c = db.Statement.Clauses["ON CONFLICT"] + onConflict, _ = c.Expression.(clause.OnConflict) + resetFieldValues = map[int]reflect.Value{} + ) for rows.Next() { BEGIN: @@ -159,7 +162,6 @@ func CreateWithReturning(db *gorm.DB) { break } - resetFields := map[int]*schema.Field{} for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) @@ -179,7 +181,7 @@ func CreateWithReturning(db *gorm.DB) { reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) reflectValue.Elem().Set(fieldValue.Addr()) values[idx] = reflectValue.Interface() - resetFields[idx] = field + resetFieldValues[idx] = fieldValue } } @@ -188,30 +190,31 @@ func CreateWithReturning(db *gorm.DB) { db.AddError(err) } - for idx, field := range resetFields { - if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { - field.ReflectValueOf(reflectValue).Set(v) + for idx, fv := range resetFieldValues { + if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { + fv.Set(v.Elem()) } } } case reflect.Struct: - resetFields := map[int]*schema.Field{} + resetFieldValues := map[int]reflect.Value{} for idx, field := range fields { if field.FieldType.Kind() == reflect.Ptr { values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } else { reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) - reflectValue.Elem().Set(field.ReflectValueOf(db.Statement.ReflectValue).Addr()) + fieldValue := field.ReflectValueOf(db.Statement.ReflectValue) + reflectValue.Elem().Set(fieldValue.Addr()) values[idx] = reflectValue.Interface() - resetFields[idx] = field + resetFieldValues[idx] = fieldValue } } if rows.Next() { db.RowsAffected++ db.AddError(rows.Scan(values...)) - for idx, field := range resetFields { - if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { - field.ReflectValueOf(db.Statement.ReflectValue).Set(v) + for idx, fv := range resetFieldValues { + if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { + fv.Set(v.Elem()) } } } From a3bd9c3ea2d3af82ab615d4bdebb17008b525e43 Mon Sep 17 00:00:00 2001 From: Wendell Sun Date: Wed, 13 Oct 2021 01:59:28 +0800 Subject: [PATCH 27/83] fix: automigrate error caused by indexes while using dynamic table name --- schema/schema.go | 24 +++++++++++++++++++----- statement.go | 2 +- tests/migrate_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 60a434faa..c8d79ddc1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -73,6 +73,15 @@ type Tabler interface { // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + return parse(dest, cacheStore, namer, "") +} + +// ParseWithSchemaTable get data type from dialector with extra schema table +func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { + return parse(dest, cacheStore, namer, schemaTable) +} + +func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } @@ -107,6 +116,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) modelValue := reflect.New(modelType) tableName := namer.TableName(modelType.Name()) + if schemaTable != "" { + tableName = schemaTable + } if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } @@ -235,11 +247,13 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { - s := v.(*Schema) - // Wait for the initialization of other goroutines to complete - <-s.initialized - return s, s.err + if schemaTable == "" { + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } } defer func() { diff --git a/statement.go b/statement.go index c631031ed..bbe001063 100644 --- a/statement.go +++ b/statement.go @@ -456,7 +456,7 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] diff --git a/tests/migrate_test.go b/tests/migrate_test.go index ba2714785..06eb96b33 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -381,3 +381,33 @@ func TestMigrateConstraint(t *testing.T) { } } } + +type MigrateUser struct { + gorm.Model + Name string `gorm:"index"` +} + +// https://github.com/go-gorm/gorm/issues/4752 +func TestMigrateIndexesWithDynamicTableName(t *testing.T) { + tableNameSuffixes := []string{"01", "02", "03"} + for _, v := range tableNameSuffixes { + tableName := "migrate_user_" + v + m := DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table(tableName) + }).Migrator() + + if err := m.AutoMigrate(&MigrateUser{}); err != nil { + t.Fatalf("Failed to create table for %#v", tableName) + } + + if !m.HasTable(tableName) { + t.Fatalf("Failed to create table for %#v", tableName) + } + if !m.HasIndex(&MigrateUser{}, "Name") { + t.Fatalf("Should find index for %s's name after AutoMigrate", tableName) + } + if !m.HasIndex(&MigrateUser{}, "DeletedAt") { + t.Fatalf("Should find index for %s's deleted_at after AutoMigrate", tableName) + } + } +} From d3211908a030169184801800ba74a3a3d93ea6ea Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Mon, 25 Oct 2021 11:26:44 +0800 Subject: [PATCH 28/83] Refactor ParseWithSchemaTable method and improve test. (#4789) * Refactor ParseWithSchemaTable method and improve test. * Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test. * Rename `schemaTable` to `specialTableName` for clearly argument. --- migrator/migrator.go | 2 +- schema/schema.go | 44 ++++++++++++++++++++++++------------------- statement.go | 6 +++++- tests/migrate_test.go | 33 ++++++++++++++++++++------------ 4 files changed, 52 insertions(+), 33 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 48db151e0..30586a8cf 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error if table, ok := value.(string); ok { stmt.Table = table - } else if err := stmt.Parse(value); err != nil { + } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { return err } diff --git a/schema/schema.go b/schema/schema.go index c8d79ddc1..ce7cf3b13 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -73,15 +73,11 @@ type Tabler interface { // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { - return parse(dest, cacheStore, namer, "") + return ParseWithSpecialTableName(dest, cacheStore, namer, "") } -// ParseWithSchemaTable get data type from dialector with extra schema table -func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { - return parse(dest, cacheStore, namer, schemaTable) -} - -func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { +// ParseWithSpecialTableName get data type from dialector with extra schema table +func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } @@ -107,7 +103,17 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - if v, ok := cacheStore.Load(modelType); ok { + // Cache the Schema for performance, + // Use the modelType or modelType + schemaTable (if it present) as cache key. + var schemaCacheKey interface{} + if specialTableName != "" { + schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) + } else { + schemaCacheKey = modelType + } + + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized @@ -116,15 +122,15 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri modelValue := reflect.New(modelType) tableName := namer.TableName(modelType.Name()) - if schemaTable != "" { - tableName = schemaTable - } if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } + if specialTableName != "" && specialTableName != tableName { + tableName = specialTableName + } schema := &Schema{ Name: modelType.Name(), @@ -140,7 +146,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - if v, loaded := cacheStore.Load(modelType); loaded { + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized @@ -247,13 +254,12 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri } } - if schemaTable == "" { - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { - s := v.(*Schema) - // Wait for the initialization of other goroutines to complete - <-s.initialized - return s, s.err - } + // Cache the schema + if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err } defer func() { diff --git a/statement.go b/statement.go index bbe001063..85432e48f 100644 --- a/statement.go +++ b/statement.go @@ -456,7 +456,11 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" { + return stmt.ParseWithSpecialTableName(value, "") +} + +func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { + if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 06eb96b33..0354e84e1 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -382,32 +382,41 @@ func TestMigrateConstraint(t *testing.T) { } } -type MigrateUser struct { +type DynamicUser struct { gorm.Model - Name string `gorm:"index"` + Name string + CompanyID string `gorm:"index"` } +// To test auto migrate crate indexes for dynamic table name // https://github.com/go-gorm/gorm/issues/4752 func TestMigrateIndexesWithDynamicTableName(t *testing.T) { - tableNameSuffixes := []string{"01", "02", "03"} - for _, v := range tableNameSuffixes { - tableName := "migrate_user_" + v + // Create primary table + if err := DB.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) + } + + // Create sub tables + for _, v := range []string{"01", "02", "03"} { + tableName := "dynamic_users_" + v m := DB.Scopes(func(db *gorm.DB) *gorm.DB { return db.Table(tableName) }).Migrator() - if err := m.AutoMigrate(&MigrateUser{}); err != nil { - t.Fatalf("Failed to create table for %#v", tableName) + if err := m.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) } if !m.HasTable(tableName) { - t.Fatalf("Failed to create table for %#v", tableName) + t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName) } - if !m.HasIndex(&MigrateUser{}, "Name") { - t.Fatalf("Should find index for %s's name after AutoMigrate", tableName) + + if !m.HasIndex(&DynamicUser{}, "CompanyID") { + t.Fatalf("Should have index on %s", "CompanyI.") } - if !m.HasIndex(&MigrateUser{}, "DeletedAt") { - t.Fatalf("Should find index for %s's deleted_at after AutoMigrate", tableName) + + if !m.HasIndex(&DynamicUser{}, "DeletedAt") { + t.Fatalf("Should have index on deleted_at.") } } } From af3fbdc2fcfface01ce2a0795ee0fac3997ddc8e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Oct 2021 22:36:37 +0800 Subject: [PATCH 29/83] Improve returning support --- callbacks/callbacks.go | 28 ++-- callbacks/create.go | 233 ++++++++++------------------------ callbacks/query.go | 2 +- callbacks/update.go | 68 ++++++---- clause/on_conflict.go | 2 +- finisher_api.go | 2 +- scan.go | 282 +++++++++++++++++++++++------------------ tests/go.mod | 6 +- tests/gorm_test.go | 9 +- tests/update_test.go | 8 +- 10 files changed, 300 insertions(+), 340 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index d85c19280..bc18d8544 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -13,7 +13,6 @@ var ( type Config struct { LastInsertIDReversed bool - WithReturning bool CreateClauses []string QueryClauses []string UpdateClauses []string @@ -25,6 +24,19 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { return !db.SkipDefaultTransaction } + if len(config.CreateClauses) == 0 { + config.CreateClauses = createClauses + } + if len(config.QueryClauses) == 0 { + config.QueryClauses = queryClauses + } + if len(config.DeleteClauses) == 0 { + config.DeleteClauses = deleteClauses + } + if len(config.UpdateClauses) == 0 { + config.UpdateClauses = updateClauses + } + createCallback := db.Callback().Create() createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Register("gorm:before_create", BeforeCreate) @@ -33,18 +45,12 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:after_create", AfterCreate) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - if len(config.CreateClauses) == 0 { - config.CreateClauses = createClauses - } createCallback.Clauses = config.CreateClauses queryCallback := db.Callback().Query() queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:after_query", AfterQuery) - if len(config.QueryClauses) == 0 { - config.QueryClauses = queryClauses - } queryCallback.Clauses = config.QueryClauses deleteCallback := db.Callback().Delete() @@ -54,9 +60,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - if len(config.DeleteClauses) == 0 { - config.DeleteClauses = deleteClauses - } deleteCallback.Clauses = config.DeleteClauses updateCallback := db.Callback().Update() @@ -64,13 +67,10 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) - updateCallback.Register("gorm:update", Update) + updateCallback.Register("gorm:update", Update(config)) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - if len(config.UpdateClauses) == 0 { - config.UpdateClauses = updateClauses - } updateCallback.Clauses = config.UpdateClauses rowCallback := db.Callback().Row() diff --git a/callbacks/create.go b/callbacks/create.go index c889caf66..fe4cd797c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -31,204 +31,111 @@ func BeforeCreate(db *gorm.DB) { } func Create(config *Config) func(db *gorm.DB) { - if config.WithReturning { - return CreateWithReturning + withReturning := false + for _, clause := range config.CreateClauses { + if clause == "RETURNING" { + withReturning = true + } } return func(db *gorm.DB) { if db.Error != nil { return } + onReturning := false - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Insert{}) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - - db.Statement.Build(db.Statement.BuildClauses...) - } - - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err != nil { - db.AddError(err) - return + if db.Statement.Schema != nil { + if !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } } - db.RowsAffected, _ = result.RowsAffected() - - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } - - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } + if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { + onReturning = true + if _, ok := db.Statement.Clauses["RETURNING"]; !ok { + fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) + for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { + fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) } - } else { - db.AddError(err) + db.Statement.AddClause(clause.Returning{Columns: fromColumns}) } } } - } -} - -func CreateWithReturning(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build(db.Statement.BuildClauses...) } - if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { - db.Statement.WriteString(" RETURNING ") - - var ( - fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) - values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) - ) - - for idx, field := range sch.FieldsWithDefaultDBValue { - if idx > 0 { - db.Statement.WriteByte(',') + if !db.DryRun && db.Error == nil { + if onReturning { + doNothing := false + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + onConflict, _ := c.Expression.(clause.OnConflict) + doNothing = onConflict.DoNothing } + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + if doNothing { + gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing) + } else { + gorm.Scan(rows, db, gorm.ScanUpdate) + } + rows.Close() + } + } else { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - fields[idx] = field - db.Statement.WriteQuoted(field.DBName) - } - - if !db.DryRun && db.Error == nil { - db.RowsAffected = 0 - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - defer rows.Close() - - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - c = db.Statement.Clauses["ON CONFLICT"] - onConflict, _ = c.Expression.(clause.OnConflict) - resetFieldValues = map[int]reflect.Value{} - ) - - for rows.Next() { - BEGIN: - reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) - if reflect.Indirect(reflectValue).Kind() != reflect.Struct { - break - } - - for idx, field := range fields { - fieldValue := field.ReflectValueOf(reflectValue) - - if onConflict.DoNothing && !fieldValue.IsZero() { - db.RowsAffected++ + if err != nil { + db.AddError(err) + return + } - if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { - return + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break } - goto BEGIN - } - - if field.FieldType.Kind() == reflect.Ptr { - values[idx] = fieldValue.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) - reflectValue.Elem().Set(fieldValue.Addr()) - values[idx] = reflectValue.Interface() - resetFieldValues[idx] = fieldValue + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } } - } - - db.RowsAffected++ - if err := rows.Scan(values...); err != nil { - db.AddError(err) - } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } - for idx, fv := range resetFieldValues { - if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { - fv.Set(v.Elem()) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } } } - } - case reflect.Struct: - resetFieldValues := map[int]reflect.Value{} - for idx, field := range fields { - if field.FieldType.Kind() == reflect.Ptr { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) - fieldValue := field.ReflectValueOf(db.Statement.ReflectValue) - reflectValue.Elem().Set(fieldValue.Addr()) - values[idx] = reflectValue.Interface() - resetFieldValues[idx] = fieldValue - } - } - if rows.Next() { - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - for idx, fv := range resetFieldValues { - if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() { - fv.Set(v.Elem()) - } + case reflect.Struct: + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } } + } else { + db.AddError(err) } - } else { - db.AddError(err) } } - } else if !db.DryRun && db.Error == nil { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) - } } } } diff --git a/callbacks/query.go b/callbacks/query.go index 0eee2a439..0cfb0b3f1 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -22,7 +22,7 @@ func Query(db *gorm.DB) { } defer rows.Close() - gorm.Scan(rows, db, false) + gorm.Scan(rows, db, 0) } } } diff --git a/callbacks/update.go b/callbacks/update.go index a0a2c579f..90dc6a89a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -50,40 +50,56 @@ func BeforeUpdate(db *gorm.DB) { } } -func Update(db *gorm.DB) { - if db.Error != nil { - return - } - - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) +func Update(config *Config) func(db *gorm.DB) { + withReturning := false + for _, clause := range config.UpdateClauses { + if clause == "RETURNING" { + withReturning = true } } - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { + return func(db *gorm.DB) { + if db.Error != nil { return } - db.Statement.Build(db.Statement.BuildClauses...) - } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build(db.Statement.BuildClauses...) + } - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + if !db.DryRun && db.Error == nil { + if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok { + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, gorm.ScanUpdate) + rows.Close() + } + } else { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } } } } diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 64ee7f530..309c5fcd2 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -26,7 +26,7 @@ func (onConflict OnConflict) Build(builder Builder) { } builder.WriteString(`) `) } - + if len(onConflict.TargetWhere.Exprs) > 0 { builder.WriteString(" WHERE ") onConflict.TargetWhere.Build(builder) diff --git a/finisher_api.go b/finisher_api.go index e98efc92a..48eb94c5f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -511,7 +511,7 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { } tx.Statement.ReflectValue = elem } - Scan(rows, tx, true) + Scan(rows, tx, ScanInitialized) return tx.Error } diff --git a/scan.go b/scan.go index 4570380d8..37f5112d9 100644 --- a/scan.go +++ b/scan.go @@ -49,13 +49,93 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func Scan(rows *sql.Rows, db *DB, initialized bool) { - columns, _ := rows.Columns() - values := make([]interface{}, len(columns)) +func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, column := range columns { + if sch == nil { + values[idx] = reflectValue.Interface() + } else if field := sch.LookUpField(column); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + continue + } + } + values[idx] = &sql.RawBytes{} + } else if len(columns) == 1 { + sch = nil + values[idx] = reflectValue.Interface() + } else { + values[idx] = &sql.RawBytes{} + } + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + field.Set(reflectValue, values[idx]) + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + relValue := rel.Field.ReflectValueOf(reflectValue) + value := reflect.ValueOf(values[idx]).Elem() + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } + } + } + } + } +} + +type ScanMode uint8 + +const ( + ScanInitialized ScanMode = 1 << 0 + ScanUpdate = 1 << 1 + ScanOnConflictDoNothing = 1 << 2 +) + +func Scan(rows *sql.Rows, db *DB, mode ScanMode) { + var ( + columns, _ = rows.Columns() + values = make([]interface{}, len(columns)) + initialized = mode&ScanInitialized != 0 + update = mode&ScanUpdate != 0 + onConflictDonothing = mode&ScanOnConflictDoNothing != 0 + ) + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: + if update && db.Statement.Schema != nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + fields := make([]*schema.Field, len(columns)) + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } + } + + if initialized || rows.Next() { + db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) + } + } + } + if initialized || rows.Next() { columnTypes, _ := rows.ColumnTypes() prepareValues(values, db, columnTypes, columns) @@ -71,7 +151,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } scanIntoMap(mapValue, values, columns) } - case *[]map[string]interface{}: + case *[]map[string]interface{}, []map[string]interface{}: columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { prepareValues(values, db, columnTypes, columns) @@ -82,7 +162,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { mapValue := map[string]interface{}{} scanIntoMap(mapValue, values, columns) - *dest = append(*dest, mapValue) + if values, ok := dest.([]map[string]interface{}); ok { + values = append(values, mapValue) + } else if values, ok := dest.(*[]map[string]interface{}); ok { + *values = append(*values, mapValue) + } } case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, @@ -96,155 +180,109 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(dest)) } default: - Schema := db.Statement.Schema - reflectValue := db.Statement.ReflectValue + var ( + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue + ) + if reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } - switch reflectValue.Kind() { - case reflect.Slice, reflect.Array: - var ( - reflectValueType = reflectValue.Type().Elem() - isPtr = reflectValueType.Kind() == reflect.Ptr - fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field - ) - - if isPtr { - reflectValueType = reflectValueType.Elem() + reflectValueType := reflectValue.Type() + switch reflectValueType.Kind() { + case reflect.Array, reflect.Slice: + reflectValueType = reflectValueType.Elem() + } + isPtr := reflectValueType.Kind() == reflect.Ptr + if isPtr { + reflectValueType = reflectValueType.Elem() + } + + if sch != nil { + if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { + sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field - if Schema != nil { - if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) - } - - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} } } - // pluck values into slice of data - isPluck := false - if len(fields) == 1 { - if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner + if len(columns) == 1 { + // isPluck + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct - Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time - isPluck = true + sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + sch = nil } } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var elem reflect.Value + + if !update { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } for initialized || rows.Next() { + BEGIN: initialized = false - db.RowsAffected++ - elem := reflect.New(reflectValueType) - if isPluck { - db.AddError(rows.Scan(elem.Interface())) - } else { - for idx, field := range fields { - if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } + if update { + if int(db.RowsAffected) >= reflectValue.Len() { + return } - - db.AddError(rows.Scan(values...)) - - for idx, field := range fields { - if len(joinFields) != 0 && joinFields[idx][0] != nil { - value := reflect.ValueOf(values[idx]).Elem() - relValue := joinFields[idx][0].ReflectValueOf(elem) - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) + elem = reflectValue.Index(int(db.RowsAffected)) + if onConflictDonothing { + for _, field := range fields { + if _, ok := field.ValueOf(elem); !ok { + db.RowsAffected++ + goto BEGIN } - - field.Set(relValue, values[idx]) - } else if field != nil { - field.Set(elem, values[idx]) } } - } - - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) + elem = reflect.New(reflectValueType) } - } - db.Statement.ReflectValue.Set(reflectValue) - case reflect.Struct, reflect.Ptr: - if reflectValue.Type() != Schema.ModelType { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) - } + db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) - if initialized || rows.Next() { - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - continue - } - } - values[idx] = &sql.RawBytes{} - } else if len(columns) == 1 { - values[idx] = dest + if !update { + if isPtr { + reflectValue = reflect.Append(reflectValue, elem) } else { - values[idx] = &sql.RawBytes{} + reflectValue = reflect.Append(reflectValue, elem.Elem()) } } + } - db.RowsAffected++ - db.AddError(rows.Scan(values...)) - - for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - field.Set(reflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(reflectValue) - value := reflect.ValueOf(values[idx]).Elem() - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(relValue, values[idx]) - } - } - } - } + if !update { + db.Statement.ReflectValue.Set(reflectValue) + } + case reflect.Struct, reflect.Ptr: + if initialized || rows.Next() { + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) } default: db.AddError(rows.Scan(dest)) diff --git a/tests/go.mod b/tests/go.mod index e18dc1dc4..96db05593 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,9 +7,9 @@ require ( github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 - gorm.io/driver/postgres v1.1.2 - gorm.io/driver/sqlite v1.1.6 - gorm.io/driver/sqlserver v1.1.0 + gorm.io/driver/postgres v1.2.0 + gorm.io/driver/sqlite v1.2.0 + gorm.io/driver/sqlserver v1.1.1 gorm.io/gorm v1.21.16 ) diff --git a/tests/gorm_test.go b/tests/gorm_test.go index 39741439f..9827465cc 100644 --- a/tests/gorm_test.go +++ b/tests/gorm_test.go @@ -1,9 +1,9 @@ package tests_test import ( - "gorm.io/gorm" - "gorm.io/gorm/callbacks" "testing" + + "gorm.io/gorm" ) func TestReturningWithNullToZeroValues(t *testing.T) { @@ -20,11 +20,6 @@ func TestReturningWithNullToZeroValues(t *testing.T) { Name string `gorm:"default:null"` } u1 := user{} - c := DB.Callback().Create().Get("gorm:create") - t.Cleanup(func() { - DB.Callback().Create().Replace("gorm:create", c) - }) - DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true})) if results := DB.Create(&u1); results.Error != nil { t.Fatalf("errors happened on create: %v", results.Error) diff --git a/tests/update_test.go b/tests/update_test.go index 631d0d6d2..0dd9465af 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -9,6 +9,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) @@ -166,13 +167,16 @@ func TestUpdates(t *testing.T) { } // update with gorm exprs - if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) } var user4 User DB.First(&user4, user3.ID) - user3.Age += 100 + // sqlite, postgres support returning + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + user3.Age += 100 + } AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") } From 835d7bde59a24ac769a1c5ded206b58f7cedfba3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 Oct 2021 07:24:38 +0800 Subject: [PATCH 30/83] Add returning support to delete --- callbacks/callbacks.go | 2 +- callbacks/create.go | 27 +++++++++------------------ callbacks/delete.go | 25 ++++++++++++++++++------- callbacks/helper.go | 13 +++++++++++++ callbacks/update.go | 16 +++++----------- clause/returning.go | 14 +++++++++----- scan.go | 2 +- tests/go.mod | 4 ++-- tests/update_test.go | 2 +- utils/utils.go | 9 +++++++++ 10 files changed, 68 insertions(+), 46 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index bc18d8544..d681aef36 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -57,7 +57,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) - deleteCallback.Register("gorm:delete", Delete) + deleteCallback.Register("gorm:delete", Delete(config)) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Clauses = config.DeleteClauses diff --git a/callbacks/create.go b/callbacks/create.go index fe4cd797c..656273fb1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func BeforeCreate(db *gorm.DB) { @@ -31,18 +32,12 @@ func BeforeCreate(db *gorm.DB) { } func Create(config *Config) func(db *gorm.DB) { - withReturning := false - for _, clause := range config.CreateClauses { - if clause == "RETURNING" { - withReturning = true - } - } + supportReturning := utils.Contains(config.CreateClauses, "RETURNING") return func(db *gorm.DB) { if db.Error != nil { return } - onReturning := false if db.Statement.Schema != nil { if !db.Statement.Unscoped { @@ -51,8 +46,7 @@ func Create(config *Config) func(db *gorm.DB) { } } - if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { - onReturning = true + if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { if _, ok := db.Statement.Clauses["RETURNING"]; !ok { fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { @@ -72,18 +66,15 @@ func Create(config *Config) func(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - if onReturning { - doNothing := false + + if ok, mode := hasReturning(db, supportReturning); ok { if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - onConflict, _ := c.Expression.(clause.OnConflict) - doNothing = onConflict.DoNothing + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing + } } if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - if doNothing { - gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing) - } else { - gorm.Scan(rows, db, gorm.ScanUpdate) - } + gorm.Scan(rows, db, mode) rows.Close() } } else { diff --git a/callbacks/delete.go b/callbacks/delete.go index 91659c511..a1fd0a573 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func BeforeDelete(db *gorm.DB) { @@ -104,8 +105,14 @@ func DeleteBeforeAssociations(db *gorm.DB) { } } -func Delete(db *gorm.DB) { - if db.Error == nil { +func Delete(config *Config) func(db *gorm.DB) { + supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") + + return func(db *gorm.DB) { + if db.Error != nil { + return + } + if db.Statement.Schema != nil && !db.Statement.Unscoped { for _, c := range db.Statement.Schema.DeleteClauses { db.Statement.AddClause(c) @@ -144,12 +151,16 @@ func Delete(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - - if err == nil { - db.RowsAffected, _ = result.RowsAffected() + if ok, mode := hasReturning(db, supportReturning); ok { + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + rows.Close() + } } else { - db.AddError(err) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.AddError(err) == nil { + db.RowsAffected, _ = result.RowsAffected() + } } } } diff --git a/callbacks/helper.go b/callbacks/helper.go index d83d20cef..1d96ab26b 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -93,3 +93,16 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st } return } + +func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { + if supportReturning { + if c, ok := tx.Statement.Clauses["RETURNING"]; ok { + returning, _ := c.Expression.(clause.Returning) + if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { + return true, 0 + } + return true, gorm.ScanUpdate + } + } + return false, 0 +} diff --git a/callbacks/update.go b/callbacks/update.go index 90dc6a89a..991581ddf 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func SetupUpdateReflectValue(db *gorm.DB) { @@ -51,12 +52,7 @@ func BeforeUpdate(db *gorm.DB) { } func Update(config *Config) func(db *gorm.DB) { - withReturning := false - for _, clause := range config.UpdateClauses { - if clause == "RETURNING" { - withReturning = true - } - } + supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") return func(db *gorm.DB) { if db.Error != nil { @@ -86,18 +82,16 @@ func Update(config *Config) func(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok { + if ok, mode := hasReturning(db, supportReturning); ok { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - gorm.Scan(rows, db, gorm.ScanUpdate) + gorm.Scan(rows, db, mode) rows.Close() } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { + if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) } } } diff --git a/clause/returning.go b/clause/returning.go index 04bc96dab..d94b7a4ca 100644 --- a/clause/returning.go +++ b/clause/returning.go @@ -11,12 +11,16 @@ func (returning Returning) Name() string { // Build build where clause func (returning Returning) Build(builder Builder) { - for idx, column := range returning.Columns { - if idx > 0 { - builder.WriteByte(',') - } + if len(returning.Columns) > 0 { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } - builder.WriteQuoted(column) + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') } } diff --git a/scan.go b/scan.go index 37f5112d9..70fcda4a7 100644 --- a/scan.go +++ b/scan.go @@ -241,7 +241,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { case reflect.Slice, reflect.Array: var elem reflect.Value - if !update { + if !update && reflectValue.Len() != 0 { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) } diff --git a/tests/go.mod b/tests/go.mod index 96db05593..6d9e68c15 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,8 +9,8 @@ require ( gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.2.0 gorm.io/driver/sqlite v1.2.0 - gorm.io/driver/sqlserver v1.1.1 - gorm.io/gorm v1.21.16 + gorm.io/driver/sqlserver v1.1.2 + gorm.io/gorm v1.22.0 ) replace gorm.io/gorm => ../ diff --git a/tests/update_test.go b/tests/update_test.go index 0dd9465af..f58656edd 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -167,7 +167,7 @@ func TestUpdates(t *testing.T) { } // update with gorm exprs - if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) } var user4 User diff --git a/utils/utils.go b/utils/utils.go index 9c238ac55..f00f92ba3 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -72,6 +72,15 @@ func ToStringKey(values ...interface{}) string { return strings.Join(results, "_") } +func Contains(elems []string, elem string) bool { + for _, e := range elems { + if elem == e { + return true + } + } + return false +} + func AssertEqual(src, dst interface{}) bool { if !reflect.DeepEqual(src, dst) { if valuer, ok := src.(driver.Valuer); ok { From e953880d19ff600c658456c4cd7734ab746f4681 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 Oct 2021 08:03:23 +0800 Subject: [PATCH 31/83] Add returning tests --- callbacks/update.go | 30 +++++++++++++++----------- scan.go | 16 -------------- soft_delete.go | 2 +- tests/delete_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++ tests/go.mod | 4 ++-- tests/update_test.go | 39 ++++++++++++++++++++++++++++----- 6 files changed, 106 insertions(+), 36 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 991581ddf..1603a5172 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) { if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + dest := db.Statement.Dest + db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() gorm.Scan(rows, db, mode) + db.Statement.Dest = dest rows.Close() } } else { @@ -152,20 +155,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - var primaryKeyExprs []clause.Expression - for i := 0; i < stmt.ReflectValue.Len(); i++ { - var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) - var notZero bool - for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) - exprs[idx] = clause.Eq{Column: field.DBName, Value: value} - notZero = notZero || !isZero - } - if notZero { - primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + if size := stmt.ReflectValue.Len(); size > 0 { + var primaryKeyExprs []clause.Expression + for i := 0; i < stmt.ReflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + } } + + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { diff --git a/scan.go b/scan.go index 70fcda4a7..360ed8b9b 100644 --- a/scan.go +++ b/scan.go @@ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: - if update && db.Statement.Schema != nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - fields := make([]*schema.Field, len(columns)) - for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } - } - - if initialized || rows.Next() { - db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) - } - } - } - if initialized || rows.Next() { columnTypes, _ := rows.ColumnTypes() prepareValues(values, db, columnTypes, columns) diff --git a/soft_delete.go b/soft_delete.go index af02f8fd4..11c4fafc0 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } stmt.AddClauseIfNotExists(clause.Update{}) - stmt.Build("UPDATE", "SET", "WHERE") + stmt.Build(stmt.DB.Callback().Update().Clauses...) } } diff --git a/tests/delete_test.go b/tests/delete_test.go index f62cc6061..049b2ac46 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) { } } } + +// only sqlite, postgres support returning +func TestSoftDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("delete-returning-1", Config{}), + GetUser("delete-returning-2", Config{}), + GetUser("delete-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} + +func TestDeleteReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + companies := []Company{ + {Name: "delete-returning-1"}, + {Name: "delete-returning-2"}, + {Name: "delete-returning-3"}, + } + DB.Create(&companies) + + var results []Company + DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results) + if len(results) != 2 { + t.Errorf("failed to return delete data, got %v", results) + } + + var count int64 + DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count) + if count != 1 { + t.Errorf("failed to delete data, current count %v", count) + } +} diff --git a/tests/go.mod b/tests/go.mod index 6d9e68c15..ab3ef8981 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,8 +7,8 @@ require ( github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.2 - gorm.io/driver/postgres v1.2.0 - gorm.io/driver/sqlite v1.2.0 + gorm.io/driver/postgres v1.2.1 + gorm.io/driver/sqlite v1.2.2 gorm.io/driver/sqlserver v1.1.2 gorm.io/gorm v1.22.0 ) diff --git a/tests/update_test.go b/tests/update_test.go index f58656edd..14ed98207 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -167,16 +167,13 @@ func TestUpdates(t *testing.T) { } // update with gorm exprs - if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) } var user4 User DB.First(&user4, user3.ID) - // sqlite, postgres support returning - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { - user3.Age += 100 - } + user3.Age += 100 AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") } @@ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) { t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) } } + +// only sqlite, postgres support returning +func TestUpdateReturning(t *testing.T) { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + return + } + + users := []*User{ + GetUser("update-returning-1", Config{}), + GetUser("update-returning-2", Config{}), + GetUser("update-returning-3", Config{}), + } + DB.Create(&users) + + var results []User + DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88) + if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 { + t.Errorf("failed to return updated data, got %v", results) + } + + if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + + if results[1].Age-results[0].Age != 100 { + t.Errorf("failed to return updated age column") + } +} From 9f533950a2864277d4210a355531abc49da0246b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 28 Oct 2021 17:12:31 +0800 Subject: [PATCH 32/83] Add dest value if current size equal zero --- scan.go | 3 ++- tests/go.mod | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scan.go b/scan.go index 360ed8b9b..119049c68 100644 --- a/scan.go +++ b/scan.go @@ -225,7 +225,8 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { case reflect.Slice, reflect.Array: var elem reflect.Value - if !update && reflectValue.Len() != 0 { + if !update || reflectValue.Len() == 0 { + update = false db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) } diff --git a/tests/go.mod b/tests/go.mod index ab3ef8981..52781a8b9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 - gorm.io/driver/mysql v1.1.2 + gorm.io/driver/mysql v1.1.3 gorm.io/driver/postgres v1.2.1 gorm.io/driver/sqlite v1.2.2 gorm.io/driver/sqlserver v1.1.2 From 9635d25150b35581bf75d5312daf2a6835af261b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Nov 2021 12:00:36 +0800 Subject: [PATCH 33/83] Fix query with uninitialized map --- scan.go | 3 +++ tests/go.mod | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 119049c68..2d0c8fc6b 100644 --- a/scan.go +++ b/scan.go @@ -130,6 +130,9 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { mapValue, ok := dest.(map[string]interface{}) if !ok { if v, ok := dest.(*map[string]interface{}); ok { + if *v == nil { + *v = map[string]interface{}{} + } mapValue = *v } } diff --git a/tests/go.mod b/tests/go.mod index 52781a8b9..8ced0b2f7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/lib/pq v1.10.3 gorm.io/driver/mysql v1.1.3 gorm.io/driver/postgres v1.2.1 - gorm.io/driver/sqlite v1.2.2 + gorm.io/driver/sqlite v1.2.3 gorm.io/driver/sqlserver v1.1.2 gorm.io/gorm v1.22.0 ) From 8de266b4a7391145e962918abb3a9705c13fd2c8 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Mon, 1 Nov 2021 17:08:54 +0800 Subject: [PATCH 34/83] Add ToSQL support to generate SQL string. (#4787) * Add db.ToSQL method for generate SQL string. * Improve sql builder test for all dialects. Improve assertEqualSQL test helper for ignore quotes in SQL. --- gorm.go | 15 +++++ tests/sql_builder_test.go | 135 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/gorm.go b/gorm.go index 71cd01e82..fc70f6845 100644 --- a/gorm.go +++ b/gorm.go @@ -441,3 +441,18 @@ func (db *DB) Use(plugin Plugin) error { db.Plugins[name] = plugin return nil } + +// ToSQL for generate SQL string. +// +// db.ToSQL(func(tx *gorm.DB) *gorm.DB { +// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) +// .Limit(10).Offset(5) +// .Order("name ASC") +// .First(&User{}) +// }) +func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { + tx := queryFn(db.Session(&Session{DryRun: true})) + stmt := tx.Statement + + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) +} diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 081b96c92..2f9fd8dad 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -8,6 +8,8 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" + + "time" ) func TestRow(t *testing.T) { @@ -287,3 +289,136 @@ func TestFromWithJoins(t *testing.T) { t.Errorf("The first join condition is over written instead of combining") } } + +func TestToSQL(t *testing.T) { + // By default DB.DryRun should false + if DB.DryRun { + t.Fatal("Failed expect DB.DryRun to be false") + } + + if DB.Dialector.Name() == "sqlserver" { + t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") + } + + date, _ := time.Parse("2006-01-02", "2021-10-18") + + // find + sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Limit(10).Order("age desc").Find(&[]User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) + + // after model chagned + if DB.Statement.DryRun || DB.DryRun { + t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") + } + + if DB.Statement.SQL.String() != "" { + t.Fatal("Failed expect DB.Statement.SQL to be empty") + } + + // first + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) + + // last and unscoped + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Unscoped().Where(&User{Name: "bar", Age: 12}).Limit(10).Offset(5).Order("name ASC").Last(&User{}) + }) + assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'bar' AND "users"."age" = 12 ORDER BY name ASC,"users"."id" DESC LIMIT 1 OFFSET 5`, sql) + + // create + user := &User{Name: "foo", Age: 20} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Create(user) + }) + assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) + + // save + user = &User{Name: "foo", Age: 20} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Save(user) + }) + assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) + + // updates + user = &User{Name: "bar", Age: 22} + user.CreatedAt = date + user.UpdatedAt = date + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Updates(user) + }) + assertEqualSQL(t, `UPDATE "users" SET "created_at"='2021-10-18 00:00:00',"updated_at"='2021-10-18 19:50:09.438',"name"='bar',"age"=22 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // update + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).Update("name", "Foo bar") + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar',"updated_at"='2021-10-18 19:50:09.438' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // UpdateColumn + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).UpdateColumn("name", "Foo bar") + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&User{}).Where("id = ?", 100).UpdateColumns(User{Name: "Foo", Age: 100}) + }) + assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) + + // after model chagned + if DB.Statement.DryRun || DB.DryRun { + t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") + } +} + +// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals. +func assertEqualSQL(t *testing.T, expected string, actually string) { + t.Helper() + + // replace SQL quote, convert into postgresql like "" + expected = replaceQuoteInSQL(expected) + actually = replaceQuoteInSQL(actually) + + // ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update. + var updatedAtRe = regexp.MustCompile(`(?i)"updated_at"=".+?"`) + actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) + expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) + + // ignore RETURNING "id" (only in PostgreSQL) + var returningRe = regexp.MustCompile(`(?i)RETURNING "id"`) + actually = returningRe.ReplaceAllString(actually, ``) + expected = returningRe.ReplaceAllString(expected, ``) + + actually = strings.TrimSpace(actually) + expected = strings.TrimSpace(expected) + + if actually != expected { + t.Fatalf("\nexpected: %s\nactually: %s", expected, actually) + } +} + +func replaceQuoteInSQL(sql string) string { + // convert single quote into double quote + sql = strings.Replace(sql, `'`, `"`, -1) + + // convert dialect speical quote into double quote + switch DB.Dialector.Name() { + case "postgres": + sql = strings.Replace(sql, `"`, `"`, -1) + case "mysql", "sqlite": + sql = strings.Replace(sql, "`", `"`, -1) + case "sqlserver": + sql = strings.Replace(sql, `'`, `"`, -1) + } + + return sql +} From 7b927900e9924ce83dba63a7aadf3866fe216044 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Nov 2021 17:09:08 +0800 Subject: [PATCH 35/83] Bump gorm.io/driver/sqlserver from 1.1.2 to 1.2.0 in /tests (#4820) Bumps [gorm.io/driver/sqlserver](https://github.com/go-gorm/sqlserver) from 1.1.2 to 1.2.0. - [Release notes](https://github.com/go-gorm/sqlserver/releases) - [Commits](https://github.com/go-gorm/sqlserver/compare/v1.1.2...v1.2.0) --- updated-dependencies: - dependency-name: gorm.io/driver/sqlserver dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 8ced0b2f7..b4c5d79df 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,8 +9,8 @@ require ( gorm.io/driver/mysql v1.1.3 gorm.io/driver/postgres v1.2.1 gorm.io/driver/sqlite v1.2.3 - gorm.io/driver/sqlserver v1.1.2 - gorm.io/gorm v1.22.0 + gorm.io/driver/sqlserver v1.2.0 + gorm.io/gorm v1.22.2 ) replace gorm.io/gorm => ../ From c170af11e909098311b0c2f188b7917803e714e9 Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Wed, 3 Nov 2021 13:39:52 +0800 Subject: [PATCH 36/83] fix connections leak (#4826) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix connections leak * fix connections leak * fix connections leak * fix connections leak Co-authored-by: 李龙 --- callbacks/transaction.go | 2 +- finisher_api.go | 60 ++++++++++++++++++++-------------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 8ba2ba3b0..f116d19f0 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -5,7 +5,7 @@ import ( ) func BeginTransaction(db *gorm.DB) { - if !db.Config.SkipDefaultTransaction { + if !db.Config.SkipDefaultTransaction && db.Error == nil { if tx := db.Begin(); tx.Error == nil { db.Statement.ConnPool = tx.Statement.ConnPool db.InstanceSet("gorm:started_transaction", true) diff --git a/finisher_api.go b/finisher_api.go index 48eb94c5f..efdbd563f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -285,44 +285,44 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - - if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { - if c, ok := tx.Statement.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok { - tx.assignInterfacesToValue(where.Exprs) + if tx = queryTx.Find(dest, conds...); tx.Error == nil { + if tx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } } - } - // initialize with attrs, conds - if len(tx.Statement.attrs) > 0 { - tx.assignInterfacesToValue(tx.Statement.attrs...) - } + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) + } - // initialize with attrs, conds - if len(tx.Statement.assigns) > 0 { - tx.assignInterfacesToValue(tx.Statement.assigns...) - } + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } - return tx.Create(dest) - } else if len(db.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) - assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - assigns[column] = eq.Value - case clause.Column: - assigns[column.Name] = eq.Value - default: + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + default: + } } } - } - return tx.Model(dest).Updates(assigns) + return tx.Model(dest).Updates(assigns) + } } - - return db + return tx } // Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields From 4c8810a8484df2ed450e41913c886b54367a3969 Mon Sep 17 00:00:00 2001 From: heige Date: Thu, 4 Nov 2021 13:45:44 +0800 Subject: [PATCH 37/83] Refactor if logic (#4683) * adjust code for preload * adjust code for Create --- callbacks/create.go | 119 +++++++++++++++++++---------------- callbacks/delete.go | 145 ++++++++++++++++++++++--------------------- callbacks/preload.go | 39 ++++++------ 3 files changed, 163 insertions(+), 140 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 656273fb1..36e165a01 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -65,67 +65,82 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Build(db.Statement.BuildClauses...) } - if !db.DryRun && db.Error == nil { + isDryRun := !db.DryRun && db.Error == nil + if !isDryRun { + return + } - if ok, mode := hasReturning(db, supportReturning); ok { - if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { - if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { - mode |= gorm.ScanOnConflictDoNothing - } - } - if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - gorm.Scan(rows, db, mode) - rows.Close() + ok, mode := hasReturning(db, supportReturning) + if ok { + if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { + mode |= gorm.ScanOnConflictDoNothing } - } else { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } - if err != nil { - db.AddError(err) - return - } + rows, err := db.Statement.ConnPool.QueryContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + rows.Close() + } - db.RowsAffected, _ = result.RowsAffected() - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + return + } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + result, err := db.Statement.ConnPool.ExecContext( + db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., + ) + if err != nil { + db.AddError(err) + return + } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } + db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && + db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } - } else { - db.AddError(err) } } + case reflect.Struct: + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index a1fd0a573..087375051 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -26,82 +26,87 @@ func BeforeDelete(db *gorm.DB) { func DeleteBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + if !restricted { + return + } + + for column, v := range selectColumns { + if !v { + continue + } + + rel, ok := db.Statement.Schema.Relationships.Relations[column] + if !ok { + continue + } - if restricted { - for column, v := range selectColumns { - if v { - if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { - switch rel.Type { - case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) - modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) - withoutConditions := false - if db.Statement.Unscoped { - tx = tx.Unscoped() - } - - if len(db.Statement.Selects) > 0 { - selects := make([]string, 0, len(db.Statement.Selects)) - for _, s := range db.Statement.Selects { - if s == clause.Associations { - selects = append(selects, s) - } else if strings.HasPrefix(s, column+".") { - selects = append(selects, strings.TrimPrefix(s, column+".")) - } - } - - if len(selects) > 0 { - tx = tx.Select(selects) - } - } - - for _, cond := range queryConds { - if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { - withoutConditions = true - break - } - } - - if !withoutConditions { - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } - } - case schema.Many2Many: - var ( - queryConds = make([]clause.Expression, 0, len(rel.References)) - foreignFields = make([]*schema.Field, 0, len(rel.References)) - relForeignKeys = make([]string, 0, len(rel.References)) - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() - table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) - ) - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - foreignFields = append(foreignFields, ref.PrimaryKey) - relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) - } else if ref.PrimaryValue != "" { - queryConds = append(queryConds, clause.Eq{ - Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - }) - } - } - - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) - column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) - queryConds = append(queryConds, clause.IN{Column: column, Values: values}) - - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return - } + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } + + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { + selects = append(selects, strings.TrimPrefix(s, columnPrefix)) } } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return } } } + } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 9882590c3..c887c6c0a 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -145,27 +145,30 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload fieldValues[idx], _ = field.ValueOf(elem) } - if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok { - for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) - if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { - reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) - } + datas, ok := identityMap[utils.ToStringKey(fieldValues...)] + if !ok { + db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", + elem.Interface())) + continue + } + + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } - reflectFieldValue = reflect.Indirect(reflectFieldValue) - switch reflectFieldValue.Kind() { - case reflect.Struct: - rel.Field.Set(data, reflectResults.Index(i).Interface()) - case reflect.Slice, reflect.Array: - if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) - } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) - } + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + rel.Field.Set(data, elem.Interface()) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } - } else { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())) } } } From d9d5c4dce0dcf322202f3336f5951c844475cc51 Mon Sep 17 00:00:00 2001 From: Mayank Govilla <31316460+mgovilla@users.noreply.github.com> Date: Sun, 7 Nov 2021 20:47:29 -0500 Subject: [PATCH 38/83] Fix self-referential belongs to constraint (#4801) * create tests for self-ref has one migration * add relation equality check to avoid skipping self-referential schemas * remove drop table error check --- schema/relationship.go | 2 +- schema/relationship_test.go | 15 +++++++++++++++ tests/migrate_test.go | 19 +++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/schema/relationship.go b/schema/relationship.go index 5699ec5f1..c5d3dcad9 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -519,7 +519,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { if rel.Type == BelongsTo { for _, r := range rel.FieldSchema.Relationships.Relations { - if r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { + if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { matched := true for idx, ref := range r.References { if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && diff --git a/schema/relationship_test.go b/schema/relationship_test.go index cb616fc07..afa103b3d 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -93,6 +93,21 @@ func TestBelongsToWithOnlyReferences2(t *testing.T) { }) } +func TestSelfReferentialBelongsTo(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatorID *int32 + Creator *User + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatorID", "User", "", false}}, + }) + +} + func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { type User struct { ID int32 `gorm:"primaryKey"` diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0354e84e1..f0467c5b8 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -54,6 +54,25 @@ func TestMigrate(t *testing.T) { } } +func TestAutoMigrateSelfReferential(t *testing.T) { + type MigratePerson struct { + ID uint + Name string + ManagerID *uint + Manager *MigratePerson + } + + DB.Migrator().DropTable(&MigratePerson{}) + + if err := DB.AutoMigrate(&MigratePerson{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } + + if !DB.Migrator().HasConstraint("migrate_people", "fk_migrate_people_manager") { + t.Fatalf("Failed to find has one constraint between people and managers") + } +} + func TestSmartMigrateColumn(t *testing.T) { fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] From b23c3b290e98d005cdc13e574d4a7e36045693dd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Nov 2021 18:49:49 +0800 Subject: [PATCH 39/83] Don't query with primary key when using Save --- callbacks.go | 8 +++++--- finisher_api.go | 2 +- logger/logger.go | 1 - statement.go | 4 ++++ 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/callbacks.go b/callbacks.go index 7ab38926b..f344649ea 100644 --- a/callbacks.go +++ b/callbacks.go @@ -130,9 +130,11 @@ func (p *processor) Execute(db *DB) *DB { f(db) } - db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected - }, db.Error) + if stmt.SQL.Len() > 0 { + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + }, db.Error) + } if !stmt.DB.DryRun { stmt.SQL.Reset() diff --git a/finisher_api.go b/finisher_api.go index efdbd563f..920ea739c 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -101,7 +101,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } diff --git a/logger/logger.go b/logger/logger.go index 69d41113f..0c4ca4a01 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -140,7 +140,6 @@ func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { // Trace print sql message func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { - if l.LogLevel <= Silent { return } diff --git a/statement.go b/statement.go index 85432e48f..1bd6c2b24 100644 --- a/statement.go +++ b/statement.go @@ -665,6 +665,10 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( for _, omit := range stmt.Omits { if stmt.Schema == nil { results[omit] = false + } else if omit == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = false + } } else if omit == clause.Associations { for _, rel := range stmt.Schema.Relationships.Relations { results[rel.Name] = false From ca7accdbf6b1ea1145c9342e661827b001c44f7a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Nov 2021 19:40:40 +0800 Subject: [PATCH 40/83] Fix preload all associations with inline conditions, close #4836 --- callbacks/query.go | 2 +- tests/preload_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index 0cfb0b3f1..6ca3a1fb0 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -221,7 +221,7 @@ func Preload(db *gorm.DB) { for _, name := range preloadNames { if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { - preload(db, rel, db.Statement.Preloads[name], preloadMap[name]) + preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/tests/preload_test.go b/tests/preload_test.go index 8f49955e7..a3e672003 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -147,6 +147,19 @@ func TestPreloadWithConds(t *testing.T) { for i, u := range users3 { CheckUser(t, u, users[i]) } + + var user4 User + DB.Delete(&users3[0].Account) + + if err := DB.Preload(clause.Associations).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID != 0 { + t.Errorf("failed to query, got error %v, account: %#v", err, user4.Account) + } + + if err := DB.Preload(clause.Associations, func(tx *gorm.DB) *gorm.DB { + return tx.Unscoped() + }).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID == 0 { + t.Errorf("failed to query, got error %v, account: %#v", err, user4.Account) + } } func TestNestedPreloadWithConds(t *testing.T) { From 5daa413f418d8b745d5e7178b07405b0a215f5f2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 8 Nov 2021 20:20:55 +0800 Subject: [PATCH 41/83] Stabilize schema.FieldsWithDefaultDBValue's order, close #4643 --- schema/schema.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/schema.go b/schema/schema.go index ce7cf3b13..eca113e96 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -222,7 +222,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } - for _, field := range schema.FieldsByDBName { + for _, field := range schema.Fields { if field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } From 33bc56cbb5916173c670d28fb7fcf6a2bbd0b185 Mon Sep 17 00:00:00 2001 From: riverchu Date: Tue, 9 Nov 2021 19:55:47 +0800 Subject: [PATCH 42/83] feat(update): update when has SET clause --- callbacks/update.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/update.go b/callbacks/update.go index 1603a5172..8efc3983a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -70,7 +70,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Update{}) if set := ConvertToAssignments(db.Statement); len(set) != 0 { db.Statement.AddClause(set) - } else { + } else if _, ok := db.Statement.Clauses["SET"]; !ok { return } db.Statement.Build(db.Statement.BuildClauses...) From 5e64ac7de9765319da7a588a13bc06d67f7416c9 Mon Sep 17 00:00:00 2001 From: "dino.ma" Date: Sat, 13 Nov 2021 14:03:33 +0800 Subject: [PATCH 43/83] feat(migrator,migrator/migrator.go,tests/migrate_test.go) : Get multiple data tables for migrator. (#4841) * feat(migrator,migrator/migrator.go,tests/migrate_test.go) : Get multiple data tables for migrator. * feat(migrator.go and migrator/migrator.go) : remove Table Struct replace with []string * fix(migrator) : Return all data tables * Update migrator.go * fix(migrator/migrator.go):remove var sql * feat(migrate_test.go/go.mod):update sqlserver,sqlite,postgres,pq version and add getTables test * fix(migrate_test.go):change GetTables Method Test,use intersection Co-authored-by: dino.ma --- migrator.go | 1 + migrator/migrator.go | 4 ++++ tests/go.mod | 8 ++++---- tests/migrate_test.go | 18 +++++++++++++++++- 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/migrator.go b/migrator.go index 7dddcabf2..2a8b42548 100644 --- a/migrator.go +++ b/migrator.go @@ -54,6 +54,7 @@ type Migrator interface { DropTable(dst ...interface{}) error HasTable(dst interface{}) bool RenameTable(oldName, newName interface{}) error + GetTables() (tableList []string, err error) // Columns AddColumn(dst interface{}, field string) error diff --git a/migrator/migrator.go b/migrator/migrator.go index 30586a8cf..95a708deb 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -155,6 +155,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return nil } +func (m Migrator) GetTables() (tableList []string, err error) { + return tableList, m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).Scan(&tableList).Error +} + func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) diff --git a/tests/go.mod b/tests/go.mod index b4c5d79df..e321d3d8b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,11 +5,11 @@ go 1.14 require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.2 - github.com/lib/pq v1.10.3 + github.com/lib/pq v1.10.4 gorm.io/driver/mysql v1.1.3 - gorm.io/driver/postgres v1.2.1 - gorm.io/driver/sqlite v1.2.3 - gorm.io/driver/sqlserver v1.2.0 + gorm.io/driver/postgres v1.2.2 + gorm.io/driver/sqlite v1.2.4 + gorm.io/driver/sqlserver v1.2.1 gorm.io/gorm v1.22.2 ) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index f0467c5b8..789a5e451 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -14,7 +14,6 @@ func TestMigrate(t *testing.T) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) - DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") if err := DB.Migrator().DropTable(allModels...); err != nil { @@ -25,6 +24,23 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to auto migrate, but got error %v", err) } + if tables, err := DB.Migrator().GetTables(); err != nil { + t.Fatalf("Failed to get database all tables, but got error %v", err) + } else { + for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} { + hasTable := false + for _, t2 := range tables { + if t2 == t1 { + hasTable = true + break + } + } + if !hasTable { + t.Fatalf("Failed to get table %v when GetTables", t1) + } + } + } + for _, m := range allModels { if !DB.Migrator().HasTable(m) { t.Fatalf("Failed to create table for %#v---", m) From 11d5c346aeab2902d801691ed4bf926c41de7c7c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Nov 2021 11:39:42 +0800 Subject: [PATCH 44/83] Bump github.com/jinzhu/now from 1.1.2 to 1.1.3 (#4865) Bumps [github.com/jinzhu/now](https://github.com/jinzhu/now) from 1.1.2 to 1.1.3. - [Release notes](https://github.com/jinzhu/now/releases) - [Commits](https://github.com/jinzhu/now/compare/v1.1.2...v1.1.3) --- updated-dependencies: - dependency-name: github.com/jinzhu/now dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index d95d3f100..75662c80e 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.14 require ( github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.1.2 + github.com/jinzhu/now v1.1.3 ) diff --git a/go.sum b/go.sum index c66a6b576..c17a1ceb5 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= -github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.3 h1:PlHq1bSCSZL9K0wUhbm2pGLoTWs2GwVhsP6emvGV/ZI= +github.com/jinzhu/now v1.1.3/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= From 0f8e86159765ac6b048ce259667eed2defbc43e9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Nov 2021 11:40:03 +0800 Subject: [PATCH 45/83] Bump github.com/jinzhu/now from 1.1.2 to 1.1.3 in /tests (#4866) Bumps [github.com/jinzhu/now](https://github.com/jinzhu/now) from 1.1.2 to 1.1.3. - [Release notes](https://github.com/jinzhu/now/releases) - [Commits](https://github.com/jinzhu/now/compare/v1.1.2...v1.1.3) --- updated-dependencies: - dependency-name: github.com/jinzhu/now dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index e321d3d8b..43c580f62 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,7 +4,7 @@ go 1.14 require ( github.com/google/uuid v1.3.0 - github.com/jinzhu/now v1.1.2 + github.com/jinzhu/now v1.1.3 github.com/lib/pq v1.10.4 gorm.io/driver/mysql v1.1.3 gorm.io/driver/postgres v1.2.2 From cff7845e584662528c2c1bff5292b18a68f2fb0a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Nov 2021 11:40:18 +0800 Subject: [PATCH 46/83] Bump gorm.io/driver/mysql from 1.1.3 to 1.2.0 in /tests (#4856) Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.1.3 to 1.2.0. - [Release notes](https://github.com/go-gorm/mysql/releases) - [Commits](https://github.com/go-gorm/mysql/compare/v1.1.3...v1.2.0) --- updated-dependencies: - dependency-name: gorm.io/driver/mysql dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 43c580f62..6502c1790 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.3 github.com/lib/pq v1.10.4 - gorm.io/driver/mysql v1.1.3 + gorm.io/driver/mysql v1.2.0 gorm.io/driver/postgres v1.2.2 gorm.io/driver/sqlite v1.2.4 gorm.io/driver/sqlserver v1.2.1 - gorm.io/gorm v1.22.2 + gorm.io/gorm v1.22.3 ) replace gorm.io/gorm => ../ From b8f33a42a469f5a4ab64bb8937ef7c8e5524af7e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Nov 2021 17:11:52 +0800 Subject: [PATCH 47/83] Add unused argument (#4871) * Append unused argument to gorm statement --- .github/workflows/reviewdog.yml | 4 +++- clause/expression.go | 6 ++++++ statement.go | 5 +++++ tests/go.mod | 4 +++- tests/postgres_test.go | 4 ++++ 5 files changed, 21 insertions(+), 2 deletions(-) diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index d55a46999..abfd57f3e 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,6 +6,8 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v1 + uses: actions/checkout@v2 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 + with: + golangci_lint_flags: '-E cyclop,unconvert,misspell,unparam,ineffassign,gocritic,prealloc,exportloopref,gosec' diff --git a/clause/expression.go b/clause/expression.go index e914b7b30..d04983068 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -67,6 +67,12 @@ func (expr Expr) Build(builder Builder) { builder.WriteByte(v) } } + + if idx < len(expr.Vars) { + for _, v := range expr.Vars[idx:] { + builder.AddVar(builder, sql.NamedArg{Value: v}) + } + } } // NamedExpr raw expression for named expr diff --git a/statement.go b/statement.go index 1bd6c2b24..453e485e6 100644 --- a/statement.go +++ b/statement.go @@ -284,6 +284,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} } + if strings.Contains(strings.TrimSpace(s), " ") { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } + if len(args) == 1 { return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} } diff --git a/tests/go.mod b/tests/go.mod index 6502c1790..7e5ea8a5c 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,11 +4,13 @@ go 1.14 require ( github.com/google/uuid v1.3.0 + github.com/jackc/pgx/v4 v4.14.0 // indirect github.com/jinzhu/now v1.1.3 github.com/lib/pq v1.10.4 + golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 // indirect gorm.io/driver/mysql v1.2.0 gorm.io/driver/postgres v1.2.2 - gorm.io/driver/sqlite v1.2.4 + gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlserver v1.2.1 gorm.io/gorm v1.22.3 ) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 94077d1d0..85671864f 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -44,6 +44,10 @@ func TestPostgres(t *testing.T) { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } + + if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } } type Post struct { From 9d5f315b6d5382dfbdaa20d46751e894b577d337 Mon Sep 17 00:00:00 2001 From: heige Date: Mon, 29 Nov 2021 09:33:20 +0800 Subject: [PATCH 48/83] feat: go code style adjust and optimize code for callbacks package (#4861) * feat: go code style adjust and optimize code for callbacks package * Update scan.go --- callbacks/associations.go | 26 +++++++++++++------------- callbacks/create.go | 21 +++++++++++---------- callbacks/delete.go | 15 +++++++++------ callbacks/preload.go | 5 +++-- callbacks/raw.go | 5 +++-- callbacks/row.go | 19 ++++++++++--------- callbacks/transaction.go | 7 ++++--- callbacks/update.go | 2 +- migrator/migrator.go | 4 +++- scan.go | 6 +++--- 10 files changed, 60 insertions(+), 50 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index d78bd9687..9d5b7c21c 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -39,7 +39,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( - objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len()) + rValLen = db.Statement.ReflectValue.Len() + objs = make([]reflect.Value, 0, rValLen) fieldType = rel.Field.FieldType isPtr = fieldType.Kind() == reflect.Ptr ) @@ -49,21 +50,20 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < rValLen; i++ { obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() != reflect.Struct { + break + } - if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) - } + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) } - } else { - break } } diff --git a/callbacks/create.go b/callbacks/create.go index 36e165a01..df7743491 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -200,15 +200,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) - values.Values = make([][]interface{}, stmt.ReflectValue.Len()) - defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} - if stmt.ReflectValue.Len() == 0 { + rValLen := stmt.ReflectValue.Len() + stmt.SQL.Grow(rValLen * 18) + values.Values = make([][]interface{}, rValLen) + if rValLen == 0 { stmt.AddError(gorm.ErrEmptySlice) return } - for i := 0; i < stmt.ReflectValue.Len(); i++ { + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + for i := 0; i < rValLen; i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) if !rv.IsValid() { stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) @@ -234,11 +235,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(rv); !isZero { + if rvOfvalue, isZero := field.ValueOf(rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { - defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) + defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) } - defaultValueFieldsHavingValue[field][i] = v + defaultValueFieldsHavingValue[field][i] = rvOfvalue } } } @@ -274,9 +275,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - values.Values[0] = append(values.Values[0], v) + values.Values[0] = append(values.Values[0], rvOfvalue) } } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 087375051..525c01456 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -156,16 +156,19 @@ func Delete(config *Config) func(db *gorm.DB) { } if !db.DryRun && db.Error == nil { - if ok, mode := hasReturning(db, supportReturning); ok { - if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { - gorm.Scan(rows, db, mode) - rows.Close() - } - } else { + ok, mode := hasReturning(db, supportReturning) + if !ok { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() } + + return + } + + if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { + gorm.Scan(rows, db, mode) + rows.Close() } } } diff --git a/callbacks/preload.go b/callbacks/preload.go index c887c6c0a..41405a22a 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -61,12 +61,13 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload fieldValues := make([]interface{}, len(joinForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields)) for i := 0; i < joinResults.Len(); i++ { + joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + fieldValues[idx], _ = field.ValueOf(joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { diff --git a/callbacks/raw.go b/callbacks/raw.go index d594ab391..013e638cb 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -9,8 +9,9 @@ func RawExec(db *gorm.DB) { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) - } else { - db.RowsAffected, _ = result.RowsAffected() + return } + + db.RowsAffected, _ = result.RowsAffected() } } diff --git a/callbacks/row.go b/callbacks/row.go index 407c32d71..56be742e8 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -7,16 +7,17 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) + if db.DryRun { + return + } - if !db.DryRun { - if isRows, ok := db.Get("rows"); ok && isRows.(bool) { - db.Statement.Settings.Delete("rows") - db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } else { - db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } - - db.RowsAffected = -1 + if isRows, ok := db.Get("rows"); ok && isRows.(bool) { + db.Statement.Settings.Delete("rows") + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } + + db.RowsAffected = -1 } } diff --git a/callbacks/transaction.go b/callbacks/transaction.go index f116d19f0..50887ccce 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -20,11 +20,12 @@ func BeginTransaction(db *gorm.DB) { func CommitOrRollbackTransaction(db *gorm.DB) { if !db.Config.SkipDefaultTransaction { if _, ok := db.InstanceGet("gorm:started_transaction"); ok { - if db.Error == nil { - db.Commit() - } else { + if db.Error != nil { db.Rollback() + } else { + db.Commit() } + db.Statement.ConnPool = db.ConnPool } } diff --git a/callbacks/update.go b/callbacks/update.go index 8efc3983a..1f4960b54 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -157,7 +157,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: if size := stmt.ReflectValue.Len(); size > 0 { var primaryKeyExprs []clause.Expression - for i := 0; i < stmt.ReflectValue.Len(); i++ { + for i := 0; i < size; i++ { var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { diff --git a/migrator/migrator.go b/migrator/migrator.go index 95a708deb..af1385e24 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -156,7 +156,9 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } func (m Migrator) GetTables() (tableList []string, err error) { - return tableList, m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).Scan(&tableList).Error + err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). + Scan(&tableList).Error + return } func (m Migrator) CreateTable(values ...interface{}) error { diff --git a/scan.go b/scan.go index 2d0c8fc6b..b931aff4c 100644 --- a/scan.go +++ b/scan.go @@ -102,9 +102,9 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re type ScanMode uint8 const ( - ScanInitialized ScanMode = 1 << 0 - ScanUpdate = 1 << 1 - ScanOnConflictDoNothing = 1 << 2 + ScanInitialized ScanMode = 1 << 0 // 1 + ScanUpdate ScanMode = 1 << 1 // 2 + ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ) func Scan(rows *sql.Rows, db *DB, mode ScanMode) { From e1b4c066a8bd3f8bca8d2f6fa141927776fca028 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Nov 2021 11:02:32 +0800 Subject: [PATCH 49/83] Fix FullSaveAssociations, close #4874 --- callbacks/create.go | 3 +++ clause/set.go | 4 ++-- tests/associations_test.go | 28 ++++++++++++++++++++++++++++ tests/go.mod | 1 + tests/migrate_test.go | 2 +- utils/tests/models.go | 12 ++++++++++++ 6 files changed, 47 insertions(+), 3 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index df7743491..c585fbe9f 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -317,6 +317,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) + if len(onConflict.DoUpdates) == 0 { + onConflict.DoNothing = true + } // use primary fields as default OnConflict columns if len(onConflict.Columns) == 0 { diff --git a/clause/set.go b/clause/set.go index 6a885711a..75eb6bdda 100644 --- a/clause/set.go +++ b/clause/set.go @@ -24,9 +24,9 @@ func (set Set) Build(builder Builder) { builder.AddVar(builder, assignment.Value) } } else { - builder.WriteQuoted(PrimaryColumn) + builder.WriteQuoted(Column{Name: PrimaryKey}) builder.WriteByte('=') - builder.WriteQuoted(PrimaryColumn) + builder.WriteQuoted(Column{Name: PrimaryKey}) } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 3b2706257..a8d478867 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -176,3 +176,31 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { t.Fatalf("Should not find deleted profile") } } + +func TestFullSaveAssociations(t *testing.T) { + err := DB. + Session(&gorm.Session{FullSaveAssociations: true}). + Create(&Coupon{ + ID: "full-save-association-coupon1", + AppliesToProduct: []*CouponProduct{ + { + CouponId: "full-save-association-coupon1", + ProductId: "full-save-association-product1", + }, + }, + AmountOff: 10, + PercentOff: 0.0, + }).Error + + if err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if DB.First(&Coupon{}, "id = ?", "full-save-association-coupon1").Error != nil { + t.Errorf("Failed to query saved coupon") + } + + if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", "full-save-association-coupon1", "full-save-association-product1").Error != nil { + t.Errorf("Failed to query saved association") + } +} diff --git a/tests/go.mod b/tests/go.mod index 7e5ea8a5c..36c7310cb 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,6 +4,7 @@ go 1.14 require ( github.com/google/uuid v1.3.0 + github.com/jackc/pgtype v1.9.1 // indirect github.com/jackc/pgx/v4 v4.14.0 // indirect github.com/jinzhu/now v1.1.3 github.com/lib/pq v1.10.4 diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 789a5e451..5cdf8e74c 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -11,7 +11,7 @@ import ( ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") diff --git a/utils/tests/models.go b/utils/tests/models.go index 8e833c932..5eee84680 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -60,3 +60,15 @@ type Language struct { Code string `gorm:"primarykey"` Name string } + +type Coupon struct { + ID string `gorm:"primarykey; size:255"` + AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` + AmountOff uint32 `gorm:"amount_off"` + PercentOff float32 `gorm:"percent_off"` +} + +type CouponProduct struct { + CouponId string `gorm:"primarykey; size:255"` + ProductId string `gorm:"primarykey; size:255"` +} From 270e38c518260be891e227bbfd7521728aeb5309 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Nov 2021 14:23:10 +0800 Subject: [PATCH 50/83] Fix duplicated error when Scan, close #4525 --- finisher_api.go | 4 +--- scan.go | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 920ea739c..633a7fa06 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -454,9 +454,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() tx.Config = &config - if rows, err := tx.Rows(); err != nil { - tx.AddError(err) - } else { + if rows, err := tx.Rows(); err == nil { defer rows.Close() if rows.Next() { tx.ScanRows(rows, dest) diff --git a/scan.go b/scan.go index b931aff4c..b03b79b45 100644 --- a/scan.go +++ b/scan.go @@ -102,9 +102,9 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re type ScanMode uint8 const ( - ScanInitialized ScanMode = 1 << 0 // 1 - ScanUpdate ScanMode = 1 << 1 // 2 - ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 + ScanInitialized ScanMode = 1 << 0 // 1 + ScanUpdate ScanMode = 1 << 1 // 2 + ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ) func Scan(rows *sql.Rows, db *DB, mode ScanMode) { From 92d5a959a02c64a12f017b205a997d515e459749 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Nov 2021 15:16:57 +0800 Subject: [PATCH 51/83] Fix tests --- tests/go.mod | 3 +-- tests/migrate_test.go | 2 +- tests/tests_test.go | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 36c7310cb..4fddb6628 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,8 +4,7 @@ go 1.14 require ( github.com/google/uuid v1.3.0 - github.com/jackc/pgtype v1.9.1 // indirect - github.com/jackc/pgx/v4 v4.14.0 // indirect + github.com/jackc/pgx/v4 v4.14.1 // indirect github.com/jinzhu/now v1.1.3 github.com/lib/pq v1.10.4 golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 // indirect diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5cdf8e74c..789a5e451 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -11,7 +11,7 @@ import ( ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") diff --git a/tests/tests_test.go b/tests/tests_test.go index cb73d267f..5799662fe 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -87,7 +87,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) From 45e804dd3fa3ca11fc3db0945fc3c4b93e8b7e66 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Nov 2021 16:19:06 +0800 Subject: [PATCH 52/83] Fix call valuer interface when using nil value --- clause/expression.go | 2 +- statement.go | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index d04983068..dde00b1d7 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -368,7 +368,7 @@ func (like Like) NegationBuild(builder Builder) { } func eqNil(value interface{}) bool { - if valuer, ok := value.(driver.Valuer); ok { + if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) { value, _ = valuer.Value() } diff --git a/statement.go b/statement.go index 453e485e6..5a948d3f1 100644 --- a/statement.go +++ b/statement.go @@ -173,7 +173,12 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case clause.Column, clause.Table: stmt.QuoteTo(writer, v) case Valuer: - stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) + reflectValue := reflect.ValueOf(v) + if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() { + stmt.AddVar(writer, nil) + } else { + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) + } case clause.Expr: v.Build(stmt) case *clause.Expr: From 27e2753c9dfbb7c4330ea14d5ff04fd672d341be Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Nov 2021 18:34:50 +0800 Subject: [PATCH 53/83] Fix create duplicated value when updating nested has many relationship, close #4796 --- callbacks/associations.go | 21 +++++++++++++++++---- tests/associations_test.go | 29 ++++++++++++++++++----------- tests/multi_primary_keys_test.go | 2 +- tests/tests_test.go | 2 +- utils/tests/models.go | 7 +++++++ 5 files changed, 44 insertions(+), 17 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 9d5b7c21c..38f21218a 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func SaveBeforeAssociations(create bool) func(db *gorm.DB) { @@ -182,6 +183,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(v)) @@ -197,10 +199,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues) + if len(relPrimaryValues) == 0 || (len(relPrimaryValues) == len(rel.FieldSchema.PrimaryFields) && !identityMap[cacheKey]) { + identityMap[cacheKey] = true + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index a8d478867..a4b1f1f28 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -178,19 +178,21 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { } func TestFullSaveAssociations(t *testing.T) { + coupon := &Coupon{ + ID: "full-save-association-coupon1", + AppliesToProduct: []*CouponProduct{ + { + CouponId: "full-save-association-coupon1", + ProductId: "full-save-association-product1", + }, + }, + AmountOff: 10, + PercentOff: 0.0, + } + err := DB. Session(&gorm.Session{FullSaveAssociations: true}). - Create(&Coupon{ - ID: "full-save-association-coupon1", - AppliesToProduct: []*CouponProduct{ - { - CouponId: "full-save-association-coupon1", - ProductId: "full-save-association-product1", - }, - }, - AmountOff: 10, - PercentOff: 0.0, - }).Error + Create(coupon).Error if err != nil { t.Errorf("Failed, got error: %v", err) @@ -203,4 +205,9 @@ func TestFullSaveAssociations(t *testing.T) { if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", "full-save-association-coupon1", "full-save-association-product1").Error != nil { t.Errorf("Failed to query saved association") } + + orders := []Order{{Num: "order1", Coupon: coupon}, {Num: "order2", Coupon: coupon}} + if err := DB.Create(&orders).Error; err != nil { + t.Errorf("failed to create orders, got %v", err) + } } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index dcc90cd9a..3a8c08aa8 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -427,7 +427,7 @@ func TestCompositePrimaryKeysAssociations(t *testing.T) { DB.Migrator().DropTable(&Label{}, &Book{}) if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { - t.Fatalf("failed to migrate") + t.Fatalf("failed to migrate, got %v", err) } book := Book{ diff --git a/tests/tests_test.go b/tests/tests_test.go index 5799662fe..d1f19df30 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -87,7 +87,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index 5eee84680..337682d61 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -72,3 +72,10 @@ type CouponProduct struct { CouponId string `gorm:"primarykey; size:255"` ProductId string `gorm:"primarykey; size:255"` } + +type Order struct { + gorm.Model + Num string + Coupon *Coupon + CouponID string +} From d8a710cba23367a0e9adbaaf751c60041a1f7df6 Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Mon, 29 Nov 2021 20:14:23 +0800 Subject: [PATCH 54/83] fix: count() when use group by and only find one record (#4885) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 李龙 --- finisher_api.go | 4 +++- tests/count_test.go | 11 +++++++++++ tests/go.mod | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 633a7fa06..b3bdedc80 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -419,9 +419,11 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx = tx.callbacks.Query().Execute(tx) - if tx.RowsAffected != 1 { + + if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { *count = tx.RowsAffected } + return } diff --git a/tests/count_test.go b/tests/count_test.go index de06d0eb7..7cae890b5 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -134,4 +134,15 @@ func TestCount(t *testing.T) { t.Fatalf("Count should be 3, but got count: %v err %v", count10, err) } + var count11 int64 + sameUsers := make([]*User, 0) + for i := 0; i < 3; i++ { + sameUsers = append(sameUsers, GetUser("count-4", Config{})) + } + DB.Create(sameUsers) + + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) + } + } diff --git a/tests/go.mod b/tests/go.mod index 4fddb6628..6315c7f1a 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/lib/pq v1.10.4 golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 // indirect gorm.io/driver/mysql v1.2.0 - gorm.io/driver/postgres v1.2.2 + gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlserver v1.2.1 gorm.io/gorm v1.22.3 From 3a3b82263a2e6a3d19c2d669ce9d299b76c47f65 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Nov 2021 20:24:04 +0800 Subject: [PATCH 55/83] Fix auto migration always alert table, close #4198 --- migrator/migrator.go | 4 ++-- tests/migrate_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index af1385e24..91bf60a7a 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -390,7 +390,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn := false // check size - if length, _ := columnType.Length(); length != int64(field.Size) { + if length, ok := columnType.Length(); length != int64(field.Size) { if length > 0 && field.Size > 0 { alterColumn = true } else { @@ -399,7 +399,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && - (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { alterColumn = true } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 789a5e451..3d15bf2c7 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -90,7 +90,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) { } func TestSmartMigrateColumn(t *testing.T) { - fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] + fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()] type UserMigrateColumn struct { ID uint From 8627634959401e4126d12a6d18f3aa8249a036ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 2 Dec 2021 10:20:16 +0800 Subject: [PATCH 56/83] Fix create associations with zero primary key, close #4890 --- callbacks/associations.go | 2 +- tests/associations_test.go | 24 +++++++++++++++++------- utils/tests/models.go | 7 ++++--- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 38f21218a..75bd6c6a1 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -207,7 +207,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } cacheKey := utils.ToStringKey(relPrimaryValues) - if len(relPrimaryValues) == 0 || (len(relPrimaryValues) == len(rel.FieldSchema.PrimaryFields) && !identityMap[cacheKey]) { + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { identityMap[cacheKey] = true if isPtr { elems = reflect.Append(elems, elem) diff --git a/tests/associations_test.go b/tests/associations_test.go index a4b1f1f28..f88d1523e 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -179,12 +179,8 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { func TestFullSaveAssociations(t *testing.T) { coupon := &Coupon{ - ID: "full-save-association-coupon1", AppliesToProduct: []*CouponProduct{ - { - CouponId: "full-save-association-coupon1", - ProductId: "full-save-association-product1", - }, + {ProductId: "full-save-association-product1"}, }, AmountOff: 10, PercentOff: 0.0, @@ -198,11 +194,11 @@ func TestFullSaveAssociations(t *testing.T) { t.Errorf("Failed, got error: %v", err) } - if DB.First(&Coupon{}, "id = ?", "full-save-association-coupon1").Error != nil { + if DB.First(&Coupon{}, "id = ?", coupon.ID).Error != nil { t.Errorf("Failed to query saved coupon") } - if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", "full-save-association-coupon1", "full-save-association-product1").Error != nil { + if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", coupon.ID, "full-save-association-product1").Error != nil { t.Errorf("Failed to query saved association") } @@ -210,4 +206,18 @@ func TestFullSaveAssociations(t *testing.T) { if err := DB.Create(&orders).Error; err != nil { t.Errorf("failed to create orders, got %v", err) } + + coupon2 := Coupon{ + AppliesToProduct: []*CouponProduct{{Desc: "coupon-description"}}, + } + + DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&coupon2) + var result Coupon + if err := DB.Preload("AppliesToProduct").First(&result, "id = ?", coupon2.ID).Error; err != nil { + t.Errorf("Failed to create coupon w/o name, got error: %v", err) + } + + if len(result.AppliesToProduct) != 1 { + t.Errorf("Failed to preload AppliesToProduct") + } } diff --git a/utils/tests/models.go b/utils/tests/models.go index 337682d61..c84f9cae9 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -62,15 +62,16 @@ type Language struct { } type Coupon struct { - ID string `gorm:"primarykey; size:255"` + ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` AmountOff uint32 `gorm:"amount_off"` PercentOff float32 `gorm:"percent_off"` } type CouponProduct struct { - CouponId string `gorm:"primarykey; size:255"` - ProductId string `gorm:"primarykey; size:255"` + CouponId int `gorm:"primarykey;size:255"` + ProductId string `gorm:"primarykey;size:255"` + Desc string } type Order struct { From 300a23fc3137b947a3ce9bca97fa5c81cc605636 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 2 Dec 2021 10:39:24 +0800 Subject: [PATCH 57/83] Check rows.Close error, close #4891 --- callbacks/create.go | 2 +- callbacks/delete.go | 2 +- callbacks/query.go | 3 +-- callbacks/update.go | 2 +- finisher_api.go | 2 +- migrator/migrator.go | 8 +++++--- tests/associations_belongs_to_test.go | 7 +++++++ 7 files changed, 17 insertions(+), 9 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index c585fbe9f..9dc5b8b1a 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -83,7 +83,7 @@ func Create(config *Config) func(db *gorm.DB) { ) if db.AddError(err) == nil { gorm.Scan(rows, db, mode) - rows.Close() + db.AddError(rows.Close()) } return diff --git a/callbacks/delete.go b/callbacks/delete.go index 525c01456..b05a9d08f 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -168,7 +168,7 @@ func Delete(config *Config) func(db *gorm.DB) { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { gorm.Scan(rows, db, mode) - rows.Close() + db.AddError(rows.Close()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 6ca3a1fb0..2f98a4b6d 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -20,9 +20,8 @@ func Query(db *gorm.DB) { db.AddError(err) return } - defer rows.Close() - gorm.Scan(rows, db, 0) + db.AddError(rows.Close()) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 1f4960b54..fa7640de0 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -88,7 +88,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() gorm.Scan(rows, db, mode) db.Statement.Dest = dest - rows.Close() + db.AddError(rows.Close()) } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index b3bdedc80..d38d60b7e 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -457,12 +457,12 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.Config = &config if rows, err := tx.Rows(); err == nil { - defer rows.Close() if rows.Next() { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 } + tx.AddError(rows.Close()) } currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { diff --git a/migrator/migrator.go b/migrator/migrator.go index 91bf60a7a..18212dbb3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -430,13 +430,15 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) - execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error { + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err != nil { return err } - defer rows.Close() + defer func() { + err = rows.Close() + }() var rawColumnTypes []*sql.ColumnType rawColumnTypes, err = rows.ColumnTypes() @@ -448,7 +450,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes = append(columnTypes, c) } - return nil + return }) return columnTypes, execErr diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 3e4de7260..e37da7d3b 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -132,6 +132,13 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Company", 0, "after clear") AssertAssociationCount(t, user2, "Manager", 0, "after clear") + + // unexist company id + unexistCompanyID := company.ID + 9999999 + user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} + if err := DB.Create(&user).Error; err == nil { + t.Errorf("should have gotten foreign key violation error") + } } func TestBelongsToAssociationForSlice(t *testing.T) { From e5bdd610c36b0e65c957c53f8a4ffb0f11714615 Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Wed, 8 Dec 2021 13:58:06 +0800 Subject: [PATCH 58/83] fix: save not use soft_delete (#4897) * fix: Save not use soft_delete * fix: save not use soft_delete * fix: save not use soft_delete * fix: save not use soft_delete Co-authored-by: kinggo <> --- callbacks/create.go | 2 +- callbacks/delete.go | 17 ++++++++++------- callbacks/query.go | 2 +- callbacks/update.go | 18 +++++++++++------- soft_delete.go | 4 ++-- tests/update_test.go | 8 +++++++- 6 files changed, 32 insertions(+), 19 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 9dc5b8b1a..291131283 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -57,7 +57,7 @@ func Create(config *Config) func(db *gorm.DB) { } } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) diff --git a/callbacks/delete.go b/callbacks/delete.go index b05a9d08f..7f1e09cee 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,13 +118,7 @@ func Delete(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -147,6 +141,15 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } diff --git a/callbacks/query.go b/callbacks/query.go index 2f98a4b6d..efb08609c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -33,7 +33,7 @@ func BuildQuerySQL(db *gorm.DB) { } } - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} diff --git a/callbacks/update.go b/callbacks/update.go index fa7640de0..b3eaaf117 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,13 +59,7 @@ func Update(config *Config) func(db *gorm.DB) { return } - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.String() == "" { + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) if set := ConvertToAssignments(db.Statement); len(set) != 0 { @@ -73,6 +67,16 @@ func Update(config *Config) func(db *gorm.DB) { } else if _, ok := db.Statement.Clauses["SET"]; !ok { return } + + } + + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } diff --git a/soft_delete.go b/soft_delete.go index 11c4fafc0..4e236fc44 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -103,7 +103,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { } func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { - if stmt.SQL.String() == "" { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { SoftDeleteQueryClause(sd).ModifyStatement(stmt) } @@ -129,7 +129,7 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { } func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { - if stmt.SQL.String() == "" { + if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { curTime := stmt.DB.NowFunc() stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) stmt.SetColumn(sd.Field.DBName, curTime, true) diff --git a/tests/update_test.go b/tests/update_test.go index 14ed98207..abe520db8 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,13 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } + + dryDB = DB.Session(&gorm.Session{DryRun: true}) + stmt = dryDB.Unscoped().Save(&user).Statement + if !regexp.MustCompile(`WHERE .id. = [^ ]+$`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } From 2a578d767f01af839c2e91fdeeb3bbb4caed4ae4 Mon Sep 17 00:00:00 2001 From: Matthieu MOREL Date: Fri, 10 Dec 2021 10:44:11 +0100 Subject: [PATCH 59/83] Use Golangci configuration file (#4896) --- .github/workflows/reviewdog.yml | 2 -- .golangci.yml | 11 +++++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 .golangci.yml diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index abfd57f3e..95b6fb048 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -9,5 +9,3 @@ jobs: uses: actions/checkout@v2 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 - with: - golangci_lint_flags: '-E cyclop,unconvert,misspell,unparam,ineffassign,gocritic,prealloc,exportloopref,gosec' diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..16903ed6c --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,11 @@ +linters: + enable: + - cyclop + - exportloopref + - gocritic + - gosec + - ineffassign + - misspell + - prealloc + - unconvert + - unparam From 380cc64ff5b3f5379a076b19b23ed0ddd1638ba7 Mon Sep 17 00:00:00 2001 From: piyongcai Date: Fri, 10 Dec 2021 17:45:36 +0800 Subject: [PATCH 60/83] =?UTF-8?q?fix=20type=20alias=20AutoMigrate=20bug?= =?UTF-8?q?=EF=BC=88Add=20Test=20Case=EF=BC=89=20(#4888)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix type alias AutoMigrate bug. eg ```go package main type IDer interface{ GetID() int64 } // ID will add some method to implement some interface eg: GetID type ID int64 func (z ID) GetID() int64 { return int64(z) } type Test struct { ID Code string `gorm:"size:50"` Name string `gorm:"size:50"` } func main() { db, err := gorm.Open(postgres.New(postgres.Config{ DSN: `dsn`, PreferSimpleProtocol: false, }), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), SkipDefaultTransaction: true, }) if err != nil { log.Fatal(err) } if err = db.AutoMigrate(&Test{}); err != nil { // invalid embedded struct for Test's field ID, should be struct, but got main.ID log.Fatal(err) } } ``` * fix type alias AutoMigrate bug. eg ```go package main type IDer interface{ GetID() int64 } // ID will add some method to implement some interface eg: GetID type ID int64 func (z ID) GetID() int64 { return int64(z) } type Test struct { ID Code string `gorm:"size:50"` Name string `gorm:"size:50"` } func main() { db, err := gorm.Open(postgres.New(postgres.Config{ DSN: `dsn`, PreferSimpleProtocol: false, }), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), SkipDefaultTransaction: true, }) if err != nil { log.Fatal(err) } if err = db.AutoMigrate(&Test{}); err != nil { // invalid embedded struct for Test's field ID, should be struct, but got main.ID log.Fatal(err) } } ``` * Add typealis test. * try to fix golangci-lint --- schema/field.go | 7 +++-- schema/field_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index f3189c7a3..c6c89cc10 100644 --- a/schema/field.go +++ b/schema/field.go @@ -347,7 +347,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { - if reflect.Indirect(fieldValue).Kind() == reflect.Struct { + kind := reflect.Indirect(fieldValue).Kind() + switch kind { + case reflect.Struct: var err error field.Creatable = false field.Updatable = false @@ -396,7 +398,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - } else { + case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128: schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } } diff --git a/schema/field_test.go b/schema/field_test.go index 4be3e5ab9..8768a4c35 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -244,7 +244,7 @@ func TestParseFieldWithPermission(t *testing.T) { t.Fatalf("Failed to parse user with permission, got error %v", err) } - fields := []schema.Field{ + fields := []*schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, @@ -257,6 +257,68 @@ func TestParseFieldWithPermission(t *testing.T) { } for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) {}) + checkSchemaField(t, user, f, func(f *schema.Field) {}) } } + +type ID int64 +type INT int +type INT8 int8 +type INT16 int16 +type INT32 int32 +type INT64 int64 +type UINT uint +type UINT8 uint8 +type UINT16 uint16 +type UINT32 uint32 +type UINT64 uint64 +type FLOAT32 float32 +type FLOAT64 float64 +type BOOL bool +type STRING string +type TypeAlias struct { + ID + INT `gorm:"column:fint"` + INT8 `gorm:"column:fint8"` + INT16 `gorm:"column:fint16"` + INT32 `gorm:"column:fint32"` + INT64 `gorm:"column:fint64"` + UINT `gorm:"column:fuint"` + UINT8 `gorm:"column:fuint8"` + UINT16 `gorm:"column:fuint16"` + UINT32 `gorm:"column:fuint32"` + UINT64 `gorm:"column:fuint64"` + FLOAT32 `gorm:"column:ffloat32"` + FLOAT64 `gorm:"column:ffloat64"` + BOOL `gorm:"column:fbool"` + STRING `gorm:"column:fstring"` +} + +func TestTypeAliasField(t *testing.T){ + alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err) + } + + fields := []*schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true }, + {Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`}, + {Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`}, + {Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`}, + {Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`}, + {Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`}, + {Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`}, + {Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`}, + {Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`}, + {Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`}, + {Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`}, + {Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float , Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`}, + {Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`}, + {Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool , Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`}, + {Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`}, + } + + for _, f := range fields { + checkSchemaField(t, alias, f, func(f *schema.Field) {}) + } +} \ No newline at end of file From adf8f70f06d905ce0ba6e5fb5dc7a1f7bb07ca23 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Dec 2021 17:50:19 +0800 Subject: [PATCH 61/83] Upgrade go.mod --- go.mod | 2 +- go.sum | 4 ++-- tests/go.mod | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 75662c80e..573627455 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.14 require ( github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.1.3 + github.com/jinzhu/now v1.1.4 ) diff --git a/go.sum b/go.sum index c17a1ceb5..50fbba2fc 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.3 h1:PlHq1bSCSZL9K0wUhbm2pGLoTWs2GwVhsP6emvGV/ZI= -github.com/jinzhu/now v1.1.3/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= +github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/tests/go.mod b/tests/go.mod index 6315c7f1a..c3133f38d 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,14 +5,14 @@ go 1.14 require ( github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.14.1 // indirect - github.com/jinzhu/now v1.1.3 + github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 // indirect - gorm.io/driver/mysql v1.2.0 + golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect + gorm.io/driver/mysql v1.2.1 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlserver v1.2.1 - gorm.io/gorm v1.22.3 + gorm.io/gorm v1.22.4 ) replace gorm.io/gorm => ../ From 24026bf1fedf588357d183025f4312a77bd1f911 Mon Sep 17 00:00:00 2001 From: liweitingwt <87644000+liweitingwt@users.noreply.github.com> Date: Thu, 16 Dec 2021 10:41:34 +0800 Subject: [PATCH 62/83] modify unscoped judge (#4929) * modify unscoped judge * modify unscoped judge Co-authored-by: liweiting --- callbacks/query.go | 2 +- soft_delete.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index efb08609c..c2bbf5f91 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -27,7 +27,7 @@ func Query(db *gorm.DB) { } func BuildQuerySQL(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { + if db.Statement.Schema != nil { for _, c := range db.Statement.Schema.QueryClauses { db.Statement.AddClause(c) } diff --git a/soft_delete.go b/soft_delete.go index 4e236fc44..51e4c0d7e 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -63,7 +63,7 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { } func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { - if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped { if c, ok := stmt.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { for _, expr := range where.Exprs { From 2c3fc2db28dc172bec0822b2851d6b1d67869015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emre=20G=C3=BCll=C3=BC?= <54181092+emregullu@users.noreply.github.com> Date: Tue, 21 Dec 2021 14:50:00 +0300 Subject: [PATCH 63/83] Fix: Where clauses with named arguments may cause generation of unintended queries (#4937) --- clause/where.go | 3 +++ tests/named_argument_test.go | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/clause/where.go b/clause/where.go index 00b1a40e9..61aa73a87 100644 --- a/clause/where.go +++ b/clause/where.go @@ -60,6 +60,9 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { case Expr: sql := strings.ToLower(v.SQL) wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + case NamedExpr: + sql := strings.ToLower(v.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") } } diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go index d0a6f915f..a3a25f7ba 100644 --- a/tests/named_argument_test.go +++ b/tests/named_argument_test.go @@ -2,6 +2,7 @@ package tests_test import ( "database/sql" + "errors" "testing" "gorm.io/gorm" @@ -66,4 +67,16 @@ func TestNamedArg(t *testing.T) { } AssertEqual(t, result6, namedUser) + + var result7 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu-new")).Where("name3 = 'jinzhu-new3'").First(&result7).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } + + DB.Delete(&namedUser) + + var result8 NamedUser + if err := DB.Where("name1 = @name OR name2 = @name", map[string]interface{}{"name": "jinzhu-new"}).First(&result8).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should return record not found error, but got %v", err) + } } From b9667cb747341fbab197f9ccde1ddea864099171 Mon Sep 17 00:00:00 2001 From: "liweiting.wt" Date: Tue, 28 Dec 2021 18:22:17 +0800 Subject: [PATCH 64/83] fix: fix the error handle in tests_test --- tests/tests_test.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/tests_test.go b/tests/tests_test.go index d1f19df30..e26f358df 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -25,12 +25,15 @@ func init() { os.Exit(1) } else { sqlDB, err := DB.DB() - if err == nil { - err = sqlDB.Ping() + if err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) } + err = sqlDB.Ping() if err != nil { - log.Printf("failed to connect database, got error %v", err) + log.Printf("failed to ping sqlDB, got error %v", err) + os.Exit(1) } RunMigrations() @@ -76,6 +79,10 @@ func OpenTestConnection() (db *gorm.DB, err error) { db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) } + if err != nil { + return + } + if debug := os.Getenv("DEBUG"); debug == "true" { db.Logger = db.Logger.LogMode(logger.Info) } else if debug == "false" { From 8dde09e0becd383bc24c7bd7d17e5600644667a8 Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Thu, 30 Dec 2021 11:47:14 +0800 Subject: [PATCH 65/83] fix: generate sql incorrect when use soft_delete and only one OR (#4969) * fix: generate sql incorrect when use soft_delete and only one OR --- clause/where.go | 9 +++++++-- soft_delete.go | 2 +- tests/soft_delete_test.go | 10 ++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/clause/where.go b/clause/where.go index 61aa73a87..20a011362 100644 --- a/clause/where.go +++ b/clause/where.go @@ -92,9 +92,14 @@ func (where Where) MergeClause(clause *Clause) { func And(exprs ...Expression) Expression { if len(exprs) == 0 { return nil - } else if len(exprs) == 1 { - return exprs[0] } + + if len(exprs) == 1 { + if _, ok := exprs[0].(OrConditions); !ok { + return exprs[0] + } + } + return AndConditions{Exprs: exprs} } diff --git a/soft_delete.go b/soft_delete.go index 51e4c0d7e..4582161dd 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -65,7 +65,7 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped { if c, ok := stmt.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 { for _, expr := range where.Exprs { if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { where.Exprs = []clause.Expression{clause.And(where.Exprs...)} diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 0dfe24d5a..9ac8da10d 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -83,3 +83,13 @@ func TestDeletedAtUnMarshal(t *testing.T) { t.Errorf("Failed, result.DeletedAt: %v is not same as expected.DeletedAt: %v", result.DeletedAt, expected.DeletedAt) } } + +func TestDeletedAtOneOr(t *testing.T) { + actualSQL := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Or("id = ?", 1).Find(&User{}) + }) + + if !regexp.MustCompile(` WHERE id = 1 AND .users.\..deleted_at. IS NULL`).MatchString(actualSQL) { + t.Fatalf("invalid sql generated, got %v", actualSQL) + } +} From b47cf57f5e01a4bf742d277c54658e798f1bb5c4 Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Thu, 6 Jan 2022 15:02:53 +0800 Subject: [PATCH 66/83] ci: add gofumpt check in reviewdog (#4973) --- .github/workflows/reviewdog.yml | 11 +++ callbacks/helper.go | 6 +- callbacks/update.go | 4 +- clause/benchmarks_test.go | 3 +- clause/group_by_test.go | 6 +- clause/order_by_test.go | 3 +- clause/set_test.go | 6 +- clause/values_test.go | 3 +- clause/where_test.go | 21 ++++-- clause/with.go | 3 +- logger/sql.go | 2 +- migrator/migrator.go | 2 +- schema/callbacks_test.go | 3 +- schema/check.go | 8 +-- schema/field.go | 4 +- schema/field_test.go | 100 +++++++++++++------------- schema/index.go | 2 +- schema/model_test.go | 8 ++- schema/naming_test.go | 10 +-- schema/relationship_test.go | 3 - schema/schema_test.go | 6 +- statement.go | 4 +- tests/associations_belongs_to_test.go | 12 ++-- tests/associations_has_many_test.go | 36 +++++----- tests/associations_has_one_test.go | 18 ++--- tests/associations_many2many_test.go | 42 +++++------ tests/associations_test.go | 3 +- tests/benchmark_test.go | 8 +-- tests/count_test.go | 7 +- tests/create_test.go | 18 ++--- tests/default_value_test.go | 2 +- tests/delete_test.go | 2 +- tests/distinct_test.go | 2 +- tests/group_by_test.go | 4 +- tests/joins_test.go | 8 +-- tests/migrate_test.go | 1 - tests/multi_primary_keys_test.go | 22 +++--- tests/non_std_test.go | 2 +- tests/preload_test.go | 14 ++-- tests/query_test.go | 18 ++--- tests/scan_test.go | 2 +- tests/scanner_valuer_test.go | 4 +- tests/scopes_test.go | 2 +- tests/sql_builder_test.go | 7 +- tests/update_belongs_to_test.go | 2 +- tests/update_has_many_test.go | 4 +- tests/update_has_one_test.go | 6 +- tests/update_many2many_test.go | 2 +- tests/update_test.go | 6 +- tests/upsert_test.go | 4 +- utils/tests/dummy_dialecter.go | 3 +- 51 files changed, 244 insertions(+), 235 deletions(-) diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index 95b6fb048..b252dd7ae 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -9,3 +9,14 @@ jobs: uses: actions/checkout@v2 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 + + - name: Setup reviewdog + uses: reviewdog/action-setup@v1 + + - name: gofumpt -s with reviewdog + env: + REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + go install mvdan.cc/gofumpt@v0.2.0 + gofumpt -e -d . | \ + reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review \ No newline at end of file diff --git a/callbacks/helper.go b/callbacks/helper.go index 1d96ab26b..a59e1880f 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -12,7 +12,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter values.Columns = make([]clause.Column, 0, len(mapValue)) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) - var keys = make([]string, 0, len(mapValue)) + keys := make([]string, 0, len(mapValue)) for k := range mapValue { keys = append(keys, k) } @@ -40,9 +40,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter // ConvertSliceOfMapToValuesForCreate convert slice of map to values func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { - var ( - columns = make([]string, 0, len(mapValues)) - ) + columns := make([]string, 0, len(mapValues)) // when the length of mapValues is zero,return directly here // no need to call stmt.SelectAndOmitColumns method diff --git a/callbacks/update.go b/callbacks/update.go index b3eaaf117..511e994e7 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -162,7 +162,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if size := stmt.ReflectValue.Len(); size > 0 { var primaryKeyExprs []clause.Expression for i := 0; i < size; i++ { - var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) @@ -242,7 +242,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: - var updatingSchema = stmt.Schema + updatingSchema := stmt.Schema if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { // different schema updatingStmt := &gorm.Statement{DB: stmt.DB} diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 88a238e36..e08677ac0 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -32,7 +32,8 @@ func BenchmarkComplexSelect(b *testing.B) { for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ - clause.Select{}, clause.From{}, + clause.Select{}, + clause.From{}, clause.Where{Exprs: []clause.Expression{ clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, diff --git a/clause/group_by_test.go b/clause/group_by_test.go index 589f96130..7c282cb96 100644 --- a/clause/group_by_test.go +++ b/clause/group_by_test.go @@ -18,7 +18,8 @@ func TestGroupBy(t *testing.T) { Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}, }}, - "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, + "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", + []interface{}{"admin"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ @@ -28,7 +29,8 @@ func TestGroupBy(t *testing.T) { Columns: []clause.Column{{Name: "gender"}}, Having: []clause.Expression{clause.Neq{"gender", "U"}}, }}, - "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"}, + "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", + []interface{}{"admin", "U"}, }, } diff --git a/clause/order_by_test.go b/clause/order_by_test.go index 8fd1e2a86..d8b5dfbf6 100644 --- a/clause/order_by_test.go +++ b/clause/order_by_test.go @@ -45,7 +45,8 @@ func TestOrderBy(t *testing.T) { Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, }, }, - "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3}, + "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", + []interface{}{1, 2, 3}, }, } diff --git a/clause/set_test.go b/clause/set_test.go index 56fac7060..7a9ee895a 100644 --- a/clause/set_test.go +++ b/clause/set_test.go @@ -20,7 +20,8 @@ func TestSet(t *testing.T) { clause.Update{}, clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), }, - "UPDATE `users` SET `users`.`id`=?", []interface{}{1}, + "UPDATE `users` SET `users`.`id`=?", + []interface{}{1}, }, { []clause.Interface{ @@ -28,7 +29,8 @@ func TestSet(t *testing.T) { clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), }, - "UPDATE `users` SET `name`=?", []interface{}{"jinzhu"}, + "UPDATE `users` SET `name`=?", + []interface{}{"jinzhu"}, }, } diff --git a/clause/values_test.go b/clause/values_test.go index 9c02c8a54..1eea8652c 100644 --- a/clause/values_test.go +++ b/clause/values_test.go @@ -21,7 +21,8 @@ func TestValues(t *testing.T) { Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, }, }, - "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1}, + "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", + []interface{}{"jinzhu", 18, "josh", 1}, }, } diff --git a/clause/where_test.go b/clause/where_test.go index 2fa11d76e..272c7b76b 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -17,25 +17,29 @@ func TestWhere(t *testing.T) { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", + []interface{}{"1", 18, "jinzhu"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", + []interface{}{"1", "jinzhu", 18}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", + []interface{}{"1", "jinzhu", 18}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", + []interface{}{"1", "jinzhu"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ @@ -43,7 +47,8 @@ func TestWhere(t *testing.T) { }, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, }}, - "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", + []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ @@ -51,13 +56,15 @@ func TestWhere(t *testing.T) { }, clause.Where{ Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, }}, - "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", + []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, }}, - "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, + "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", + []interface{}{18, "jinzhu"}, }, } diff --git a/clause/with.go b/clause/with.go index 7e9eaef17..0768488e5 100644 --- a/clause/with.go +++ b/clause/with.go @@ -1,4 +1,3 @@ package clause -type With struct { -} +type With struct{} diff --git a/logger/sql.go b/logger/sql.go index 3d31d23c6..5ecb0ae23 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -32,7 +32,7 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var convertParams func(interface{}, int) - var vars = make([]string, len(avars)) + vars := make([]string, len(avars)) convertParams = func(v interface{}, idx int) { switch v := v.(type) { diff --git a/migrator/migrator.go b/migrator/migrator.go index 18212dbb3..2be15a7d1 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -541,7 +541,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { } if constraint != nil { - var vars = []interface{}{clause.Table{Name: table}} + vars := []interface{}{clause.Table{Name: table}} if stmt.TableExpr != nil { vars[0] = stmt.TableExpr } diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index dec41eba0..4583a2072 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -9,8 +9,7 @@ import ( "gorm.io/gorm/schema" ) -type UserWithCallback struct { -} +type UserWithCallback struct{} func (UserWithCallback) BeforeSave(*gorm.DB) error { return nil diff --git a/schema/check.go b/schema/check.go index 161a6ac6d..89e732d36 100644 --- a/schema/check.go +++ b/schema/check.go @@ -5,10 +5,8 @@ import ( "strings" ) -var ( - // reg match english letters and midline - regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") -) +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") type Check struct { Name string @@ -18,7 +16,7 @@ type Check struct { // ParseCheckConstraints parse schema check constraints func (schema *Schema) ParseCheckConstraints() map[string]Check { - var checks = map[string]Check{} + checks := map[string]Check{} for _, field := range schema.FieldsByDBName { if chk := field.TagSettings["CHECK"]; chk != "" { names := strings.Split(chk, ",") diff --git a/schema/field.go b/schema/field.go index c6c89cc10..d4f879c57 100644 --- a/schema/field.go +++ b/schema/field.go @@ -398,8 +398,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, - reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128: + case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface, + reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128: schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } } diff --git a/schema/field_test.go b/schema/field_test.go index 8768a4c35..2cf2d0838 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -261,64 +261,66 @@ func TestParseFieldWithPermission(t *testing.T) { } } -type ID int64 -type INT int -type INT8 int8 -type INT16 int16 -type INT32 int32 -type INT64 int64 -type UINT uint -type UINT8 uint8 -type UINT16 uint16 -type UINT32 uint32 -type UINT64 uint64 -type FLOAT32 float32 -type FLOAT64 float64 -type BOOL bool -type STRING string -type TypeAlias struct { - ID - INT `gorm:"column:fint"` - INT8 `gorm:"column:fint8"` - INT16 `gorm:"column:fint16"` - INT32 `gorm:"column:fint32"` - INT64 `gorm:"column:fint64"` - UINT `gorm:"column:fuint"` - UINT8 `gorm:"column:fuint8"` - UINT16 `gorm:"column:fuint16"` - UINT32 `gorm:"column:fuint32"` - UINT64 `gorm:"column:fuint64"` - FLOAT32 `gorm:"column:ffloat32"` - FLOAT64 `gorm:"column:ffloat64"` - BOOL `gorm:"column:fbool"` - STRING `gorm:"column:fstring"` -} +type ( + ID int64 + INT int + INT8 int8 + INT16 int16 + INT32 int32 + INT64 int64 + UINT uint + UINT8 uint8 + UINT16 uint16 + UINT32 uint32 + UINT64 uint64 + FLOAT32 float32 + FLOAT64 float64 + BOOL bool + STRING string + TypeAlias struct { + ID + INT `gorm:"column:fint"` + INT8 `gorm:"column:fint8"` + INT16 `gorm:"column:fint16"` + INT32 `gorm:"column:fint32"` + INT64 `gorm:"column:fint64"` + UINT `gorm:"column:fuint"` + UINT8 `gorm:"column:fuint8"` + UINT16 `gorm:"column:fuint16"` + UINT32 `gorm:"column:fuint32"` + UINT64 `gorm:"column:fuint64"` + FLOAT32 `gorm:"column:ffloat32"` + FLOAT64 `gorm:"column:ffloat64"` + BOOL `gorm:"column:fbool"` + STRING `gorm:"column:fstring"` + } +) -func TestTypeAliasField(t *testing.T){ +func TestTypeAliasField(t *testing.T) { alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err) } fields := []*schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true }, - {Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`}, - {Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`}, - {Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`}, - {Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`}, - {Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`}, - {Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`}, - {Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`}, - {Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`}, - {Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`}, - {Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`}, - {Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float , Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`}, - {Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float , Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`}, - {Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool , Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`}, - {Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`}, + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true}, + {Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`}, + {Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`}, + {Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`}, + {Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`}, + {Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`}, + {Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`}, + {Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`}, + {Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`}, + {Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`}, + {Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`}, + {Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`}, + {Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`}, + {Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`}, + {Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`}, } for _, f := range fields { checkSchemaField(t, alias, f, func(f *schema.Field) {}) } -} \ No newline at end of file +} diff --git a/schema/index.go b/schema/index.go index b54e08ad2..5f775f30f 100644 --- a/schema/index.go +++ b/schema/index.go @@ -27,7 +27,7 @@ type IndexOption struct { // ParseIndexes parse schema indexes func (schema *Schema) ParseIndexes() map[string]Index { - var indexes = map[string]Index{} + indexes := map[string]Index{} for _, field := range schema.Fields { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { diff --git a/schema/model_test.go b/schema/model_test.go index 1f2b09481..9e6c3590f 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -26,9 +26,11 @@ type User struct { Active *bool } -type mytime time.Time -type myint int -type mybool = bool +type ( + mytime time.Time + myint int + mybool = bool +) type AdvancedDataTypeUser struct { ID sql.NullInt64 diff --git a/schema/naming_test.go b/schema/naming_test.go index 6add338e6..c3e6bf923 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -6,7 +6,7 @@ import ( ) func TestToDBName(t *testing.T) { - var maps = map[string]string{ + maps := map[string]string{ "": "", "x": "x", "X": "x", @@ -56,7 +56,7 @@ func TestToDBName(t *testing.T) { } func TestNamingStrategy(t *testing.T) { - var ns = NamingStrategy{ + ns := NamingStrategy{ TablePrefix: "public.", SingularTable: true, NameReplacer: strings.NewReplacer("CID", "Cid"), @@ -102,7 +102,7 @@ func (r CustomReplacer) Replace(name string) string { } func TestCustomReplacer(t *testing.T) { - var ns = NamingStrategy{ + ns := NamingStrategy{ TablePrefix: "public.", SingularTable: true, NameReplacer: CustomReplacer{ @@ -146,7 +146,7 @@ func TestCustomReplacer(t *testing.T) { } func TestCustomReplacerWithNoLowerCase(t *testing.T) { - var ns = NamingStrategy{ + ns := NamingStrategy{ TablePrefix: "public.", SingularTable: true, NameReplacer: CustomReplacer{ @@ -190,7 +190,7 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) { } func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { - var ns = NamingStrategy{} + ns := NamingStrategy{} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index afa103b3d..e2cf11a91 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -105,7 +105,6 @@ func TestSelfReferentialBelongsTo(t *testing.T) { Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", References: []Reference{{"ID", "User", "CreatorID", "User", "", false}}, }) - } func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { @@ -160,7 +159,6 @@ func TestHasOneOverrideReferences(t *testing.T) { } func TestHasOneOverrideReferences2(t *testing.T) { - type Profile struct { gorm.Model Name string @@ -518,7 +516,6 @@ func TestSameForeignKey(t *testing.T) { } func TestBelongsToSameForeignKey(t *testing.T) { - type User struct { gorm.Model Name string diff --git a/schema/schema_test.go b/schema/schema_test.go index a426cd905..8a752fb7b 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -145,8 +145,7 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { } } -type CustomizeTable struct { -} +type CustomizeTable struct{} func (CustomizeTable) TableName() string { return "customize" @@ -165,7 +164,6 @@ func TestCustomizeTableName(t *testing.T) { func TestNestedModel(t *testing.T) { versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) - if err != nil { t.Fatalf("failed to parse nested user, got error %v", err) } @@ -204,7 +202,6 @@ func TestEmbeddedStruct(t *testing.T) { } cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) - if err != nil { t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) } @@ -273,7 +270,6 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { } cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) - if err != nil { t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) } diff --git a/statement.go b/statement.go index 5a948d3f1..f69339d4f 100644 --- a/statement.go +++ b/statement.go @@ -328,7 +328,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds = append(conds, clause.Eq{Column: i, Value: j}) } case map[string]string: - var keys = make([]string, 0, len(v)) + keys := make([]string, 0, len(v)) for i := range v { keys = append(keys, i) } @@ -338,7 +338,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } case map[string]interface{}: - var keys = make([]string, 0, len(v)) + keys := make([]string, 0, len(v)) for i := range v { keys = append(keys, i) } diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index e37da7d3b..f74799cea 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -7,7 +7,7 @@ import ( ) func TestBelongsToAssociation(t *testing.T) { - var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) + user := *GetUser("belongs-to", Config{Company: true, Manager: true}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -31,8 +31,8 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user, "Manager", 1, "") // Append - var company = Company{Name: "company-belongs-to-append"} - var manager = GetUser("manager-belongs-to-append", Config{}) + company := Company{Name: "company-belongs-to-append"} + manager := GetUser("manager-belongs-to-append", Config{}) if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { t.Fatalf("Error happened when append Company, got %v", err) @@ -60,8 +60,8 @@ func TestBelongsToAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") // Replace - var company2 = Company{Name: "company-belongs-to-replace"} - var manager2 = GetUser("manager-belongs-to-replace", Config{}) + company2 := Company{Name: "company-belongs-to-replace"} + manager2 := GetUser("manager-belongs-to-replace", Config{}) if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { t.Fatalf("Error happened when replace Company, got %v", err) @@ -142,7 +142,7 @@ func TestBelongsToAssociation(t *testing.T) { } func TestBelongsToAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 173e92319..002ae6364 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -7,7 +7,7 @@ import ( ) func TestHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Pets: 2}) + user := *GetUser("hasmany", Config{Pets: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -42,7 +42,7 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 2, "") // Append - var pet = Pet{Name: "pet-has-many-append"} + pet := Pet{Name: "pet-has-many-append"} if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -57,14 +57,14 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") - var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + pets2 := []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } for _, pet := range pets2 { - var pet = pet + pet := pet if pet.ID == 0 { t.Fatalf("Pet's ID should be created") } @@ -77,7 +77,7 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") // Replace - var pet2 = Pet{Name: "pet-has-many-replace"} + pet2 := Pet{Name: "pet-has-many-replace"} if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) @@ -119,7 +119,7 @@ func TestHasManyAssociation(t *testing.T) { } func TestSingleTableHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Team: 2}) + user := *GetUser("hasmany", Config{Team: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -137,7 +137,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Team", 2, "") // Append - var team = *GetUser("team", Config{}) + team := *GetUser("team", Config{}) if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -152,14 +152,14 @@ func TestSingleTableHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Team", 3, "AfterAppend") - var teams = []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} + teams := []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { t.Fatalf("Error happened when append team, got %v", err) } for _, team := range teams { - var team = team + team := team if team.ID == 0 { t.Fatalf("Team's ID should be created") } @@ -172,7 +172,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") // Replace - var team2 = *GetUser("team-replace", Config{}) + team2 := *GetUser("team-replace", Config{}) if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { t.Fatalf("Error happened when append team, got %v", err) @@ -214,7 +214,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) { } func TestHasManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasmany-1", Config{Pets: 2}), *GetUser("slice-hasmany-2", Config{Pets: 0}), *GetUser("slice-hasmany-3", Config{Pets: 4}), @@ -268,7 +268,7 @@ func TestHasManyAssociationForSlice(t *testing.T) { } func TestSingleTableHasManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasmany-1", Config{Team: 2}), *GetUser("slice-hasmany-2", Config{Team: 0}), *GetUser("slice-hasmany-3", Config{Team: 4}), @@ -324,7 +324,7 @@ func TestSingleTableHasManyAssociationForSlice(t *testing.T) { } func TestPolymorphicHasManyAssociation(t *testing.T) { - var user = *GetUser("hasmany", Config{Toys: 2}) + user := *GetUser("hasmany", Config{Toys: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -342,7 +342,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Toys", 2, "") // Append - var toy = Toy{Name: "toy-has-many-append"} + toy := Toy{Name: "toy-has-many-append"} if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -357,14 +357,14 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") - var toys = []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} + toys := []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { t.Fatalf("Error happened when append toy, got %v", err) } for _, toy := range toys { - var toy = toy + toy := toy if toy.ID == 0 { t.Fatalf("Toy's ID should be created") } @@ -377,7 +377,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") // Replace - var toy2 = Toy{Name: "toy-has-many-replace"} + toy2 := Toy{Name: "toy-has-many-replace"} if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { t.Fatalf("Error happened when append toy, got %v", err) @@ -419,7 +419,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { } func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasmany-1", Config{Toys: 2}), *GetUser("slice-hasmany-2", Config{Toys: 0}), *GetUser("slice-hasmany-3", Config{Toys: 4}), diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index a4fc8c4fc..a2c075090 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -7,7 +7,7 @@ import ( ) func TestHasOneAssociation(t *testing.T) { - var user = *GetUser("hasone", Config{Account: true}) + user := *GetUser("hasone", Config{Account: true}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -25,7 +25,7 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user, "Account", 1, "") // Append - var account = Account{Number: "account-has-one-append"} + account := Account{Number: "account-has-one-append"} if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -41,7 +41,7 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user, "Account", 1, "AfterAppend") // Replace - var account2 = Account{Number: "account-has-one-replace"} + account2 := Account{Number: "account-has-one-replace"} if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { t.Fatalf("Error happened when append Account, got %v", err) @@ -84,7 +84,7 @@ func TestHasOneAssociation(t *testing.T) { } func TestHasOneAssociationWithSelect(t *testing.T) { - var user = *GetUser("hasone", Config{Account: true}) + user := *GetUser("hasone", Config{Account: true}) DB.Omit("Account.Number").Create(&user) @@ -98,7 +98,7 @@ func TestHasOneAssociationWithSelect(t *testing.T) { } func TestHasOneAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-hasone-1", Config{Account: true}), *GetUser("slice-hasone-2", Config{Account: false}), *GetUser("slice-hasone-3", Config{Account: true}), @@ -139,7 +139,7 @@ func TestHasOneAssociationForSlice(t *testing.T) { } func TestPolymorphicHasOneAssociation(t *testing.T) { - var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} + pet := Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} if err := DB.Create(&pet).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -157,7 +157,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { AssertAssociationCount(t, pet, "Toy", 1, "") // Append - var toy = Toy{Name: "toy-has-one-append"} + toy := Toy{Name: "toy-has-one-append"} if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { t.Fatalf("Error happened when append toy, got %v", err) @@ -173,7 +173,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") // Replace - var toy2 = Toy{Name: "toy-has-one-replace"} + toy2 := Toy{Name: "toy-has-one-replace"} if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { t.Fatalf("Error happened when append Toy, got %v", err) @@ -216,7 +216,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) { } func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { - var pets = []Pet{ + pets := []Pet{ {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, {Name: "hasone-2", Toy: Toy{}}, {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 739d1682d..28b441bd8 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -7,7 +7,7 @@ import ( ) func TestMany2ManyAssociation(t *testing.T) { - var user = *GetUser("many2many", Config{Languages: 2}) + user := *GetUser("many2many", Config{Languages: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -26,7 +26,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 2, "") // Append - var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"} + language := Language{Code: "language-many2many-append", Name: "language-many2many-append"} DB.Create(&language) if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { @@ -38,7 +38,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") - var languages = []Language{ + languages := []Language{ {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, } @@ -55,7 +55,7 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") // Replace - var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} + language2 := Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} DB.Create(&language2) if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { @@ -94,7 +94,7 @@ func TestMany2ManyAssociation(t *testing.T) { } func TestMany2ManyOmitAssociations(t *testing.T) { - var user = *GetUser("many2many_omit_associations", Config{Languages: 2}) + user := *GetUser("many2many_omit_associations", Config{Languages: 2}) if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { t.Fatalf("should raise error when create users without languages reference") @@ -114,14 +114,14 @@ func TestMany2ManyOmitAssociations(t *testing.T) { t.Errorf("languages count should be %v, but got %v", 2, len(languages)) } - var newLang = Language{Code: "omitmany2many", Name: "omitmany2many"} + newLang := Language{Code: "omitmany2many", Name: "omitmany2many"} if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) } } func TestMany2ManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-many2many-1", Config{Languages: 2}), *GetUser("slice-many2many-2", Config{Languages: 0}), *GetUser("slice-many2many-3", Config{Languages: 4}), @@ -139,11 +139,11 @@ func TestMany2ManyAssociationForSlice(t *testing.T) { } // Append - var languages1 = []Language{ + languages1 := []Language{ {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, } - var languages2 = []Language{} - var languages3 = []Language{ + languages2 := []Language{} + languages3 := []Language{ {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, } @@ -191,7 +191,7 @@ func TestMany2ManyAssociationForSlice(t *testing.T) { } func TestSingleTableMany2ManyAssociation(t *testing.T) { - var user = *GetUser("many2many", Config{Friends: 2}) + user := *GetUser("many2many", Config{Friends: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -210,7 +210,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Friends", 2, "") // Append - var friend = *GetUser("friend", Config{}) + friend := *GetUser("friend", Config{}) if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { t.Fatalf("Error happened when append account, got %v", err) @@ -221,7 +221,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") - var friends = []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} + friends := []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { t.Fatalf("Error happened when append friend, got %v", err) @@ -234,7 +234,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") // Replace - var friend2 = *GetUser("friend-replace-2", Config{}) + friend2 := *GetUser("friend-replace-2", Config{}) if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { t.Fatalf("Error happened when append friend, got %v", err) @@ -272,7 +272,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) { } func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice-many2many-1", Config{Team: 2}), *GetUser("slice-many2many-2", Config{Team: 0}), *GetUser("slice-many2many-3", Config{Team: 4}), @@ -290,17 +290,17 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { } // Append - var teams1 = []User{*GetUser("friend-append-1", Config{})} - var teams2 = []User{} - var teams3 = []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} + teams1 := []User{*GetUser("friend-append-1", Config{})} + teams2 := []User{} + teams3 := []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) AssertAssociationCount(t, users, "Team", 9, "After Append") - var teams2_1 = []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} - var teams2_2 = []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} - var teams2_3 = GetUser("friend-replace-3-1", Config{}) + teams2_1 := []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} + teams2_2 := []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} + teams2_3 := GetUser("friend-replace-3-1", Config{}) // Replace DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) diff --git a/tests/associations_test.go b/tests/associations_test.go index f88d1523e..5ce98c7dc 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -27,7 +27,7 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result } func TestInvalidAssociation(t *testing.T) { - var user = *GetUser("invalid", Config{Company: true, Manager: true}) + user := *GetUser("invalid", Config{Company: true, Manager: true}) if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { t.Fatalf("should return errors for invalid association, but got nil") } @@ -189,7 +189,6 @@ func TestFullSaveAssociations(t *testing.T) { err := DB. Session(&gorm.Session{FullSaveAssociations: true}). Create(coupon).Error - if err != nil { t.Errorf("Failed, got error: %v", err) } diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go index c6ce93a26..d897a6341 100644 --- a/tests/benchmark_test.go +++ b/tests/benchmark_test.go @@ -7,7 +7,7 @@ import ( ) func BenchmarkCreate(b *testing.B) { - var user = *GetUser("bench", Config{}) + user := *GetUser("bench", Config{}) for x := 0; x < b.N; x++ { user.ID = 0 @@ -16,7 +16,7 @@ func BenchmarkCreate(b *testing.B) { } func BenchmarkFind(b *testing.B) { - var user = *GetUser("find", Config{}) + user := *GetUser("find", Config{}) DB.Create(&user) for x := 0; x < b.N; x++ { @@ -25,7 +25,7 @@ func BenchmarkFind(b *testing.B) { } func BenchmarkUpdate(b *testing.B) { - var user = *GetUser("find", Config{}) + user := *GetUser("find", Config{}) DB.Create(&user) for x := 0; x < b.N; x++ { @@ -34,7 +34,7 @@ func BenchmarkUpdate(b *testing.B) { } func BenchmarkDelete(b *testing.B) { - var user = *GetUser("find", Config{}) + user := *GetUser("find", Config{}) for x := 0; x < b.N; x++ { user.ID = 0 diff --git a/tests/count_test.go b/tests/count_test.go index 7cae890b5..27d7ee607 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -87,7 +87,7 @@ func TestCount(t *testing.T) { t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) } - expects := []User{User{Name: "main"}, {Name: "other"}, {Name: "other"}} + expects := []User{{Name: "main"}, {Name: "other"}, {Name: "other"}} sort.SliceStable(users, func(i, j int) bool { return strings.Compare(users[i].Name, users[j].Name) < 0 }) @@ -101,7 +101,7 @@ func TestCount(t *testing.T) { t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) } - expects = []User{User{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} + expects = []User{{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} sort.SliceStable(users, func(i, j int) bool { return strings.Compare(users[i].Name, users[j].Name) < 0 }) @@ -115,7 +115,7 @@ func TestCount(t *testing.T) { t.Fatalf("Count should work, but got err %v", err) } - expects = []User{User{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} + expects = []User{{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} sort.SliceStable(users, func(i, j int) bool { return strings.Compare(users[i].Name, users[j].Name) < 0 }) @@ -144,5 +144,4 @@ func TestCount(t *testing.T) { if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) } - } diff --git a/tests/create_test.go b/tests/create_test.go index 060f78af2..af2abdb08 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -13,7 +13,7 @@ import ( ) func TestCreate(t *testing.T) { - var user = *GetUser("create", Config{}) + user := *GetUser("create", Config{}) if results := DB.Create(&user); results.Error != nil { t.Fatalf("errors happened when create: %v", results.Error) @@ -139,7 +139,7 @@ func TestCreateFromMap(t *testing.T) { } func TestCreateWithAssociations(t *testing.T) { - var user = *GetUser("create_with_associations", Config{ + user := *GetUser("create_with_associations", Config{ Account: true, Pets: 2, Toys: 3, @@ -223,7 +223,7 @@ func TestBulkCreatePtrDataWithAssociations(t *testing.T) { func TestPolymorphicHasOne(t *testing.T) { t.Run("Struct", func(t *testing.T) { - var pet = Pet{ + pet := Pet{ Name: "PolymorphicHasOne", Toy: Toy{Name: "Toy-PolymorphicHasOne"}, } @@ -240,7 +240,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("Slice", func(t *testing.T) { - var pets = []Pet{{ + pets := []Pet{{ Name: "PolymorphicHasOne-Slice-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, }, { @@ -269,7 +269,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("SliceOfPtr", func(t *testing.T) { - var pets = []*Pet{{ + pets := []*Pet{{ Name: "PolymorphicHasOne-Slice-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, }, { @@ -290,7 +290,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("Array", func(t *testing.T) { - var pets = [...]Pet{{ + pets := [...]Pet{{ Name: "PolymorphicHasOne-Array-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, }, { @@ -311,7 +311,7 @@ func TestPolymorphicHasOne(t *testing.T) { }) t.Run("ArrayPtr", func(t *testing.T) { - var pets = [...]*Pet{{ + pets := [...]*Pet{{ Name: "PolymorphicHasOne-Array-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, }, { @@ -348,12 +348,12 @@ func TestCreateEmptyStruct(t *testing.T) { } func TestCreateEmptySlice(t *testing.T) { - var data = []User{} + data := []User{} if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { t.Errorf("no data should be created, got %v", err) } - var sliceMap = []map[string]interface{}{} + sliceMap := []map[string]interface{}{} if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { t.Errorf("no data should be created, got %v", err) } diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 14a0a9774..5e00b1546 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -23,7 +23,7 @@ func TestDefaultValue(t *testing.T) { t.Fatalf("Failed to migrate with default value, got error: %v", err) } - var harumph = Harumph{Email: "hello@gorm.io"} + harumph := Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { diff --git a/tests/delete_test.go b/tests/delete_test.go index 049b2ac46..5cb4b91e6 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -10,7 +10,7 @@ import ( ) func TestDelete(t *testing.T) { - var users = []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} + users := []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} if err := DB.Create(&users).Error; err != nil { t.Errorf("errors happened when create: %v", err) diff --git a/tests/distinct_test.go b/tests/distinct_test.go index f97738a77..8c8298ae2 100644 --- a/tests/distinct_test.go +++ b/tests/distinct_test.go @@ -9,7 +9,7 @@ import ( ) func TestDistinct(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("distinct", Config{}), *GetUser("distinct", Config{}), *GetUser("distinct", Config{}), diff --git a/tests/group_by_test.go b/tests/group_by_test.go index 96dfc5477..5335fed14 100644 --- a/tests/group_by_test.go +++ b/tests/group_by_test.go @@ -7,7 +7,7 @@ import ( ) func TestGroupBy(t *testing.T) { - var users = []User{{ + users := []User{{ Name: "groupby", Age: 10, Birthday: Now(), @@ -67,7 +67,7 @@ func TestGroupBy(t *testing.T) { t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) } - var result = struct { + result := struct { Name string Total int64 }{} diff --git a/tests/joins_test.go b/tests/joins_test.go index ca8477dc9..e276a74a1 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -57,7 +57,7 @@ func TestJoinsForSlice(t *testing.T) { } func TestJoinConds(t *testing.T) { - var user = *GetUser("joins-conds", Config{Account: true, Pets: 3}) + user := *GetUser("joins-conds", Config{Account: true, Pets: 3}) DB.Save(&user) var users1 []User @@ -111,7 +111,7 @@ func TestJoinConds(t *testing.T) { } func TestJoinOn(t *testing.T) { - var user = *GetUser("joins-on", Config{Pets: 2}) + user := *GetUser("joins-on", Config{Pets: 2}) DB.Save(&user) var user1 User @@ -168,8 +168,8 @@ func TestJoinCount(t *testing.T) { DB.Create(&user) query := DB.Model(&User{}).Joins("Company") - //Bug happens when .Count is called on a query. - //Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. + // Bug happens when .Count is called on a query. + // Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. var total int64 query.Count(&total) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 3d15bf2c7..15e851934 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -174,7 +174,6 @@ func TestSmartMigrateColumn(t *testing.T) { } } } - } func TestMigrateWithColumnComment(t *testing.T) { diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 3a8c08aa8..4a7ab9f61 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -71,7 +71,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { @@ -95,8 +95,8 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { } // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog).Association("Tags").Replace(tag5, tag6) var tags2 []Tag DB.Model(&blog).Association("Tags").Find(&tags2) @@ -170,7 +170,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Blog should has three tags after Append") @@ -201,7 +201,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { t.Fatalf("Preload many2many relations") } - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + tag4 := &Tag{Locale: "ZH", Value: "tag4"} DB.Model(&blog2).Association("SharedTags").Append(tag4) DB.Model(&blog).Association("SharedTags").Find(&tags) @@ -215,8 +215,8 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) var tags2 []Tag DB.Model(&blog).Association("SharedTags").Find(&tags2) @@ -291,7 +291,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { DB.Create(&blog2) // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Blog should has three tags after Append") @@ -322,7 +322,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Fatalf("Preload many2many relations") } - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + tag4 := &Tag{Locale: "ZH", Value: "tag4"} DB.Model(&blog2).Association("LocaleTags").Append(tag4) DB.Model(&blog).Association("LocaleTags").Find(&tags) @@ -336,8 +336,8 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + tag5 := &Tag{Locale: "ZH", Value: "tag5"} + tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) var tags2 []Tag diff --git a/tests/non_std_test.go b/tests/non_std_test.go index d3561b11e..8ae426911 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -8,7 +8,7 @@ import ( type Animal struct { Counter uint64 `gorm:"primary_key:yes"` Name string `gorm:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name + From string // test reserved sql keyword as field name Age *time.Time unexported string // unexported value CreatedAt time.Time diff --git a/tests/preload_test.go b/tests/preload_test.go index a3e672003..adb54ee19 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -14,7 +14,7 @@ import ( ) func TestPreloadWithAssociations(t *testing.T) { - var user = *GetUser("preload_with_associations", Config{ + user := *GetUser("preload_with_associations", Config{ Account: true, Pets: 2, Toys: 3, @@ -35,7 +35,7 @@ func TestPreloadWithAssociations(t *testing.T) { DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) - var user3 = *GetUser("preload_with_associations_new", Config{ + user3 := *GetUser("preload_with_associations_new", Config{ Account: true, Pets: 2, Toys: 3, @@ -51,7 +51,7 @@ func TestPreloadWithAssociations(t *testing.T) { } func TestNestedPreload(t *testing.T) { - var user = *GetUser("nested_preload", Config{Pets: 2}) + user := *GetUser("nested_preload", Config{Pets: 2}) for idx, pet := range user.Pets { pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} @@ -75,7 +75,7 @@ func TestNestedPreload(t *testing.T) { } func TestNestedPreloadForSlice(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice_nested_preload_1", Config{Pets: 2}), *GetUser("slice_nested_preload_2", Config{Pets: 0}), *GetUser("slice_nested_preload_3", Config{Pets: 3}), @@ -105,7 +105,7 @@ func TestNestedPreloadForSlice(t *testing.T) { } func TestPreloadWithConds(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice_nested_preload_1", Config{Account: true}), *GetUser("slice_nested_preload_2", Config{Account: false}), *GetUser("slice_nested_preload_3", Config{Account: true}), @@ -163,7 +163,7 @@ func TestPreloadWithConds(t *testing.T) { } func TestNestedPreloadWithConds(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("slice_nested_preload_1", Config{Pets: 2}), *GetUser("slice_nested_preload_2", Config{Pets: 0}), *GetUser("slice_nested_preload_3", Config{Pets: 3}), @@ -213,7 +213,7 @@ func TestNestedPreloadWithConds(t *testing.T) { } func TestPreloadEmptyData(t *testing.T) { - var user = *GetUser("user_without_associations", Config{}) + user := *GetUser("user_without_associations", Config{}) DB.Create(&user) DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) diff --git a/tests/query_test.go b/tests/query_test.go index 8a476598b..c99214b60 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -17,7 +17,7 @@ import ( ) func TestFind(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("find", Config{}), *GetUser("find", Config{}), *GetUser("find", Config{}), @@ -57,7 +57,7 @@ func TestFind(t *testing.T) { } t.Run("FirstMap", func(t *testing.T) { - var first = map[string]interface{}{} + first := map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { @@ -88,7 +88,7 @@ func TestFind(t *testing.T) { }) t.Run("FirstMapWithTable", func(t *testing.T) { - var first = map[string]interface{}{} + first := map[string]interface{}{} if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { @@ -120,7 +120,7 @@ func TestFind(t *testing.T) { }) t.Run("FirstPtrMap", func(t *testing.T) { - var first = map[string]interface{}{} + first := map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { @@ -135,7 +135,7 @@ func TestFind(t *testing.T) { }) t.Run("FirstSliceOfMap", func(t *testing.T) { - var allMap = []map[string]interface{}{} + allMap := []map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { t.Errorf("errors happened when query find: %v", err) } else { @@ -170,7 +170,7 @@ func TestFind(t *testing.T) { }) t.Run("FindSliceOfMapWithTable", func(t *testing.T) { - var allMap = []map[string]interface{}{} + allMap := []map[string]interface{}{} if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { t.Errorf("errors happened when query find: %v", err) } else { @@ -241,7 +241,7 @@ func TestQueryWithAssociation(t *testing.T) { } func TestFindInBatches(t *testing.T) { - var users = []User{ + users := []User{ *GetUser("find_in_batches", Config{}), *GetUser("find_in_batches", Config{}), *GetUser("find_in_batches", Config{}), @@ -297,7 +297,7 @@ func TestFindInBatchesWithError(t *testing.T) { t.Skip("skip sqlserver due to it will raise data race for invalid sql") } - var users = []User{ + users := []User{ *GetUser("find_in_batches_with_error", Config{}), *GetUser("find_in_batches_with_error", Config{}), *GetUser("find_in_batches_with_error", Config{}), @@ -440,7 +440,7 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IS NOT NULL").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } - + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) diff --git a/tests/scan_test.go b/tests/scan_test.go index 59fc6de5d..1a188facf 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -45,7 +45,7 @@ func TestScan(t *testing.T) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } - var doubleAgeRes = &result{} + doubleAgeRes := &result{} if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { t.Errorf("Scan to pointer of pointer") } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index fb1f57917..14121699e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -182,11 +182,11 @@ func (data *EncryptedData) Scan(value interface{}) error { func (data EncryptedData) Value() (driver.Value, error) { if len(data) > 0 && data[0] == 'x' { - //needed to test failures + // needed to test failures return nil, errors.New("Should not start with 'x'") } - //prepend asterisks + // prepend asterisks return append([]byte("***"), data...), nil } diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 94fff3082..ab3807ea2 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -23,7 +23,7 @@ func NameIn(names []string) func(d *gorm.DB) *gorm.DB { } func TestScopes(t *testing.T) { - var users = []*User{ + users := []*User{ GetUser("ScopeUser1", Config{}), GetUser("ScopeUser2", Config{}), GetUser("ScopeUser3", Config{}), diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 2f9fd8dad..237d807b7 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -4,12 +4,11 @@ import ( "regexp" "strings" "testing" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" - - "time" ) func TestRow(t *testing.T) { @@ -389,12 +388,12 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { actually = replaceQuoteInSQL(actually) // ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update. - var updatedAtRe = regexp.MustCompile(`(?i)"updated_at"=".+?"`) + updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) // ignore RETURNING "id" (only in PostgreSQL) - var returningRe = regexp.MustCompile(`(?i)RETURNING "id"`) + returningRe := regexp.MustCompile(`(?i)RETURNING "id"`) actually = returningRe.ReplaceAllString(actually, ``) expected = returningRe.ReplaceAllString(expected, ``) diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 736dfc5b6..8fe0f2897 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -8,7 +8,7 @@ import ( ) func TestUpdateBelongsTo(t *testing.T) { - var user = *GetUser("update-belongs-to", Config{}) + user := *GetUser("update-belongs-to", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 9066cbacc..2ca93e2be 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -8,7 +8,7 @@ import ( ) func TestUpdateHasManyAssociations(t *testing.T) { - var user = *GetUser("update-has-many", Config{}) + user := *GetUser("update-has-many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -44,7 +44,7 @@ func TestUpdateHasManyAssociations(t *testing.T) { CheckUser(t, user4, user) t.Run("Polymorphic", func(t *testing.T) { - var user = *GetUser("update-has-many", Config{}) + user := *GetUser("update-has-many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 59d30e42a..c926fbcf8 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -10,7 +10,7 @@ import ( ) func TestUpdateHasOne(t *testing.T) { - var user = *GetUser("update-has-one", Config{}) + user := *GetUser("update-has-one", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) @@ -35,7 +35,7 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Account").Find(&user3, "id = ?", user.ID) CheckUser(t, user2, user3) - var lastUpdatedAt = user2.Account.UpdatedAt + lastUpdatedAt := user2.Account.UpdatedAt time.Sleep(time.Second) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { @@ -53,7 +53,7 @@ func TestUpdateHasOne(t *testing.T) { } t.Run("Polymorphic", func(t *testing.T) { - var pet = Pet{Name: "create"} + pet := Pet{Name: "create"} if err := DB.Create(&pet).Error; err != nil { t.Fatalf("errors happened when create: %v", err) diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index d94ef4ab9..f1218cc0f 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -8,7 +8,7 @@ import ( ) func TestUpdateMany2ManyAssociations(t *testing.T) { - var user = *GetUser("update-many2many", Config{}) + user := *GetUser("update-many2many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) diff --git a/tests/update_test.go b/tests/update_test.go index abe520db8..b471ba9be 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -125,7 +125,7 @@ func TestUpdate(t *testing.T) { } func TestUpdates(t *testing.T) { - var users = []*User{ + users := []*User{ GetUser("updates_01", Config{}), GetUser("updates_02", Config{}), } @@ -178,7 +178,7 @@ func TestUpdates(t *testing.T) { } func TestUpdateColumn(t *testing.T) { - var users = []*User{ + users := []*User{ GetUser("update_column_01", Config{}), GetUser("update_column_02", Config{}), } @@ -622,7 +622,7 @@ func TestSave(t *testing.T) { time.Sleep(time.Second) user1UpdatedAt := result.UpdatedAt user2UpdatedAt := user2.UpdatedAt - var users = []*User{&result, &user2} + users := []*User{&result, &user2} DB.Save(&users) if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index a7b53ab7c..c5d196055 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -67,7 +67,7 @@ func TestUpsert(t *testing.T) { } } - var user = *GetUser("upsert_on_conflict", Config{}) + user := *GetUser("upsert_on_conflict", Config{}) user.Age = 20 if err := DB.Create(&user).Error; err != nil { t.Errorf("failed to create user, got error %v", err) @@ -320,11 +320,9 @@ func TestUpdateWithMissWhere(t *testing.T) { if err := tx.Error; err != nil { t.Fatalf("failed to update user,missing where condtion,err=%+v", err) - } if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", tx.Statement.SQL.String()) } - } diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 84fdd2b6e..9543f750a 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -7,8 +7,7 @@ import ( "gorm.io/gorm/schema" ) -type DummyDialector struct { -} +type DummyDialector struct{} func (DummyDialector) Name() string { return "dummy" From f757b8fdc9f9fd52a1d6454b13394fc5561fa299 Mon Sep 17 00:00:00 2001 From: halfcrazy Date: Thu, 6 Jan 2022 18:55:20 +0800 Subject: [PATCH 67/83] fix: auto migration column order unpredictable (#4980) --- migrator/migrator.go | 7 +++-- tests/migrate_test.go | 72 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 2be15a7d1..138917fb5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -97,11 +97,12 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { columnTypes, _ := m.DB.Migrator().ColumnTypes(value) - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] var foundColumn gorm.ColumnType for _, columnType := range columnTypes { - if columnType.Name() == field.DBName { + if columnType.Name() == dbName { foundColumn = columnType break } @@ -109,7 +110,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { + if err := tx.Migrator().AddColumn(value, dbName); err != nil { return err } } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 15e851934..aa0a84ab5 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -2,11 +2,13 @@ package tests_test import ( "math/rand" + "reflect" "strings" "testing" "time" "gorm.io/gorm" + "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -454,3 +456,73 @@ func TestMigrateIndexesWithDynamicTableName(t *testing.T) { } } } + +// check column order after migration, flaky test +// https://github.com/go-gorm/gorm/issues/4351 +func TestMigrateColumnOrder(t *testing.T) { + type UserMigrateColumn struct { + ID uint + } + DB.Migrator().DropTable(&UserMigrateColumn{}) + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + F1 string + F2 string + F3 string + F4 string + F5 string + F6 string + F7 string + F8 string + F9 string + F10 string + F11 string + F12 string + F13 string + F14 string + F15 string + F16 string + F17 string + F18 string + F19 string + F20 string + F21 string + F22 string + F23 string + F24 string + F25 string + F26 string + F27 string + F28 string + F29 string + F30 string + F31 string + F32 string + F33 string + F34 string + F35 string + } + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn2{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + typ := reflect.Indirect(reflect.ValueOf(&UserMigrateColumn2{})).Type() + numField := typ.NumField() + if numField != len(columnTypes) { + t.Fatalf("column's number not match struct and ddl, %d != %d", numField, len(columnTypes)) + } + namer := schema.NamingStrategy{} + for i := 0; i < numField; i++ { + expectName := namer.ColumnName("", typ.Field(i).Name) + if columnTypes[i].Name() != expectName { + t.Fatalf("column order not match struct and ddl, idx %d: %s != %s", + i, columnTypes[i].Name(), expectName) + } + } +} From 0df42e9afc15544a6927e4393b36f2ebd32a561e Mon Sep 17 00:00:00 2001 From: kinggo <30891428+longlihale@users.noreply.github.com> Date: Fri, 7 Jan 2022 09:49:56 +0800 Subject: [PATCH 68/83] feat: add `Connection` to execute multiple commands in a single connection; (#4982) --- finisher_api.go | 24 ++++++++++++++++++++ tests/connection_test.go | 48 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 tests/connection_test.go diff --git a/finisher_api.go b/finisher_api.go index d38d60b7e..dd0eb83a8 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -515,6 +515,30 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { return tx.Error } +// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +func (db *DB) Connection(fc func(tx *DB) error) (err error) { + if db.Error != nil { + return db.Error + } + + tx := db.getInstance() + sqlDB, err := tx.DB() + if err != nil { + return + } + + conn, err := sqlDB.Conn(tx.Statement.Context) + if err != nil { + return + } + + defer conn.Close() + tx.Statement.ConnPool = conn + err = fc(tx) + + return +} + // Transaction start a transaction as a block, return error will rollback, otherwise to commit. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true diff --git a/tests/connection_test.go b/tests/connection_test.go new file mode 100644 index 000000000..9b5dcd058 --- /dev/null +++ b/tests/connection_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "fmt" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "testing" +) + +func TestWithSingleConnection(t *testing.T) { + + var expectedName = "test" + var actualName string + + setSQL, getSQL := getSetSQL(DB.Dialector.Name()) + if len(setSQL) == 0 || len(getSQL) == 0 { + return + } + + err := DB.Connection(func(tx *gorm.DB) error { + if err := tx.Exec(setSQL, expectedName).Error; err != nil { + return err + } + + if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil { + return err + } + return nil + }) + + if err != nil { + t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) + } + + if actualName != expectedName { + t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) + } + +} + +func getSetSQL(driverName string) (string, string) { + switch driverName { + case mysql.Dialector{}.Name(): + return "SET @testName := ?", "SELECT @testName" + default: + return "", "" + } +} From eae73624ad43384d34ee0c9f85055b1fe48434b1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 7 Jan 2022 10:04:35 +0800 Subject: [PATCH 69/83] Fix return failed to begin transaction error when failed to start a transaction --- finisher_api.go | 24 ++++++++++++------------ tests/connection_test.go | 5 ++--- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index dd0eb83a8..355d89bd9 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -534,9 +534,7 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { defer conn.Close() tx.Statement.ConnPool = conn - err = fc(tx) - - return + return fc(tx) } // Transaction start a transaction as a block, return error will rollback, otherwise to commit. @@ -547,6 +545,10 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er // nested transaction if !db.DisableNestedTransaction { err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + if err != nil { + return + } + defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { @@ -555,11 +557,12 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er }() } - if err == nil { - err = fc(db.Session(&Session{})) - } + err = fc(db.Session(&Session{})) } else { tx := db.Begin(opts...) + if tx.Error != nil { + return tx.Error + } defer func() { // Make sure to rollback when panic, Block error or Commit error @@ -568,12 +571,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - if err = tx.Error; err == nil { - err = fc(tx) - } - - if err == nil { - err = tx.Commit().Error + if err = fc(tx); err == nil { + panicked = false + return tx.Commit().Error } } diff --git a/tests/connection_test.go b/tests/connection_test.go index 9b5dcd058..92b13dd68 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -2,13 +2,13 @@ package tests_test import ( "fmt" + "testing" + "gorm.io/driver/mysql" "gorm.io/gorm" - "testing" ) func TestWithSingleConnection(t *testing.T) { - var expectedName = "test" var actualName string @@ -35,7 +35,6 @@ func TestWithSingleConnection(t *testing.T) { if actualName != expectedName { t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) } - } func getSetSQL(driverName string) (string, string) { From a0d6ff1feadcac2480af2b3cbc4db3d47b0a8f42 Mon Sep 17 00:00:00 2001 From: piyongcai Date: Wed, 12 Jan 2022 13:11:40 +0800 Subject: [PATCH 70/83] time.Time, []byte type add alias support. (rebase master) (#4992) * time.Time, []byte type add alias support * reformat --- schema/field.go | 3 ++- schema/field_test.go | 37 ++++++++++++++++++++++--------------- statement.go | 3 +++ 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/schema/field.go b/schema/field.go index d4f879c57..485bbdf3d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -346,7 +346,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { + if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes && + (ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) { kind := reflect.Indirect(fieldValue).Kind() switch kind { case reflect.Struct: diff --git a/schema/field_test.go b/schema/field_test.go index 2cf2d0838..8fa46b876 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -262,21 +262,24 @@ func TestParseFieldWithPermission(t *testing.T) { } type ( - ID int64 - INT int - INT8 int8 - INT16 int16 - INT32 int32 - INT64 int64 - UINT uint - UINT8 uint8 - UINT16 uint16 - UINT32 uint32 - UINT64 uint64 - FLOAT32 float32 - FLOAT64 float64 - BOOL bool - STRING string + ID int64 + INT int + INT8 int8 + INT16 int16 + INT32 int32 + INT64 int64 + UINT uint + UINT8 uint8 + UINT16 uint16 + UINT32 uint32 + UINT64 uint64 + FLOAT32 float32 + FLOAT64 float64 + BOOL bool + STRING string + TIME time.Time + BYTES []byte + TypeAlias struct { ID INT `gorm:"column:fint"` @@ -293,6 +296,8 @@ type ( FLOAT64 `gorm:"column:ffloat64"` BOOL `gorm:"column:fbool"` STRING `gorm:"column:fstring"` + TIME `gorm:"column:ftime"` + BYTES `gorm:"column:fbytes"` } ) @@ -318,6 +323,8 @@ func TestTypeAliasField(t *testing.T) { {Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`}, {Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`}, {Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`}, + {Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`}, + {Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`}, } for _, f := range fields { diff --git a/statement.go b/statement.go index f69339d4f..146722a9c 100644 --- a/statement.go +++ b/statement.go @@ -232,6 +232,9 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case reflect.Slice, reflect.Array: if rv.Len() == 0 { writer.WriteString("(NULL)") + } else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) { + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) } else { writer.WriteByte('(') for i := 0; i < rv.Len(); i++ { From e5894ca44951fecc3b3f31f1aa46df7de6024b04 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 12 Jan 2022 13:11:57 +0800 Subject: [PATCH 71/83] chore(deps): bump gorm.io/driver/mysql from 1.2.1 to 1.2.3 in /tests (#4987) Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.2.1 to 1.2.3. - [Release notes](https://github.com/go-gorm/mysql/releases) - [Commits](https://github.com/go-gorm/mysql/compare/v1.2.1...v1.2.3) --- updated-dependencies: - dependency-name: gorm.io/driver/mysql dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index c3133f38d..3233ea951 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect - gorm.io/driver/mysql v1.2.1 + gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlserver v1.2.1 From cec0d32aecc8d5068873304abe7f85e9409d4b10 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Jan 2022 18:48:32 +0800 Subject: [PATCH 72/83] Support use clause.Expression as argument --- clause/select_test.go | 17 +++++++++++++++++ statement.go | 2 ++ tests/go.mod | 4 +++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/clause/select_test.go b/clause/select_test.go index 9fce0783f..18bc2693b 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -43,6 +43,23 @@ func TestSelect(t *testing.T) { }, clause.From{}}, "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.Expr{ + SQL: "? as name", + Vars: []interface{}{clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, + }, + }, + }, + }, + }, clause.From{}}, + "SELECT `age` = ? as name FROM `users`", []interface{}{18}, + }, } for idx, result := range results { diff --git a/statement.go b/statement.go index 146722a9c..72359da29 100644 --- a/statement.go +++ b/statement.go @@ -183,6 +183,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { v.Build(stmt) case *clause.Expr: v.Build(stmt) + case clause.Expression: + v.Build(stmt) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) diff --git a/tests/go.mod b/tests/go.mod index 3233ea951..5415cf746 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,11 +3,13 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.14.1 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect + github.com/mattn/go-sqlite3 v1.14.10 // indirect + golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 From 98c4b78e4dcceea93eaaabd051f8c021e645e017 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Jan 2022 19:26:10 +0800 Subject: [PATCH 73/83] Add Session Initialized option --- gorm.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gorm.go b/gorm.go index fc70f6845..a982bee40 100644 --- a/gorm.go +++ b/gorm.go @@ -96,6 +96,7 @@ type Session struct { DryRun bool PrepareStmt bool NewDB bool + Initialized bool SkipHooks bool SkipDefaultTransaction bool DisableNestedTransaction bool @@ -282,6 +283,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.NowFunc = config.NowFunc } + if config.Initialized { + tx = tx.getInstance() + } + return tx } From c0bea447b9eb707cfc1712d2d423f43309e247a2 Mon Sep 17 00:00:00 2001 From: li-jin-gou <97824201+li-jin-gou@users.noreply.github.com> Date: Fri, 28 Jan 2022 22:16:42 +0800 Subject: [PATCH 74/83] fix: omit not work when use join (#5034) --- callbacks/query.go | 2 +- tests/connection_test.go | 3 +-- tests/joins_test.go | 16 ++++++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index c2bbf5f91..490863549 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -100,7 +100,7 @@ func BuildQuerySQL(db *gorm.DB) { } if len(db.Statement.Joins) != 0 || len(joins) != 0 { - if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} diff --git a/tests/connection_test.go b/tests/connection_test.go index 92b13dd68..7bc23009d 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -9,7 +9,7 @@ import ( ) func TestWithSingleConnection(t *testing.T) { - var expectedName = "test" + expectedName := "test" var actualName string setSQL, getSQL := getSetSQL(DB.Dialector.Name()) @@ -27,7 +27,6 @@ func TestWithSingleConnection(t *testing.T) { } return nil }) - if err != nil { t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) } diff --git a/tests/joins_test.go b/tests/joins_test.go index e276a74a1..4c9cffae9 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -158,6 +158,22 @@ func TestJoinsWithSelect(t *testing.T) { } } +func TestJoinWithOmit(t *testing.T) { + user := *GetUser("joins_with_omit", Config{Pets: 2}) + DB.Save(&user) + + results := make([]*User, 0) + + if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil { + return + } + + if len(results) != 2 || results[0].Name != "" || results[1].Name != "" { + t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results) + return + } +} + func TestJoinCount(t *testing.T) { companyA := Company{Name: "A"} companyB := Company{Name: "B"} From 8c3673286dc6091967e2349687f0dbbaa55d66f8 Mon Sep 17 00:00:00 2001 From: Ning Date: Sun, 30 Jan 2022 18:17:06 +0800 Subject: [PATCH 75/83] preoload not allowd before count (#5023) Co-authored-by: ningfei --- errors.go | 2 ++ finisher_api.go | 4 ++++ tests/count_test.go | 10 ++++++++++ 3 files changed, 16 insertions(+) diff --git a/errors.go b/errors.go index 145614d94..49cbfe64a 100644 --- a/errors.go +++ b/errors.go @@ -39,4 +39,6 @@ var ( ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") // ErrInvalidValueOfLength invalid values do not match length ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") + // ErrPreloadNotAllowed preload is not allowed when count is used + ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") ) diff --git a/finisher_api.go b/finisher_api.go index 355d89bd9..cbbd48cbf 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -367,6 +367,10 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() + if len(tx.Statement.Preloads) > 0 { + tx.AddError(ErrPreloadNotAllowed) + return + } if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest defer func() { diff --git a/tests/count_test.go b/tests/count_test.go index 27d7ee607..b63a55fcc 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -144,4 +144,14 @@ func TestCount(t *testing.T) { if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) } + + var count12 int64 + if err := DB.Table("users"). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count12).Error; err != gorm.ErrPreloadNotAllowed { + t.Errorf("should returns preload not allowed error, but got %v", err) + } + } From 8d293d44dd7e4e6f61d759cb6c9a5be2c6523c5e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 30 Jan 2022 22:00:56 +0800 Subject: [PATCH 76/83] Fix docker-compose test env for Mac M1 --- tests/docker-compose.yml | 4 ++-- tests/go.mod | 6 +++--- tests/tests_all.sh | 17 +++++++++++++++++ tests/tests_test.go | 11 ++++++----- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 05e0956ee..9ab4ddb66 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: mysql: - image: 'mysql:latest' + image: 'mysql/mysql-server:latest' ports: - 9910:3306 environment: @@ -20,7 +20,7 @@ services: - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm mssql: - image: 'mcmoe/mssqldocker:latest' + image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' ports: - 9930:1433 environment: diff --git a/tests/go.mod b/tests/go.mod index 5415cf746..f2addaa15 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,13 +8,13 @@ require ( github.com/jackc/pgx/v4 v4.14.1 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - github.com/mattn/go-sqlite3 v1.14.10 // indirect - golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 // indirect + github.com/mattn/go-sqlite3 v1.14.11 // indirect + golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlserver v1.2.1 - gorm.io/gorm v1.22.4 + gorm.io/gorm v1.22.5 ) replace gorm.io/gorm => ../ diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 79e0b5b71..e1f394e53 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -15,6 +15,23 @@ then cd .. fi +# SqlServer for Mac M1 +if [ -d tests ] +then + cd tests + if [[ $(uname -a) == *" arm64" ]]; then + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start + go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null + else + docker-compose start + fi + cd .. +fi + + for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then diff --git a/tests/tests_test.go b/tests/tests_test.go index e26f358df..11b6f0675 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -62,13 +62,14 @@ func OpenTestConnection() (db *gorm.DB, err error) { PreferSimpleProtocol: true, }), &gorm.Config{}) case "sqlserver": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest + // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 // CREATE DATABASE gorm; - // USE gorm; + // GO + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - // npm install -g sql-cli - // mssql -u gorm -p LoremIpsum86 -d gorm -o 9930 + // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm]; + // GO log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" From f19b84d104a2659af7b32c1cacd92a35efa33d34 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 30 Jan 2022 22:32:34 +0800 Subject: [PATCH 77/83] Fix github action --- .github/workflows/tests.yml | 8 ++++---- tests/tests_all.sh | 26 ++++++++++++++------------ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 700af759d..91a0abc9f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlite ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh mysql: strategy: @@ -77,7 +77,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: strategy: @@ -120,7 +120,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: @@ -163,4 +163,4 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh diff --git a/tests/tests_all.sh b/tests/tests_all.sh index e1f394e53..5b9bae97a 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -16,19 +16,21 @@ then fi # SqlServer for Mac M1 -if [ -d tests ] -then - cd tests - if [[ $(uname -a) == *" arm64" ]]; then - MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start - go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest - SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null - SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null - SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null - else - docker-compose start +if [[ -z $GITHUB_ACTION ]]; then + if [ -d tests ] + then + cd tests + if [[ $(uname -a) == *" arm64" ]]; then + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true + go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true + else + docker-compose start + fi + cd .. fi - cd .. fi From 581a879bf1ff1af7fcb361f0c6e4b201dbed75f0 Mon Sep 17 00:00:00 2001 From: Saurabh Thakre Date: Mon, 31 Jan 2022 17:26:28 +0530 Subject: [PATCH 78/83] Added comments to existing methods Added two comments to describe FirstOrInit and FirstOrCreate methods. --- finisher_api.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index cbbd48cbf..3a1799778 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -255,7 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } } - +// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -281,6 +281,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } +// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, From 416c4d0653ce6e0569e6c868963a6c3cc769c2fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 8 Feb 2022 16:31:24 +0800 Subject: [PATCH 79/83] Test query with Or and soft delete --- tests/go.mod | 4 ++-- tests/query_test.go | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index f2addaa15..5488c17e9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,11 +5,11 @@ go 1.14 require ( github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/google/uuid v1.3.0 - github.com/jackc/pgx/v4 v4.14.1 // indirect + github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed // indirect + golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 diff --git a/tests/query_test.go b/tests/query_test.go index c99214b60..d10df1807 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -512,7 +512,13 @@ func TestNotWithAllFields(t *testing.T) { func TestOr(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) - result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) + var count int64 + result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count) + if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } From d22215129ee4747f9a9dd5b089d9f6920efc91ad Mon Sep 17 00:00:00 2001 From: li-jin-gou <97824201+li-jin-gou@users.noreply.github.com> Date: Tue, 8 Feb 2022 17:06:10 +0800 Subject: [PATCH 80/83] fix: replace empty table name result in panic (#5048) * fix: replace empty name result in panic * fix: replace empty table name result in panic --- schema/naming.go | 8 +++++++- schema/naming_test.go | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/schema/naming.go b/schema/naming.go index 8407bffa1..a4e3a75b6 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -120,7 +120,13 @@ func (ns NamingStrategy) toDBName(name string) string { } if ns.NameReplacer != nil { - name = ns.NameReplacer.Replace(name) + tmpName := ns.NameReplacer.Replace(name) + + if tmpName == "" { + return name + } + + name = tmpName } if ns.NoLowerCase { diff --git a/schema/naming_test.go b/schema/naming_test.go index c3e6bf923..1fdab9a06 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -197,3 +197,14 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { t.Errorf("invalid formatted name generated, got %v", formattedName) } } + +func TestReplaceEmptyTableName(t *testing.T) { + ns := NamingStrategy{ + SingularTable: true, + NameReplacer: strings.NewReplacer("Model", ""), + } + tableName := ns.TableName("Model") + if tableName != "Model" { + t.Errorf("invalid table name generated, got %v", tableName) + } +} From 4eeb839ceabb983b634f9cf9fffa1dd773b6803d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Feb 2022 15:17:19 +0800 Subject: [PATCH 81/83] Better support Stringer when explain SQL --- logger/logger.go | 14 ++++++++++- logger/sql.go | 24 ++++++++++++++---- tests/go.mod | 2 +- tests/sql_builder_test.go | 53 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 7 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 0c4ca4a01..2ffd28d5a 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm/utils" ) +// ErrRecordNotFound record not found error var ErrRecordNotFound = errors.New("record not found") // Colors @@ -30,13 +31,17 @@ const ( YellowBold = "\033[33;1m" ) -// LogLevel +// LogLevel log level type LogLevel int const ( + // Silent silent log level Silent LogLevel = iota + 1 + // Error error log level Error + // Warn warn log level Warn + // Info info log level Info ) @@ -45,6 +50,7 @@ type Writer interface { Printf(string, ...interface{}) } +// Config logger config type Config struct { SlowThreshold time.Duration Colorful bool @@ -62,16 +68,20 @@ type Interface interface { } var ( + // Discard Discard logger will print any log to ioutil.Discard Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, Colorful: true, }) + // Recorder Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) +// New initialize logger func New(writer Writer, config Config) Interface { var ( infoStr = "%s\n[info] " @@ -179,10 +189,12 @@ type traceRecorder struct { Err error } +// New new trace recorder func (l traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } +// Trace implement logger interface func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { l.BeginAt = begin l.SQL, l.RowsAffected = fc() diff --git a/logger/sql.go b/logger/sql.go index 5ecb0ae23..e0be57c01 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -30,9 +30,12 @@ func isPrintable(s []byte) bool { var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { - var convertParams func(interface{}, int) - vars := make([]string, len(avars)) + var ( + convertParams func(interface{}, int) + vars = make([]string, len(avars)) + ) convertParams = func(v interface{}, idx int) { switch v := v.(type) { @@ -64,10 +67,21 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } case fmt.Stringer: reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + switch reflectValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) + case reflect.Float32, reflect.Float64: + vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) + case reflect.Bool: + vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) + case reflect.String: vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper - } else { - vars[idx] = nullStr + default: + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = nullStr + } } case []byte: if isPrintable(v) { diff --git a/tests/go.mod b/tests/go.mod index 5488c17e9..3453f77b0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect + golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 237d807b7..897f687f7 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) { } } +type ageInt int8 + +func (ageInt) String() string { + return "age" +} + +type ageBool bool + +func (ageBool) String() string { + return "age" +} + +type ageUint64 uint64 + +func (ageUint64) String() string { + return "age" +} + +type ageFloat float64 + +func (ageFloat) String() string { + return "age" +} + +func TestExplainSQL(t *testing.T) { + user := *GetUser("explain-sql", Config{}) + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement + sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } +} + func TestGroupConditions(t *testing.T) { type Pizza struct { ID uint From df2365057bb6c809b03d470323238262a93a9685 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Feb 2022 17:23:16 +0800 Subject: [PATCH 82/83] Remove uncessary switch case --- statement.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/statement.go b/statement.go index 72359da29..232126426 100644 --- a/statement.go +++ b/statement.go @@ -179,10 +179,6 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } - case clause.Expr: - v.Build(stmt) - case *clause.Expr: - v.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: From a0aceeb33e7eabbecae5b7fd2eef874b1a77b086 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Feb 2022 17:39:01 +0800 Subject: [PATCH 83/83] Migrator AlterColumn with full data type --- gorm.go | 6 ++++++ migrator/migrator.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index a982bee40..7967b0945 100644 --- a/gorm.go +++ b/gorm.go @@ -59,6 +59,7 @@ type Config struct { cacheStore *sync.Map } +// Apply update config to new config func (c *Config) Apply(config *Config) error { if config != c { *config = *c @@ -66,6 +67,7 @@ func (c *Config) Apply(config *Config) error { return nil } +// AfterInitialize initialize plugins after db connected func (c *Config) AfterInitialize(db *DB) error { if db != nil { for _, plugin := range c.Plugins { @@ -77,6 +79,7 @@ func (c *Config) AfterInitialize(db *DB) error { return nil } +// Option gorm option interface type Option interface { Apply(*Config) error AfterInitialize(*DB) error @@ -381,10 +384,12 @@ func (db *DB) getInstance() *DB { return db } +// Expr returns clause.Expr, which can be used to pass SQL expression as params func Expr(expr string, args ...interface{}) clause.Expr { return clause.Expr{SQL: expr, Vars: args} } +// SetupJoinTable setup join table schema func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { var ( tx = db.getInstance() @@ -435,6 +440,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac return nil } +// Use use plugin func (db *DB) Use(plugin Plugin) error { name := plugin.Name() if _, ok := db.Plugins[name]; ok { diff --git a/migrator/migrator.go b/migrator/migrator.go index 138917fb5..80c4e2b3c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -337,7 +337,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - fileType := clause.Expr{SQL: m.DataTypeOf(field)} + fileType := m.FullDataTypeOf(field) return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,