diff --git a/updater/group.go b/updater/group.go index c6fae68..d874d39 100644 --- a/updater/group.go +++ b/updater/group.go @@ -19,8 +19,10 @@ type Group struct { // Parameters that apply to members: // Range is a comma separated list of allowed semver ranges - Range string `yaml:"range"` - CoolDown string `yaml:"cooldown"` + Range string `yaml:"range"` + CoolDown string `yaml:"cooldown"` + PreScript string `yaml:"pre-script"` + PostScript string `yaml:"post-script"` compiledPattern *regexp.Regexp } @@ -48,7 +50,7 @@ func (g *Group) Validate() error { return nil } -func (g Group) InRange(v string) bool { +func (g *Group) InRange(v string) bool { for _, rangeCond := range strings.Split(g.Range, ",") { rangeCond = strings.TrimSpace(rangeCond) switch { @@ -90,7 +92,7 @@ const ( oneDay = 24 * time.Hour ) -func (g Group) CoolDownDuration() time.Duration { +func (g *Group) CoolDownDuration() time.Duration { m := durPattern.FindStringSubmatch(g.CoolDown) var ret time.Duration diff --git a/updater/group_test.go b/updater/group_test.go index c5dbdb2..9c97583 100644 --- a/updater/group_test.go +++ b/updater/group_test.go @@ -65,14 +65,15 @@ func TestGroup_InRange(t *testing.T) { for r, tc := range cases { t.Run(r, func(t *testing.T) { + u := &updater.Group{Range: r} for _, v := range tc.included { t.Run(fmt.Sprintf("includes %s", v), func(t *testing.T) { - assert.True(t, updater.Group{Range: r}.InRange(v)) + assert.True(t, u.InRange(v)) }) } for _, v := range tc.excluded { t.Run(fmt.Sprintf("excludes %q", v), func(t *testing.T) { - assert.False(t, updater.Group{Range: r}.InRange(v)) + assert.False(t, u.InRange(v)) }) } }) diff --git a/updater/updater.go b/updater/updater.go index 42b4202..6211158 100644 --- a/updater/updater.go +++ b/updater/updater.go @@ -3,6 +3,8 @@ package updater import ( "context" "fmt" + "os" + "os/exec" "github.com/sirupsen/logrus" ) @@ -214,14 +216,39 @@ func (u *RepoUpdater) groupedUpdate(ctx context.Context, log logrus.FieldLogger, return 0, fmt.Errorf("switching to target branch: %w", err) } + if err := u.updateScript(ctx, "pre", group.PreScript); err != nil { + return 0, fmt.Errorf("executing pre-update script: %w", err) + } + for _, update := range updates { if err := u.updater.ApplyUpdate(ctx, update); err != nil { return 0, fmt.Errorf("applying batched update: %w", err) } } + if err := u.updateScript(ctx, "post", group.PostScript); err != nil { + return 0, fmt.Errorf("executing pre-update script: %w", err) + } + if err := u.repo.Push(ctx, updates...); err != nil { return 0, fmt.Errorf("pushing update: %w", err) } return len(updates), nil } + +func (u *RepoUpdater) updateScript(ctx context.Context, label, script string) error { + if script == "" { + return nil + } + cmd := exec.CommandContext(ctx, "/bin/sh", "-c", script) + cmd.Dir = u.repo.Root() + out := os.Stdout + _, _ = fmt.Fprintf(out, "--- start %s update script ---\n", label) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Run(); err != nil { + return err + } + _, _ = fmt.Fprintf(out, "--- end %s update script ---\n", label) + return nil +} diff --git a/updater/updater_test.go b/updater/updater_test.go index 4a3461e..bed985c 100644 --- a/updater/updater_test.go +++ b/updater/updater_test.go @@ -3,6 +3,8 @@ package updater_test import ( "context" "fmt" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/mock" @@ -120,3 +122,47 @@ func TestRepoUpdater_UpdateAll_MultipleGrouped(t *testing.T) { r.AssertExpectations(t) u.AssertExpectations(t) } + +func TestRepoUpdater_UpdateAll_Scripts(t *testing.T) { + cases := []*updater.Group{ + { + Name: groupName, + Pattern: "github.com/foo", + PreScript: `echo "sup" && touch token`, + }, + { + Name: groupName, + Pattern: "github.com/foo", + PostScript: `echo "sup" && touch token`, + }, + } + + for _, group := range cases { + err := group.Validate() + require.NoError(t, err) + + tmpDir := t.TempDir() + tokenPath := filepath.Join(tmpDir, "token") + r := &mockRepo{} + u := &mockUpdater{} + ru := updater.NewRepoUpdater(r, u, updater.WithGroups(group)) + ctx := context.Background() + + r.On("SetBranch", baseBranch).Return(nil) + dep := updater.Dependency{Path: mockUpdate.Path, Version: mockUpdate.Previous} + u.On("Dependencies", ctx).Return([]updater.Dependency{dep}, nil) + availableUpdate := mockUpdate // avoid pointer to shared reference + u.On("Check", ctx, dep, mock.Anything).Return(&availableUpdate, nil) + r.On("NewBranch", baseBranch, "action-update-go/main/foo").Times(1).Return(nil) + u.On("ApplyUpdate", ctx, mock.Anything).Times(1).Return(nil) + r.On("Push", ctx, mock.Anything, mock.Anything).Times(1).Return(nil) + r.On("Root").Return(tmpDir) + + err = ru.UpdateAll(ctx, baseBranch) + require.NoError(t, err) + r.AssertExpectations(t) + u.AssertExpectations(t) + _, err = os.Stat(tokenPath) + require.NoError(t, err) + } +}