From 883eb1862d5839546eff65e5ba74a38087770e27 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Sat, 4 Jun 2022 13:13:18 +0200 Subject: [PATCH] Add migration functions --- dbump.go | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/dbump.go b/dbump.go index 424e572..77ef8f4 100644 --- a/dbump.go +++ b/dbump.go @@ -31,10 +31,20 @@ type Loader interface { // Migration represents migration step that will be runned on DB. type Migration struct { - ID int // ID of the migration, unique, positive, starts from 1. - Name string // Name of the migration - Apply string // Apply query - Rollback string // Rollback query + ID int // ID of the migration, unique, positive, starts from 1. + Name string // Name of the migration + Apply string // Apply query + Rollback string // Rollback query + ApplyFn MigrationFn // Apply func + RollbackFn MigrationFn // Rollback func + + isQuery bool // shortcut for the type of migration (query or func) +} + +type MigrationFn func(db DB) error + +type DB interface { + Exec(ctx context.Context, query string, args ...interface{}) error } // Run the Migrator with migration queries provided by the Loader. @@ -62,7 +72,10 @@ func loadMigrations(ms []*Migration, err error) ([]*Migration, error) { case m.ID > want: return nil, fmt.Errorf("missing migration number: %d (have %d)", want, m.ID) default: - // pass + if (m.Apply != "" || m.Rollback != "") && (m.ApplyFn != nil || m.RollbackFn != nil) { + return nil, fmt.Errorf("mixing queries and functions is not allowed (migration %d)", m.ID) + } + m.isQuery = m.Apply != "" } } return ms, nil @@ -115,15 +128,20 @@ func runMigrationExclusive(ctx context.Context, m Migrator, ms []*Migration) err for currentVersion != targetVersion { current := ms[currentVersion] sequence := current.ID - query := current.Apply + query, queryFn := current.Apply, current.ApplyFn if direction == -1 { current = ms[currentVersion-1] sequence = current.ID - 1 - query = current.Rollback + query, queryFn = current.Rollback, current.RollbackFn } - if err := m.Exec(ctx, query); err != nil { + if current.isQuery { + err = m.Exec(ctx, query) + } else { + err = queryFn(m) + } + if err != nil { return fmt.Errorf("exec: %w", err) }