diff --git a/queue.go b/queue.go deleted file mode 100644 index d9dfa12..0000000 --- a/queue.go +++ /dev/null @@ -1,49 +0,0 @@ -package concurrency - -import "sync" - -type Queue struct { - sync.Mutex - data []interface{} -} - -func NewQueue() *Queue { - return &Queue{ - data: make([]interface{}, 0), - } -} - -func (c *Queue) Push(eles ...interface{}) { - c.Lock() - c.data = append(c.data, eles...) - c.Unlock() -} - -func (c *Queue) Front() interface{} { - c.Lock() - defer c.Unlock() - - if n := len(c.data); n == 0 { - return nil - } else { - var result = c.data[0] - c.data = c.data[1:] - return result - } -} - -func (c *Queue) Len() int { - c.Lock() - length := len(c.data) - c.Unlock() - return length -} - -// All 返回所有数据并清空队列 -func (c *Queue) All() []interface{} { - c.Lock() - data := c.data - c.data = make([]interface{}, 0) - c.Unlock() - return data -} diff --git a/worker_group.go b/worker_group.go index c9b6099..e19d701 100644 --- a/worker_group.go +++ b/worker_group.go @@ -3,7 +3,6 @@ package concurrency import ( "github.com/hashicorp/go-multierror" "sync" - "sync/atomic" ) type WorkerGroup struct { @@ -11,7 +10,7 @@ type WorkerGroup struct { err error config *Config done chan bool - q *Queue + q []Job taskDone int64 taskTotal int64 } @@ -25,24 +24,23 @@ func NewWorkerGroup(options ...Option) *WorkerGroup { o := &WorkerGroup{ mu: &sync.Mutex{}, config: config.init(), - q: NewQueue(), + q: make([]Job, 0), taskDone: 0, done: make(chan bool), } return o } -// Len 获取队列中剩余任务数量 -func (c *WorkerGroup) Len() int { - return c.q.Len() -} +func (c *WorkerGroup) getJob() interface{} { + c.mu.Lock() + defer c.mu.Unlock() -// AddJob 往任务队列中追加任务 -func (c *WorkerGroup) AddJob(jobs ...Job) { - atomic.AddInt64(&c.taskTotal, int64(len(jobs))) - for i, _ := range jobs { - c.q.Push(jobs[i]) + if n := len(c.q); n == 0 { + return nil } + var result = c.q[0] + c.q = c.q[1:] + return result } func (c *WorkerGroup) appendError(err error) { @@ -54,33 +52,55 @@ func (c *WorkerGroup) appendError(err error) { c.mu.Unlock() } -func (c *WorkerGroup) do() { - if atomic.LoadInt64(&c.taskDone) == atomic.LoadInt64(&c.taskTotal) { +func (c *WorkerGroup) incr(d int64) bool { + c.mu.Lock() + c.taskDone += d + ok := c.taskDone == c.taskTotal + c.mu.Unlock() + return ok +} + +func (c *WorkerGroup) do(job Job) { + if !isCanceled(c.config.Context) { + c.appendError(c.config.Caller(job)) + } + if c.incr(1) { c.done <- true return } - - if item := c.q.Front(); item != nil { - go func(job Job) { - if !isCanceled(c.config.Context) { - c.appendError(c.config.Caller(job)) - } - atomic.AddInt64(&c.taskDone, 1) - c.do() - }(item.(Job)) + if nextJob := c.getJob(); nextJob != nil { + c.do(nextJob.(Job)) } } +// Len 获取队列中剩余任务数量 +func (c *WorkerGroup) Len() int { + c.mu.Lock() + x := len(c.q) + c.mu.Unlock() + return x +} + +// AddJob 往任务队列中追加任务 +func (c *WorkerGroup) AddJob(jobs ...Job) { + c.mu.Lock() + c.taskTotal += int64(len(jobs)) + c.q = append(c.q, jobs...) + c.mu.Unlock() +} + // StartAndWait 启动并等待所有任务执行完成 func (c *WorkerGroup) StartAndWait() { - var taskTotal = atomic.LoadInt64(&c.taskTotal) + var taskTotal = int64(c.Len()) if taskTotal == 0 { return } var co = min(c.config.Concurrency, taskTotal) for i := int64(0); i < co; i++ { - c.do() + if item := c.getJob(); item != nil { + go c.do(item.(Job)) + } } <-c.done diff --git a/worker_queue.go b/worker_queue.go index 53209b1..468ac27 100644 --- a/worker_queue.go +++ b/worker_queue.go @@ -30,18 +30,6 @@ func NewWorkerQueue(options ...Option) *WorkerQueue { } } -// AddJob 追加任务, 有资源空闲的话会立即执行 -func (c *WorkerQueue) AddJob(jobs ...Job) { - c.mu.Lock() - c.q = append(c.q, jobs...) - c.mu.Unlock() - - var n = len(jobs) - for i := 0; i < n; i++ { - c.do() - } -} - func (c *WorkerQueue) getJob() interface{} { c.mu.Lock() defer c.mu.Unlock() @@ -65,15 +53,13 @@ func (c *WorkerQueue) incr(d int64) { c.mu.Unlock() } -func (c *WorkerQueue) do() { - if item := c.getJob(); item != nil { - go func(job Job) { - if !isCanceled(c.config.Context) { - c.callOnError(c.config.Caller(job)) - } - c.incr(-1) - c.do() - }(item.(Job)) +func (c *WorkerQueue) do(job Job) { + if !isCanceled(c.config.Context) { + c.callOnError(c.config.Caller(job)) + } + c.incr(-1) + if nextJob := c.getJob(); nextJob != nil { + c.do(nextJob.(Job)) } } @@ -93,6 +79,28 @@ func (c *WorkerQueue) getCurConcurrency() int64 { return x } +// Len 获取队列中剩余任务数量 +func (c *WorkerQueue) Len() int { + c.mu.Lock() + x := len(c.q) + c.mu.Unlock() + return x +} + +// AddJob 追加任务, 有资源空闲的话会立即执行 +func (c *WorkerQueue) AddJob(jobs ...Job) { + c.mu.Lock() + c.q = append(c.q, jobs...) + c.mu.Unlock() + + var n = len(jobs) + for i := 0; i < n; i++ { + if item := c.getJob(); item != nil { + go c.do(item.(Job)) + } + } +} + // Stop 优雅退出 // timeout 超时时间 func (c *WorkerQueue) StopAndWait(timeout time.Duration) {