diff --git a/pool/node.go b/pool/node.go
index 84f24cf..8d2a4d2 100644
--- a/pool/node.go
+++ b/pool/node.go
@@ -338,7 +338,7 @@ func (node *Node) Shutdown(ctx context.Context) error {
 	}
 	if node.clientOnly {
 		node.lock.Unlock()
-		return fmt.Errorf("Shutdown: pool %q is client-only", node.Name)
+		return fmt.Errorf("Shutdown: client-only node cannot shutdown worker pool")
 	}
 	node.lock.Unlock()
 	node.logger.Info("shutting down")
@@ -401,31 +401,32 @@ func (node *Node) close(ctx context.Context, requeue bool) error {
 
 	node.logger.Info("closing")
 	node.closing = true
 
-	// Need to stop workers before requeueing jobs to prevent
-	// requeued jobs from being handled by this node.
-	var wg sync.WaitGroup
-	node.logger.Debug("stopping workers", "count", len(node.localWorkers))
-	for _, w := range node.localWorkers {
-		wg.Add(1)
-		go func(w *Worker) {
-			defer wg.Done()
-			w.stopAndWait(ctx)
-		}(w)
-	}
-	wg.Wait()
-	node.logger.Debug("workers stopped")
-
-	for _, w := range node.localWorkers {
-		if requeue {
-			if err := w.requeueJobs(ctx); err != nil {
-				node.logger.Error(fmt.Errorf("Close: failed to requeue jobs for worker %q: %w", w.ID, err))
-				continue
+	if len(node.localWorkers) > 0 {
+		// Need to stop workers before requeueing jobs to prevent
+		// requeued jobs from being handled by this node.
+		var wg sync.WaitGroup
+		node.logger.Debug("stopping workers", "count", len(node.localWorkers))
+		for _, w := range node.localWorkers {
+			wg.Add(1)
+			go func(w *Worker) {
+				defer wg.Done()
+				w.stopAndWait(ctx)
+			}(w)
+		}
+		wg.Wait()
+		node.logger.Debug("workers stopped")
+		for _, w := range node.localWorkers {
+			if requeue {
+				if err := w.requeueJobs(ctx); err != nil {
+					node.logger.Error(fmt.Errorf("Close: failed to requeue jobs for worker %q: %w", w.ID, err))
+					continue
+				}
 			}
+			w.cleanup(ctx)
 		}
-		w.cleanup(ctx)
+		node.localWorkers = nil
 	}
-	node.localWorkers = nil
 	if !node.clientOnly {
 		node.poolSink.Close()
 		node.tickerMap.Close()
@@ -532,7 +533,10 @@ func (node *Node) ackWorkerEvent(ctx context.Context, ev *streaming.Event) {
 	key := workerID + ":" + ack.EventID
 	pending, ok := node.pendingEvents[key]
 	if !ok {
-		node.logger.Error(fmt.Errorf("ackWorkerEvent: received event %s from worker %s that was not dispatched", ack.EventID, workerID))
+		node.logger.Error(fmt.Errorf("ackWorkerEvent: received unknown event %s from worker %s", ack.EventID, workerID))
+		if err := node.poolSink.Ack(ctx, ev); err != nil {
+			node.logger.Error(fmt.Errorf("ackWorkerEvent: failed to ack unknown event: %w", err), "event", ev.EventName, "id", ev.ID)
+		}
 		return
 	}
 
diff --git a/pool/node_test.go b/pool/node_test.go
index 1aa5c08..1c8a34e 100644
--- a/pool/node_test.go
+++ b/pool/node_test.go
@@ -1,13 +1,17 @@
 package pool
 
 import (
+	"context"
+	"fmt"
 	"strings"
 	"testing"
 	"time"
 
+	"github.com/redis/go-redis/v9"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+
+	"goa.design/pulse/streaming"
 	ptesting "goa.design/pulse/testing"
 )
 
@@ -90,6 +94,82 @@ func TestDispatchJobTwoWorkers(t *testing.T) {
 	assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node")
 }
 
+func TestNotifyWorker(t *testing.T) {
+	testName := strings.Replace(t.Name(), "/", "_", -1)
+	ctx := ptesting.NewTestContext(t)
+	rdb := ptesting.NewRedisClient(t)
+	node := newTestNode(t, ctx, rdb, testName)
+	defer ptesting.CleanupRedis(t, rdb, true, testName)
+
+	// Create a worker
+	worker := newTestWorker(t, ctx, node)
+
+	// Set up notification handling
+	jobKey := "test-job"
+	jobPayload := []byte("job payload")
+	notificationPayload := []byte("test notification")
[]byte("test notification") + ch := make(chan []byte, 1) + worker.handler.(*mockHandler).notifyFunc = func(key string, payload []byte) error { + assert.Equal(t, jobKey, key, "Received notification for the wrong key") + assert.Equal(t, notificationPayload, payload, "Received notification for the wrong payload") + close(ch) + return nil + } + + // Dispatch a job to ensure the worker is assigned + require.NoError(t, node.DispatchJob(ctx, jobKey, jobPayload)) + + // Send a notification + err := node.NotifyWorker(ctx, jobKey, notificationPayload) + require.NoError(t, err, "Failed to send notification") + + // Wait for the notification to be received + select { + case <-ch: + case <-time.After(max): + t.Fatal("Timeout waiting for notification to be received") + } + + // Shutdown node + assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node") +} + +func TestNotifyWorkerNoHandler(t *testing.T) { + testName := strings.Replace(t.Name(), "/", "_", -1) + ctx, buf := ptesting.NewBufferedLogContext(t) + rdb := ptesting.NewRedisClient(t) + node := newTestNode(t, ctx, rdb, testName) + defer ptesting.CleanupRedis(t, rdb, true, testName) + + // Create a worker without NotificationHandler implementation + worker := newTestWorkerWithoutNotify(t, ctx, node) + + // Dispatch a job to ensure the worker is assigned + jobKey := "test-job" + jobPayload := []byte("job payload") + require.NoError(t, node.DispatchJob(ctx, jobKey, jobPayload)) + + // Wait for the job to be received by the worker + require.Eventually(t, func() bool { + return len(worker.Jobs()) == 1 + }, max, delay, "Job was not received by the worker") + + // Send a notification + notificationPayload := []byte("test notification") + assert.NoError(t, node.NotifyWorker(ctx, jobKey, notificationPayload), "Failed to send notification") + + // Check that an error was logged + assert.Eventually(t, func() bool { + return strings.Contains(buf.String(), "worker does not implement NotificationHandler, ignoring notification") + }, max, delay, "Expected error message was not logged within the timeout period") + + // Ensure the worker is still functioning + assert.Len(t, worker.Jobs(), 1, "Worker should still have the job") + + // Shutdown node + assert.NoError(t, node.Shutdown(ctx), "Failed to shutdown node") +} + func TestRemoveWorkerThenShutdown(t *testing.T) { ctx := ptesting.NewTestContext(t) testName := strings.Replace(t.Name(), "/", "_", -1) @@ -225,3 +305,79 @@ func TestNodeCloseAndRequeue(t *testing.T) { // Clean up require.NoError(t, node2.Shutdown(ctx), "Failed to shutdown node2") } + +func TestStaleEventsAreRemoved(t *testing.T) { + // Setup + ctx := ptesting.NewTestContext(t) + testName := strings.Replace(t.Name(), "/", "_", -1) + rdb := ptesting.NewRedisClient(t) + defer ptesting.CleanupRedis(t, rdb, true, testName) + node := newTestNode(t, ctx, rdb, testName) + defer func() { assert.NoError(t, node.Shutdown(ctx)) }() + + // Add a stale event manually + staleEventID := fmt.Sprintf("%d-0", time.Now().Add(-2*node.pendingJobTTL).UnixNano()/int64(time.Millisecond)) + staleEvent := &streaming.Event{ + ID: staleEventID, + EventName: "test-event", + Payload: []byte("test-payload"), + Acker: &mockAcker{ + XAckFunc: func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd { + return redis.NewIntCmd(ctx, 0) + }, + }, + } + node.pendingEvents["worker:stale-event-id"] = staleEvent + + // Add a fresh event + freshEventID := fmt.Sprintf("%d-0", time.Now().Add(-time.Second).UnixNano()/int64(time.Millisecond)) + freshEvent := 
&streaming.Event{ + ID: freshEventID, + EventName: "test-event", + Payload: []byte("test-payload"), + Acker: &mockAcker{ + XAckFunc: func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd { + return redis.NewIntCmd(ctx, 0) + }, + }, + } + node.pendingEvents["worker:fresh-event-id"] = freshEvent + + // Create a mock event to trigger the ackWorkerEvent function + mockEvent := &streaming.Event{ + ID: "mock-event-id", + EventName: evAck, + Payload: marshalEnvelope("worker", marshalAck(&ack{EventID: "mock-event-id"})), + Acker: &mockAcker{ + XAckFunc: func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd { + return redis.NewIntCmd(ctx, 0) + }, + }, + } + node.pendingEvents["worker:mock-event-id"] = mockEvent + + // Call ackWorkerEvent to trigger the stale event cleanup + node.ackWorkerEvent(ctx, mockEvent) + + assert.Eventually(t, func() bool { + node.lock.Lock() + defer node.lock.Unlock() + _, exists := node.pendingEvents["worker:stale-event-id"] + return !exists + }, max, delay, "Stale event should have been removed") + + assert.Eventually(t, func() bool { + node.lock.Lock() + defer node.lock.Unlock() + _, exists := node.pendingEvents["worker:fresh-event-id"] + return exists + }, max, delay, "Fresh event should still be present") +} + +type mockAcker struct { + XAckFunc func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd +} + +func (m *mockAcker) XAck(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd { + return m.XAckFunc(ctx, streamKey, sinkName, ids...) +} diff --git a/pool/testing.go b/pool/testing.go index 63d0a95..75476d7 100644 --- a/pool/testing.go +++ b/pool/testing.go @@ -14,7 +14,13 @@ import ( type mockHandler struct { startFunc func(job *Job) error stopFunc func(key string) error - notifyFunc func(payload []byte) error + notifyFunc func(key string, payload []byte) error +} + +// mockHandlerWithoutNotify is a mock handler that doesn't implement NotificationHandler +type mockHandlerWithoutNotify struct { + startFunc func(job *Job) error + stopFunc func(key string) error } const ( @@ -47,7 +53,20 @@ func newTestWorker(t *testing.T, ctx context.Context, node *Node) *Worker { handler := &mockHandler{ startFunc: func(job *Job) error { return nil }, stopFunc: func(key string) error { return nil }, - notifyFunc: func(payload []byte) error { return nil }, + notifyFunc: func(key string, payload []byte) error { return nil }, + } + worker, err := node.AddWorker(ctx, handler) + require.NoError(t, err) + return worker +} + +// newTestWorkerWithoutNotify creates a new Worker instance for testing purposes. +// It sets up a mock handler without NotificationHandler for testing. 
+func newTestWorkerWithoutNotify(t *testing.T, ctx context.Context, node *Node) *Worker {
+	t.Helper()
+	handler := &mockHandlerWithoutNotify{
+		startFunc: func(job *Job) error { return nil },
+		stopFunc:  func(key string) error { return nil },
 	}
 	worker, err := node.AddWorker(ctx, handler)
 	require.NoError(t, err)
@@ -56,4 +75,9 @@ func newTestWorker(t *testing.T, ctx context.Context, node *Node) *Worker {
 
 func (w *mockHandler) Start(job *Job) error { return w.startFunc(job) }
 func (w *mockHandler) Stop(key string) error { return w.stopFunc(key) }
-func (w *mockHandler) Notify(p []byte) error { return w.notifyFunc(p) }
+func (w *mockHandler) HandleNotification(key string, payload []byte) error {
+	return w.notifyFunc(key, payload)
+}
+
+func (h *mockHandlerWithoutNotify) Start(job *Job) error { return h.startFunc(job) }
+func (h *mockHandlerWithoutNotify) Stop(key string) error { return h.stopFunc(key) }
diff --git a/pool/worker.go b/pool/worker.go
index 1c8f969..b5cfcd7 100644
--- a/pool/worker.go
+++ b/pool/worker.go
@@ -285,7 +285,7 @@ func (w *Worker) notify(_ context.Context, key string, payload []byte) error {
 	}
 	nh, ok := w.handler.(NotificationHandler)
 	if !ok {
-		w.logger.Debug("worker does not implement NotificationHandler, ignoring notification")
+		w.logger.Error(fmt.Errorf("worker does not implement NotificationHandler, ignoring notification"), "worker", w.ID)
 		return nil
 	}
 	w.logger.Debug("handled notification", "payload", string(payload))
diff --git a/streaming/reader.go b/streaming/reader.go
index 6f37b78..1afa402 100644
--- a/streaming/reader.go
+++ b/streaming/reader.go
@@ -59,6 +59,11 @@ type (
 		rdb *redis.Client
 	}
 
+	// Acker is the interface used by events to acknowledge themselves.
+	Acker interface {
+		XAck(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd
+	}
+
 	// Event is a stream event.
 	Event struct {
 		// ID is the unique event ID.
@@ -73,10 +78,10 @@ type (
 		Topic string
 		// Payload is the event payload.
 		Payload []byte
+		// Acker is used to acknowledge events.
+		Acker Acker
 		// streamKey is the Redis key of the stream.
 		streamKey string
-		// rdb is the redis client.
-		rdb *redis.Client
 	}
 )
@@ -314,7 +319,7 @@ func streamEvents(
 			Topic:      topic,
 			Payload:    []byte(event.Values[payloadKey].(string)),
 			streamKey:  streamKey,
-			rdb:        rdb,
+			Acker:      rdb,
 		}
 		if eventFilter != nil && !eventFilter(ev) {
 			logger.Debug("event filtered", "event", ev.EventName, "id", ev.ID, "stream", streamName)
diff --git a/streaming/sink.go b/streaming/sink.go
index d241c4d..035c6fb 100644
--- a/streaming/sink.go
+++ b/streaming/sink.go
@@ -197,7 +197,7 @@ func (s *Sink) Unsubscribe(c <-chan *Event) {
 
 // Ack acknowledges the event.
 func (s *Sink) Ack(ctx context.Context, e *Event) error {
-	err := e.rdb.XAck(ctx, e.streamKey, e.SinkName, e.ID).Err()
+	err := e.Acker.XAck(ctx, e.streamKey, e.SinkName, e.ID).Err()
 	if err != nil {
 		s.logger.Error(err, "ack", e.ID, "stream", e.StreamName)
 		return err