From d758e4b8c9aa55e001c93fa525a86a5b9bf72b93 Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Mon, 22 Jul 2024 16:56:08 +0800 Subject: [PATCH] fix: MemoryQueue --- mq/memamq/queue.go | 111 ++++++++++++++++++++++++++++++++-------- mq/memamq/queue_test.go | 55 ++++++++++++++------ 2 files changed, 127 insertions(+), 39 deletions(-) diff --git a/mq/memamq/queue.go b/mq/memamq/queue.go index fa57c605..161ed3f0 100644 --- a/mq/memamq/queue.go +++ b/mq/memamq/queue.go @@ -1,72 +1,139 @@ package memamq import ( + "context" "errors" "sync" + "sync/atomic" "time" ) +var ( + ErrStop = errors.New("push failed: queue is stopped") + ErrFull = errors.New("push failed: queue is full") +) + // AsyncQueue is the interface responsible for asynchronous processing of functions. -type AsyncQueue interface { - Initialize(processFunc func(), workerCount int, bufferSize int) - Push(task func()) error -} +//type AsyncQueue interface { +// Initialize(processFunc func(), workerCount int, bufferSize int) +// Push(task func()) error +//} // MemoryQueue is an implementation of the AsyncQueue interface using a channel to process functions. type MemoryQueue struct { taskChan chan func() wg sync.WaitGroup - isStopped bool - stopMutex sync.Mutex // Mutex to protect access to isStopped + isStopped atomic.Bool + count atomic.Int64 + //stopMutex sync.Mutex // Mutex to protect access to isStopped } func NewMemoryQueue(workerCount int, bufferSize int) *MemoryQueue { + if workerCount < 1 || bufferSize < 1 { + panic("workerCount and bufferSize must be greater than 0") + } mq := &MemoryQueue{} // Create a new instance of MemoryQueue - mq.Initialize(workerCount, bufferSize) // Initialize it with specified parameters + mq.initialize(workerCount, bufferSize) // Initialize it with specified parameters return mq } // Initialize sets up the worker nodes and the buffer size of the channel, // starting internal goroutines to handle tasks from the channel. -func (mq *MemoryQueue) Initialize(workerCount int, bufferSize int) { +func (mq *MemoryQueue) initialize(workerCount int, bufferSize int) { mq.taskChan = make(chan func(), bufferSize) // Initialize the channel with the provided buffer size. - mq.isStopped = false - // Start multiple goroutines based on the specified workerCount. for i := 0; i < workerCount; i++ { mq.wg.Add(1) - go func(workerID int) { + go func() { defer mq.wg.Done() for task := range mq.taskChan { task() // Execute the function } - }(i) + }() } } // Push submits a function to the queue. // Returns an error if the queue is stopped or if the queue is full. func (mq *MemoryQueue) Push(task func()) error { - mq.stopMutex.Lock() - if mq.isStopped { - mq.stopMutex.Unlock() - return errors.New("push failed: queue is stopped") + mq.count.Add(1) + defer mq.count.Add(-1) + if mq.isStopped.Load() { + return ErrStop + } + timer := time.NewTimer(time.Millisecond * 100) + defer timer.Stop() + select { + case mq.taskChan <- task: + return nil + case <-timer.C: // Timeout to prevent deadlock/blocking + return ErrFull } - mq.stopMutex.Unlock() +} +func (mq *MemoryQueue) PushCtx(ctx context.Context, task func()) error { + mq.count.Add(1) + defer mq.count.Add(-1) + if mq.isStopped.Load() { + return ErrStop + } select { case mq.taskChan <- task: return nil - case <-time.After(time.Millisecond * 100): // Timeout to prevent deadlock/blocking - return errors.New("push failed: queue is full") + case <-ctx.Done(): + return context.Cause(ctx) + } +} + +func (mq *MemoryQueue) BatchPushCtx(ctx context.Context, tasks ...func()) (int, error) { + mq.count.Add(1) + defer mq.count.Add(-1) + if mq.isStopped.Load() { + return 0, ErrStop + } + for i := range tasks { + select { + case <-ctx.Done(): + return i, context.Cause(ctx) + case mq.taskChan <- tasks[i]: + } + } + return len(tasks), nil +} + +func (mq *MemoryQueue) NotWaitPush(task func()) error { + mq.count.Add(1) + defer mq.count.Add(-1) + if mq.isStopped.Load() { + return ErrStop + } + select { + case mq.taskChan <- task: + return nil + default: + return ErrFull } } // Stop is used to terminate the internal goroutines and close the channel. func (mq *MemoryQueue) Stop() { - mq.stopMutex.Lock() - mq.isStopped = true + if !mq.isStopped.CompareAndSwap(false, true) { + return + } + mq.waitSafeClose() close(mq.taskChan) - mq.stopMutex.Unlock() mq.wg.Wait() } + +func (mq *MemoryQueue) waitSafeClose() { + if mq.count.Load() == 0 { + return + } + ticker := time.NewTicker(time.Second / 10) + defer ticker.Stop() + for range ticker.C { + if mq.count.Load() == 0 { + return + } + } +} diff --git a/mq/memamq/queue_test.go b/mq/memamq/queue_test.go index 76526f00..9a0f3cff 100644 --- a/mq/memamq/queue_test.go +++ b/mq/memamq/queue_test.go @@ -2,27 +2,28 @@ package memamq import ( "sync" + "sync/atomic" "testing" "time" ) -func TestNewMemoryQueue(t *testing.T) { - workerCount := 3 - bufferSize := 10 - queue := NewMemoryQueue(workerCount, bufferSize) - - if cap(queue.taskChan) != bufferSize { - t.Errorf("Expected buffer size %d, got %d", bufferSize, cap(queue.taskChan)) - } - - if queue.isStopped { - t.Errorf("New queue is prematurely stopped") - } - - if len(queue.taskChan) != 0 { - t.Errorf("New queue should be empty, found %d items", len(queue.taskChan)) - } -} +//func TestNewMemoryQueue(t *testing.T) { +// workerCount := 3 +// bufferSize := 10 +// queue := NewMemoryQueue(workerCount, bufferSize) +// +// if cap(queue.taskChan) != bufferSize { +// t.Errorf("Expected buffer size %d, got %d", bufferSize, cap(queue.taskChan)) +// } +// +// if queue.isStopped { +// t.Errorf("New queue is prematurely stopped") +// } +// +// if len(queue.taskChan) != 0 { +// t.Errorf("New queue should be empty, found %d items", len(queue.taskChan)) +// } +//} func TestPushAndStop(t *testing.T) { queue := NewMemoryQueue(1, 5) @@ -59,3 +60,23 @@ func TestPushTimeout(t *testing.T) { t.Error("Expected timeout error, got nil") } } + +func TestName(t *testing.T) { + queue := NewMemoryQueue(16, 1024) + var count atomic.Int64 + for i := 0; i < 128; i++ { + go func() { + for { + queue.Push(func() { + count.Add(1) + }) + } + }() + } + + <-time.After(time.Second * 2) + t.Log("stop 1", time.Now()) + queue.Stop() + t.Log("stop 2", time.Now()) + t.Log(count.Load(), time.Now()) +}