Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[snowflake] Run merges in parallel during normalize flow #662

Merged
merged 12 commits into from
Nov 15, 2023
84 changes: 44 additions & 40 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"regexp"
"strings"
"sync/atomic"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
Expand All @@ -20,6 +21,7 @@ import (
"github.com/snowflakedb/gosnowflake"
"go.temporal.io/sdk/activity"
"golang.org/x/exp/maps"
"golang.org/x/sync/errgroup"
)

//nolint:stylecheck
Expand Down Expand Up @@ -738,44 +740,39 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest
return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err)
}

// transaction for NormalizeRecords
normalizeRecordsTx, err := c.database.BeginTx(c.ctx, nil)
if err != nil {
return nil, fmt.Errorf("unable to begin transactions for NormalizeRecords: %w", err)
}
// in case we return after error, ensure transaction is rolled back
defer func() {
deferErr := normalizeRecordsTx.Rollback()
if deferErr != sql.ErrTxDone && deferErr != nil {
log.WithFields(log.Fields{
"flowName": req.FlowJobName,
}).Errorf("unexpected error while rolling back transaction for NormalizeRecords: %v", deferErr)
}
}()

var totalRowsAffected int64 = 0
// execute merge statements per table that uses CTEs to merge data into the normalized table
g, gCtx := errgroup.WithContext(c.ctx)
g.SetLimit(8) // limit parallel merges to 8

for _, destinationTableName := range destinationTableNames {
rowsAffected, err := c.generateAndExecuteMergeStatement(
destinationTableName,
tableNametoUnchangedToastCols[destinationTableName],
getRawTableIdentifier(req.FlowJobName),
syncBatchID, normalizeBatchID,
req,
normalizeRecordsTx)
if err != nil {
return nil, err
}
totalRowsAffected += rowsAffected
tableName := destinationTableName // local variable for the closure

g.Go(func() error {
rowsAffected, err := c.generateAndExecuteMergeStatement(
gCtx,
tableName,
tableNametoUnchangedToastCols[tableName],
getRawTableIdentifier(req.FlowJobName),
syncBatchID, normalizeBatchID,
req)
if err != nil {
log.WithFields(log.Fields{
"flowName": req.FlowJobName,
}).Errorf("[merge] error while normalizing records: %v", err)
return err
}

atomic.AddInt64(&totalRowsAffected, rowsAffected)
return nil
})
}

// updating metadata with new normalizeBatchID
err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID, normalizeRecordsTx)
if err != nil {
return nil, err
if err := g.Wait(); err != nil {
return nil, fmt.Errorf("error while normalizing records: %w", err)
}
// transaction commits
err = normalizeRecordsTx.Commit()

// updating metadata with new normalizeBatchID
err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -962,13 +959,13 @@ func (c *SnowflakeConnector) insertRecordsInRawTable(rawTableIdentifier string,
}

func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
ctx context.Context,
destinationTableIdentifier string,
unchangedToastColumns []string,
rawTableIdentifier string,
syncBatchID int64,
normalizeBatchID int64,
normalizeReq *model.NormalizeRecordsRequest,
normalizeRecordsTx *sql.Tx,
) (int64, error) {
normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier]
columnNames := maps.Keys(normalizedTableSchema.Columns)
Expand Down Expand Up @@ -1069,12 +1066,21 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
fmt.Sprintf("(%s)", strings.Join(normalizedTableSchema.PrimaryKeyColumns, ",")),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)

result, err := normalizeRecordsTx.ExecContext(c.ctx, mergeStatement, destinationTableIdentifier)
startTime := time.Now()
log.WithFields(log.Fields{
"flowName": destinationTableIdentifier,
}).Infof("[merge] merging records into %s...", destinationTableIdentifier)

result, err := c.database.ExecContext(ctx, mergeStatement, destinationTableIdentifier)
if err != nil {
return 0, fmt.Errorf("failed to merge records into %s (statement: %s): %w",
destinationTableIdentifier, mergeStatement, err)
}

endTime := time.Now()
log.Infof("[merge] merged records into %s, took: %d seconds",
destinationTableIdentifier, endTime.Sub(startTime)/time.Second)

return result.RowsAffected()
}

Expand Down Expand Up @@ -1133,8 +1139,7 @@ func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64
return nil
}

func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string,
normalizeBatchID int64, normalizeRecordsTx *sql.Tx) error {
func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
Expand All @@ -1143,9 +1148,8 @@ func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string,
return fmt.Errorf("job metadata does not exist, unable to update")
}

_, err = normalizeRecordsTx.ExecContext(c.ctx,
fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier),
normalizeBatchID, flowJobName)
stmt := fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier)
_, err = c.database.ExecContext(c.ctx, stmt, normalizeBatchID, flowJobName)
if err != nil {
return fmt.Errorf("failed to update metadata for NormalizeTables: %w", err)
}
Expand Down
Loading