diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 3c47f4462b..0794db478c 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -1301,7 +1301,7 @@ func (m *MergeStmtGenerator) generateDeDupedCTE() string {
 
 // generateMergeStmt generates a merge statement.
 func (m *MergeStmtGenerator) generateMergeStmt(tempTable string) string {
-	pkey := m.NormalizedTableSchema.PrimaryKeyColumn
+	pkey := m.NormalizedTableSchema.PrimaryKeyColumns[0]
 
 	// comma separated list of column names
 	backtickColNames := make([]string, 0)
diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go
index a4cd92a82c..db9c6ef6c0 100644
--- a/flow/connectors/bigquery/qrep_avro_sync.go
+++ b/flow/connectors/bigquery/qrep_avro_sync.go
@@ -387,7 +387,6 @@ func (s *QRepAvroSyncMethod) writeToStage(
 			return 0, fmt.Errorf("failed to write record to OCF file: %w", err)
 		}
 		numRecords++
-
 	}
 	activity.RecordHeartbeat(s.connector.ctx, fmt.Sprintf(
 		"Writing OCF contents to BigQuery for partition/batch ID %s",
diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go
index 34d0926fea..e9eb565840 100644
--- a/flow/connectors/postgres/client.go
+++ b/flow/connectors/postgres/client.go
@@ -51,7 +51,7 @@ const (
 	)
 	MERGE INTO %s dst
 	USING (SELECT %s,_peerdb_record_type,_peerdb_unchanged_toast_columns FROM src_rank WHERE _peerdb_rank=1) src
-	ON dst.%s=src.%s
+	ON %s
 	WHEN NOT MATCHED AND src._peerdb_record_type!=2 THEN
 	INSERT (%s) VALUES (%s)
 	%s
@@ -69,7 +69,7 @@ const (
 		RANK() OVER (PARTITION BY %s ORDER BY _peerdb_timestamp DESC) AS _peerdb_rank
 		FROM %s.%s WHERE _peerdb_batch_id>$1 AND _peerdb_batch_id<=$2 AND _peerdb_destination_table_name=$3
 	)
-	DELETE FROM %s USING src_rank WHERE %s.%s=%s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
+	DELETE FROM %s USING src_rank WHERE %s AND src_rank._peerdb_rank=1 AND src_rank._peerdb_record_type=2`
 
 	dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s"
 	deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=$1"
@@ -307,14 +307,19 @@ func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string,
 	sourceTableSchema *protos.TableSchema) string {
 	createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
 	for columnName, genericColumnType := range sourceTableSchema.Columns {
-		if sourceTableSchema.PrimaryKeyColumn == strings.ToLower(columnName) {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s PRIMARY KEY,",
-				columnName, qValueKindToPostgresType(genericColumnType)))
-		} else {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
-				qValueKindToPostgresType(genericColumnType)))
-		}
+		createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("\"%s\" %s,", columnName,
+			qValueKindToPostgresType(genericColumnType)))
+	}
+
+	// add composite primary key to the table
+	primaryKeyColsQuoted := make([]string, 0)
+	for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
+		primaryKeyColsQuoted = append(primaryKeyColsQuoted,
+			fmt.Sprintf(`"%s"`, primaryKeyCol))
 	}
+	createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
+		strings.TrimSuffix(strings.Join(primaryKeyColsQuoted, ","), ",")))
+
 	return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
 		strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
 }
@@ -507,13 +512,11 @@ func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifier string,
 	fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL,
 		strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), internalSchema,
 		rawTableIdentifier, destinationTableIdentifier, insertColumnsSQL, flattenedCastsSQL,
-		strings.TrimSuffix(strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), ","), updateColumnsSQL)
+		strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL)
 	fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL,
-		strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), internalSchema,
+		strings.Join(maps.Values(primaryKeyColumnCasts), ","), internalSchema,
 		rawTableIdentifier, destinationTableIdentifier, deleteWhereClauseSQL)
-	log.Errorln("fallbackUpsertStatement", fallbackUpsertStatement)
-	log.Errorln("fallbackDeleteStatement", fallbackDeleteStatement)
 
 	return []string{fallbackUpsertStatement, fallbackDeleteStatement}
 }
@@ -527,6 +530,7 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier string,
 	flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
 	primaryKeyColumnCasts := make(map[string]string)
+	primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
 	for columnName, genericColumnType := range normalizedTableSchema.Columns {
 		pgType := qValueKindToPostgresType(genericColumnType)
 		if strings.Contains(genericColumnType, "array") {
@@ -537,15 +541,25 @@ func (c *PostgresConnector) generateMergeStatement(destinationTableIdentifier string,
 			flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"",
 				strings.Trim(columnName, "\""), pgType, columnName))
 		}
-		if normalizedTableSchema.PrimaryKeyColumn == columnName {
-			primaryKeyColumnCast = fmt.Sprintf("(_peerdb_data->>'%s')::%s", strings.Trim(columnName, "\""), pgType)
+		if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) {
+			primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType)
+			primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
+				columnName, columnName))
 		}
 	}
 	flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",")
 
-	return fmt.Sprintf(mergeStatementSQL, primaryKeyColumnCast, internalSchema, rawTableIdentifier,
-		destinationTableIdentifier, flattenedCastsSQL, normalizedTableSchema.PrimaryKeyColumn,
-		normalizedTableSchema.PrimaryKeyColumn, insertColumnsSQL, insertValuesSQL, updateStatements)
+	insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",")
+	insertValuesSQLArray := make([]string, 0, len(columnNames))
+	for _, columnName := range columnNames {
+		insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName))
+	}
+	insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",")
+	updateStatements := c.generateUpdateStatement(columnNames, unchangedToastColumns)
+
+	return fmt.Sprintf(mergeStatementSQL, strings.Join(maps.Values(primaryKeyColumnCasts), ","),
+		internalSchema, rawTableIdentifier, destinationTableIdentifier, flattenedCastsSQL,
+		strings.Join(primaryKeySelectSQLArray, " AND "), insertColumnsSQL, insertValuesSQL, updateStatements)
 }
 
 func (c *PostgresConnector) generateUpdateStatement(allCols []string, unchangedToastColsLists []string) string {
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 4436365cec..b43d51ff77 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -533,7 +533,7 @@ func (c *PostgresConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) {
 	_, err = createRawTableTx.Exec(c.ctx, fmt.Sprintf(createRawTableDstTableIndexSQL, rawTableIdentifier,
 		internalSchema, rawTableIdentifier))
 	if err != nil {
-		return nil, fmt.Errorf("error creating batch ID index on raw table: %w", err)
+		return nil, fmt.Errorf("error creating destination table index on raw table: %w", err)
 	}
 
 	err = createRawTableTx.Commit(c.ctx)
diff --git a/flow/connectors/postgres/postgres_cdc_test.go b/flow/connectors/postgres/postgres_cdc_test.go
index d7db24fef7..46a3e78464 100644
--- a/flow/connectors/postgres/postgres_cdc_test.go
+++ b/flow/connectors/postgres/postgres_cdc_test.go
@@ -505,7 +505,7 @@ func (suite *PostgresCDCTestSuite) TestSimpleHappyFlow() {
 				"id":   string(qvalue.QValueKindInt32),
 				"name": string(qvalue.QValueKindString),
 			},
-			PrimaryKeyColumn: "id",
+			PrimaryKeyColumns: []string{"id"},
 		},
 	}}, tableNameSchema)
 	tableNameSchemaMapping[simpleHappyFlowDstTableName] =
@@ -666,7 +666,7 @@ func (suite *PostgresCDCTestSuite) TestAllTypesHappyFlow() {
 					"c40": string(qvalue.QValueKindUUID),
 					"c41": string(qvalue.QValueKindString),
 				},
-				PrimaryKeyColumn: "id",
+				PrimaryKeyColumns: []string{"id"},
 			},
 		},
 	}, tableNameSchema)
@@ -765,14 +765,14 @@ func (suite *PostgresCDCTestSuite) TestToastHappyFlow() {
 				"n_b":   string(qvalue.QValueKindBytes),
 				"lz4_b": string(qvalue.QValueKindBytes),
 			},
-			PrimaryKeyColumn: "id",
+			PrimaryKeyColumns: []string{"id"},
 		},
 	}}, tableNameSchema)
 	tableNameSchemaMapping[toastHappyFlowDstTableName] =
 		tableNameSchema.TableNameSchemaMapping[toastHappyFlowSrcTableName]
 
 	suite.insertToastRecords(toastHappyFlowSrcTableName)
-	recordsWithSchemaDelta, err := suite.connector.PullRecords(&model.PullRecordsRequest{
+	_, err = suite.connector.PullRecords(&model.PullRecordsRequest{
 		FlowJobName:            toastHappyFlowName,
 		LastSyncState:          nil,
 		IdleTimeout:            10 * time.Second,
@@ -783,7 +783,7 @@ func (suite *PostgresCDCTestSuite) TestToastHappyFlow() {
 		RelationMessageMapping: relationMessageMapping,
 	})
 	suite.failTestError(err)
-	recordsWithSchemaDelta, err = suite.connector.PullRecords(&model.PullRecordsRequest{
+	recordsWithSchemaDelta, err := suite.connector.PullRecords(&model.PullRecordsRequest{
 		FlowJobName:   toastHappyFlowName,
 		LastSyncState: nil,
 		IdleTimeout:   10 * time.Second,
diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go
index 243f597864..6d53533378 100644
--- a/flow/connectors/postgres/postgres_schema_delta_test.go
+++ b/flow/connectors/postgres/postgres_schema_delta_test.go
@@ -99,7 +99,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() {
 			"id": string(qvalue.QValueKindInt32),
 			"hi": string(qvalue.QValueKindInt64),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}, output.TableNameSchemaMapping[tableName])
 }
 
@@ -125,7 +125,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleDropColumn() {
 		Columns: map[string]string{
 			"id": string(qvalue.QValueKindInt32),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}, output.TableNameSchemaMapping[tableName])
 }
 
@@ -156,7 +156,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestSimpleAddDropColumn() {
 			"id": string(qvalue.QValueKindInt32),
 			"hi": string(qvalue.QValueKindInt64),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}, output.TableNameSchemaMapping[tableName])
 }
@@ -187,7 +187,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropSameColumn() {
 			"id":  string(qvalue.QValueKindInt32),
 			"bye": string(qvalue.QValueKindJSON),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}, output.TableNameSchemaMapping[tableName])
 }
 
@@ -219,7 +219,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropAllColumnTypes() {
 			"c15": string(qvalue.QValueKindTimestampTZ),
 			"c16": string(qvalue.QValueKindUUID),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
 	for columnName, columnType := range expectedTableSchema.Columns {
@@ -270,7 +270,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropAllColumnTypes() {
 		Columns: map[string]string{
 			"id": string(qvalue.QValueKindInt32),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}, output.TableNameSchemaMapping[tableName])
 }
 
@@ -294,7 +294,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropTrickyColumnNames() {
 			"±ªþ³§": string(qvalue.QValueKindString),
 			"カラム":  string(qvalue.QValueKindString),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
 	for columnName, columnType := range expectedTableSchema.Columns {
@@ -342,7 +342,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropTrickyColumnNames() {
 		Columns: map[string]string{
 			"id": string(qvalue.QValueKindInt32),
 		},
-		PrimaryKeyColumn: "id",
+		PrimaryKeyColumns: []string{"id"},
 	}, output.TableNameSchemaMapping[tableName])
 }
 
@@ -360,7 +360,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
 			" ":  string(qvalue.QValueKindInt64),
 			"  ": string(qvalue.QValueKindDate),
 		},
-		PrimaryKeyColumn: " ",
+		PrimaryKeyColumns: []string{" "},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
 	for columnName, columnType := range expectedTableSchema.Columns {
@@ -372,7 +372,6 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
 		}
 	}
 
-	fmt.Println(addedColumns)
 	err = suite.connector.ReplayTableSchemaDelta("schema_delta_flow", &protos.TableSchemaDelta{
 		SrcTableName: tableName,
 		DstTableName: tableName,
@@ -409,7 +408,7 @@ func (suite *PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
 		Columns: map[string]string{
 			" ": string(qvalue.QValueKindInt32),
 		},
-		PrimaryKeyColumn: " ",
+		PrimaryKeyColumns: []string{" "},
 	}, output.TableNameSchemaMapping[tableName])
 }
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index bef026dd4e..c78b32a3d7 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -968,7 +968,8 @@ func (c *SnowflakeConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) {
 	if err != nil {
 		return nil, err
 	}
-	// there is no easy way to check if a table has the same schema in Snowflake, so just executing the CREATE TABLE IF NOT EXISTS blindly.
+	// there is no easy way to check if a table has the same schema in Snowflake,
+	// so just executing the CREATE TABLE IF NOT EXISTS blindly.
 	_, err = createRawTableTx.ExecContext(c.ctx,
 		fmt.Sprintf(createRawTableSQL, peerDBInternalSchema, rawTableIdentifier))
 	if err != nil {
@@ -1060,16 +1061,10 @@ func generateCreateTableSQLForNormalizedTable(
 	sourceTableSchema *protos.TableSchema,
 ) string {
 	createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
-	primaryColUpper := strings.ToUpper(sourceTableSchema.PrimaryKeyColumns[0])
 	for columnName, genericColumnType := range sourceTableSchema.Columns {
 		columnNameUpper := strings.ToUpper(columnName)
-		if primaryColUpper == columnNameUpper {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s PRIMARY KEY,`,
-				columnNameUpper, qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))))
-		} else {
-			createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper,
-				qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))))
-		}
+		createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" %s,`, columnNameUpper,
+			qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))))
 	}
 
 	// add a _peerdb_is_deleted column to the normalized table
@@ -1077,6 +1072,15 @@ func generateCreateTableSQLForNormalizedTable(
 	createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`"%s" BOOLEAN DEFAULT FALSE,`,
 		isDeletedColumnName))
 
+	// add composite primary key to the table
+	primaryKeyColsUpperQuoted := make([]string, 0)
+	for _, primaryKeyCol := range sourceTableSchema.PrimaryKeyColumns {
+		primaryKeyColsUpperQuoted = append(primaryKeyColsUpperQuoted,
+			fmt.Sprintf(`"%s"`, strings.ToUpper(primaryKeyCol)))
+	}
+	createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("PRIMARY KEY(%s),",
+		strings.TrimSuffix(strings.Join(primaryKeyColsUpperQuoted, ","), ",")))
+
 	return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
 		strings.TrimSuffix(strings.Join(createTableSQLArray, ""), ","))
 }
@@ -1159,9 +1163,13 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
 	updateStatementsforToastCols := c.generateUpdateStatement(columnNames, unchangedToastColumns)
 	updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")
 
-	// TARGET.<pkey> = SOURCE.<pkey>
-	pkeyColStr := fmt.Sprintf("TARGET.%s = SOURCE.%s",
-		normalizedTableSchema.PrimaryKeyColumns[0], normalizedTableSchema.PrimaryKeyColumns[0])
+	pkeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
+	for _, pkeyColName := range normalizedTableSchema.PrimaryKeyColumns {
+		pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s",
+			pkeyColName, pkeyColName))
+	}
+	// TARGET.<col1> = SOURCE.<col1> AND TARGET.<col2> = SOURCE.<col2> ...
+	pkeyColStr := strings.Join(pkeySelectSQLArray, " AND ")
 
 	deletePart := "DELETE"
 	if softDelete {
@@ -1170,8 +1178,8 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(
 
 	mergeStatement := fmt.Sprintf(mergeStatementSQL, destinationTableIdentifier, toVariantColumnName,
 		rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL,
-		normalizedTableSchema.PrimaryKeyColumn, pkeyColStr, insertColumnsSQL, insertValuesSQL,
-		updateStringToastCols, deletePart)
+		fmt.Sprintf("(%s)", strings.Join(normalizedTableSchema.PrimaryKeyColumns, ",")),
+		pkeyColStr, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart)
 
 	result, err := normalizeRecordsTx.ExecContext(c.ctx, mergeStatement, destinationTableIdentifier)
 	if err != nil {
@@ -1237,7 +1245,8 @@ func (c *SnowflakeConnector) updateSyncMetadata(flowJobName string, lastCP int64
 	return nil
 }
 
-func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64, normalizeRecordsTx *sql.Tx) error {
+func (c *SnowflakeConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64,
+	normalizeRecordsTx *sql.Tx) error {
 	jobMetadataExists, err := c.jobMetadataExists(flowJobName)
 	if err != nil {
 		return fmt.Errorf("failed to get sync status for flow job: %w", err)
diff --git a/nexus/analyzer/src/lib.rs b/nexus/analyzer/src/lib.rs
index a73d02a687..b9d01e1739 100644
--- a/nexus/analyzer/src/lib.rs
+++ b/nexus/analyzer/src/lib.rs
@@ -152,14 +152,14 @@ impl<'a> StatementAnalyzer for PeerDDLAnalyzer<'a> {
         match create_mirror {
             CDC(cdc) => {
                 let mut flow_job_table_mappings = vec![];
-                for table_mapping in &cdc.mappings {
+                for table_mapping in &cdc.table_mappings {
                     flow_job_table_mappings.push(FlowJobTableMapping {
                         source_table_identifier: table_mapping
                             .source_table_identifier
                             .to_string()
                             .to_lowercase(),
                         target_table_identifier: table_mapping
-                            .target_identifier
+                            .target_table_identifier
                            .to_string()
                             .to_lowercase(),
                     });
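For reviewers: a minimal, self-contained sketch of the composite-primary-key string building that this change repeats across the Postgres and Snowflake connectors. The helper names below (buildPrimaryKeyConstraint, buildMergeOnClause) and the sample columns are illustrative only, not identifiers from this codebase; the real code builds the same fragments inline from TableSchema.PrimaryKeyColumns.

package main

import (
	"fmt"
	"strings"
)

// buildPrimaryKeyConstraint renders the table-level constraint that replaces
// the old inline `<column> <type> PRIMARY KEY` column definition, so that
// multi-column keys work in CREATE TABLE, e.g. PRIMARY KEY("id","region").
func buildPrimaryKeyConstraint(pkeyCols []string) string {
	quoted := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		quoted = append(quoted, fmt.Sprintf(`"%s"`, col))
	}
	return fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(quoted, ","))
}

// buildMergeOnClause renders the MERGE join condition that replaces the old
// single-column `ON dst.%s=src.%s` template, ANDing one equality per key
// column, e.g. dst.id=src.id AND dst.region=src.region.
func buildMergeOnClause(pkeyCols []string) string {
	conds := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		conds = append(conds, fmt.Sprintf("dst.%s=src.%s", col, col))
	}
	return strings.Join(conds, " AND ")
}

func main() {
	pkeyCols := []string{"id", "region"} // hypothetical composite key
	fmt.Println(buildPrimaryKeyConstraint(pkeyCols)) // PRIMARY KEY("id","region")
	fmt.Println(buildMergeOnClause(pkeyCols))        // dst.id=src.id AND dst.region=src.region
}

Note that both fragments degrade to the old single-column behavior when PrimaryKeyColumns has one element, which is why the test expectations above change only the field shape, from PrimaryKeyColumn: "id" to PrimaryKeyColumns: []string{"id"}.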