Skip to content

Commit

Permalink
fix: MemoryQueue
Browse files Browse the repository at this point in the history
  • Loading branch information
withchao committed Jul 22, 2024
1 parent 828da30 commit d758e4b
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 39 deletions.
111 changes: 89 additions & 22 deletions mq/memamq/queue.go
Original file line number Diff line number Diff line change
@@ -1,72 +1,139 @@
package memamq

import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
)

var (
ErrStop = errors.New("push failed: queue is stopped")
ErrFull = errors.New("push failed: queue is full")
)

// AsyncQueue is the interface responsible for asynchronous processing of functions.
type AsyncQueue interface {
Initialize(processFunc func(), workerCount int, bufferSize int)
Push(task func()) error
}
//type AsyncQueue interface {
// Initialize(processFunc func(), workerCount int, bufferSize int)
// Push(task func()) error
//}

// MemoryQueue is an implementation of the AsyncQueue interface using a channel to process functions.
type MemoryQueue struct {
taskChan chan func()
wg sync.WaitGroup
isStopped bool
stopMutex sync.Mutex // Mutex to protect access to isStopped
isStopped atomic.Bool
count atomic.Int64
//stopMutex sync.Mutex // Mutex to protect access to isStopped
}

func NewMemoryQueue(workerCount int, bufferSize int) *MemoryQueue {
if workerCount < 1 || bufferSize < 1 {
panic("workerCount and bufferSize must be greater than 0")
}
mq := &MemoryQueue{} // Create a new instance of MemoryQueue
mq.Initialize(workerCount, bufferSize) // Initialize it with specified parameters
mq.initialize(workerCount, bufferSize) // Initialize it with specified parameters
return mq
}

// Initialize sets up the worker nodes and the buffer size of the channel,
// starting internal goroutines to handle tasks from the channel.
func (mq *MemoryQueue) Initialize(workerCount int, bufferSize int) {
func (mq *MemoryQueue) initialize(workerCount int, bufferSize int) {
mq.taskChan = make(chan func(), bufferSize) // Initialize the channel with the provided buffer size.
mq.isStopped = false

// Start multiple goroutines based on the specified workerCount.
for i := 0; i < workerCount; i++ {
mq.wg.Add(1)
go func(workerID int) {
go func() {
defer mq.wg.Done()
for task := range mq.taskChan {
task() // Execute the function
}
}(i)
}()
}
}

// Push submits a function to the queue.
// Returns an error if the queue is stopped or if the queue is full.
func (mq *MemoryQueue) Push(task func()) error {
mq.stopMutex.Lock()
if mq.isStopped {
mq.stopMutex.Unlock()
return errors.New("push failed: queue is stopped")
mq.count.Add(1)
defer mq.count.Add(-1)
if mq.isStopped.Load() {
return ErrStop
}
timer := time.NewTimer(time.Millisecond * 100)
defer timer.Stop()
select {
case mq.taskChan <- task:
return nil
case <-timer.C: // Timeout to prevent deadlock/blocking
return ErrFull
}
mq.stopMutex.Unlock()
}

func (mq *MemoryQueue) PushCtx(ctx context.Context, task func()) error {
mq.count.Add(1)
defer mq.count.Add(-1)
if mq.isStopped.Load() {
return ErrStop
}
select {
case mq.taskChan <- task:
return nil
case <-time.After(time.Millisecond * 100): // Timeout to prevent deadlock/blocking
return errors.New("push failed: queue is full")
case <-ctx.Done():
return context.Cause(ctx)
}
}

func (mq *MemoryQueue) BatchPushCtx(ctx context.Context, tasks ...func()) (int, error) {
mq.count.Add(1)
defer mq.count.Add(-1)
if mq.isStopped.Load() {
return 0, ErrStop
}
for i := range tasks {
select {
case <-ctx.Done():
return i, context.Cause(ctx)
case mq.taskChan <- tasks[i]:
}
}
return len(tasks), nil
}

func (mq *MemoryQueue) NotWaitPush(task func()) error {
mq.count.Add(1)
defer mq.count.Add(-1)
if mq.isStopped.Load() {
return ErrStop
}
select {
case mq.taskChan <- task:
return nil
default:
return ErrFull
}
}

// Stop is used to terminate the internal goroutines and close the channel.
func (mq *MemoryQueue) Stop() {
mq.stopMutex.Lock()
mq.isStopped = true
if !mq.isStopped.CompareAndSwap(false, true) {
return
}
mq.waitSafeClose()
close(mq.taskChan)
mq.stopMutex.Unlock()
mq.wg.Wait()
}

func (mq *MemoryQueue) waitSafeClose() {
if mq.count.Load() == 0 {
return
}
ticker := time.NewTicker(time.Second / 10)
defer ticker.Stop()
for range ticker.C {
if mq.count.Load() == 0 {
return
}
}
}
55 changes: 38 additions & 17 deletions mq/memamq/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,28 @@ package memamq

import (
"sync"
"sync/atomic"
"testing"
"time"
)

func TestNewMemoryQueue(t *testing.T) {
workerCount := 3
bufferSize := 10
queue := NewMemoryQueue(workerCount, bufferSize)

if cap(queue.taskChan) != bufferSize {
t.Errorf("Expected buffer size %d, got %d", bufferSize, cap(queue.taskChan))
}

if queue.isStopped {
t.Errorf("New queue is prematurely stopped")
}

if len(queue.taskChan) != 0 {
t.Errorf("New queue should be empty, found %d items", len(queue.taskChan))
}
}
//func TestNewMemoryQueue(t *testing.T) {
// workerCount := 3
// bufferSize := 10
// queue := NewMemoryQueue(workerCount, bufferSize)
//
// if cap(queue.taskChan) != bufferSize {
// t.Errorf("Expected buffer size %d, got %d", bufferSize, cap(queue.taskChan))
// }
//
// if queue.isStopped {
// t.Errorf("New queue is prematurely stopped")
// }
//
// if len(queue.taskChan) != 0 {
// t.Errorf("New queue should be empty, found %d items", len(queue.taskChan))
// }
//}

func TestPushAndStop(t *testing.T) {
queue := NewMemoryQueue(1, 5)
Expand Down Expand Up @@ -59,3 +60,23 @@ func TestPushTimeout(t *testing.T) {
t.Error("Expected timeout error, got nil")
}
}

func TestName(t *testing.T) {
queue := NewMemoryQueue(16, 1024)
var count atomic.Int64
for i := 0; i < 128; i++ {
go func() {
for {
queue.Push(func() {
count.Add(1)
})
}
}()
}

<-time.After(time.Second * 2)
t.Log("stop 1", time.Now())
queue.Stop()
t.Log("stop 2", time.Now())
t.Log(count.Load(), time.Now())
}

0 comments on commit d758e4b

Please sign in to comment.