Skip to content

Commit

Permalink
sf: merge one batch at a time
Browse files (browse the repository at this point in the history)
  • Loading branch information
Amogh-Bharadwaj committed Mar 9, 2024
1 parent facfe8b commit 6e252f7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 44 deletions.
8 changes: 3 additions & 5 deletions flow/connectors/snowflake/merge_stmt_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ type mergeStmtGenerator struct {
rawTableName string
// destination table name, used to retrieve records from raw table
dstTableName string
// last synced batchID.
syncBatchID int64
// last normalized batchID.
normalizeBatchID int64
// ID of the batch currently being merged
batchIdForThisMerge int64
// the schema of the table to merge into
normalizedTableSchema *protos.TableSchema
// array of toast column combinations that are unchanged
Expand Down Expand Up @@ -136,7 +134,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
}

mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable),
toVariantColumnName, m.rawTableName, m.normalizeBatchID, m.syncBatchID, flattenedCastsSQL,
toVariantColumnName, m.rawTableName, m.batchIdForThisMerge, flattenedCastsSQL,
fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")),
pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)

Expand Down
90 changes: 51 additions & 39 deletions flow/connectors/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const (
mergeStatementSQL = `MERGE INTO %s TARGET USING (WITH VARIANT_CONVERTED AS (
SELECT _PEERDB_UID,_PEERDB_TIMESTAMP,TO_VARIANT(PARSE_JSON(_PEERDB_DATA)) %s,_PEERDB_RECORD_TYPE,
_PEERDB_MATCH_DATA,_PEERDB_BATCH_ID,_PEERDB_UNCHANGED_TOAST_COLUMNS
FROM _PEERDB_INTERNAL.%s WHERE _PEERDB_BATCH_ID > %d AND _PEERDB_BATCH_ID <= %d AND
FROM _PEERDB_INTERNAL.%s WHERE _PEERDB_BATCH_ID = %d AND
_PEERDB_DESTINATION_TABLE_NAME = ? ), FLATTENED AS
(SELECT _PEERDB_UID,_PEERDB_TIMESTAMP,_PEERDB_RECORD_TYPE,_PEERDB_MATCH_DATA,_PEERDB_BATCH_ID,
_PEERDB_UNCHANGED_TOAST_COLUMNS,%s
Expand All @@ -55,10 +55,10 @@ const (
%s
WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN %s`
getDistinctDestinationTableNames = `SELECT DISTINCT _PEERDB_DESTINATION_TABLE_NAME FROM %s.%s WHERE
_PEERDB_BATCH_ID > %d AND _PEERDB_BATCH_ID <= %d`
_PEERDB_BATCH_ID = %d`
getTableNameToUnchangedColsSQL = `SELECT _PEERDB_DESTINATION_TABLE_NAME,
ARRAY_AGG(DISTINCT _PEERDB_UNCHANGED_TOAST_COLUMNS) FROM %s.%s WHERE
_PEERDB_BATCH_ID > %d AND _PEERDB_BATCH_ID <= %d AND _PEERDB_RECORD_TYPE != 2
_PEERDB_BATCH_ID = %d AND _PEERDB_RECORD_TYPE != 2
GROUP BY _PEERDB_DESTINATION_TABLE_NAME`
getTableSchemaSQL = `SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS
WHERE UPPER(TABLE_SCHEMA)=? AND UPPER(TABLE_NAME)=? ORDER BY ORDINAL_POSITION`
Expand Down Expand Up @@ -265,13 +265,12 @@ func (c *SnowflakeConnector) GetLastNormalizeBatchID(ctx context.Context, jobNam
func (c *SnowflakeConnector) getDistinctTableNamesInBatch(
ctx context.Context,
flowJobName string,
syncBatchID int64,
normalizeBatchID int64,
batchId int64,
) ([]string, error) {
rawTableIdentifier := getRawTableIdentifier(flowJobName)

rows, err := c.database.QueryContext(ctx, fmt.Sprintf(getDistinctDestinationTableNames, c.rawSchema,
rawTableIdentifier, normalizeBatchID, syncBatchID))
rawTableIdentifier, batchId))
if err != nil {
return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err)
}
Expand All @@ -297,13 +296,12 @@ func (c *SnowflakeConnector) getDistinctTableNamesInBatch(
func (c *SnowflakeConnector) getTableNameToUnchangedCols(
ctx context.Context,
flowJobName string,
syncBatchID int64,
normalizeBatchID int64,
batchId int64,
) (map[string][]string, error) {
rawTableIdentifier := getRawTableIdentifier(flowJobName)

rows, err := c.database.QueryContext(ctx, fmt.Sprintf(getTableNameToUnchangedColsSQL, c.rawSchema,
rawTableIdentifier, normalizeBatchID, syncBatchID))
rawTableIdentifier, batchId))
if err != nil {
return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err)
}
Expand Down Expand Up @@ -500,19 +498,47 @@ func (c *SnowflakeConnector) NormalizeRecords(ctx context.Context, req *model.No
}, nil
}

destinationTableNames, err := c.getDistinctTableNamesInBatch(
ctx,
req.FlowJobName,
req.SyncBatchID,
normBatchID,
)
for batchId := normBatchID + 1; batchId <= req.SyncBatchID; batchId++ {
mergeErr := c.MergeTablesInBatch(ctx, batchId,
req.FlowJobName, req.TableNameSchemaMapping,
&protos.PeerDBColumns{
SoftDelete: req.SoftDelete,
SoftDeleteColName: req.SoftDeleteColName,
SyncedAtColName: req.SyncedAtColName,
},
)
if mergeErr != nil {
return nil, mergeErr
}

err = c.pgMetadata.UpdateNormalizeBatchID(ctx, req.FlowJobName, batchId)
if err != nil {
return nil, err
}
}

return &model.NormalizeResponse{
Done: true,
StartBatchID: normBatchID + 1,
EndBatchID: req.SyncBatchID,
}, nil
}

func (c *SnowflakeConnector) MergeTablesInBatch(
ctx context.Context,
batchId int64,
flowName string,
tableToSchema map[string]*protos.TableSchema,
peerdbCols *protos.PeerDBColumns,
) error {
destinationTableNames, err := c.getDistinctTableNamesInBatch(ctx, flowName, batchId)
if err != nil {
return nil, err
return err
}

tableNameToUnchangedToastCols, err := c.getTableNameToUnchangedCols(ctx, req.FlowJobName, req.SyncBatchID, normBatchID)
tableNameToUnchangedToastCols, err := c.getTableNameToUnchangedCols(ctx, flowName, batchId)
if err != nil {
return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err)
return fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err)
}

var totalRowsAffected int64 = 0
Expand All @@ -521,22 +547,17 @@ func (c *SnowflakeConnector) NormalizeRecords(ctx context.Context, req *model.No

for _, tableName := range destinationTableNames {
if err := gCtx.Err(); err != nil {
return nil, fmt.Errorf("canceled while normalizing records: %w", err)
return fmt.Errorf("canceled while normalizing records: %w", err)
}

g.Go(func() error {
mergeGen := &mergeStmtGenerator{
rawTableName: getRawTableIdentifier(req.FlowJobName),
rawTableName: getRawTableIdentifier(flowName),
dstTableName: tableName,
syncBatchID: req.SyncBatchID,
normalizeBatchID: normBatchID,
normalizedTableSchema: req.TableNameSchemaMapping[tableName],
batchIdForThisMerge: batchId,
normalizedTableSchema: tableToSchema[tableName],
unchangedToastColumns: tableNameToUnchangedToastCols[tableName],
peerdbCols: &protos.PeerDBColumns{
SoftDelete: req.SoftDelete,
SoftDeleteColName: req.SoftDeleteColName,
SyncedAtColName: req.SyncedAtColName,
},
peerdbCols: peerdbCols,
}
mergeStatement, err := mergeGen.generateMergeStmt()
if err != nil {
Expand Down Expand Up @@ -567,19 +588,10 @@ func (c *SnowflakeConnector) NormalizeRecords(ctx context.Context, req *model.No
}

if err := g.Wait(); err != nil {
return nil, fmt.Errorf("error while normalizing records: %w", err)
return fmt.Errorf("error while normalizing records: %w", err)
}

err = c.pgMetadata.UpdateNormalizeBatchID(ctx, req.FlowJobName, req.SyncBatchID)
if err != nil {
return nil, err
}

return &model.NormalizeResponse{
Done: true,
StartBatchID: normBatchID + 1,
EndBatchID: req.SyncBatchID,
}, nil
return nil
}

func (c *SnowflakeConnector) CreateRawTable(ctx context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) {
Expand Down

0 comments on commit 6e252f7

Please sign in to comment.