From a2847f9fe1ddf95bcbbc74714873d39e448fdd95 Mon Sep 17 00:00:00 2001 From: Kaushik Iska Date: Wed, 15 Nov 2023 09:01:59 -0500 Subject: [PATCH] [snowflake] Run merges in parallel during normalize flow --- flow/connectors/snowflake/snowflake.go | 41 ++++++++++++++++++-------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 4737a6def6..cbc1501230 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -8,6 +8,7 @@ import ( "fmt" "regexp" "strings" + "sync/atomic" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -20,6 +21,7 @@ import ( "github.com/snowflakedb/gosnowflake" "go.temporal.io/sdk/activity" "golang.org/x/exp/maps" + "golang.org/x/sync/errgroup" ) //nolint:stylecheck @@ -754,19 +756,34 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest }() var totalRowsAffected int64 = 0 - // execute merge statements per table that uses CTEs to merge data into the normalized table + g, _ := errgroup.WithContext(context.Background()) + sem := make(chan struct{}, 8) // semaphore to limit parallel merges + for _, destinationTableName := range destinationTableNames { - rowsAffected, err := c.generateAndExecuteMergeStatement( - destinationTableName, - tableNametoUnchangedToastCols[destinationTableName], - getRawTableIdentifier(req.FlowJobName), - syncBatchID, normalizeBatchID, - req, - normalizeRecordsTx) - if err != nil { - return nil, err - } - totalRowsAffected += rowsAffected + sem <- struct{}{} // block if semaphore is full + tableName := destinationTableName // local variable for the closure + + g.Go(func() error { + defer func() { <-sem }() // release semaphore + + rowsAffected, err := c.generateAndExecuteMergeStatement( + tableName, + tableNametoUnchangedToastCols[tableName], + getRawTableIdentifier(req.FlowJobName), + syncBatchID, normalizeBatchID, + req, + normalizeRecordsTx) + if err != nil { + return err + } + + atomic.AddInt64(&totalRowsAffected, rowsAffected) + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("error while normalizing records: %w", err) } // updating metadata with new normalizeBatchID