Skip to content

Commit

Permalink
refactored the enum mirator to also be simpler to maintain
Browse files Browse the repository at this point in the history
  • Loading branch information
CommanderStorm committed Oct 23, 2023
1 parent c03dbc8 commit 24f45fd
Showing 1 changed file with 13 additions and 41 deletions.
54 changes: 13 additions & 41 deletions server/backend/migration/safe_enum_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,27 @@ import (
"fmt"
"strings"

log "github.com/sirupsen/logrus"

"gorm.io/gorm"
)

func SafeEnumAdd(tx *gorm.DB, table interface{}, column string, additionalTypes ...string) error {
enumTypes, err := getEnumTypesFromDB(tx, table, column)

if err != nil {
return err
}

enumTypes = append(enumTypes, additionalTypes...)
enumTypes = ensureUnique(enumTypes)

return alterEnumColumn(tx, table, column, enumTypes)
return alterEnumColumn(tx, table, column, ensureUnique(append(enumTypes, additionalTypes...)))
}

func SafeEnumRemove(tx *gorm.DB, table interface{}, column string, rollbackTypes ...string) error {
enumTypes, err := getEnumTypesFromDB(tx, table, column)

if err != nil {
return err
}

enumTypes = RemoveTypes(enumTypes, rollbackTypes...)

return alterEnumColumn(tx, table, column, enumTypes)
return alterEnumColumn(tx, table, column, RemoveTypes(enumTypes, rollbackTypes...))
}

func ensureUnique(types []string) []string {
Expand All @@ -46,44 +41,34 @@ func ensureUnique(types []string) []string {
}

func alterEnumColumn(tx *gorm.DB, table interface{}, column string, types []string) error {
enum := BuildEnum(types)

stmt := &gorm.Statement{DB: tx}
err := stmt.Parse(&table)
if err != nil {
return errors.New("could not parse enum table")
}
tableName := stmt.Schema.Table

rawQuery := fmt.Sprintf(
if err := tx.Exec(fmt.Sprintf(
"ALTER TABLE %s MODIFY %s %s;",
tableName,
stmt.Schema.Table,
column,
enum,
)

tx = tx.Exec(rawQuery)

if tx.Error != nil {
BuildEnum(types),
)).Error; err != nil {
log.WithError(err).Error("Error altering enum table")
return errors.New("could not alter enum table")
}

return nil
}

func getEnumTypesFromDB(tx *gorm.DB, table interface{}, column string) ([]string, error) {
columnType, err := tx.Migrator().ColumnTypes(&table)

if err != nil {
return nil, errors.New("could not get enum column types")
}

enumTypes, err := getEnumTypes(columnType, column)

if err != nil {
return nil, err
}

return enumTypes, nil
}

Expand All @@ -100,20 +85,16 @@ func RemoveTypes(types []string, rollbackTypes ...string) []string {
}

func getEnumTypes(columTypes []gorm.ColumnType, column string) ([]string, error) {
var cType string

for _, t := range columTypes {
if t.Name() == column {
if t, ok := t.ColumnType(); ok {
cType = t
break
return EnumTypesFromString(t)
} else {
return nil, errors.New("could not get column type")
}
}
}

return EnumTypesFromString(cType)
return nil, errors.New("column does not exist")
}

func EnumTypesFromString(enum string) ([]string, error) {
Expand All @@ -138,15 +119,6 @@ func EnumTypesFromString(enum string) ([]string, error) {
}

func BuildEnum(types []string) string {
str := "enum("

for _, t := range types {
str += fmt.Sprintf("'%s',", t)
}

str = strings.TrimRight(str, ",")

str += ")"

return str
enums := strings.Join(types, "','")
return fmt.Sprintf("enum('%s')", enums)
}

0 comments on commit 24f45fd

Please sign in to comment.