diff --git a/server/backend/migration/safe_enum_migrate.go b/server/backend/migration/safe_enum_migrate.go index c034f861..a5f91a6e 100644 --- a/server/backend/migration/safe_enum_migrate.go +++ b/server/backend/migration/safe_enum_migrate.go @@ -5,32 +5,27 @@ import ( "fmt" "strings" + log "github.com/sirupsen/logrus" + "gorm.io/gorm" ) func SafeEnumAdd(tx *gorm.DB, table interface{}, column string, additionalTypes ...string) error { enumTypes, err := getEnumTypesFromDB(tx, table, column) - if err != nil { return err } - enumTypes = append(enumTypes, additionalTypes...) - enumTypes = ensureUnique(enumTypes) - - return alterEnumColumn(tx, table, column, enumTypes) + return alterEnumColumn(tx, table, column, ensureUnique(append(enumTypes, additionalTypes...))) } func SafeEnumRemove(tx *gorm.DB, table interface{}, column string, rollbackTypes ...string) error { enumTypes, err := getEnumTypesFromDB(tx, table, column) - if err != nil { return err } - enumTypes = RemoveTypes(enumTypes, rollbackTypes...) - - return alterEnumColumn(tx, table, column, enumTypes) + return alterEnumColumn(tx, table, column, RemoveTypes(enumTypes, rollbackTypes...)) } func ensureUnique(types []string) []string { @@ -46,44 +41,34 @@ func ensureUnique(types []string) []string { } func alterEnumColumn(tx *gorm.DB, table interface{}, column string, types []string) error { - enum := BuildEnum(types) - stmt := &gorm.Statement{DB: tx} err := stmt.Parse(&table) if err != nil { return errors.New("could not parse enum table") } - tableName := stmt.Schema.Table - rawQuery := fmt.Sprintf( + if err := tx.Exec(fmt.Sprintf( "ALTER TABLE %s MODIFY %s %s;", - tableName, + stmt.Schema.Table, column, - enum, - ) - - tx = tx.Exec(rawQuery) - - if tx.Error != nil { + BuildEnum(types), + )).Error; err != nil { + log.WithError(err).Error("Error altering enum table") return errors.New("could not alter enum table") } - return nil } func getEnumTypesFromDB(tx *gorm.DB, table interface{}, column string) ([]string, error) { columnType, err := tx.Migrator().ColumnTypes(&table) - if err != nil { return nil, errors.New("could not get enum column types") } enumTypes, err := getEnumTypes(columnType, column) - if err != nil { return nil, err } - return enumTypes, nil } @@ -100,20 +85,16 @@ func RemoveTypes(types []string, rollbackTypes ...string) []string { } func getEnumTypes(columTypes []gorm.ColumnType, column string) ([]string, error) { - var cType string - for _, t := range columTypes { if t.Name() == column { if t, ok := t.ColumnType(); ok { - cType = t - break + return EnumTypesFromString(t) } else { return nil, errors.New("could not get column type") } } } - - return EnumTypesFromString(cType) + return nil, errors.New("column does not exist") } func EnumTypesFromString(enum string) ([]string, error) { @@ -138,15 +119,6 @@ func EnumTypesFromString(enum string) ([]string, error) { } func BuildEnum(types []string) string { - str := "enum(" - - for _, t := range types { - str += fmt.Sprintf("'%s',", t) - } - - str = strings.TrimRight(str, ",") - - str += ")" - - return str + enums := strings.Join(types, "','") + return fmt.Sprintf("enum('%s')", enums) }