diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index b2065a39c1..4737a6def6 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -8,7 +8,6 @@ import ( "fmt" "regexp" "strings" - "sync/atomic" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -21,7 +20,6 @@ import ( "github.com/snowflakedb/gosnowflake" "go.temporal.io/sdk/activity" "golang.org/x/exp/maps" - "golang.org/x/sync/errgroup" ) //nolint:stylecheck @@ -740,39 +738,44 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err) } - var totalRowsAffected int64 = 0 - g, gCtx := errgroup.WithContext(c.ctx) - g.SetLimit(8) // limit parallel merges to 8 - - for _, destinationTableName := range destinationTableNames { - tableName := destinationTableName // local variable for the closure - - g.Go(func() error { - rowsAffected, err := c.generateAndExecuteMergeStatement( - gCtx, - tableName, - tableNametoUnchangedToastCols[tableName], - getRawTableIdentifier(req.FlowJobName), - syncBatchID, normalizeBatchID, - req) - if err != nil { - log.WithFields(log.Fields{ - "flowName": req.FlowJobName, - }).Errorf("[merge] error while normalizing records: %v", err) - return err - } - - atomic.AddInt64(&totalRowsAffected, rowsAffected) - return nil - }) + // transaction for NormalizeRecords + normalizeRecordsTx, err := c.database.BeginTx(c.ctx, nil) + if err != nil { + return nil, fmt.Errorf("unable to begin transactions for NormalizeRecords: %w", err) } + // in case we return after error, ensure transaction is rolled back + defer func() { + deferErr := normalizeRecordsTx.Rollback() + if deferErr != sql.ErrTxDone && deferErr != nil { + log.WithFields(log.Fields{ + "flowName": req.FlowJobName, + }).Errorf("unexpected error while rolling back transaction for NormalizeRecords: %v", deferErr) + } + }() - if err := g.Wait(); err != nil { - return nil, fmt.Errorf("error while normalizing records: %w", err) + var totalRowsAffected int64 = 0 + // execute merge statements per table that uses CTEs to merge data into the normalized table + for _, destinationTableName := range destinationTableNames { + rowsAffected, err := c.generateAndExecuteMergeStatement( + destinationTableName, + tableNametoUnchangedToastCols[destinationTableName], + getRawTableIdentifier(req.FlowJobName), + syncBatchID, normalizeBatchID, + req, + normalizeRecordsTx) + if err != nil { + return nil, err + } + totalRowsAffected += rowsAffected } // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID) + err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID, normalizeRecordsTx) + if err != nil { + return nil, err + } + // transaction commits + err = normalizeRecordsTx.Commit() if err != nil { return nil, err } @@ -959,13 +962,13 @@ func (c *SnowflakeConnector) insertRecordsInRawTable(rawTableIdentifier string, } func (c *SnowflakeConnector) generateAndExecuteMergeStatement( - ctx context.Context, destinationTableIdentifier string, unchangedToastColumns []string, rawTableIdentifier string, syncBatchID int64, normalizeBatchID int64, normalizeReq *model.NormalizeRecordsRequest, + normalizeRecordsTx *sql.Tx, ) (int64, error) { normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] columnNames := maps.Keys(normalizedTableSchema.Columns) @@ -1066,21 +1069,12 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement( fmt.Sprintf("(%s)", strings.Join(normalizedTableSchema.PrimaryKeyColumns, ",")), pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) - startTime := time.Now() - log.WithFields(log.Fields{ - "flowName": destinationTableIdentifier, - }).Infof("[merge] merging records into %s...", destinationTableIdentifier) - - result, err := c.database.ExecContext(ctx, mergeStatement, destinationTableIdentifier) + result, err := normalizeRecordsTx.ExecContext(c.ctx, mergeStatement, destinationTableIdentifier) if err != nil { return 0, fmt.Errorf("failed to merge records into %s (statement: %s): %w", destinationTableIdentifier, mergeStatement, err) } - endTime := time.Now() - log.Infof("[merge] merged records into %s, took: %d seconds", - destinationTableIdentifier, endTime.Sub(startTime)/time.Second) - return result.RowsAffected() } @@ -1139,7 +1133,8 @@ func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64 return nil } -func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64) error { +func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, + normalizeBatchID int64, normalizeRecordsTx *sql.Tx) error { jobMetadataExists, err := c.jobMetadataExists(flowJobName) if err != nil { return fmt.Errorf("failed to get sync status for flow job: %w", err) @@ -1148,8 +1143,9 @@ func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normali return fmt.Errorf("job metadata does not exist, unable to update") } - stmt := fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier) - _, err = c.database.ExecContext(c.ctx, stmt, normalizeBatchID, flowJobName) + _, err = normalizeRecordsTx.ExecContext(c.ctx, + fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier), + normalizeBatchID, flowJobName) if err != nil { return fmt.Errorf("failed to update metadata for NormalizeTables: %w", err) }