composite keys for Snowflake and Postgres MERGE
heavycrystal committed Oct 10, 2023
1 parent d4e0997 commit 03632a1
Showing 8 changed files with 75 additions and 54 deletions.
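In short: the commit swaps the schema's single `PrimaryKeyColumn` field for a `PrimaryKeyColumns` list, then teaches the Postgres and Snowflake normalize paths to emit composite `PRIMARY KEY(...)` clauses and multi-column `MERGE ... ON` conditions. A minimal sketch of the shape of that change (the struct below is illustrative; the real `protos.TableSchema` is protobuf-generated):

```go
package main

import "fmt"

// Illustrative stand-in for protos.TableSchema. Field names mirror the diff,
// but the real type is generated from PeerDB's protobuf definitions.
type TableSchema struct {
	Columns map[string]string
	// Before this commit: PrimaryKeyColumn string
	PrimaryKeyColumns []string // a slice, so tables can declare composite keys
}

func main() {
	schema := TableSchema{
		Columns:           map[string]string{"tenant_id": "int32", "id": "int32", "name": "string"},
		PrimaryKeyColumns: []string{"tenant_id", "id"},
	}
	fmt.Println(schema.PrimaryKeyColumns) // [tenant_id id]
}
```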
2 changes: 1 addition & 1 deletion flow/connectors/bigquery/bigquery.go
@@ -1301,7 +1301,7 @@ func (m *MergeStmtGenerator) generateDeDupedCTE() string {

// generateMergeStmt generates a merge statement.
func (m *MergeStmtGenerator) generateMergeStmt(tempTable string) string {
- pkey := m.NormalizedTableSchema.PrimaryKeyColumn
+ pkey := m.NormalizedTableSchema.PrimaryKeyColumns[0]

// comma separated list of column names
backtickColNames := make([]string, 0)
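BigQuery's MERGE generator keeps its single-key join for now and simply reads the first element of the new slice. Supporting composite keys there would presumably mean AND-joining one equality predicate per key column, as the Postgres and Snowflake connectors below now do. A hedged sketch of that idea (hypothetical helper, not part of this commit):

```go
package main

import (
	"fmt"
	"strings"
)

// buildMergeOnClause AND-joins one equality predicate per key column,
// e.g. "dst.tenant_id=src.tenant_id AND dst.id=src.id".
func buildMergeOnClause(pkeyCols []string) string {
	conditions := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		conditions = append(conditions, fmt.Sprintf("dst.%s=src.%s", col, col))
	}
	return strings.Join(conditions, " AND ")
}

func main() {
	fmt.Println(buildMergeOnClause([]string{"tenant_id", "id"}))
	// dst.tenant_id=src.tenant_id AND dst.id=src.id
}
```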
1 change: 0 additions & 1 deletion flow/connectors/bigquery/qrep_avro_sync.go
@@ -387,7 +387,6 @@ func (s *QRepAvroSyncMethod) writeToStage(
return 0, fmt.Errorf("failed to write record to OCF file: %w", err)
}
numRecords++
-
}
activity.RecordHeartbeat(s.connector.ctx, fmt.Sprintf(
"Writing OCF contents to BigQuery for partition/batch ID %s",
50 changes: 32 additions & 18 deletions flow/connectors/postgres/client.go
@@ -51,7 +51,7 @@ const (
)
MERGE INTO %s dst
USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE _peerdb_rank=1) src
- ON dst.%s=src.%s
+ ON %s
WHEN NOT MATCHED AND src._peerdb_record_type!=2 THEN
INSERT (%s) VALUES (%s)
%s
@@ -69,7 +69,7 @@ const (
RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
)
- DELETE FROM %s USING src_rank WHERE %s.%s=%s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
+ DELETE FROM %s USING src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`

dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s"
deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1"
@@ -307,14 +307,19 @@ func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string,
sourceTableSchema *protos.TableSchema) string {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
for columnName, genericColumnType := range sourceTableSchema.Columns {
- if sourceTableSchema.PrimaryKeyColumn == strings.ToLower(columnName) {
- createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s PRIMARY KEY,",
- columnName, qValueKindToPostgresType(genericColumnType)))
- } else {
- createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
- qValueKindToPostgresType(genericColumnType)))
- }
+ createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
+ qValueKindToPostgresType(genericColumnType)))
}

+ // add composite primary key to the table
+ primaryKeyColsQuoted := make([]string, 0)
+ for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
+ primaryKeyColsQuoted = append(primaryKeyColsQuoted,
+ fmt.Sprintf(`"%s"`, primaryKeyCol))
+ }
+ createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
+ strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))

return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
}
@@ -507,13 +512,11 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifie
fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), internalSchema,
rawTableIdentifier, destinationTableIdentifier, insertColumnsSQL, flattenedCastsSQL,
- strings.TrimSuffix(strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), ","), updateColumnsSQL)
+ strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
- strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), internalSchema,
+ strings.Join(maps.Values(primaryKeyColumnCasts), ","), internalSchema,
rawTableIdentifier, destinationTableIdentifier, deleteWhereClauseSQL)

- log.Errorln("fallbackUpsertStatement", fallbackUpsertStatement)
- log.Errorln("fallbackDeleteStatement", fallbackDeleteStatement)
return []string{fallbackUpsertStatement, fallbackDeleteStatement}
}

@@ -527,6 +530,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st

flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
primaryKeyColumnCasts := make(map[string]string)
+ primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
for columnName, genericColumnType := range normalizedTableSchema.Columns {
pgType := qValueKindToPostgresType(genericColumnType)
if strings.Contains(genericColumnType, "array") {
@@ -537,15 +541,25 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier st
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
strings.Trim(columnName, "\""), pgType, columnName))
}
- if normalizedTableSchema.PrimaryKeyColumn == columnName {
- primaryKeyColumnCast = fmt.Sprintf("(_peerdb_data->>'%s')::%s", strings.Trim(columnName, "\""), pgType)
+ if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
+ primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
+ primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
+ columnName, columnName))
}
}
flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")

- return fmt.Sprintf(mergeStatementSQL, primaryKeyColumnCast, internalSchema, rawTableIdentifier,
- destinationTableIdentifier, flattenedCastsSQL, normalizedTableSchema.PrimaryKeyColumn,
- normalizedTableSchema.PrimaryKeyColumn, insertColumnsSQL, insertValuesSQL, updateStatements)
+ insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
+ insertValuesSQLArray := make([]string, 0, len(columnNames))
+ for _, columnName := range columnNames {
+ insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName))
+ }
+ insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",")
+ updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns)

+ return fmt.Sprintf(mergeStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","),
+ internalSchema, rawTableIdentifier, destinationTableIdentifier, flattenedCastsSQL,
+ strings.Join(primaryKeySelectSQLArray, " AND "), insertColumnsSQL, insertValuesSQL, updateStatements)
}

func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedToastColsLists []string) string {
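Net effect on the Postgres side: a table whose source key is (tenant_id, id) now gets a trailing `PRIMARY KEY("tenant_id","id")` in its CREATE TABLE, and the generated MERGE and fallback DELETE match on every key column instead of one. A self-contained sketch of the two string-building steps (simplified from the diff; the column names are made up):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	pkeyCols := []string{"tenant_id", "id"}

	// Composite PRIMARY KEY clause, as appended in
	// generateCreateTableSQLForNormalizedTable.
	quoted := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		quoted = append(quoted, fmt.Sprintf(`"%s"`, col))
	}
	fmt.Printf("PRIMARY KEY(%s)\n", strings.Join(quoted, ","))
	// PRIMARY KEY("tenant_id","id")

	// Multi-column ON condition, as collected in primaryKeySelectSQLArray
	// and joined with " AND " for the MERGE statement.
	conds := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		conds = append(conds, fmt.Sprintf("src.%s=dst.%s", col, col))
	}
	fmt.Println("ON " + strings.Join(conds, " AND "))
	// ON src.tenant_id=dst.tenant_id AND src.id=dst.id
}
```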
2 changes: 1 addition & 1 deletion flow/connectors/postgres/postgres.go
@@ -533,7 +533,7 @@ func (c *PostgresConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr
_, err = createRawTableTx.Exec(c.ctx, fmt.Sprintf(createRawTableDstTableIndexSQL, rawTableIdentifier,
internalSchema, rawTableIdentifier))
if err != nil {
return nil, fmt.Errorf("error creating batch ID index on raw table: %w", err)
return nil, fmt.Errorf("error creating destion table index on raw table: %w", err)
}

err = createRawTableTx.Commit(c.ctx)
10 changes: 5 additions & 5 deletions flow/connectors/postgres/postgres_cdc_test.go
@@ -505,7 +505,7 @@ func (suite *PostgresCDCTestSuite) TestSimpleHappyFlow() {
"id": string(qvalue.QValueKindInt32),
"name": string(qvalue.QValueKindString),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
},
}}, tableNameSchema)
tableNameSchemaMapping[simpleHappyFlowDstTableName] =
@@ -666,7 +666,7 @@ func (suite *PostgresCDCTestSuite) TestAllTypesHappyFlow() {
"c40": string(qvalue.QValueKindUUID),
"c41": string(qvalue.QValueKindString),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
},
},
}, tableNameSchema)
@@ -765,14 +765,14 @@ func (suite *PostgresCDCTestSuite) TestToastHappyFlow() {
"n_b": string(qvalue.QValueKindBytes),
"lz4_b": string(qvalue.QValueKindBytes),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
},
}}, tableNameSchema)
tableNameSchemaMapping[toastHappyFlowDstTableName] =
tableNameSchema.TableNameSchemaMapping[toastHappyFlowSrcTableName]

suite.insertToastRecords(toastHappyFlowSrcTableName)
- recordsWithSchemaDelta, err := suite.connector.PullRecords(&model.PullRecordsRequest{
+ _, err = suite.connector.PullRecords(&model.PullRecordsRequest{
FlowJobName: toastHappyFlowName,
LastSyncState: nil,
IdleTimeout: 10 * time.Second,
@@ -783,7 +783,7 @@
RelationMessageMapping: relationMessageMapping,
})
suite.failTestError(err)
- recordsWithSchemaDelta, err = suite.connector.PullRecords(&model.PullRecordsRequest{
+ recordsWithSchemaDelta, err := suite.connector.PullRecords(&model.PullRecordsRequest{
FlowJobName: toastHappyFlowName,
LastSyncState: nil,
IdleTimeout: 10 * time.Second,
21 changes: 10 additions & 11 deletions flow/connectors/postgres/postgres_schema_delta_test.go
@@ -99,7 +99,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() {
"id": string(qvalue.QValueKindInt32),
"hi": string(qvalue.QValueKindInt64),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}

@@ -125,7 +125,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleDropColumn() {
Columns: map[string]string{
"id": string(qvalue.QValueKindInt32),
},
PrimaryKeyColumn: "id",
PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}

@@ -156,7 +156,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddDropColumn() {
"id": string(qvalue.QValueKindInt32),
"hi": string(qvalue.QValueKindInt64),
},
- PrimaryKeyColumn: "id",
+ PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}

@@ -187,7 +187,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropSameColumn() {
"id": string(qvalue.QValueKindInt32),
"bye": string(qvalue.QValueKindJSON),
},
- PrimaryKeyColumn: "id",
+ PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}

@@ -219,7 +219,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropAllColumnTypes() {
"c15": string(qvalue.QValueKindTimestampTZ),
"c16": string(qvalue.QValueKindUUID),
},
- PrimaryKeyColumn: "id",
+ PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
for columnName, columnType := range expectedTableSchema.Columns {
@@ -270,7 +270,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropAllColumnTypes() {
Columns: map[string]string{
"id": string(qvalue.QValueKindInt32),
},
- PrimaryKeyColumn: "id",
+ PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}

@@ -294,7 +294,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropTrickyColumnNames() {
"±ªþ³§": string(qvalue.QValueKindString),
"カラム": string(qvalue.QValueKindString),
},
- PrimaryKeyColumn: "id",
+ PrimaryKeyColumns: []string{"id"},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
for columnName, columnType := range expectedTableSchema.Columns {
@@ -342,7 +342,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropTrickyColumnNames() {
Columns: map[string]string{
"id": string(qvalue.QValueKindInt32),
},
- PrimaryKeyColumn: "id",
+ PrimaryKeyColumns: []string{"id"},
}, output.TableNameSchemaMapping[tableName])
}

@@ -360,7 +360,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
" ": string(qvalue.QValueKindInt64),
" ": string(qvalue.QValueKindDate),
},
- PrimaryKeyColumn: " ",
+ PrimaryKeyColumns: []string{" "},
}
addedColumns := make([]*protos.DeltaAddedColumn, 0)
for columnName, columnType := range expectedTableSchema.Columns {
@@ -372,7 +372,6 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
}
}

- fmt.Println(addedColumns)
err = suite.connector.ReplayTableSchemaDelta("schema_delta_flow", &protos.TableSchemaDelta{
SrcTableName: tableName,
DstTableName: tableName,
@@ -409,7 +408,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
Columns: map[string]string{
" ": string(qvalue.QValueKindInt32),
},
- PrimaryKeyColumn: " ",
+ PrimaryKeyColumns: []string{" "},
}, output.TableNameSchemaMapping[tableName])
}

39 changes: 24 additions & 15 deletions flow/connectors/snowflake/snowflake.go
@@ -968,7 +968,8 @@ func (c *SnowflakeConnector) CreateRawTable(req *protos.CreateRawTableInput) (*p
if err != nil {
return nil, err
}
- // there is no easy way to check if a table has the same schema in Snowflake, so just executing the CREATE TABLE IF NOT EXISTS blindly.
+ // there is no easy way to check if a table has the same schema in Snowflake,
+ // so just executing the CREATE TABLE IF NOT EXISTS blindly.
_, err = createRawTableTx.ExecContext(c.ctx,
fmt.Sprintf(createRawTableSQL, peerDBInternalSchema, rawTableIdentifier))
if err != nil {
@@ -1060,23 +1061,26 @@ func generateCreateTableSQLForNormalizedTable(
sourceTableSchema *protos.TableSchema,
) string {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
- primaryColUpper := strings.ToUpper(sourceTableSchema.PrimaryKeyColumns[0])
for columnName, genericColumnType := range sourceTableSchema.Columns {
columnNameUpper := strings.ToUpper(columnName)
- if primaryColUpper == columnNameUpper {
- createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s PRIMARY KEY,`,
- columnNameUpper, qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))))
- } else {
- createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper,
- qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))))
- }
+ createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper,
+ qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))))
}

// add a _peerdb_is_deleted column to the normalized table
// this is boolean default false, and is used to mark records as deleted
createTableSQLArray = append(createTableSQLArray,
fmt.Sprintf(`"%s" BOOLEAN DEFAULT FALSE,`, isDeletedColumnName))

+ // add composite primary key to the table
+ primaryKeyColsUpperQuoted := make([]string, 0)
+ for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
+ primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted,
+ fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol)))
+ }
+ createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
+ strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ",")))

return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
}
@@ -1159,9 +1163,13 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
updateStatementsforToastCols := c.generateUpdateStatement(columnNames, unchangedToastColumns)
updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")

- // TARGET.<pkey> = SOURCE.<pkey>
- pkeyColStr := fmt.Sprintf("TARGET.%s = SOURCE.%s",
- normalizedTableSchema.PrimaryKeyColumns[0], normalizedTableSchema.PrimaryKeyColumns[0])
+ pkeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
+ for _, pkeyColName := range normalizedTableSchema.PrimaryKeyColumns {
+ pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s",
+ pkeyColName, pkeyColName))
+ }
+ // TARGET.<pkey1> = SOURCE.<pkey1> AND TARGET.<pkey2> = SOURCE.<pkey2> ...
+ pkeyColStr := strings.Join(pkeySelectSQLArray, " AND ")

deletePart := "DELETE"
if softDelete {
@@ -1170,8 +1178,8 @@

mergeStatement := fmt.Sprintf(mergeStatementSQL, destinationTableIdentifier, toVariantColumnName,
rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL,
- normalizedTableSchema.PrimaryKeyColumn, pkeyColStr, insertColumnsSQL, insertValuesSQL,
- updateStringToastCols, deletePart)
+ fmt.Sprintf("(%s)", strings.Join(normalizedTableSchema.PrimaryKeyColumns, ",")),
+ pkeyColStr, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)

result, err := normalizeRecordsTx.ExecContext(c.ctx, mergeStatement, destinationTableIdentifier)
if err != nil {
@@ -1237,7 +1245,8 @@ func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64
return nil
}

- func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64, normalizeRecordsTx *sql.Tx) error {
+ func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64,
+ normalizeRecordsTx *sql.Tx) error {
jobMetadataExists, err := c.jobMetadataExists(flowJobName)
if err != nil {
return fmt.Errorf("failed to get sync status for flow job: %w", err)
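The Snowflake path mirrors the Postgres one, with two wrinkles visible above: key columns are upper-cased and double-quoted for the PRIMARY KEY clause, and the MERGE join condition is spelled `TARGET.<col> = SOURCE.<col>`. A self-contained sketch of those two steps (column names are made up):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	pkeyCols := []string{"tenant_id", "id"}

	// Upper-cased, quoted composite key, as in Snowflake's
	// generateCreateTableSQLForNormalizedTable.
	quoted := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		quoted = append(quoted, fmt.Sprintf(`"%s"`, strings.ToUpper(col)))
	}
	fmt.Printf("PRIMARY KEY(%s)\n", strings.Join(quoted, ","))
	// PRIMARY KEY("TENANT_ID","ID")

	// MERGE join condition, as collected in pkeySelectSQLArray.
	conds := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		conds = append(conds, fmt.Sprintf("TARGET.%s = SOURCE.%s", col, col))
	}
	fmt.Println(strings.Join(conds, " AND "))
	// TARGET.tenant_id = SOURCE.tenant_id AND TARGET.id = SOURCE.id
}
```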
4 changes: 2 additions & 2 deletions nexus/analyzer/src/lib.rs
@@ -152,14 +152,14 @@ impl<'a> StatementAnalyzer for PeerDDLAnalyzer<'a> {
match create_mirror {
CDC(cdc) => {
let mut flow_job_table_mappings = vec![];
- for table_mapping in &cdc.mappings {
+ for table_mapping in &cdc.table_mappings {
flow_job_table_mappings.push(FlowJobTableMapping {
source_table_identifier: table_mapping
.source_table_identifier
.to_string()
.to_lowercase(),
target_table_identifier: table_mapping
- .target_identifier
+ .target_table_identifier
.to_string()
.to_lowercase(),
});
