diff --git a/background.go b/background.go index 1af3617..6452949 100644 --- a/background.go +++ b/background.go @@ -2,16 +2,32 @@ package background import ( "context" - "sync" - "sync/atomic" + "errors" "time" "github.com/kamilsk/retry/v5" "github.com/kamilsk/retry/v5/strategy" ) -// Manager keeps track of scheduled goroutines and provides mechanisms to wait for them to finish. `Meta` is whatever -// you wish to associate with this task, usually something that will help you keep track of the tasks. +// TaskType determines how the task will be executed by the manager. +type TaskType int + +const ( + // TaskTypeOneOff is the default task type. It will be executed only once. + TaskTypeOneOff TaskType = iota + // TaskTypeLoop will be executed in an infinite loop until the manager's Cancel() method is called. The task will + // restart immediately after the previous iteration returns. + TaskTypeLoop +) + +var ( + // ErrUnknownTaskType is returned when the task type is not a valid value of TaskType. + ErrUnknownTaskType = errors.New("unknown task type") +) + +// Manager keeps track of scheduled goroutines and provides mechanisms to wait for them to finish or cancel their +// execution. `Meta` is whatever you wish to associate with this task, usually something that will help you keep track +// of the tasks in the observer. // // This is useful in context of HTTP servers, where a customer request may result in some kind of background processing // activity that should not block the response and you schedule a goroutine to handle it. However, if your server @@ -20,11 +36,11 @@ import ( // package to schedule the queue jobs without the customer waiting for that to happen while at the same time being able // to wait for all those goroutines to finish before allowing the process to exit. type Manager struct { - wg sync.WaitGroup - len atomic.Int32 stalledThreshold time.Duration observer Observer retry Retry + taskmgr taskmgr + loopmgr loopmgr } // Options provides a means for configuring the background manager and attaching hooks to it. @@ -44,6 +60,8 @@ type Options struct { type Task struct { // Fn is the function to be executed in a goroutine. Fn Fn + // Type is the type of the task. It determines how the task will be executed by the manager. Default is TaskTypeOneOff. + Type TaskType // Meta is whatever custom information you wish to associate with the task. This will be passed to the observer's // functions. Meta Metadata @@ -66,7 +84,7 @@ type Metadata map[string]string // NewManager creates a new instance of Manager with default options and no observer. func NewManager() *Manager { - return &Manager{} + return NewManagerWithOptions(Options{}) } // NewManagerWithOptions creates a new instance of Manager with the provided options and observer. @@ -75,36 +93,60 @@ func NewManagerWithOptions(options Options) *Manager { stalledThreshold: options.StalledThreshold, observer: options.Observer, retry: options.Retry, + loopmgr: mkloopmgr(), } } -// Run schedules the provided function to be executed in a goroutine. +// Run schedules the provided function to be executed once in a goroutine. func (m *Manager) Run(ctx context.Context, fn Fn) { task := Task{Fn: fn} m.RunTask(ctx, task) } -// RunTask schedules the provided task to be executed in a goroutine. +// RunTask schedules the provided task to be executed in a goroutine. The task will be executed according to its type. +// By default, the task will be executed only once (TaskTypeOneOff). func (m *Manager) RunTask(ctx context.Context, task Task) { - m.observer.callOnTaskAdded(ctx, task) - m.wg.Add(1) - m.len.Add(1) - ctx = context.WithoutCancel(ctx) done := make(chan error, 1) + m.observer.callOnTaskAdded(ctx, task) + + switch task.Type { + case TaskTypeOneOff: + m.taskmgr.start() + go m.observe(ctx, task, done) + go m.run(ctx, task, done) - go m.monitor(ctx, task, done) - go m.run(ctx, task, done) + case TaskTypeLoop: + m.loopmgr.start() + go m.loop(ctx, task, done) + + default: + m.observer.callOnTaskFailed(ctx, task, ErrUnknownTaskType) + } } -// Wait blocks until all scheduled tasks have finished. +// Wait blocks until all scheduled one-off tasks have finished. Adding more one-off tasks will prolong the wait time. func (m *Manager) Wait() { - m.wg.Wait() + m.taskmgr.group.Wait() } -// Len returns the number of currently running tasks. -func (m *Manager) Len() int32 { - return m.len.Load() +// Cancel blocks until all loop tasks finish their current loop and stops looping further. The tasks' context is not +// cancelled. Adding a new loop task after calling Cancel() will cause the task to be ignored and not run. +func (m *Manager) Cancel() { + m.loopmgr.cancel() +} + +// CountOf returns the number of tasks of the specified type that are currently running. When the TaskType is invalid it +// returns 0. +func (m *Manager) CountOf(t TaskType) int { + switch t { + case TaskTypeOneOff: + return int(m.taskmgr.count.Load()) + case TaskTypeLoop: + return int(m.loopmgr.count.Load()) + default: + return 0 + } } func (m *Manager) run(ctx context.Context, task Task, done chan<- error) { @@ -112,8 +154,25 @@ func (m *Manager) run(ctx context.Context, task Task, done chan<- error) { done <- retry.Do(ctx, task.Fn, strategies...) } -func (m *Manager) monitor(ctx context.Context, task Task, done <-chan error) { +func (m *Manager) loop(ctx context.Context, task Task, done chan error) { + defer m.loopmgr.finish() + + for { + if m.loopmgr.ctx.Err() != nil { + return + } + + m.run(ctx, task, done) + err := <-done + if err != nil { + m.observer.callOnTaskFailed(ctx, task, err) + } + } +} + +func (m *Manager) observe(ctx context.Context, task Task, done <-chan error) { timeout := mktimeout(m.stalledThreshold) + defer m.taskmgr.finish() for { select { @@ -126,8 +185,6 @@ func (m *Manager) monitor(ctx context.Context, task Task, done <-chan error) { m.observer.callOnTaskSucceeded(ctx, task) } - m.wg.Done() - m.len.Add(-1) return } } diff --git a/background_test.go b/background_test.go index 236e9ff..6b9c256 100644 --- a/background_test.go +++ b/background_test.go @@ -16,10 +16,11 @@ func Test_NewManager(t *testing.T) { m := background.NewManager() assert.NotNil(t, m) assert.IsType(t, &background.Manager{}, m) - assert.EqualValues(t, 0, m.Len()) + assert.EqualValues(t, 0, m.CountOf(background.TaskTypeOneOff)) + assert.EqualValues(t, 0, m.CountOf(background.TaskTypeLoop)) } -func Test_RunExecutesInGoroutine(t *testing.T) { +func Test_RunTaskExecutesInGoroutine(t *testing.T) { m := background.NewManager() proceed := make(chan bool, 1) @@ -61,7 +62,7 @@ func Test_WaitWaitsForPendingTasks(t *testing.T) { assert.True(t, waited) } -func Test_CancelledParentContext(t *testing.T) { +func Test_RunTaskCancelledParentContext(t *testing.T) { m := background.NewManager() ctx, cancel := context.WithCancel(context.Background()) proceed := make(chan bool, 1) @@ -77,30 +78,6 @@ func Test_CancelledParentContext(t *testing.T) { m.Wait() } -func Test_Len(t *testing.T) { - proceed := make(chan bool, 1) - remaining := 10 - m := background.NewManagerWithOptions(background.Options{ - Observer: background.Observer{ - OnTaskSucceeded: func(ctx context.Context, task background.Task) { - remaining-- - proceed <- true - }, - }, - }) - - for range 10 { - m.Run(context.Background(), func(ctx context.Context) error { - <-proceed - return nil - }) - } - - proceed <- true - m.Wait() - assert.EqualValues(t, 0, m.Len()) -} - func Test_OnTaskAdded(t *testing.T) { metadata := background.Metadata{"test": "value"} executed := false @@ -183,7 +160,7 @@ func Test_OnTaskFailed(t *testing.T) { assert.True(t, executed) } -func Test_OnGoroutineStalled(t *testing.T) { +func Test_OnTaskStalled(t *testing.T) { tests := []struct { duration time.Duration shouldExecute bool @@ -232,7 +209,7 @@ func Test_OnGoroutineStalled(t *testing.T) { } } -func Test_StalledGoroutineStillCallsOnTaskSucceeded(t *testing.T) { +func Test_StalledTaskStillCallsOnTaskSucceeded(t *testing.T) { executed := false var wg sync.WaitGroup m := background.NewManagerWithOptions(background.Options{ @@ -255,7 +232,7 @@ func Test_StalledGoroutineStillCallsOnTaskSucceeded(t *testing.T) { assert.True(t, executed) } -func Test_TaskDefinitionRetryStrategies(t *testing.T) { +func Test_TaskRetryStrategies(t *testing.T) { var limit uint = 5 var count uint = 0 m := background.NewManager() @@ -275,7 +252,7 @@ func Test_TaskDefinitionRetryStrategies(t *testing.T) { assert.Equal(t, limit, count) } -func Test_ManagerDefaultRetryStrategies(t *testing.T) { +func Test_ManagerRetryStrategies(t *testing.T) { var limit uint = 5 var count uint = 0 m := background.NewManagerWithOptions(background.Options{ @@ -292,3 +269,123 @@ func Test_ManagerDefaultRetryStrategies(t *testing.T) { assert.Equal(t, limit, count) } + +func Test_RunTaskTypeLoop(t *testing.T) { + loops := 0 + m := background.NewManager() + def := background.Task{ + Type: background.TaskTypeLoop, + Fn: func(ctx context.Context) error { + loops++ + return nil + }, + } + + m.RunTask(context.Background(), def) + <-time.After(time.Microsecond * 500) + + m.Cancel() + assert.GreaterOrEqual(t, loops, 100) +} + +func Test_RunTaskTypeLoop_RetryStrategies(t *testing.T) { + done := make(chan error, 1) + count := 0 + + m := background.NewManagerWithOptions(background.Options{ + Observer: background.Observer{ + OnTaskFailed: func(ctx context.Context, task background.Task, err error) { + done <- err + }, + }, + }) + def := background.Task{ + Type: background.TaskTypeLoop, + Fn: func(ctx context.Context) error { + count++ + // TODO: Figure out why we need to wait here to avoid test timeout + <-time.After(time.Millisecond) + return assert.AnError + }, + Retry: background.Retry{ + strategy.Limit(2), + }, + } + + m.RunTask(context.Background(), def) + err := <-done + m.Cancel() + + assert.Equal(t, assert.AnError, err) + // We cannot guarantee exact count of executions because by the time we cancel the task the loop might have made + // several additional iterations. + assert.GreaterOrEqual(t, count, 2) +} + +func Test_RunTaskTypeLoop_CancelledParentContext(t *testing.T) { + m := background.NewManager() + cancellable, cancel := context.WithCancel(context.Background()) + proceed := make(chan bool, 1) + done := make(chan error, 1) + var once sync.Once + + def := background.Task{ + Type: background.TaskTypeLoop, + Fn: func(ctx context.Context) error { + once.Do(func() { + proceed <- true + // Cancel the parent context and send the child context's error out to the test + // The expectation is that the child context will not be cancelled + cancel() + done <- ctx.Err() + }) + + return nil + }, + } + + m.RunTask(cancellable, def) + // Make sure we wait for the loop to run at least one iteration before cancelling it + <-proceed + m.Cancel() + err := <-done + + assert.Equal(t, nil, err) +} + +func Test_CountOf(t *testing.T) { + m := background.NewManager() + + assert.Equal(t, 0, m.CountOf(background.TaskTypeOneOff)) + assert.Equal(t, 0, m.CountOf(background.TaskTypeLoop)) + assert.Equal(t, 0, m.CountOf(background.TaskType(3))) + + def := background.Task{ + Type: background.TaskTypeOneOff, + Fn: func(ctx context.Context) error { + return nil + }, + } + m.RunTask(context.Background(), def) + assert.Equal(t, 1, m.CountOf(background.TaskTypeOneOff)) + assert.Equal(t, 0, m.CountOf(background.TaskTypeLoop)) + assert.Equal(t, 0, m.CountOf(background.TaskType(3))) + m.Wait() + + def = background.Task{ + Type: background.TaskTypeLoop, + Fn: func(ctx context.Context) error { + return nil + }, + } + + m.RunTask(context.Background(), def) + assert.Equal(t, 0, m.CountOf(background.TaskTypeOneOff)) + assert.Equal(t, 1, m.CountOf(background.TaskTypeLoop)) + assert.Equal(t, 0, m.CountOf(background.TaskType(3))) + m.Cancel() + + assert.Equal(t, 0, m.CountOf(background.TaskTypeOneOff)) + assert.Equal(t, 0, m.CountOf(background.TaskTypeLoop)) + assert.Equal(t, 0, m.CountOf(background.TaskType(3))) +} diff --git a/internal.go b/internal.go index 0149162..852960f 100644 --- a/internal.go +++ b/internal.go @@ -1,11 +1,65 @@ package background import ( + "context" + "sync" + "sync/atomic" "time" "github.com/kamilsk/retry/v5/strategy" ) +// taskmgr is used internally for task tracking and synchronization. +type taskmgr struct { + group sync.WaitGroup + count atomic.Int32 +} + +// start tells the taskmgr that a new task has started. +func (m *taskmgr) start() { + m.group.Add(1) + m.count.Add(1) +} + +// finish tells the taskmgr that a task has finished. +func (m *taskmgr) finish() { + m.group.Done() + m.count.Add(-1) +} + +// loopmgr is used internally for loop tracking and synchronization and cancellation of the loops. +type loopmgr struct { + group sync.WaitGroup + count atomic.Int32 + ctx context.Context + cancelfn context.CancelFunc +} + +func mkloopmgr() loopmgr { + ctx, cancelfn := context.WithCancel(context.Background()) + return loopmgr{ + ctx: ctx, + cancelfn: cancelfn, + } +} + +// start tells the loopmgr that a new loop has started. +func (m *loopmgr) start() { + m.group.Add(1) + m.count.Add(1) +} + +// cancel tells the loopmgr that a loop has finished. +func (m *loopmgr) finish() { + m.group.Done() + m.count.Add(-1) +} + +func (m *loopmgr) cancel() { + m.cancelfn() + m.group.Wait() +} + // mktimeout returns a channel that will receive the current time after the specified duration. If the duration is 0, // the channel will never receive any message. func mktimeout(duration time.Duration) <-chan time.Time {