Skip to content

Commit

Permalink
Add Before/After step
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Jun 30, 2022
1 parent d757836 commit d727315
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
18 changes: 18 additions & 0 deletions dbump.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ type Config struct {
// Going up does apply-revert-apply of each migration.
// Going down does revert-apply-revert of each migration.
ZigZag bool

// BeforeStep function will be invoked right before the DoStep for each step.
// Default is nil and means no-op.
BeforeStep func(ctx context.Context, step Step)
// AfterStep function will be invoked right after the DoStep for each step.
// Default is nil and means no-op.
AfterStep func(ctx context.Context, step Step)
}

// Migrator represents database over which we will run migrations.
Expand Down Expand Up @@ -120,6 +127,13 @@ func Run(ctx context.Context, config Config) error {
return fmt.Errorf("incorrect mode provided: %d", config.Mode)
}

if config.BeforeStep == nil {
config.BeforeStep = func(ctx context.Context, step Step) {}
}
if config.AfterStep == nil {
config.AfterStep = func(ctx context.Context, step Step) {}
}

m := mig{
Config: config,
Migrator: config.Migrator,
Expand Down Expand Up @@ -203,9 +217,13 @@ func (m *mig) runMigrationsLocked(ctx context.Context, ms []*Migration) error {
}

for _, step := range m.prepareSteps(curr, target, ms) {
m.BeforeStep(ctx, step)

if err := m.DoStep(ctx, step); err != nil {
return err
}

m.AfterStep(ctx, step)
}
return nil
}
Expand Down
35 changes: 35 additions & 0 deletions dbump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dbump
import (
"context"
"errors"
"fmt"
"reflect"
"testing"
)
Expand Down Expand Up @@ -230,6 +231,40 @@ func TestMigrateDrop(t *testing.T) {
mustEqual(t, mm.log, wantLog)
}

func TestBeforeAfterStep(t *testing.T) {
currVersion := 3
wantLog := []string{
"lockdb", "init", "getversion",
"before", "{v:4 q:'SELECT 4;' notx:false}",
"dostep", "{v:4 q:'SELECT 4;' notx:false}",
"after", "{v:4 q:'SELECT 4;' notx:false}",
"before", "{v:5 q:'SELECT 5;' notx:false}",
"dostep", "{v:5 q:'SELECT 5;' notx:false}",
"after", "{v:5 q:'SELECT 5;' notx:false}",
"unlockdb",
}

mm := &MockMigrator{
VersionFn: func(ctx context.Context) (version int, err error) {
return currVersion, nil
},
}
cfg := Config{
Migrator: mm,
Loader: NewSliceLoader(testdataMigrations),
Mode: ModeUp,
BeforeStep: func(ctx context.Context, step Step) {
mm.log = append(mm.log, "before", fmt.Sprintf("{v:%d q:'%s' notx:%v}", step.Version, step.Query, step.DisableTx))
},
AfterStep: func(ctx context.Context, step Step) {
mm.log = append(mm.log, "after", fmt.Sprintf("{v:%d q:'%s' notx:%v}", step.Version, step.Query, step.DisableTx))
},
}

failIfErr(t, Run(context.Background(), cfg))
mustEqual(t, mm.log, wantLog)
}

func TestLockless(t *testing.T) {
wantLog := []string{
"init",
Expand Down

0 comments on commit d727315

Please sign in to comment.