diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index dcf1fb3ec4..f37bc55ff5 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -613,9 +613,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, return fmt.Errorf("failed to update start time for partition: %w", err) } - pullCtx, pullCancel := context.WithCancel(ctx) - defer pullCancel() - srcConn, err := connectors.GetQRepPullConnector(pullCtx, config.SourcePeer) + srcConn, err := connectors.GetQRepPullConnector(ctx, config.SourcePeer) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to get qrep source connector: %w", err) @@ -635,33 +633,42 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, }) defer shutdown() - var stream *model.QRecordStream + var rowsSynced int bufferSize := shared.FetchAndChannelSize - var wg sync.WaitGroup - - var goroutineErr error = nil if config.SourcePeer.Type == protos.DBType_POSTGRES { - stream = model.NewQRecordStream(bufferSize) - wg.Add(1) - - go func() { + errGroup, errCtx := errgroup.WithContext(ctx) + stream := model.NewQRecordStream(bufferSize) + errGroup.Go(func() error { pgConn := srcConn.(*connpostgres.PostgresConnector) - tmp, err := pgConn.PullQRepRecordStream(ctx, config, partition, stream) + tmp, err := pgConn.PullQRepRecordStream(errCtx, config, partition, stream) numRecords := int64(tmp) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) - logger.Error("failed to pull records", slog.Any("error", err)) - goroutineErr = err + return fmt.Errorf("failed to pull records: %w", err) } else { - err = monitoring.UpdatePullEndTimeAndRowsForPartition(ctx, + err = monitoring.UpdatePullEndTimeAndRowsForPartition(errCtx, a.CatalogPool, runUUID, partition, numRecords) if err != nil { logger.Error(err.Error()) - goroutineErr = err } } - wg.Done() - }() + return nil + }) + + errGroup.Go(func() error { + rowsSynced, err = dstConn.SyncQRepRecords(ctx, config, partition, stream) + if err != nil { + a.Alerter.LogFlowError(ctx, config.FlowJobName, err) + return fmt.Errorf("failed to sync records: %w", err) + } + return nil + }) + + err = errGroup.Wait() + if err != nil { + a.Alerter.LogFlowError(ctx, config.FlowJobName, err) + return err + } } else { recordBatch, err := srcConn.PullQRepRecords(ctx, config, partition) if err != nil { @@ -675,35 +682,27 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, return err } - stream, err = recordBatch.ToQRecordStream(bufferSize) + stream, err := recordBatch.ToQRecordStream(bufferSize) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to convert to qrecord stream: %w", err) } - } - rowsSynced, err := dstConn.SyncQRepRecords(ctx, config, partition, stream) - if err != nil { - a.Alerter.LogFlowError(ctx, config.FlowJobName, err) - return fmt.Errorf("failed to sync records: %w", err) - } - - if rowsSynced == 0 { - logger.Info("no records to push for partition " + partition.PartitionId) - pullCancel() - } else { - wg.Wait() - if goroutineErr != nil { - a.Alerter.LogFlowError(ctx, config.FlowJobName, goroutineErr) - return goroutineErr + rowsSynced, err = dstConn.SyncQRepRecords(ctx, config, partition, stream) + if err != nil { + a.Alerter.LogFlowError(ctx, config.FlowJobName, err) + return fmt.Errorf("failed to sync records: %w", err) } + } + if rowsSynced > 0 { + logger.Info(fmt.Sprintf("pushed %d records", rowsSynced)) err := monitoring.UpdateRowsSyncedForPartition(ctx, a.CatalogPool, rowsSynced, runUUID, partition) if err != nil { return err } - - logger.Info(fmt.Sprintf("pushed %d records", rowsSynced)) + } else { + logger.Info("no records to push for partition " + partition.PartitionId) } err = monitoring.UpdateEndTimeForPartition(ctx, a.CatalogPool, runUUID, partition)