diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 9cefbecc97..1cfadd9727 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -8,6 +8,7 @@ import (
     "time"

     "github.com/PeerDB-io/peer-flow/connectors"
+    connbigquery "github.com/PeerDB-io/peer-flow/connectors/bigquery"
     connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres"
     connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
     "github.com/PeerDB-io/peer-flow/connectors/utils"
@@ -731,17 +732,20 @@ func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.Rena
     }
     defer connectors.CloseConnector(dstConn)

-    // check if destination is snowflake, if not error out
-    if config.Peer.Type != protos.DBType_SNOWFLAKE {
-        return nil, fmt.Errorf("rename tables is only supported for snowflake")
-    }
-
-    sfConn, ok := dstConn.(*connsnowflake.SnowflakeConnector)
-    if !ok {
-        return nil, fmt.Errorf("failed to cast connector to snowflake connector")
+    if config.Peer.Type == protos.DBType_SNOWFLAKE {
+        sfConn, ok := dstConn.(*connsnowflake.SnowflakeConnector)
+        if !ok {
+            return nil, fmt.Errorf("failed to cast connector to snowflake connector")
+        }
+        return sfConn.RenameTables(config)
+    } else if config.Peer.Type == protos.DBType_BIGQUERY {
+        bqConn, ok := dstConn.(*connbigquery.BigQueryConnector)
+        if !ok {
+            return nil, fmt.Errorf("failed to cast connector to bigquery connector")
+        }
+        return bqConn.RenameTables(config)
     }
-
-    return sfConn.RenameTables(config)
+    return nil, fmt.Errorf("rename tables is only supported on snowflake and bigquery")
 }

 func (a *FlowableActivity) CreateTablesFromExisting(ctx context.Context, req *protos.CreateTablesFromExistingInput) (
@@ -752,15 +756,18 @@ func (a *FlowableActivity) CreateTablesFromExisting(ctx context.Context, req *pr
     }
     defer connectors.CloseConnector(dstConn)

-    // check if destination is snowflake, if not error out
-    if req.Peer.Type != protos.DBType_SNOWFLAKE {
-        return nil, fmt.Errorf("create tables from existing is only supported on snowflake")
-    }
-
-    sfConn, ok := dstConn.(*connsnowflake.SnowflakeConnector)
-    if !ok {
-        return nil, fmt.Errorf("failed to cast connector to snowflake connector")
+    if req.Peer.Type == protos.DBType_SNOWFLAKE {
+        sfConn, ok := dstConn.(*connsnowflake.SnowflakeConnector)
+        if !ok {
+            return nil, fmt.Errorf("failed to cast connector to snowflake connector")
+        }
+        return sfConn.CreateTablesFromExisting(req)
+    } else if req.Peer.Type == protos.DBType_BIGQUERY {
+        bqConn, ok := dstConn.(*connbigquery.BigQueryConnector)
+        if !ok {
+            return nil, fmt.Errorf("failed to cast connector to bigquery connector")
+        }
+        return bqConn.CreateTablesFromExisting(req)
     }
-
-    return sfConn.CreateTablesFromExisting(req)
+    return nil, fmt.Errorf("create tables from existing is only supported on snowflake and bigquery")
 }
diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 1a3b7fab6d..4f6eb35ab4 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -16,9 +16,9 @@ import (
     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model"
     "github.com/PeerDB-io/peer-flow/model/qvalue"
-    util "github.com/PeerDB-io/peer-flow/utils"
     "github.com/google/uuid"
     log "github.com/sirupsen/logrus"
+    "go.temporal.io/sdk/activity"
     "google.golang.org/api/iterator"
     "google.golang.org/api/option"
 )
@@ -479,6 +479,7 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest,
     // separate staging batchID which is random/unique
     // to handle the case where ingestion into staging passes but raw fails
     // helps avoid duplicates in the raw table
+    //nolint:gosec
     stagingBatchID := rand.Int63()
     records := make([]StagingBQRecord, 0)
     tableNameRowsMapping := make(map[string]uint32)
@@ -913,7 +914,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
     stmts = append(stmts, "BEGIN TRANSACTION;")

     for _, tableName := range distinctTableNames {
-        mergeGen := &MergeStmtGenerator{
+        mergeGen := &mergeStmtGenerator{
             Dataset:               c.datasetID,
             NormalizedTable:       tableName,
             RawTable:              rawTableName,
@@ -923,7 +924,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
             UnchangedToastColumns: tableNametoUnchangedToastCols[tableName],
         }
         // normalize anything between last normalized batch id to last sync batchid
-        mergeStmts := mergeGen.GenerateMergeStmts()
+        mergeStmts := mergeGen.generateMergeStmts()
         stmts = append(stmts, mergeStmts...)
     }
     //update metadata to make the last normalized batch id to the recent last sync batch id.
@@ -1179,181 +1180,62 @@ func (c *BigQueryConnector) truncateTable(tableIdentifier string) error {
     return nil
 }

-type MergeStmtGenerator struct {
-    // dataset of all the tables
-    Dataset string
-    // the table to merge into
-    NormalizedTable string
-    // the table where the data is currently staged.
-    RawTable string
-    // last synced batchID.
-    SyncBatchID int64
-    // last normalized batchID.
-    NormalizeBatchID int64
-    // the schema of the table to merge into
-    NormalizedTableSchema *protos.TableSchema
-    // array of toast column combinations that are unchanged
-    UnchangedToastColumns []string
-}
-
-// GenerateMergeStmt generates a merge statements.
-func (m *MergeStmtGenerator) GenerateMergeStmts() []string {
-    // return an empty array for now
-    flattenedCTE := m.generateFlattenedCTE()
-    deDupedCTE := m.generateDeDupedCTE()
-    tempTable := fmt.Sprintf("_peerdb_de_duplicated_data_%s", util.RandomString(5))
-    // create temp table stmt
-    createTempTableStmt := fmt.Sprintf(
-        "CREATE TEMP TABLE %s AS (%s, %s);",
-        tempTable, flattenedCTE, deDupedCTE)
+func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) {
+    for _, renameRequest := range req.RenameTableOptions {
+        src := renameRequest.CurrentName
+        dst := renameRequest.NewName

-    mergeStmt := m.generateMergeStmt(tempTable)
+        log.WithFields(log.Fields{
+            "flowName": req.FlowJobName,
+        }).Infof("renaming table '%s' to '%s'...", src, dst)

-    dropTempTableStmt := fmt.Sprintf("DROP TABLE %s;", tempTable)
+        activity.RecordHeartbeat(c.ctx, fmt.Sprintf("renaming table '%s' to '%s'...", src, dst))

-    return []string{createTempTableStmt, mergeStmt, dropTempTableStmt}
-}
-
-// generateFlattenedCTE generates a flattened CTE.
-func (m *MergeStmtGenerator) generateFlattenedCTE() string {
-    // for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
-    // statement.
-    flattenedProjs := make([]string, 0)
-    for colName, colType := range m.NormalizedTableSchema.Columns {
-        bqType := qValueKindToBigQueryType(colType)
-        // CAST doesn't work for FLOAT, so rewrite it to FLOAT64.
-        if bqType == bigquery.FloatFieldType {
-            bqType = "FLOAT64"
+        // drop the dst table if it exists
+        _, err := c.client.Query(fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", c.datasetID, dst)).Run(c.ctx)
+        if err != nil {
+            return nil, fmt.Errorf("unable to drop table %s: %w", dst, err)
         }
-        var castStmt string
-
-        switch qvalue.QValueKind(colType) {
-        case qvalue.QValueKindJSON:
-            //if the type is JSON, then just extract JSON
-            castStmt = fmt.Sprintf("CAST(JSON_EXTRACT(_peerdb_data, '$.%s') AS %s) AS `%s`",
-                colName, bqType, colName)
-        // expecting data in BASE64 format
-        case qvalue.QValueKindBytes, qvalue.QValueKindBit:
-            castStmt = fmt.Sprintf("FROM_BASE64(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s')) AS `%s`",
-                colName, colName)
-        case qvalue.QValueKindArrayFloat32, qvalue.QValueKindArrayFloat64,
-            qvalue.QValueKindArrayInt32, qvalue.QValueKindArrayInt64, qvalue.QValueKindArrayString:
-            castStmt = fmt.Sprintf("ARRAY(SELECT CAST(element AS %s) FROM "+
-                "UNNEST(CAST(JSON_EXTRACT_ARRAY(_peerdb_data, '$.%s') AS ARRAY)) AS element) AS `%s`",
-                bqType, colName, colName)
-        // MAKE_INTERVAL(years INT64, months INT64, days INT64, hours INT64, minutes INT64, seconds INT64)
-        // Expecting interval to be in the format of {"Microseconds":2000000,"Days":0,"Months":0,"Valid":true}
-        // json.Marshal in SyncRecords for Postgres already does this - once new data-stores are added,
-        // this needs to be handled again
-        // TODO add interval types again
-        // case model.ColumnTypeInterval:
-        //     castStmt = fmt.Sprintf("MAKE_INTERVAL(0,CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Months') AS INT64),"+
-        //         "CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Days') AS INT64),0,0,"+
-        //         "CAST(CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Microseconds') AS INT64)/1000000 AS INT64)) AS %s",
-        //         colName, colName, colName, colName)
-        // TODO add proper granularity for time types, then restore this
-        // case model.ColumnTypeTime:
-        //     castStmt = fmt.Sprintf("time(timestamp_micros(CAST(JSON_EXTRACT(_peerdb_data, '$.%s.Microseconds')"+
-        //         " AS int64))) AS %s",
-        //         colName, colName)
-        default:
-            castStmt = fmt.Sprintf("CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s') AS %s) AS `%s`",
-                colName, bqType, colName)
+
+        // rename the src table to dst
+        _, err = c.client.Query(fmt.Sprintf("ALTER TABLE %s.%s RENAME TO %s",
+            c.datasetID, src, dst)).Run(c.ctx)
+        if err != nil {
+            return nil, fmt.Errorf("unable to rename table %s to %s: %w", src, dst, err)
         }
-        flattenedProjs = append(flattenedProjs, castStmt)
-    }
-    flattenedProjs = append(flattenedProjs, "_peerdb_timestamp")
-    flattenedProjs = append(flattenedProjs, "_peerdb_timestamp_nanos")
-    flattenedProjs = append(flattenedProjs, "_peerdb_record_type")
-    flattenedProjs = append(flattenedProjs, "_peerdb_unchanged_toast_columns")
-
-    // normalize anything between last normalized batch id to last sync batchid
-    return fmt.Sprintf(`WITH _peerdb_flattened AS
-        (SELECT %s FROM %s.%s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d and
-         _peerdb_destination_table_name='%s')`,
-        strings.Join(flattenedProjs, ", "), m.Dataset, m.RawTable, m.NormalizeBatchID,
-        m.SyncBatchID, m.NormalizedTable)
-}
-// generateDeDupedCTE generates a de-duped CTE.
-func (m *MergeStmtGenerator) generateDeDupedCTE() string {
-    const cte = `_peerdb_de_duplicated_data_res AS (
-        SELECT _peerdb_ranked.*
-            FROM (
-                SELECT RANK() OVER (
-                    PARTITION BY %s ORDER BY _peerdb_timestamp_nanos DESC
-                ) as _peerdb_rank, * FROM _peerdb_flattened
-            ) _peerdb_ranked
-            WHERE _peerdb_rank = 1
-    ) SELECT * FROM _peerdb_de_duplicated_data_res`
-    pkeyColsStr := fmt.Sprintf("(CONCAT(%s))", strings.Join(m.NormalizedTableSchema.PrimaryKeyColumns,
-        ", '_peerdb_concat_', "))
-    return fmt.Sprintf(cte, pkeyColsStr)
-}
+        log.WithFields(log.Fields{
+            "flowName": req.FlowJobName,
+        }).Infof("successfully renamed table '%s' to '%s'", src, dst)
+    }

-// generateMergeStmt generates a merge statement.
-func (m *MergeStmtGenerator) generateMergeStmt(tempTable string) string {
-    // comma separated list of column names
-    backtickColNames := make([]string, 0)
-    pureColNames := make([]string, 0)
-    for colName := range m.NormalizedTableSchema.Columns {
-        backtickColNames = append(backtickColNames, fmt.Sprintf("`%s`", colName))
-        pureColNames = append(pureColNames, colName)
-    }
-    csep := strings.Join(backtickColNames, ", ")
-
-    updateStatementsforToastCols := m.generateUpdateStatements(pureColNames, m.UnchangedToastColumns)
-    updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")
-
-    pkeySelectSQLArray := make([]string, 0, len(m.NormalizedTableSchema.PrimaryKeyColumns))
-    for _, pkeyColName := range m.NormalizedTableSchema.PrimaryKeyColumns {
-        pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("_peerdb_target.%s = _peerdb_deduped.%s",
-            pkeyColName, pkeyColName))
-    }
-    // _peerdb_target. = _peerdb_deduped. AND _peerdb_target. = _peerdb_deduped. ...
-    pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ")
-
-    return fmt.Sprintf(`
-    MERGE %s.%s _peerdb_target USING %s _peerdb_deduped
-    ON %s
-    WHEN NOT MATCHED and (_peerdb_deduped._peerdb_record_type != 2) THEN
-        INSERT (%s) VALUES (%s)
-    %s
-    WHEN MATCHED AND (_peerdb_deduped._peerdb_record_type = 2) THEN
-    DELETE;
-    `, m.Dataset, m.NormalizedTable, tempTable, pkeySelectSQL, csep, csep, updateStringToastCols)
+    return &protos.RenameTablesOutput{
+        FlowJobName: req.FlowJobName,
+    }, nil
 }

-/*
-This function takes an array of unique unchanged toast column groups and an array of all column names,
-and returns suitable UPDATE statements as part of a MERGE operation.
-
-Algorithm:
-1. Iterate over each unique unchanged toast column group.
-2. Split the group into individual column names.
-3. Calculate the other columns by finding the set difference between all column names
-and the unchanged columns.
-4. Generate an update statement for the current group by setting the appropriate conditions
-and updating the other columns (not the unchanged toast columns)
-5. Append the update statement to the list of generated statements.
-6. Repeat steps 1-5 for each unique unchanged toast column group.
-7. Return the list of generated update statements.
-*/
-func (m *MergeStmtGenerator) generateUpdateStatements(allCols []string, unchangedToastCols []string) []string {
-    updateStmts := make([]string, 0)
-
-    for _, cols := range unchangedToastCols {
-        unchangedColsArray := strings.Split(cols, ", ")
-        otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
-        tmpArray := make([]string, 0)
-        for _, colName := range otherCols {
-            tmpArray = append(tmpArray, fmt.Sprintf("`%s` = _peerdb_deduped.%s", colName, colName))
+func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFromExistingInput) (
+    *protos.CreateTablesFromExistingOutput, error) {
+    for newTable, existingTable := range req.NewToExistingTableMapping {
+        log.WithFields(log.Fields{
+            "flowName": req.FlowJobName,
+        }).Infof("creating table '%s' similar to '%s'", newTable, existingTable)
+
+        activity.RecordHeartbeat(c.ctx, fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable))
+
+        // create the new table with the same schema as the existing table
+        _, err := c.client.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s LIKE %s.%s",
+            c.datasetID, newTable, c.datasetID, existingTable)).Run(c.ctx)
+        if err != nil {
+            return nil, fmt.Errorf("unable to create table %s: %w", newTable, err)
         }
-        ssep := strings.Join(tmpArray, ", ")
-        updateStmt := fmt.Sprintf(`WHEN MATCHED AND
-        (_peerdb_deduped._peerdb_record_type != 2) AND _peerdb_unchanged_toast_columns='%s'
-        THEN UPDATE SET %s `, cols, ssep)
-        updateStmts = append(updateStmts, updateStmt)
+
+        log.WithFields(log.Fields{
+            "flowName": req.FlowJobName,
+        }).Infof("successfully created table '%s'", newTable)
     }
-    return updateStmts
+
+    return &protos.CreateTablesFromExistingOutput{
+        FlowJobName: req.FlowJobName,
+    }, nil
 }
diff --git a/flow/connectors/bigquery/merge_statement_generator.go b/flow/connectors/bigquery/merge_statement_generator.go
new file mode 100644
index 0000000000..1bef877a67
--- /dev/null
+++ b/flow/connectors/bigquery/merge_statement_generator.go
@@ -0,0 +1,191 @@
+package connbigquery
+
+import (
+    "fmt"
+    "strings"
+
+    "cloud.google.com/go/bigquery"
+    "github.com/PeerDB-io/peer-flow/connectors/utils"
+    "github.com/PeerDB-io/peer-flow/generated/protos"
+    "github.com/PeerDB-io/peer-flow/model/qvalue"
+    util "github.com/PeerDB-io/peer-flow/utils"
+)
+
+type mergeStmtGenerator struct {
+    // dataset of all the tables
+    Dataset string
+    // the table to merge into
+    NormalizedTable string
+    // the table where the data is currently staged.
+    RawTable string
+    // last synced batchID.
+    SyncBatchID int64
+    // last normalized batchID.
+    NormalizeBatchID int64
+    // the schema of the table to merge into
+    NormalizedTableSchema *protos.TableSchema
+    // array of toast column combinations that are unchanged
+    UnchangedToastColumns []string
+}
+
+// generateMergeStmts generates the merge statements.
+func (m *mergeStmtGenerator) generateMergeStmts() []string {
+    // generate the flattened and de-duplicated CTEs
+    flattenedCTE := m.generateFlattenedCTE()
+    deDupedCTE := m.generateDeDupedCTE()
+    tempTable := fmt.Sprintf("_peerdb_de_duplicated_data_%s", util.RandomString(5))
+    // create temp table stmt
+    createTempTableStmt := fmt.Sprintf(
+        "CREATE TEMP TABLE %s AS (%s, %s);",
+        tempTable, flattenedCTE, deDupedCTE)
+
+    mergeStmt := m.generateMergeStmt(tempTable)
+
+    dropTempTableStmt := fmt.Sprintf("DROP TABLE %s;", tempTable)
+
+    return []string{createTempTableStmt, mergeStmt, dropTempTableStmt}
+}
+
+// generateFlattenedCTE generates a flattened CTE.
+func (m *mergeStmtGenerator) generateFlattenedCTE() string {
+    // for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
+    // statement.
+    flattenedProjs := make([]string, 0)
+    for colName, colType := range m.NormalizedTableSchema.Columns {
+        bqType := qValueKindToBigQueryType(colType)
+        // CAST doesn't work for FLOAT, so rewrite it to FLOAT64.
+        if bqType == bigquery.FloatFieldType {
+            bqType = "FLOAT64"
+        }
+        var castStmt string
+
+        switch qvalue.QValueKind(colType) {
+        case qvalue.QValueKindJSON:
+            //if the type is JSON, then just extract JSON
+            castStmt = fmt.Sprintf("CAST(JSON_EXTRACT(_peerdb_data, '$.%s') AS %s) AS `%s`",
+                colName, bqType, colName)
+        // expecting data in BASE64 format
+        case qvalue.QValueKindBytes, qvalue.QValueKindBit:
+            castStmt = fmt.Sprintf("FROM_BASE64(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s')) AS `%s`",
+                colName, colName)
+        case qvalue.QValueKindArrayFloat32, qvalue.QValueKindArrayFloat64,
+            qvalue.QValueKindArrayInt32, qvalue.QValueKindArrayInt64, qvalue.QValueKindArrayString:
+            castStmt = fmt.Sprintf("ARRAY(SELECT CAST(element AS %s) FROM "+
+                "UNNEST(CAST(JSON_EXTRACT_ARRAY(_peerdb_data, '$.%s') AS ARRAY)) AS element) AS `%s`",
+                bqType, colName, colName)
+        // MAKE_INTERVAL(years INT64, months INT64, days INT64, hours INT64, minutes INT64, seconds INT64)
+        // Expecting interval to be in the format of {"Microseconds":2000000,"Days":0,"Months":0,"Valid":true}
+        // json.Marshal in SyncRecords for Postgres already does this - once new data-stores are added,
+        // this needs to be handled again
+        // TODO add interval types again
+        // case model.ColumnTypeInterval:
+        //     castStmt = fmt.Sprintf("MAKE_INTERVAL(0,CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Months') AS INT64),"+
+        //         "CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Days') AS INT64),0,0,"+
+        //         "CAST(CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Microseconds') AS INT64)/1000000 AS INT64)) AS %s",
+        //         colName, colName, colName, colName)
+        // TODO add proper granularity for time types, then restore this
+        // case model.ColumnTypeTime:
+        //     castStmt = fmt.Sprintf("time(timestamp_micros(CAST(JSON_EXTRACT(_peerdb_data, '$.%s.Microseconds')"+
+        //         " AS int64))) AS %s",
+        //         colName, colName)
+        default:
+            castStmt = fmt.Sprintf("CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s') AS %s) AS `%s`",
+                colName, bqType, colName)
+        }
+        flattenedProjs = append(flattenedProjs, castStmt)
+    }
+    flattenedProjs = append(flattenedProjs, "_peerdb_timestamp")
+    flattenedProjs = append(flattenedProjs, "_peerdb_timestamp_nanos")
+    flattenedProjs = append(flattenedProjs, "_peerdb_record_type")
+    flattenedProjs = append(flattenedProjs, "_peerdb_unchanged_toast_columns")
+
+    // normalize anything between last normalized batch id to last sync batchid
+    return fmt.Sprintf(`WITH _peerdb_flattened AS
+        (SELECT %s FROM %s.%s WHERE _peerdb_batch_id > %d and _peerdb_batch_id <= %d and
+         _peerdb_destination_table_name='%s')`,
+        strings.Join(flattenedProjs, ", "), m.Dataset, m.RawTable, m.NormalizeBatchID,
+        m.SyncBatchID, m.NormalizedTable)
+}
+
+// generateDeDupedCTE generates a de-duped CTE.
+func (m *mergeStmtGenerator) generateDeDupedCTE() string {
+    const cte = `_peerdb_de_duplicated_data_res AS (
+        SELECT _peerdb_ranked.*
+            FROM (
+                SELECT RANK() OVER (
+                    PARTITION BY %s ORDER BY _peerdb_timestamp_nanos DESC
+                ) as _peerdb_rank, * FROM _peerdb_flattened
+            ) _peerdb_ranked
+            WHERE _peerdb_rank = 1
+    ) SELECT * FROM _peerdb_de_duplicated_data_res`
+    pkeyColsStr := fmt.Sprintf("(CONCAT(%s))", strings.Join(m.NormalizedTableSchema.PrimaryKeyColumns,
+        ", '_peerdb_concat_', "))
+    return fmt.Sprintf(cte, pkeyColsStr)
+}
+
+// generateMergeStmt generates a merge statement.
+func (m *mergeStmtGenerator) generateMergeStmt(tempTable string) string {
+    // comma separated list of column names
+    backtickColNames := make([]string, 0)
+    pureColNames := make([]string, 0)
+    for colName := range m.NormalizedTableSchema.Columns {
+        backtickColNames = append(backtickColNames, fmt.Sprintf("`%s`", colName))
+        pureColNames = append(pureColNames, colName)
+    }
+    csep := strings.Join(backtickColNames, ", ")
+
+    updateStatementsforToastCols := m.generateUpdateStatements(pureColNames, m.UnchangedToastColumns)
+    updateStringToastCols := strings.Join(updateStatementsforToastCols, " ")
+
+    pkeySelectSQLArray := make([]string, 0, len(m.NormalizedTableSchema.PrimaryKeyColumns))
+    for _, pkeyColName := range m.NormalizedTableSchema.PrimaryKeyColumns {
+        pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("_peerdb_target.%s = _peerdb_deduped.%s",
+            pkeyColName, pkeyColName))
+    }
+    // _peerdb_target. = _peerdb_deduped. AND _peerdb_target. = _peerdb_deduped. ...
+    pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ")
+
+    return fmt.Sprintf(`
+    MERGE %s.%s _peerdb_target USING %s _peerdb_deduped
+    ON %s
+    WHEN NOT MATCHED and (_peerdb_deduped._peerdb_record_type != 2) THEN
+        INSERT (%s) VALUES (%s)
+    %s
+    WHEN MATCHED AND (_peerdb_deduped._peerdb_record_type = 2) THEN
+    DELETE;
+    `, m.Dataset, m.NormalizedTable, tempTable, pkeySelectSQL, csep, csep, updateStringToastCols)
+}
+
+/*
+This function takes an array of unique unchanged toast column groups and an array of all column names,
+and returns suitable UPDATE statements as part of a MERGE operation.
+
+Algorithm:
+1. Iterate over each unique unchanged toast column group.
+2. Split the group into individual column names.
+3. Calculate the other columns by finding the set difference between all column names
+and the unchanged columns.
+4. Generate an update statement for the current group by setting the appropriate conditions
+and updating the other columns (not the unchanged toast columns)
+5. Append the update statement to the list of generated statements.
+6. Repeat steps 1-5 for each unique unchanged toast column group.
+7. Return the list of generated update statements.
+*/
+func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string, unchangedToastCols []string) []string {
+    updateStmts := make([]string, 0)
+
+    for _, cols := range unchangedToastCols {
+        unchangedColsArray := strings.Split(cols, ", ")
+        otherCols := utils.ArrayMinus(allCols, unchangedColsArray)
+        tmpArray := make([]string, 0)
+        for _, colName := range otherCols {
+            tmpArray = append(tmpArray, fmt.Sprintf("`%s` = _peerdb_deduped.%s", colName, colName))
+        }
+        ssep := strings.Join(tmpArray, ", ")
+        updateStmt := fmt.Sprintf(`WHEN MATCHED AND
+        (_peerdb_deduped._peerdb_record_type != 2) AND _peerdb_unchanged_toast_columns='%s'
+        THEN UPDATE SET %s `, cols, ssep)
+        updateStmts = append(updateStmts, updateStmt)
+    }
+    return updateStmts
+}
diff --git a/flow/connectors/bigquery/merge_stmt_generator_test.go b/flow/connectors/bigquery/merge_stmt_generator_test.go
index 3d8892d4c5..41e54114e6 100644
--- a/flow/connectors/bigquery/merge_stmt_generator_test.go
+++ b/flow/connectors/bigquery/merge_stmt_generator_test.go
@@ -7,7 +7,7 @@ import (
 )

 func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) {
-    m := &MergeStmtGenerator{}
+    m := &mergeStmtGenerator{}
     allCols := []string{"col1", "col2", "col3"}
     unchangedToastCols := []string{"", "col2, col3", "col2", "col3"}

@@ -43,7 +43,7 @@ func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) {
 }

 func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) {
-    m := &MergeStmtGenerator{}
+    m := &mergeStmtGenerator{}
     allCols := []string{"col1", "col2", "col3"}
     unchangedToastCols := []string{""}

diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index 4023027949..326b44c109 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -1201,8 +1201,6 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto
         log.WithFields(log.Fields{
             "flowName": req.FlowJobName,
         }).Infof("successfully renamed table '%s' to '%s'", src, dst)
-
-        activity.RecordHeartbeat(c.ctx, fmt.Sprintf("successfully renamed table '%s' to '%s'", src, dst))
     }

     err = renameTablesTx.Commit()
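Reviewer note (not part of the patch): the per-group UPDATE clause built by generateUpdateStatements can be sanity-checked in isolation, using the same inputs as the updated test (allCols = {col1, col2, col3}, unchanged-TOAST group "col2, col3"). A minimal standalone sketch follows; arrayMinus below is a local stand-in for utils.ArrayMinus (set difference) and is an assumption, not code from the patch. Whitespace differs slightly from the multi-line string the generator emits.

package main

import (
	"fmt"
	"strings"
)

// arrayMinus returns the elements of all that are not in remove.
// It approximates utils.ArrayMinus, which is assumed here for a self-contained example.
func arrayMinus(all, remove []string) []string {
	removed := make(map[string]struct{}, len(remove))
	for _, r := range remove {
		removed[r] = struct{}{}
	}
	out := make([]string, 0, len(all))
	for _, a := range all {
		if _, ok := removed[a]; !ok {
			out = append(out, a)
		}
	}
	return out
}

func main() {
	allCols := []string{"col1", "col2", "col3"}
	group := "col2, col3" // one unchanged-TOAST column group, as in the test file

	// columns that were actually changed, so they are the ones to update
	otherCols := arrayMinus(allCols, strings.Split(group, ", "))
	assignments := make([]string, 0, len(otherCols))
	for _, c := range otherCols {
		assignments = append(assignments, fmt.Sprintf("`%s` = _peerdb_deduped.%s", c, c))
	}
	fmt.Printf("WHEN MATCHED AND (_peerdb_deduped._peerdb_record_type != 2) "+
		"AND _peerdb_unchanged_toast_columns='%s' THEN UPDATE SET %s\n",
		group, strings.Join(assignments, ", "))
	// Prints a clause ending in: THEN UPDATE SET `col1` = _peerdb_deduped.col1
}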