-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
127 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,139 @@ | ||
package memamq | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
) | ||
|
||
var ( | ||
ErrStop = errors.New("push failed: queue is stopped") | ||
ErrFull = errors.New("push failed: queue is full") | ||
) | ||
|
||
// AsyncQueue is the interface responsible for asynchronous processing of functions. | ||
type AsyncQueue interface { | ||
Initialize(processFunc func(), workerCount int, bufferSize int) | ||
Push(task func()) error | ||
} | ||
//type AsyncQueue interface { | ||
// Initialize(processFunc func(), workerCount int, bufferSize int) | ||
// Push(task func()) error | ||
//} | ||
|
||
// MemoryQueue is an implementation of the AsyncQueue interface using a channel to process functions. | ||
type MemoryQueue struct { | ||
taskChan chan func() | ||
wg sync.WaitGroup | ||
isStopped bool | ||
stopMutex sync.Mutex // Mutex to protect access to isStopped | ||
isStopped atomic.Bool | ||
count atomic.Int64 | ||
//stopMutex sync.Mutex // Mutex to protect access to isStopped | ||
} | ||
|
||
func NewMemoryQueue(workerCount int, bufferSize int) *MemoryQueue { | ||
if workerCount < 1 || bufferSize < 1 { | ||
panic("workerCount and bufferSize must be greater than 0") | ||
} | ||
mq := &MemoryQueue{} // Create a new instance of MemoryQueue | ||
mq.Initialize(workerCount, bufferSize) // Initialize it with specified parameters | ||
mq.initialize(workerCount, bufferSize) // Initialize it with specified parameters | ||
return mq | ||
} | ||
|
||
// Initialize sets up the worker nodes and the buffer size of the channel, | ||
// starting internal goroutines to handle tasks from the channel. | ||
func (mq *MemoryQueue) Initialize(workerCount int, bufferSize int) { | ||
func (mq *MemoryQueue) initialize(workerCount int, bufferSize int) { | ||
mq.taskChan = make(chan func(), bufferSize) // Initialize the channel with the provided buffer size. | ||
mq.isStopped = false | ||
|
||
// Start multiple goroutines based on the specified workerCount. | ||
for i := 0; i < workerCount; i++ { | ||
mq.wg.Add(1) | ||
go func(workerID int) { | ||
go func() { | ||
defer mq.wg.Done() | ||
for task := range mq.taskChan { | ||
task() // Execute the function | ||
} | ||
}(i) | ||
}() | ||
} | ||
} | ||
|
||
// Push submits a function to the queue. | ||
// Returns an error if the queue is stopped or if the queue is full. | ||
func (mq *MemoryQueue) Push(task func()) error { | ||
mq.stopMutex.Lock() | ||
if mq.isStopped { | ||
mq.stopMutex.Unlock() | ||
return errors.New("push failed: queue is stopped") | ||
mq.count.Add(1) | ||
defer mq.count.Add(-1) | ||
if mq.isStopped.Load() { | ||
return ErrStop | ||
} | ||
timer := time.NewTimer(time.Millisecond * 100) | ||
defer timer.Stop() | ||
select { | ||
case mq.taskChan <- task: | ||
return nil | ||
case <-timer.C: // Timeout to prevent deadlock/blocking | ||
return ErrFull | ||
} | ||
mq.stopMutex.Unlock() | ||
} | ||
|
||
func (mq *MemoryQueue) PushCtx(ctx context.Context, task func()) error { | ||
mq.count.Add(1) | ||
defer mq.count.Add(-1) | ||
if mq.isStopped.Load() { | ||
return ErrStop | ||
} | ||
select { | ||
case mq.taskChan <- task: | ||
return nil | ||
case <-time.After(time.Millisecond * 100): // Timeout to prevent deadlock/blocking | ||
return errors.New("push failed: queue is full") | ||
case <-ctx.Done(): | ||
return context.Cause(ctx) | ||
} | ||
} | ||
|
||
func (mq *MemoryQueue) BatchPushCtx(ctx context.Context, tasks ...func()) (int, error) { | ||
mq.count.Add(1) | ||
defer mq.count.Add(-1) | ||
if mq.isStopped.Load() { | ||
return 0, ErrStop | ||
} | ||
for i := range tasks { | ||
select { | ||
case <-ctx.Done(): | ||
return i, context.Cause(ctx) | ||
case mq.taskChan <- tasks[i]: | ||
} | ||
} | ||
return len(tasks), nil | ||
} | ||
|
||
func (mq *MemoryQueue) NotWaitPush(task func()) error { | ||
mq.count.Add(1) | ||
defer mq.count.Add(-1) | ||
if mq.isStopped.Load() { | ||
return ErrStop | ||
} | ||
select { | ||
case mq.taskChan <- task: | ||
return nil | ||
default: | ||
return ErrFull | ||
} | ||
} | ||
|
||
// Stop is used to terminate the internal goroutines and close the channel. | ||
func (mq *MemoryQueue) Stop() { | ||
mq.stopMutex.Lock() | ||
mq.isStopped = true | ||
if !mq.isStopped.CompareAndSwap(false, true) { | ||
return | ||
} | ||
mq.waitSafeClose() | ||
close(mq.taskChan) | ||
mq.stopMutex.Unlock() | ||
mq.wg.Wait() | ||
} | ||
|
||
func (mq *MemoryQueue) waitSafeClose() { | ||
if mq.count.Load() == 0 { | ||
return | ||
} | ||
ticker := time.NewTicker(time.Second / 10) | ||
defer ticker.Stop() | ||
for range ticker.C { | ||
if mq.count.Load() == 0 { | ||
return | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters