diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 072044d787..efdb8d489a 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "sync/atomic" "time" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" @@ -123,7 +124,6 @@ func (c *EventHubConnector) SetLastOffset(jobName string, offset int64) error { func (c *EventHubConnector) processBatch( flowJobName string, batch *model.CDCRecordStream, - maxParallelism int64, ) (uint32, error) { ctx := context.Background() batchPerTopic := NewHubBatches(c.hubManager) @@ -137,22 +137,32 @@ func (c *EventHubConnector) processBatch( lastSeenLSN := int64(0) lastUpdatedOffset := int64(0) - numRecords := 0 + numRecords := atomic.Uint32{} + shutdown := utils.HeartbeatRoutine(c.ctx, 10*time.Second, func() string { + return fmt.Sprintf( + "processed %d records for flow %s", + numRecords.Load(), flowJobName, + ) + }) + defer shutdown() + for { select { case record, ok := <-batch.GetRecords(): if !ok { c.logger.Info("flushing batches because no more records") - err := batchPerTopic.flushAllBatches(ctx, maxParallelism, flowJobName) + err := batchPerTopic.flushAllBatches(ctx, flowJobName) if err != nil { return 0, err } - c.logger.Info("processBatch", slog.Int("Total records sent to event hub", numRecords)) - return uint32(numRecords), nil + currNumRecords := numRecords.Load() + + c.logger.Info("processBatch", slog.Int("Total records sent to event hub", int(currNumRecords))) + return currNumRecords, nil } - numRecords++ + numRecords.Add(1) recordLSN := record.GetCheckPointID() if recordLSN > lastSeenLSN { @@ -190,12 +200,13 @@ func (c *EventHubConnector) processBatch( return 0, err } - if numRecords%1000 == 0 { - c.logger.Error("processBatch", slog.Int("number of records processed for sending", numRecords)) + curNumRecords := numRecords.Load() + if curNumRecords%1000 == 0 { + c.logger.Error("processBatch", slog.Int("number of records processed for sending", int(curNumRecords))) } case <-ticker.C: - err := batchPerTopic.flushAllBatches(ctx, maxParallelism, flowJobName) + err := batchPerTopic.flushAllBatches(ctx, flowJobName) if err != nil { return 0, err } @@ -215,24 +226,9 @@ func (c *EventHubConnector) processBatch( } func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { - maxParallelism := req.PushParallelism - if maxParallelism <= 0 { - maxParallelism = 10 - } - - var err error batch := req.Records - var numRecords uint32 - - shutdown := utils.HeartbeatRoutine(c.ctx, 10*time.Second, func() string { - return fmt.Sprintf( - "processed %d records for flow %s", - numRecords, req.FlowJobName, - ) - }) - defer shutdown() - numRecords, err = c.processBatch(req.FlowJobName, batch, maxParallelism) + numRecords, err := c.processBatch(req.FlowJobName, batch) if err != nil { c.logger.Error("failed to process batch", slog.Any("error", err)) return nil, err diff --git a/flow/connectors/eventhub/hub_batches.go b/flow/connectors/eventhub/hub_batches.go index b4d8997bb1..5634173faf 100644 --- a/flow/connectors/eventhub/hub_batches.go +++ b/flow/connectors/eventhub/hub_batches.go @@ -10,7 +10,6 @@ import ( azeventhubs "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs" "github.com/PeerDB-io/peer-flow/shared" - "golang.org/x/sync/errgroup" ) // multimap from ScopedEventhub to *azeventhubs.EventDataBatch @@ -76,10 +75,14 @@ func (h *HubBatches) Len() int { } // ForEach calls the given function for each ScopedEventhub and batch pair -func (h *HubBatches) ForEach(fn func(ScopedEventhub, *azeventhubs.EventDataBatch)) { - for destination, batch := range h.batch { - fn(destination, batch) +func (h *HubBatches) ForEach(fn func(ScopedEventhub, *azeventhubs.EventDataBatch) error) error { + for name, batch := range h.batch { + err := fn(name, batch) + if err != nil { + return err + } } + return nil } func (h *HubBatches) sendBatch( @@ -108,7 +111,6 @@ func (h *HubBatches) sendBatch( func (h *HubBatches) flushAllBatches( ctx context.Context, - maxParallelism int64, flowName string, ) error { if h.Len() == 0 { @@ -116,34 +118,36 @@ func (h *HubBatches) flushAllBatches( return nil } - var numEventsPushed int32 - g, gCtx := errgroup.WithContext(ctx) - g.SetLimit(int(maxParallelism)) - h.ForEach(func(destination ScopedEventhub, eventBatch *azeventhubs.EventDataBatch) { - g.Go(func() error { + var numEventsPushed atomic.Int32 + err := h.ForEach( + func( + destination ScopedEventhub, + eventBatch *azeventhubs.EventDataBatch, + ) error { numEvents := eventBatch.NumEvents() - err := h.sendBatch(gCtx, destination, eventBatch) + err := h.sendBatch(ctx, destination, eventBatch) if err != nil { return err } - atomic.AddInt32(&numEventsPushed, numEvents) + numEventsPushed.Add(numEvents) slog.Info("flushAllBatches", slog.String(string(shared.FlowNameKey), flowName), slog.Int("events sent", int(numEvents)), slog.String("event hub topic ", destination.ToString())) return nil }) - }) - err := g.Wait() + h.Clear() + + if err != nil { + return fmt.Errorf("failed to flushAllBatches: %v", err) + } slog.Info("hub batches flush", slog.String(string(shared.FlowNameKey), flowName), - slog.Int("events sent", int(numEventsPushed))) + slog.Int("events sent", int(numEventsPushed.Load()))) // clear the batches after flushing them. - h.Clear() - return err }