Skip to content

Commit

Permalink
Revert "[snowflake] Run merges in parallel during normalize flow (#662)"
Browse files — browse the repository at this point in the history
This reverts commit d0b4f20.
Committed by iskakaushik on Dec 15, 2023
1 parent de62091 commit 7d152c9
Showing 1 changed file with 40 additions and 44 deletions.
84 changes: 40 additions & 44 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"regexp"
"strings"
"sync/atomic"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
Expand All @@ -21,7 +20,6 @@ import (
"github.com/snowflakedb/gosnowflake"
"go.temporal.io/sdk/activity"
"golang.org/x/exp/maps"
"golang.org/x/sync/errgroup"
)

//nolint:stylecheck
Expand Down Expand Up @@ -740,39 +738,44 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest
return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err)
}

var totalRowsAffected int64 = 0
g, gCtx := errgroup.WithContext(c.ctx)
g.SetLimit(8) // limit parallel merges to 8

for _, destinationTableName := range destinationTableNames {
tableName := destinationTableName // local variable for the closure

g.Go(func() error {
rowsAffected, err := c.generateAndExecuteMergeStatement(
gCtx,
tableName,
tableNametoUnchangedToastCols[tableName],
getRawTableIdentifier(req.FlowJobName),
syncBatchID, normalizeBatchID,
req)
if err != nil {
log.WithFields(log.Fields{
"flowName": req.FlowJobName,
}).Errorf("[merge] error while normalizing records: %v", err)
return err
}

atomic.AddInt64(&totalRowsAffected, rowsAffected)
return nil
})
// transaction for NormalizeRecords
normalizeRecordsTx, err := c.database.BeginTx(c.ctx, nil)
if err != nil {
return nil, fmt.Errorf("unable to begin transactions for NormalizeRecords: %w", err)
}
// in case we return after error, ensure transaction is rolled back
defer func() {
deferErr := normalizeRecordsTx.Rollback()
if deferErr != sql.ErrTxDone && deferErr != nil {
log.WithFields(log.Fields{
"flowName": req.FlowJobName,
}).Errorf("unexpected error while rolling back transaction for NormalizeRecords: %v", deferErr)
}
}()

if err := g.Wait(); err != nil {
return nil, fmt.Errorf("error while normalizing records: %w", err)
var totalRowsAffected int64 = 0
// execute merge statements per table that uses CTEs to merge data into the normalized table
for _, destinationTableName := range destinationTableNames {
rowsAffected, err := c.generateAndExecuteMergeStatement(
destinationTableName,
tableNametoUnchangedToastCols[destinationTableName],
getRawTableIdentifier(req.FlowJobName),
syncBatchID, normalizeBatchID,
req,
normalizeRecordsTx)
if err != nil {
return nil, err
}
totalRowsAffected += rowsAffected
}

// updating metadata with new normalizeBatchID
err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID)
err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID, normalizeRecordsTx)
if err != nil {
return nil, err
}
// transaction commits
err = normalizeRecordsTx.Commit()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -959,13 +962,13 @@ func (c *SnowflakeConnector) insertRecordsInRawTable(rawTableIdentifier string,
}

func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
ctx context.Context,
destinationTableIdentifier string,
unchangedToastColumns []string,
rawTableIdentifier string,
syncBatchID int64,
normalizeBatchID int64,
normalizeReq *model.NormalizeRecordsRequest,
normalizeRecordsTx *sql.Tx,
) (int64, error) {
normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
columnNames := maps.Keys(normalizedTableSchema.Columns)
Expand Down Expand Up @@ -1066,21 +1069,12 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
fmt.Sprintf("(%s)", strings.Join(normalizedTableSchema.PrimaryKeyColumns, ",")),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)

startTime := time.Now()
log.WithFields(log.Fields{
"flowName": destinationTableIdentifier,
}).Infof("[merge] merging records into %s...", destinationTableIdentifier)

result, err := c.database.ExecContext(ctx, mergeStatement, destinationTableIdentifier)
result, err := normalizeRecordsTx.ExecContext(c.ctx, mergeStatement, destinationTableIdentifier)
if err != nil {
return 0, fmt.Errorf("failed to merge records into %s (statement: %s): %w",
destinationTableIdentifier, mergeStatement, err)
}

endTime := time.Now()
log.Infof("[merge] merged records into %s, took: %d seconds",
destinationTableIdentifier, endTime.Sub(startTime)/time.Second)

return result.RowsAffected()
}

Expand Down Expand Up @@ -1139,7 +1133,8 @@ func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64
return nil
}

func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64) error {
func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string,
normalizeBatchID int64, normalizeRecordsTx *sql.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
Expand All @@ -1148,8 +1143,9 @@ func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normali
return fmt.Errorf("job metadata does not exist, unable to update")
}

stmt := fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier)
_, err = c.database.ExecContext(c.ctx, stmt, normalizeBatchID, flowJobName)
_, err = normalizeRecordsTx.ExecContext(c.ctx,
fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier),
normalizeBatchID, flowJobName)
if err != nil {
return fmt.Errorf("failed to update metadata for NormalizeTables: %w", err)
}
Expand Down

0 comments on commit 7d152c9

Please sign in to comment.