From 96ca6282b521a64293b2c6c30ac1fca9c0d08942 Mon Sep 17 00:00:00 2001
From: Amogh Bharadwaj <amogh@peerdb.io>
Date: Wed, 28 Feb 2024 22:21:57 +0530
Subject: [PATCH] Flowable: Use errgroup in replicateQRep (#1390)

Replaces a usage waitgroup with errgroup in replicateQRepPartitions in
flowable.go
This is a preceding PR to #1368
---
 flow/activities/flowable.go | 71 ++++++++++++++++++-------------------
 1 file changed, 35 insertions(+), 36 deletions(-)

diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index dbe9459642..62e8fe89f2 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -617,9 +617,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
 		return fmt.Errorf("failed to update start time for partition: %w", err)
 	}
 
-	pullCtx, pullCancel := context.WithCancel(ctx)
-	defer pullCancel()
-	srcConn, err := connectors.GetQRepPullConnector(pullCtx, config.SourcePeer)
+	srcConn, err := connectors.GetQRepPullConnector(ctx, config.SourcePeer)
 	if err != nil {
 		a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
 		return fmt.Errorf("failed to get qrep source connector: %w", err)
@@ -639,33 +637,42 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
 	})
 	defer shutdown()
 
-	var stream *model.QRecordStream
+	var rowsSynced int
 	bufferSize := shared.FetchAndChannelSize
-	var wg sync.WaitGroup
-
-	var goroutineErr error = nil
 	if config.SourcePeer.Type == protos.DBType_POSTGRES {
-		stream = model.NewQRecordStream(bufferSize)
-		wg.Add(1)
-
-		go func() {
+		errGroup, errCtx := errgroup.WithContext(ctx)
+		stream := model.NewQRecordStream(bufferSize)
+		errGroup.Go(func() error {
 			pgConn := srcConn.(*connpostgres.PostgresConnector)
-			tmp, err := pgConn.PullQRepRecordStream(ctx, config, partition, stream)
+			tmp, err := pgConn.PullQRepRecordStream(errCtx, config, partition, stream)
 			numRecords := int64(tmp)
 			if err != nil {
 				a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
-				logger.Error("failed to pull records", slog.Any("error", err))
-				goroutineErr = err
+				return fmt.Errorf("failed to pull records: %w", err)
 			} else {
-				err = monitoring.UpdatePullEndTimeAndRowsForPartition(ctx,
+				err = monitoring.UpdatePullEndTimeAndRowsForPartition(errCtx,
 					a.CatalogPool, runUUID, partition, numRecords)
 				if err != nil {
 					logger.Error(err.Error())
-					goroutineErr = err
 				}
 			}
-			wg.Done()
-		}()
+			return nil
+		})
+
+		errGroup.Go(func() error {
+			rowsSynced, err = dstConn.SyncQRepRecords(ctx, config, partition, stream)
+			if err != nil {
+				a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
+				return fmt.Errorf("failed to sync records: %w", err)
+			}
+			return nil
+		})
+
+		err = errGroup.Wait()
+		if err != nil {
+			a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
+			return err
+		}
 	} else {
 		recordBatch, err := srcConn.PullQRepRecords(ctx, config, partition)
 		if err != nil {
@@ -679,35 +686,27 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context,
 			return err
 		}
 
-		stream, err = recordBatch.ToQRecordStream(bufferSize)
+		stream, err := recordBatch.ToQRecordStream(bufferSize)
 		if err != nil {
 			a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
 			return fmt.Errorf("failed to convert to qrecord stream: %w", err)
 		}
-	}
 
-	rowsSynced, err := dstConn.SyncQRepRecords(ctx, config, partition, stream)
-	if err != nil {
-		a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
-		return fmt.Errorf("failed to sync records: %w", err)
-	}
-
-	if rowsSynced == 0 {
-		logger.Info("no records to push for partition " + partition.PartitionId)
-		pullCancel()
-	} else {
-		wg.Wait()
-		if goroutineErr != nil {
-			a.Alerter.LogFlowError(ctx, config.FlowJobName, goroutineErr)
-			return goroutineErr
+		rowsSynced, err = dstConn.SyncQRepRecords(ctx, config, partition, stream)
+		if err != nil {
+			a.Alerter.LogFlowError(ctx, config.FlowJobName, err)
+			return fmt.Errorf("failed to sync records: %w", err)
 		}
+	}
 
+	if rowsSynced > 0 {
+		logger.Info(fmt.Sprintf("pushed %d records", rowsSynced))
 		err := monitoring.UpdateRowsSyncedForPartition(ctx, a.CatalogPool, rowsSynced, runUUID, partition)
 		if err != nil {
 			return err
 		}
-
-		logger.Info(fmt.Sprintf("pushed %d records", rowsSynced))
+	} else {
+		logger.Info("no records to push for partition " + partition.PartitionId)
 	}
 
 	err = monitoring.UpdateEndTimeForPartition(ctx, a.CatalogPool, runUUID, partition)