From 35b1eff15249ea39b669da28a8f35545ac4bdb77 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Wed, 11 Dec 2024 18:02:00 -0800 Subject: [PATCH] :rotating_light: fix race conditions Signed-off-by: Salim Afiune Maya --- internal/workerpool/collector.go | 27 +++++++++++++++++++--- internal/workerpool/pool.go | 39 ++++++++++++++------------------ internal/workerpool/worker.go | 2 +- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/internal/workerpool/collector.go b/internal/workerpool/collector.go index 4c5257afd..2d105501b 100644 --- a/internal/workerpool/collector.go +++ b/internal/workerpool/collector.go @@ -3,9 +3,15 @@ package workerpool +import ( + "sync" + "sync/atomic" +) + type collector[R any] struct { resultsCh <-chan R results []R + read sync.Mutex errorsCh <-chan error errors []error @@ -13,22 +19,37 @@ type collector[R any] struct { requestsRead int64 } -func (c *collector[R]) Start() { +func (c *collector[R]) start() { go func() { for { select { case result := <-c.resultsCh: + c.read.Lock() c.results = append(c.results, result) + c.read.Unlock() case err := <-c.errorsCh: + c.read.Lock() c.errors = append(c.errors, err) + c.read.Unlock() } - c.requestsRead++ + atomic.AddInt64(&c.requestsRead, 1) } }() } +func (c *collector[R]) GetResults() []R { + c.read.Lock() + defer c.read.Unlock() + return c.results +} + +func (c *collector[R]) GetErrors() []error { + c.read.Lock() + defer c.read.Unlock() + return c.errors +} func (c *collector[R]) RequestsRead() int64 { - return c.requestsRead + return atomic.LoadInt64(&c.requestsRead) } diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go index d407543cf..8553ca25d 100644 --- a/internal/workerpool/pool.go +++ b/internal/workerpool/pool.go @@ -4,6 +4,7 @@ package workerpool import ( + "sync" "sync/atomic" "time" @@ -14,10 +15,12 @@ type Task[R any] func() (result R, err error) // Pool is a generic pool of workers. type Pool[R any] struct { - queueCh chan Task[R] - resultsCh chan R - errorsCh chan error + queueCh chan Task[R] + resultsCh chan R + errorsCh chan error + requestsSent int64 + once sync.Once workers []*worker[R] workerCount int @@ -51,13 +54,15 @@ func New[R any](count int) *Pool[R] { // pool.Start() // defer pool.Close() func (p *Pool[R]) Start() { - for i := 0; i < p.workerCount; i++ { - w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} - w.Start() - p.workers = append(p.workers, &w) - } + p.once.Do(func() { + for i := 0; i < p.workerCount; i++ { + w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} + w.start() + p.workers = append(p.workers, &w) + } - p.collector.Start() + p.collector.start() + }) } // Submit sends a task to the workers @@ -68,14 +73,14 @@ func (p *Pool[R]) Submit(t Task[R]) { // GetErrors returns any error from a processed task func (p *Pool[R]) GetErrors() error { - return errors.Join(p.collector.errors...) + return errors.Join(p.collector.GetErrors()...) } // GetResults returns the tasks results. // // It is recommended to call `Wait()` before reading the results. func (p *Pool[R]) GetResults() []R { - return p.collector.results + return p.collector.GetResults() } // Close waits for workers and collector to process all the requests, and then closes @@ -98,20 +103,10 @@ func (p *Pool[R]) Wait() { // PendingRequests returns the number of pending requests. func (p *Pool[R]) PendingRequests() int64 { - return p.requestsSent - p.collector.RequestsRead() + return atomic.LoadInt64(&p.requestsSent) - p.collector.RequestsRead() } // Processing return true if tasks are being processed. func (p *Pool[R]) Processing() bool { - if !p.empty() { - return false - } - return p.PendingRequests() != 0 } - -func (p *Pool[R]) empty() bool { - return len(p.queueCh) == 0 && - len(p.resultsCh) == 0 && - len(p.errorsCh) == 0 -} diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go index 19b21de1e..77b5c81f1 100644 --- a/internal/workerpool/worker.go +++ b/internal/workerpool/worker.go @@ -10,7 +10,7 @@ type worker[R any] struct { errorsCh chan<- error } -func (w *worker[R]) Start() { +func (w *worker[R]) start() { go func() { for task := range w.queueCh { if task == nil {