Skip to content

Commit

Permalink
Add migration functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Jun 4, 2022
1 parent eeca5c5 commit 883eb18
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions dbump.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,20 @@ type Loader interface {

// Migration represents migration step that will be runned on DB.
type Migration struct {
ID int // ID of the migration, unique, positive, starts from 1.
Name string // Name of the migration
Apply string // Apply query
Rollback string // Rollback query
ID int // ID of the migration, unique, positive, starts from 1.
Name string // Name of the migration
Apply string // Apply query
Rollback string // Rollback query
ApplyFn MigrationFn // Apply func
RollbackFn MigrationFn // Rollback func

isQuery bool // shortcut for the type of migration (query or func)
}

type MigrationFn func(db DB) error

type DB interface {
Exec(ctx context.Context, query string, args ...interface{}) error
}

// Run the Migrator with migration queries provided by the Loader.
Expand Down Expand Up @@ -62,7 +72,10 @@ func loadMigrations(ms []*Migration, err error) ([]*Migration, error) {
case m.ID > want:
return nil, fmt.Errorf("missing migration number: %d (have %d)", want, m.ID)
default:
// pass
if (m.Apply != "" || m.Rollback != "") && (m.ApplyFn != nil || m.RollbackFn != nil) {
return nil, fmt.Errorf("mixing queries and functions is not allowed (migration %d)", m.ID)
}
m.isQuery = m.Apply != ""
}
}
return ms, nil
Expand Down Expand Up @@ -115,15 +128,20 @@ func runMigrationExclusive(ctx context.Context, m Migrator, ms []*Migration) err
for currentVersion != targetVersion {
current := ms[currentVersion]
sequence := current.ID
query := current.Apply
query, queryFn := current.Apply, current.ApplyFn

if direction == -1 {
current = ms[currentVersion-1]
sequence = current.ID - 1
query = current.Rollback
query, queryFn = current.Rollback, current.RollbackFn
}

if err := m.Exec(ctx, query); err != nil {
if current.isQuery {
err = m.Exec(ctx, query)
} else {
err = queryFn(m)
}
if err != nil {
return fmt.Errorf("exec: %w", err)
}

Expand Down

0 comments on commit 883eb18

Please sign in to comment.