From d72731512bdb5a5e60ce432360f0abee0dc86482 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Fri, 1 Jul 2022 00:02:35 +0200 Subject: [PATCH] Add Before/After step --- dbump.go | 18 ++++++++++++++++++ dbump_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/dbump.go b/dbump.go index 8307cfa..7065959 100644 --- a/dbump.go +++ b/dbump.go @@ -43,6 +43,13 @@ type Config struct { // Going up does apply-revert-apply of each migration. // Going down does revert-apply-revert of each migration. ZigZag bool + + // BeforeStep function will be invoked right before the DoStep for each step. + // Default is nil and means no-op. + BeforeStep func(ctx context.Context, step Step) + // AfterStep function will be invoked right after the DoStep for each step. + // Default is nil and means no-op. + AfterStep func(ctx context.Context, step Step) } // Migrator represents database over which we will run migrations. @@ -120,6 +127,13 @@ func Run(ctx context.Context, config Config) error { return fmt.Errorf("incorrect mode provided: %d", config.Mode) } + if config.BeforeStep == nil { + config.BeforeStep = func(ctx context.Context, step Step) {} + } + if config.AfterStep == nil { + config.AfterStep = func(ctx context.Context, step Step) {} + } + m := mig{ Config: config, Migrator: config.Migrator, @@ -203,9 +217,13 @@ func (m *mig) runMigrationsLocked(ctx context.Context, ms []*Migration) error { } for _, step := range m.prepareSteps(curr, target, ms) { + m.BeforeStep(ctx, step) + if err := m.DoStep(ctx, step); err != nil { return err } + + m.AfterStep(ctx, step) } return nil } diff --git a/dbump_test.go b/dbump_test.go index a0a64e7..70f89aa 100644 --- a/dbump_test.go +++ b/dbump_test.go @@ -3,6 +3,7 @@ package dbump import ( "context" "errors" + "fmt" "reflect" "testing" ) @@ -230,6 +231,40 @@ func TestMigrateDrop(t *testing.T) { mustEqual(t, mm.log, wantLog) } +func TestBeforeAfterStep(t *testing.T) { + currVersion := 3 + wantLog := []string{ + "lockdb", "init", "getversion", + "before", "{v:4 q:'SELECT 4;' notx:false}", + "dostep", "{v:4 q:'SELECT 4;' notx:false}", + "after", "{v:4 q:'SELECT 4;' notx:false}", + "before", "{v:5 q:'SELECT 5;' notx:false}", + "dostep", "{v:5 q:'SELECT 5;' notx:false}", + "after", "{v:5 q:'SELECT 5;' notx:false}", + "unlockdb", + } + + mm := &MockMigrator{ + VersionFn: func(ctx context.Context) (version int, err error) { + return currVersion, nil + }, + } + cfg := Config{ + Migrator: mm, + Loader: NewSliceLoader(testdataMigrations), + Mode: ModeUp, + BeforeStep: func(ctx context.Context, step Step) { + mm.log = append(mm.log, "before", fmt.Sprintf("{v:%d q:'%s' notx:%v}", step.Version, step.Query, step.DisableTx)) + }, + AfterStep: func(ctx context.Context, step Step) { + mm.log = append(mm.log, "after", fmt.Sprintf("{v:%d q:'%s' notx:%v}", step.Version, step.Query, step.DisableTx)) + }, + } + + failIfErr(t, Run(context.Background(), cfg)) + mustEqual(t, mm.log, wantLog) +} + func TestLockless(t *testing.T) { wantLog := []string{ "init",