diff --git a/internal/dbutil/testing.go b/internal/dbutil/testing.go index 0b81687..378f913 100644 --- a/internal/dbutil/testing.go +++ b/internal/dbutil/testing.go @@ -34,7 +34,7 @@ func NewTestDB(t *testing.T, migrationTables ...interface{}) (testDB *gorm.DB, c dsn = os.ExpandEnv("$DB_USER:$DB_PASSWORD@tcp($DB_HOST:$DB_PORT)/$DB_DATABASE?charset=utf8mb4&parseTime=True&loc=Local") dialectFunc = mysql.Open case "postgres": - dsn = os.ExpandEnv("host=$DB_HOST user=$DB_USER password=$DB_PASSWORD port=$DB_PORT sslmode=disable TimeZone=Asia/Shanghai") + dsn = os.ExpandEnv("host=$DB_HOST user=$DB_USER password=$DB_PASSWORD dbname=$DB_DATABASE port=$DB_PORT sslmode=disable TimeZone=Asia/Shanghai") dialectFunc = postgres.Open default: t.Fatalf("Unknown database type: %q", dbType) @@ -111,7 +111,15 @@ func NewTestDB(t *testing.T, migrationTables ...interface{}) (testDB *gorm.DB, c } for _, table := range tables { - err := testDB.WithContext(ctx).Exec(`TRUNCATE TABLE ` + QuoteIdentifier(dbType, table)).Error + var query string + switch dbType { + case "mysql": + query = `TRUNCATE TABLE ` + QuoteIdentifier(dbType, table) + case "postgres": + query = `TRUNCATE TABLE ` + QuoteIdentifier(dbType, table) + ` RESTART IDENTITY CASCADE` + } + + err := testDB.WithContext(ctx).Exec(query).Error if err != nil { return err }