diff --git a/kafka/main.go b/kafka/main.go
index 359da047..a111ce93 100644
--- a/kafka/main.go
+++ b/kafka/main.go
@@ -23,10 +23,12 @@ var brokers []string
 // Processor is a function that is used to process Kafka messages on
 type Processor func(context.Context, kafka.Message, *zerolog.Logger) error
 
-// Subset of kafka.Reader methods that we use. This is used for testing.
+// Subset of kafka.Reader methods that we use, factored as an interface for
+// unit testing support.
 type messageReader interface {
 	FetchMessage(ctx context.Context) (kafka.Message, error)
 	Stats() kafka.ReaderStats
+	CommitMessages(ctx context.Context, msgs ...kafka.Message) error
 }
 
 // TopicMapping represents a kafka topic, how to process it, and where to emit the result.
@@ -43,8 +45,20 @@ type MessageContext struct {
 	msg kafka.Message
 }
 
-// StartConsumers reads configuration variables and starts the associated kafka consumers
-func StartConsumers(providedServer *server.Server, logger *zerolog.Logger) error {
+func closeWriter(writer io.Closer, logger *zerolog.Logger) {
+	err := writer.Close()
+	if err != nil {
+		logger.Error().Err(err).Msg("failed to close a writer")
+	}
+}
+
+// RunConsumers reads configuration variables, creates the associated Kafka
+// reader and writers, and runs them until an error occurs.
+func RunConsumers(
+	ctx context.Context,
+	providedServer *server.Server,
+	logger *zerolog.Logger,
+) error {
 	adsRequestRedeemV1Topic := os.Getenv("REDEEM_CONSUMER_TOPIC")
 	adsResultRedeemV1Topic := os.Getenv("REDEEM_PRODUCER_TOPIC")
 	adsRequestSignV1Topic := os.Getenv("SIGN_CONSUMER_TOPIC")
@@ -59,11 +73,13 @@ func StartConsumers(providedServer *server.Server, logger *zerolog.Logger) error
 		Topic: adsResultRedeemV1Topic,
 		Dialer: getDialer(logger),
 	})
+	defer closeWriter(redeemWriter, logger)
 	signWriter := kafka.NewWriter(kafka.WriterConfig{
 		Brokers: brokers,
 		Topic: adsResultSignV1Topic,
 		Dialer: getDialer(logger),
 	})
+	defer closeWriter(signWriter, logger)
 	topicMappings := []TopicMapping{
 		{
 			Topic: adsRequestRedeemV1Topic,
@@ -86,86 +102,101 @@ func StartConsumers(providedServer *server.Server, logger *zerolog.Logger) error
 	}
 	reader := newConsumer(topics, adsConsumerGroupV1, logger)
+	defer reader.Close()
 
 	// Each message in batchPipeline is associated with goroutine doing
 	// CPU-intensive cryptography, so limit the channel capacity by CPU cores
 	// plus some extra buffer to account for IO that a processor may potentially
 	// do.
 	batchPipeline := make(chan *MessageContext, runtime.NumCPU()+2)
-	ctx := context.Background()
 	go processMessagesIntoBatchPipeline(ctx, topicMappings, reader, batchPipeline, logger)
-	for {
-		err := readAndCommitBatchPipelineResults(ctx, reader, batchPipeline, logger)
-		if err != nil {
-			// If readAndCommitBatchPipelineResults returns an error.
-			close(batchPipeline)
-			return err
-		}
-	}
+	return readAndCommitBatchPipelineResults(ctx, reader, batchPipeline, logger)
 }
 
-// readAndCommitBatchPipelineResults does a blocking read of the batchPipeline channel and
-// then does a blocking read of the done field in the MessageContext in the batchPipeline.
-// When an error appears it means that the channel was closed or a temporary error was
-// encountered. In the case of a temporary error, the application returns an error without
-// committing so that the next reader gets the same message to try again.
+// readAndCommitBatchPipelineResults receives messages from the batchPipeline
+// channel and commits them until ctx is cancelled, batchPipeline is closed, or
+// a message error is received.
 func readAndCommitBatchPipelineResults(
 	ctx context.Context,
-	reader *kafka.Reader,
+	reader messageReader,
 	batchPipeline chan *MessageContext,
 	logger *zerolog.Logger,
 ) error {
-	msgCtx := <-batchPipeline
-	<-msgCtx.done
+	for {
+		var msgCtx *MessageContext
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case msgCtx = <-batchPipeline:
+			break
+		}
+		if msgCtx == nil {
+			// processMessagesIntoBatchPipeline has closed the channel. Report
+			// that to the caller as EOF.
+			return io.EOF
+		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-msgCtx.done:
+			break
+		}
 
-	if msgCtx.err != nil {
-		logger.Error().Err(msgCtx.err).Msg("temporary failure encountered")
-		return fmt.Errorf("temporary failure encountered: %w", msgCtx.err)
-	}
-	logger.Info().Msgf("Committing offset %d", msgCtx.msg.Offset)
-	if err := reader.CommitMessages(ctx, msgCtx.msg); err != nil {
-		logger.Error().Err(err).Msg("failed to commit")
-		return errors.New("failed to commit")
+		if msgCtx.err != nil {
+			return fmt.Errorf("temporary failure encountered: %w", msgCtx.err)
+		}
+		logger.Info().Msgf("Committing offset %d", msgCtx.msg.Offset)
+		if err := reader.CommitMessages(ctx, msgCtx.msg); err != nil {
+			return fmt.Errorf("failed to commit - %w", err)
+		}
 	}
-	return nil
 }
 
-// processMessagesIntoBatchPipeline fetches messages from Kafka indefinitely,
-// pushes a MessageContext into the batchPipeline to maintain message order, and
-// then spawns a goroutine that will process the message and push to errorResult
-// of the MessageContext when the processing completes.
+// processMessagesIntoBatchPipeline fetches messages from Kafka, pushes a
+// MessageContext into the batchPipeline to maintain message order, and then
+// spawns a goroutine that will process the message and closes the done channel
+// of the MessageContext when the processing completes. This returns when the
+// reader is closed or ctx is cancelled.
 func processMessagesIntoBatchPipeline(ctx context.Context,
 	topicMappings []TopicMapping,
 	reader messageReader,
 	batchPipeline chan *MessageContext,
 	logger *zerolog.Logger,
 ) {
-	// Loop forever
+	// Signal to readAndCommitBatchPipelineResults() that processing stopped.
+	defer close(batchPipeline)
 	for {
 		msg, err := reader.FetchMessage(ctx)
 		if err != nil {
-			// Indicates batch has no more messages. End the loop for
-			// this batch and fetch another.
+			if ctxErr := ctx.Err(); ctxErr != nil {
+				// Cancelled context; log err only if it is not related to the
+				// cancellation.
+				if !errors.Is(err, ctxErr) {
+					logger.Error().Err(err).Msg("FetchMessage error")
+				}
+				return
+			}
 			if err == io.EOF {
-				logger.Info().Msg("Batch complete")
-			} else if errors.Is(err, context.DeadlineExceeded) {
-				logger.Error().Err(err).Msg("batch item error")
-				panic("failed to fetch kafka messages and closed channel")
+				logger.Info().Msg("Kafka reader closed")
+				return
 			}
 			// There are other possible errors, but the underlying consumer
 			// group handler handle retryable failures well. If further
 			// investigation is needed you can review the handler here:
 			// https://github.com/segmentio/kafka-go/blob/main/consumergroup.go#L729
+			logger.Error().Err(err).Msg("FetchMessage error")
 			continue
 		}
 		msgCtx := &MessageContext{
 			done: make(chan struct{}),
 			msg: msg,
 		}
-		// If batchPipeline has been closed by an error in readAndCommitBatchPipelineResults,
-		// this write will panic, which is desired behavior, as the rest of the context
-		// will also have died and will be restarted from kafka/main.go
-		batchPipeline <- msgCtx
+		select {
+		case <-ctx.Done():
+			return
+		case batchPipeline <- msgCtx:
+			break
+		}
 		logger.Debug().Msgf("Processing message for topic %s at offset %d", msg.Topic, msg.Offset)
 		logger.Debug().Msgf("Reader Stats: %#v", reader.Stats())
 		logger.Debug().Msgf("topicMappings: %+v", topicMappings)
diff --git a/kafka/main_test.go b/kafka/main_test.go
index 86fc86ac..ee44e897 100644
--- a/kafka/main_test.go
+++ b/kafka/main_test.go
@@ -4,6 +4,7 @@ package kafka
 import (
 	"context"
 	"errors"
+	"io"
 	"sync/atomic"
 	"testing"
 
@@ -13,7 +14,8 @@ import (
 )
 
 type testMessageReader struct {
-	fetch func() (kafka.Message, error)
+	fetch func() (kafka.Message, error)
+	commit func(msgs []kafka.Message) error
 }
 
 func (r *testMessageReader) FetchMessage(ctx context.Context) (kafka.Message, error) {
@@ -24,6 +26,10 @@ func (r *testMessageReader) Stats() kafka.ReaderStats {
 	return kafka.ReaderStats{}
 }
 
+func (r *testMessageReader) CommitMessages(ctx context.Context, msgs ...kafka.Message) error {
+	return r.commit(msgs)
+}
+
 func TestProcessMessagesIntoBatchPipeline(t *testing.T) {
 	nopLog := zerolog.Nop()
 	t.Run("AbsentTopicClosesMsg", func(t *testing.T) {
@@ -38,9 +44,7 @@ func TestProcessMessagesIntoBatchPipeline(t *testing.T) {
 			if messageCounter == 1 {
 				return kafka.Message{Topic: "absent"}, nil
 			}
-			// processMessagesIntoBatchPipeline never returns, so leak its
-			// goroutine via blocking here forever.
-			select {}
+			return kafka.Message{}, io.EOF
 		}
 		go processMessagesIntoBatchPipeline(context.Background(),
 			nil, r, batchPipeline, &nopLog)
@@ -52,6 +56,9 @@ func TestProcessMessagesIntoBatchPipeline(t *testing.T) {
 		// Absent topic signals permanent error and the message should be
 		// committed, so msg.err must be nil.
 		assert.Nil(t, msg.err)
+
+		_, ok := <-batchPipeline
+		assert.False(t, ok)
 	})
 
 	t.Run("OrderPreserved", func(t *testing.T) {
@@ -73,7 +80,7 @@ func TestProcessMessagesIntoBatchPipeline(t *testing.T) {
 				// Processor below.
 				return kafka.Message{Topic: "topicA", Partition: i}, nil
 			}
-			select {} // block forever
+			return kafka.Message{}, io.EOF
 		}
 		atomicCounter := int32(N)
 		topicMappings := []TopicMapping{{
@@ -108,5 +115,226 @@ func TestProcessMessagesIntoBatchPipeline(t *testing.T) {
 				assert.Nil(t, msg.err)
 			}
 		}
+		_, ok := <-batchPipeline
+		assert.False(t, ok)
+	})
+
+	t.Run("ContextCancelStops", func(t *testing.T) {
+		t.Parallel()
+
+		// Generate two messages and cancel the context when returning the second.
+
+		ctx, cancel := context.WithCancel(context.Background())
+
+		batchPipeline := make(chan *MessageContext)
+
+		r := &testMessageReader{}
+		messageCounter := 0
+		r.fetch = func() (kafka.Message, error) {
+			i := messageCounter
+			messageCounter++
+			if i > 1 {
+				panic("called more than once")
+			}
+			if i == 1 {
+				cancel()
+			}
+			return kafka.Message{Topic: "topicA", Partition: i}, nil
+		}
+
+		topicMappings := []TopicMapping{{
+			Topic: "topicA",
+			Processor: func(ctx context.Context, msg kafka.Message, logger *zerolog.Logger) error {
+				if msg.Partition > 0 {
+					panic("should only be called once")
+				}
+				return nil
+			},
+		}}
+
+		processFinished := make(chan struct{})
+		go func() {
+			processMessagesIntoBatchPipeline(ctx,
+				topicMappings, r, batchPipeline, &nopLog)
+			close(processFinished)
+		}()
+
+		msg := <-batchPipeline
+		assert.NotNil(t, msg)
+		<-msg.done
+
+		<-processFinished
+
+		// After processMessagesIntoBatchPipeline returns, ctx must be cancelled and batchPipeline closed.
+		assert.Error(t, ctx.Err())
+		_, ok := <-batchPipeline
+		assert.False(t, ok)
+	})
+}
+
+func TestReadAndCommitBatchPipelineResults(t *testing.T) {
+	nopLog := zerolog.Nop()
+
+	t.Run("WaitsForMessageDoneAfterReceiving", func(t *testing.T) {
+		t.Parallel()
+
+		ctx := context.Background()
+
+		r := &testMessageReader{}
+
+		r.commit = func(msgs []kafka.Message) error {
+			assert.Equal(t, 1, len(msgs))
+			assert.Equal(t, "testA", msgs[0].Topic)
+			return nil
+		}
+
+		batchPipeline := make(chan *MessageContext)
+
+		readErr := make(chan error)
+		go func() {
+			readErr <- readAndCommitBatchPipelineResults(ctx, r, batchPipeline, &nopLog)
+		}()
+
+		makeMsg := func() *MessageContext {
+			return &MessageContext{
+				msg: kafka.Message{Topic: "testA"},
+				done: make(chan struct{}),
+			}
+		}
+
+		msg := makeMsg()
+		batchPipeline <- msg
+
+		// Do not close, but write an empty struct to trigger a deadlock if the
+		// read happens in the wrong order. For this to work all channels must
+		// be unbuffered.
+		var empty struct{}
+		msg.done <- empty
+
+		msg = makeMsg()
+		batchPipeline <- msg
+		msg.done <- empty
+
+		msg = makeMsg()
+		batchPipeline <- msg
+		msg.done <- empty
+
+		close(batchPipeline)
+
+		err := <-readErr
+		assert.ErrorIs(t, err, io.EOF)
+	})
+
+	t.Run("MessageWithErrorStopsReading", func(t *testing.T) {
+		t.Parallel()
+
+		ctx := context.Background()
+
+		r := &testMessageReader{}
+		r.commit = func(msgs []kafka.Message) error {
+			panic("should not be called")
+		}
+
+		batchPipeline := make(chan *MessageContext, 1)
+
+		msg := &MessageContext{
+			done: make(chan struct{}),
+			err: errors.New("New error"),
+		}
+		close(msg.done)
+		batchPipeline <- msg
+
+		err := readAndCommitBatchPipelineResults(ctx, r, batchPipeline, &nopLog)
+		assert.ErrorIs(t, err, msg.err)
+
+		close(batchPipeline)
+		err = readAndCommitBatchPipelineResults(ctx, r, batchPipeline, &nopLog)
+		assert.ErrorIs(t, err, io.EOF)
+	})
+
+	t.Run("CommitErrorStopsReading", func(t *testing.T) {
+		t.Parallel()
+
+		ctx := context.Background()
+
+		r := &testMessageReader{}
+
+		emitErr := errors.New("emit error")
+
+		r.commit = func(msgs []kafka.Message) error {
+			assert.Equal(t, 1, len(msgs))
+			assert.Equal(t, "testA", msgs[0].Topic)
+			return emitErr
+		}
+
+		batchPipeline := make(chan *MessageContext, 1)
+
+		msg := &MessageContext{
+			msg: kafka.Message{Topic: "testA"},
+			done: make(chan struct{}),
+		}
+		close(msg.done)
+		batchPipeline <- msg
+
+		err := readAndCommitBatchPipelineResults(ctx, r, batchPipeline, &nopLog)
+		assert.ErrorIs(t, err, emitErr)
+	})
+
+	// Check that context cancellation exits the blocking read of the batchPipeline arg.
+	t.Run("ContextCancelStops", func(t *testing.T) {
+		t.Parallel()
+
+		ctx, cancel := context.WithCancel(context.Background())
+
+		r := &testMessageReader{}
+
+		r.commit = func(msgs []kafka.Message) error {
+			panic("should not be called")
+		}
+
+		batchPipeline := make(chan *MessageContext)
+		readErr := make(chan error)
+		go func() {
+			readErr <- readAndCommitBatchPipelineResults(ctx, r, batchPipeline, &nopLog)
+		}()
+
+		cancel()
+		err := <-readErr
+		assert.Equal(t, ctx.Err(), err)
 	})
+
+	// Check that context cancellation exits the blocking read of MessageContext.done.
+	t.Run("ContextCancelStops2", func(t *testing.T) {
+		t.Parallel()
+
+		ctx, cancel := context.WithCancel(context.Background())
+
+		r := &testMessageReader{}
+
+		r.commit = func(msgs []kafka.Message) error {
+			panic("should not be called")
+		}
+
+		batchPipeline := make(chan *MessageContext)
+		readErr := make(chan error)
+		go func() {
+			readErr <- readAndCommitBatchPipelineResults(ctx, r, batchPipeline, &nopLog)
+		}()
+
+		msg := &MessageContext{
+			msg: kafka.Message{Topic: "testA"},
+			done: make(chan struct{}),
+		}
+		batchPipeline <- msg
+
+		// As batchPipeline has zero capacity, we can be here only after
+		// readAndCommitBatchPipelineResults received from the channel.
+		// Cancelling the context at this point should stop the blocking read
+		// from msg.done.
+		cancel()
+
+		err := <-readErr
+		assert.Equal(t, ctx.Err(), err)
+	})
+
 }
diff --git a/main.go b/main.go
index 7c74bf39..badfbf04 100644
--- a/main.go
+++ b/main.go
@@ -7,7 +7,6 @@ import (
 	_ "net/http/pprof"
 	"os"
 	"strconv"
-	"time"
 
 	"github.com/brave-intl/bat-go/libs/logging"
 	"github.com/brave-intl/challenge-bypass-server/kafka"
@@ -88,7 +87,7 @@ func main() {
 
 	if os.Getenv("KAFKA_ENABLED") != "false" {
 		zeroLogger.Trace().Msg("Spawning Kafka goroutine")
-		go startKafka(srv, zeroLogger)
+		go runKafka(srv, zeroLogger)
 	}
 
 	zeroLogger.Trace().Msg("Initializing API server")
@@ -103,14 +102,17 @@ func main() {
 	}
 }
 
-func startKafka(srv server.Server, zeroLogger *zerolog.Logger) {
-	zeroLogger.Trace().Msg("Initializing Kafka consumers")
-	err := kafka.StartConsumers(&srv, zeroLogger)
+func runKafka(srv server.Server, zeroLogger *zerolog.Logger) {
+	ctx := context.Background()
+	zeroLogger.Trace().Msg("Running Kafka consumers")
+	err := kafka.RunConsumers(ctx, &srv, zeroLogger)
+	// For now, if RunConsumers terminates due to temporary errors, exit the
+	// process and let the container runtime restart it.
 
 	if err != nil {
-		zeroLogger.Error().Err(err).Msg("Failed to initialize Kafka consumers")
-		// If err is something then start consumer again
-		time.Sleep(10 * time.Second)
-		startKafka(srv, zeroLogger)
+		zeroLogger.Error().Err(err).Msg("Failed to run Kafka reader/writer")
+	} else {
+		zeroLogger.Error().Msg("kafka.RunConsumers() returned with no errors")
 	}
+	os.Exit(10)
 }
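
The commit-in-order design used in the patch can be illustrated in isolation. The following minimal Go sketch is not part of the patch and does not use the kafka-go API; the names (item, produce, commitInOrder) are illustrative only. It shows the same pattern under those assumptions: a producer sends items into an ordered pipeline and processes each one concurrently, while a single consumer waits on every item's done channel in fetch order and stops on cancellation, a closed pipeline, or the first item error.

package main

import (
	"context"
	"errors"
	"fmt"
)

// item mirrors the role of MessageContext: a payload plus a done channel
// that is closed when the concurrent processing of the item finishes.
type item struct {
	id   int
	err  error
	done chan struct{}
}

// produce sends items into the pipeline in order and processes each one in
// its own goroutine, closing the pipeline when it has nothing more to send.
func produce(ctx context.Context, pipeline chan<- *item, n int) {
	defer close(pipeline)
	for i := 0; i < n; i++ {
		it := &item{id: i, done: make(chan struct{})}
		select {
		case <-ctx.Done():
			return
		case pipeline <- it:
		}
		go func(it *item) {
			// CPU-heavy work would run here; a failure would be recorded
			// in it.err before closing it.done.
			close(it.done)
		}(it)
	}
}

// commitInOrder receives items in the order they were produced, waits for
// each to finish, and "commits" it. It stops on context cancellation, a
// closed pipeline, or the first item error.
func commitInOrder(ctx context.Context, pipeline <-chan *item) error {
	for {
		var it *item
		select {
		case <-ctx.Done():
			return ctx.Err()
		case it = <-pipeline:
		}
		if it == nil {
			// The producer closed the pipeline.
			return errors.New("pipeline closed")
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-it.done:
		}
		if it.err != nil {
			return it.err
		}
		fmt.Println("committed", it.id)
	}
}

func main() {
	ctx := context.Background()
	pipeline := make(chan *item, 4)
	go produce(ctx, pipeline, 10)
	fmt.Println("stopped:", commitInOrder(ctx, pipeline))
}

Because the consumer waits on each done channel in the order the items were sent, a slow item blocks the commits of everything fetched after it, which is exactly the property that makes per-message offset commits safe.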