From 25bab3783a2b75d4e550b32f554325a1cf54795c Mon Sep 17 00:00:00 2001 From: Kevin Biju Date: Tue, 9 Jan 2024 02:13:48 +0530 Subject: [PATCH] refactor to PG as well, fixed review comments --- .../bigquery/merge_stmt_generator.go | 13 +- .../bigquery/merge_stmt_generator_test.go | 156 +++++++++--- flow/connectors/postgres/client.go | 212 ---------------- .../postgres/normalize_stmt_generator.go | 235 ++++++++++++++++++ .../postgres/normalize_stmt_generator_test.go | 146 +++++++++++ flow/connectors/postgres/postgres.go | 20 +- .../snowflake/merge_stmt_generator.go | 59 +++-- .../snowflake/merge_stmt_generator_test.go | 55 ++-- flow/connectors/utils/identifiers.go | 7 + 9 files changed, 601 insertions(+), 302 deletions(-) create mode 100644 flow/connectors/postgres/normalize_stmt_generator.go create mode 100644 flow/connectors/postgres/normalize_stmt_generator_test.go diff --git a/flow/connectors/bigquery/merge_stmt_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go index 4efd0c9c64..b2ac9c1942 100644 --- a/flow/connectors/bigquery/merge_stmt_generator.go +++ b/flow/connectors/bigquery/merge_stmt_generator.go @@ -196,7 +196,14 @@ and updating the other columns (not the unchanged toast columns) 7. Return the list of generated update statements. */ func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string { - updateStmts := make([]string, 0, len(m.unchangedToastColumns)) + handleSoftDelete := m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * len(m.unchangedToastColumns) + } + return len(m.unchangedToastColumns) + }()) for _, cols := range m.unchangedToastColumns { unchangedColsArray := strings.Split(cols, ",") @@ -212,7 +219,7 @@ func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string m.peerdbCols.SyncedAtColName)) } // set soft-deleted to false, tackles insert after soft-delete - if m.peerdbCols.SoftDeleteColName != "" { + if handleSoftDelete { tmpArray = append(tmpArray, fmt.Sprintf("`%s`=FALSE", m.peerdbCols.SoftDeleteColName)) } @@ -226,7 +233,7 @@ func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string // generates update statements for the case where updates and deletes happen in the same branch // the backfill has happened from the pull side already, so treat the DeleteRecord as an update // and then set soft-delete to true. - if m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") { + if handleSoftDelete { tmpArray = append(tmpArray[:len(tmpArray)-1], fmt.Sprintf("`%s`=TRUE", m.peerdbCols.SoftDeleteColName)) ssep := strings.Join(tmpArray, ",") diff --git a/flow/connectors/bigquery/merge_stmt_generator_test.go b/flow/connectors/bigquery/merge_stmt_generator_test.go index 1857b41d76..cc49b17cbd 100644 --- a/flow/connectors/bigquery/merge_stmt_generator_test.go +++ b/flow/connectors/bigquery/merge_stmt_generator_test.go @@ -2,70 +2,52 @@ package connbigquery import ( "reflect" - "strings" "testing" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) -func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { +func TestGenerateUpdateStatement(t *testing.T) { allCols := []string{"col1", "col2", "col3"} - unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + unchangedToastCols := []string{""} m := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, shortColumn: map[string]string{ "col1": "_c0", "col2": "_c1", "col3": "_c2", }, - unchangedToastColumns: unchangedToastCols, peerdbCols: &protos.PeerDBColumns{ - SoftDelete: true, + SoftDelete: false, SoftDeleteColName: "deleted", SyncedAtColName: "synced_at", }, } expected := []string{ - "WHEN MATCHED AND _rt!=2 AND _ut=''" + - " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", - "WHEN MATCHED AND _rt=2 " + - "AND _ut='' " + - "THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1," + - "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", - "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + - "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", - "WHEN MATCHED AND _rt=2 AND _ut='col2,col3' " + - "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", "WHEN MATCHED AND _rt!=2 " + - "AND _ut='col2' " + - "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", - "WHEN MATCHED AND _rt=2 " + - "AND _ut='col2' " + - "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE ", - "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + - "THEN UPDATE SET `col1`=_d._c0," + - "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", - "WHEN MATCHED AND _rt=2 AND _ut='col3' " + - "THEN UPDATE SET `col1`=_d._c0," + - "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "AND _ut=''" + + "THEN UPDATE SET " + + "`col1`=_d._c0," + + "`col2`=_d._c1," + + "`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", } result := m.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { - t.Errorf("Unexpected result. Expected: %v,\nbut got: %v", expected, result) + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) } } -func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) { +func TestGenerateUpdateStatement_WithSoftDelete(t *testing.T) { allCols := []string{"col1", "col2", "col3"} unchangedToastCols := []string{""} m := &mergeStmtGenerator{ @@ -100,8 +82,8 @@ func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) { result := m.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { @@ -109,9 +91,103 @@ func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) { } } -func removeSpacesTabsNewlines(s string) string { - s = strings.ReplaceAll(s, " ", "") - s = strings.ReplaceAll(s, "\t", "") - s = strings.ReplaceAll(s, "\n", "") - return s +func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + m := &mergeStmtGenerator{ + shortColumn: map[string]string{ + "col1": "_c0", + "col2": "_c1", + "col3": "_c2", + }, + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, + } + + expected := []string{ + "WHEN MATCHED AND _rt!=2 AND _ut=''" + + " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP", + } + + result := m.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v,\nbut got: %v", expected, result) + } +} + +func TestGenerateUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + m := &mergeStmtGenerator{ + shortColumn: map[string]string{ + "col1": "_c0", + "col2": "_c1", + "col3": "_c2", + }, + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, + } + + expected := []string{ + "WHEN MATCHED AND _rt!=2 AND _ut=''" + + " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", + "WHEN MATCHED AND _rt=2 " + + "AND _ut='' " + + "THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1," + + "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", + "WHEN MATCHED AND _rt=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "WHEN MATCHED AND _rt!=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", + "WHEN MATCHED AND _rt=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE ", + "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", + "WHEN MATCHED AND _rt=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + } + + result := m.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v,\nbut got: %v", expected, result) + } } diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index b51741aff6..0a99bce668 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -5,18 +5,15 @@ import ( "fmt" "log" "regexp" - "slices" "strings" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq/oid" - "golang.org/x/exp/maps" ) const ( @@ -599,215 +596,6 @@ func (c *PostgresConnector) getTableNametoUnchangedCols(flowJobName string, sync return resultMap, nil } -func (c *PostgresConnector) generateNormalizeStatements(destinationTableIdentifier string, - unchangedToastColumns []string, rawTableIdentifier string, supportsMerge bool, - peerdbCols *protos.PeerDBColumns, -) []string { - if supportsMerge { - return []string{c.generateMergeStatement(destinationTableIdentifier, unchangedToastColumns, - rawTableIdentifier, peerdbCols)} - } - c.logger.Warn("Postgres version is not high enough to support MERGE, falling back to UPSERT + DELETE") - c.logger.Warn("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres") - return c.generateFallbackStatements(destinationTableIdentifier, rawTableIdentifier, peerdbCols) -} - -func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifier string, - rawTableIdentifier string, peerdbCols *protos.PeerDBColumns, -) []string { - normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] - columnCount := utils.TableSchemaColumns(normalizedTableSchema) - columnNames := make([]string, 0, columnCount) - flattenedCastsSQLArray := make([]string, 0, columnCount) - primaryKeyColumnCasts := make(map[string]string) - utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) { - columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) - pgType := qValueKindToPostgresType(genericColumnType) - if qvalue.QValueKind(genericColumnType).IsArray() { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } - if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) - } - }) - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") - parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) - - insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") - updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - utils.IterColumns(normalizedTableSchema, func(columnName, _ string) { - updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName)) - }) - updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",") - deleteWhereClauseArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - for columnName, columnCast := range primaryKeyColumnCasts { - deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `, - parsedDstTable.String(), columnName, columnCast)) - } - deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") - deletePart := fmt.Sprintf( - "DELETE FROM %s USING", - parsedDstTable.String()) - - if peerdbCols.SoftDelete { - deletePart = fmt.Sprintf(`UPDATE %s SET "%s" = TRUE`, - parsedDstTable.String(), peerdbCols.SoftDeleteColName) - if peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s, "%s" = CURRENT_TIMESTAMP`, - deletePart, peerdbCols.SyncedAtColName) - } - deletePart += " FROM" - } - fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, - strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), c.metadataSchema, - rawTableIdentifier, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL, - strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) - fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, - strings.Join(maps.Values(primaryKeyColumnCasts), ","), c.metadataSchema, - rawTableIdentifier, deletePart, deleteWhereClauseSQL) - - return []string{fallbackUpsertStatement, fallbackDeleteStatement} -} - -func (c *PostgresConnector) generateMergeStatement( - destinationTableIdentifier string, - unchangedToastColumns []string, - rawTableIdentifier string, - peerdbCols *protos.PeerDBColumns, -) string { - normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] - columnNames := utils.TableSchemaColumnNames(normalizedTableSchema) - for i, columnName := range columnNames { - columnNames[i] = fmt.Sprintf("\"%s\"", columnName) - } - - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) - - primaryKeyColumnCasts := make(map[string]string) - primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) { - pgType := qValueKindToPostgresType(genericColumnType) - if qvalue.QValueKind(genericColumnType).IsArray() { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } - if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) - primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", - columnName, columnName)) - } - }) - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") - insertValuesSQLArray := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { - insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) - } - - updateStatementsforToastCols := c.generateUpdateStatement(columnNames, unchangedToastColumns, peerdbCols) - // append synced_at column - columnNames = append(columnNames, fmt.Sprintf(`"%s"`, peerdbCols.SyncedAtColName)) - insertColumnsSQL := strings.Join(columnNames, ",") - // fill in synced_at column - insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") - insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",") - - if peerdbCols.SoftDelete { - softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, - fmt.Sprintf(`"%s"`, peerdbCols.SoftDeleteColName)), ","), ",") - softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") - - updateStatementsforToastCols = append(updateStatementsforToastCols, - fmt.Sprintf("WHEN NOT MATCHED AND (src._peerdb_record_type = 2) THEN INSERT (%s) VALUES(%s)", - softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) - } - updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n") - - deletePart := "DELETE" - if peerdbCols.SoftDelete { - colName := peerdbCols.SoftDeleteColName - deletePart = fmt.Sprintf(`UPDATE SET "%s" = TRUE`, colName) - if peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s, "%s" = CURRENT_TIMESTAMP`, - deletePart, peerdbCols.SyncedAtColName) - } - } - - mergeStmt := fmt.Sprintf( - mergeStatementSQL, - strings.Join(maps.Values(primaryKeyColumnCasts), ","), - c.metadataSchema, - rawTableIdentifier, - parsedDstTable.String(), - flattenedCastsSQL, - strings.Join(primaryKeySelectSQLArray, " AND "), - insertColumnsSQL, - insertValuesSQL, - updateStringToastCols, - deletePart, - ) - - return mergeStmt -} - -func (c *PostgresConnector) generateUpdateStatement(allCols []string, - unchangedToastColsLists []string, peerdbCols *protos.PeerDBColumns, -) []string { - updateStmts := make([]string, 0, len(unchangedToastColsLists)) - - for _, cols := range unchangedToastColsLists { - unquotedUnchangedColsArray := strings.Split(cols, ",") - unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) - for _, unchangedToastCol := range unquotedUnchangedColsArray { - unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) - } - otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0, len(otherCols)) - for _, colName := range otherCols { - tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) - } - // set the synced at column to the current timestamp - if peerdbCols.SyncedAtColName != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, - peerdbCols.SyncedAtColName)) - } - // set soft-deleted to false, tackles insert after soft-delete - if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`, - peerdbCols.SoftDeleteColName)) - } - - ssep := strings.Join(tmpArray, ",") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - - // generates update statements for the case where updates and deletes happen in the same branch - // the backfill has happened from the pull side already, so treat the DeleteRecord as an update - // and then set soft-delete to true. - if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { - tmpArray = append(tmpArray[:len(tmpArray)-1], - fmt.Sprintf(`"%s" = TRUE`, peerdbCols.SoftDeleteColName)) - ssep := strings.Join(tmpArray, ", ") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - } - } - return updateStmts -} - func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) { row := c.pool.QueryRow(c.ctx, "SELECT CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END") diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go new file mode 100644 index 0000000000..e3af2f48f0 --- /dev/null +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -0,0 +1,235 @@ +package connpostgres + +import ( + "fmt" + "log/slog" + "slices" + "strings" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" + "golang.org/x/exp/maps" +) + +type normalizeStmtGenerator struct { + rawTableName string + // destination table name, used to retrieve records from raw table + dstTableName string + // the schema of the table to merge into + normalizedTableSchema *protos.TableSchema + // array of toast column combinations that are unchanged + unchangedToastColumns []string + // _PEERDB_IS_DELETED and _SYNCED_AT columns + peerdbCols *protos.PeerDBColumns + // Postgres version 15 introduced MERGE, fallback statements before that + supportsMerge bool + // Postgres metadata schema + metadataSchema string + // to log fallback statement selection + logger slog.Logger +} + +func (n *normalizeStmtGenerator) generateNormalizeStatements() []string { + if n.supportsMerge { + return []string{n.generateMergeStatement()} + } + n.logger.Warn("Postgres version is not high enough to support MERGE, falling back to UPSERT+DELETE") + n.logger.Warn("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres") + if n.peerdbCols.SoftDelete { + n.logger.Warn("soft delete enabled with fallback statements! this combination is unsupported") + } + return n.generateFallbackStatements() +} + +func (n *normalizeStmtGenerator) generateFallbackStatements() []string { + columnCount := utils.TableSchemaColumns(n.normalizedTableSchema) + columnNames := make([]string, 0, columnCount) + flattenedCastsSQLArray := make([]string, 0, columnCount) + primaryKeyColumnCasts := make(map[string]string) + utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) + pgType := qValueKindToPostgresType(genericColumnType) + if qvalue.QValueKind(genericColumnType).IsArray() { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } + if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + } + }) + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") + parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) + + insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") + updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) + utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) { + updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName)) + }) + updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",") + deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) + for columnName, columnCast := range primaryKeyColumnCasts { + deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `, + parsedDstTable.String(), columnName, columnCast)) + } + deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") + deletePart := fmt.Sprintf( + "DELETE FROM %s USING", + parsedDstTable.String()) + + if n.peerdbCols.SoftDelete { + deletePart = fmt.Sprintf(`UPDATE %s SET "%s"=TRUE`, + parsedDstTable.String(), n.peerdbCols.SoftDeleteColName) + if n.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, + deletePart, n.peerdbCols.SyncedAtColName) + } + deletePart += " FROM" + } + fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, + strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), n.metadataSchema, + n.rawTableName, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL, + strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) + fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, + strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema, + n.rawTableName, deletePart, deleteWhereClauseSQL) + + return []string{fallbackUpsertStatement, fallbackDeleteStatement} +} + +func (n *normalizeStmtGenerator) generateMergeStatement() string { + columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema) + for i, columnName := range columnNames { + columnNames[i] = fmt.Sprintf("\"%s\"", columnName) + } + + flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) + parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) + + primaryKeyColumnCasts := make(map[string]string) + primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) + utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + pgType := qValueKindToPostgresType(genericColumnType) + if qvalue.QValueKind(genericColumnType).IsArray() { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } + if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", + columnName, columnName)) + } + }) + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") + insertValuesSQLArray := make([]string, 0, len(columnNames)) + for _, columnName := range columnNames { + insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) + } + + updateStatementsforToastCols := n.generateUpdateStatements(columnNames) + // append synced_at column + columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName)) + insertColumnsSQL := strings.Join(columnNames, ",") + // fill in synced_at column + insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") + insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",") + + if n.peerdbCols.SoftDelete { + softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, + fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",") + softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") + + updateStatementsforToastCols = append(updateStatementsforToastCols, + fmt.Sprintf("WHEN NOT MATCHED AND (src._peerdb_record_type=2) THEN INSERT (%s) VALUES(%s)", + softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) + } + updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n") + + deletePart := "DELETE" + if n.peerdbCols.SoftDelete { + colName := n.peerdbCols.SoftDeleteColName + deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName) + if n.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, + deletePart, n.peerdbCols.SyncedAtColName) + } + } + + mergeStmt := fmt.Sprintf( + mergeStatementSQL, + strings.Join(maps.Values(primaryKeyColumnCasts), ","), + n.metadataSchema, + n.rawTableName, + parsedDstTable.String(), + flattenedCastsSQL, + strings.Join(primaryKeySelectSQLArray, " AND "), + insertColumnsSQL, + insertValuesSQL, + updateStringToastCols, + deletePart, + ) + + return mergeStmt +} + +func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []string { + handleSoftDelete := n.peerdbCols.SoftDelete && (n.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * len(n.unchangedToastColumns) + } + return len(n.unchangedToastColumns) + }()) + + for _, cols := range n.unchangedToastColumns { + unquotedUnchangedColsArray := strings.Split(cols, ",") + unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) + for _, unchangedToastCol := range unquotedUnchangedColsArray { + unchangedColsArray = append(unchangedColsArray, unchangedToastCol) + } + otherCols := utils.ArrayMinus(allCols, unchangedColsArray) + tmpArray := make([]string, 0, len(otherCols)) + for _, colName := range otherCols { + tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) + } + // set the synced at column to the current timestamp + if n.peerdbCols.SyncedAtColName != "" { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`, + n.peerdbCols.SyncedAtColName)) + } + // set soft-deleted to false, tackles insert after soft-delete + if handleSoftDelete { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`, + n.peerdbCols.SoftDeleteColName)) + } + + ssep := strings.Join(tmpArray, ",") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' + THEN UPDATE SET %s`, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + + // generates update statements for the case where updates and deletes happen in the same branch + // the backfill has happened from the pull side already, so treat the DeleteRecord as an update + // and then set soft-delete to true. + if handleSoftDelete { + tmpArray = append(tmpArray[:len(tmpArray)-1], + fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName)) + ssep := strings.Join(tmpArray, ", ") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' + THEN UPDATE SET %s `, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + } + } + return updateStmts +} diff --git a/flow/connectors/postgres/normalize_stmt_generator_test.go b/flow/connectors/postgres/normalize_stmt_generator_test.go new file mode 100644 index 0000000000..46ffd058df --- /dev/null +++ b/flow/connectors/postgres/normalize_stmt_generator_test.go @@ -0,0 +1,146 @@ +package connpostgres + +import ( + "reflect" + "testing" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" +) + +func TestGenerateMergeUpdateStatement(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET col1=src.col1,col2=src.col2,col3=src.col3, + "_peerdb_synced_at"=CURRENT_TIMESTAMP`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET col1=src.col1,col2=src.col2,col3=src.col3, + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET col1=src.col1,col2=src.col2,col3=src.col3, + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithUnchangedToastCols(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET col1=src.col1,col2=src.col2,col3=src.col3,"_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET col1=src.col1,"_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET col1=src.col1,col3=src.col3,"_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET col1=src.col1,col2=src.col2,"_peerdb_synced_at"=CURRENT_TIMESTAMP`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET col1=src.col1,col2=src.col2,col3=src.col3,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET col1=src.col1,col2=src.col2,col3=src.col3,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET col1=src.col1,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET col1=src.col1,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET col1=src.col1,col3=src.col3,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET col1=src.col1,col3=src.col3,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET col1=src.col1,col2=src.col2,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET col1=src.col1,col2=src.col2,"_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 43a11a2e72..71ede101c5 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -461,13 +461,21 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) mergeStatementsBatch := &pgx.Batch{} totalRowsAffected := 0 for _, destinationTableName := range destinationTableNames { - peerdbCols := protos.PeerDBColumns{ - SoftDeleteColName: req.SoftDeleteColName, - SyncedAtColName: req.SyncedAtColName, - SoftDelete: req.SoftDelete, + normalizeStmtGen := &normalizeStmtGenerator{ + rawTableName: rawTableIdentifier, + dstTableName: destinationTableName, + normalizedTableSchema: c.tableSchemaMapping[destinationTableName], + unchangedToastColumns: unchangedToastColsMap[destinationTableName], + peerdbCols: &protos.PeerDBColumns{ + SoftDeleteColName: req.SoftDeleteColName, + SyncedAtColName: req.SyncedAtColName, + SoftDelete: req.SoftDelete, + }, + supportsMerge: supportsMerge, + metadataSchema: c.metadataSchema, + logger: c.logger, } - normalizeStatements := c.generateNormalizeStatements(destinationTableName, unchangedToastColsMap[destinationTableName], - rawTableIdentifier, supportsMerge, &peerdbCols) + normalizeStatements := normalizeStmtGen.generateNormalizeStatements() for _, normalizeStatement := range normalizeStatements { mergeStatementsBatch.Queue(normalizeStatement, batchIDs.NormalizeBatchID, batchIDs.SyncBatchID, destinationTableName).Exec( func(ct pgconn.CommandTag) error { diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go index 58983fe048..684f922dfd 100644 --- a/flow/connectors/snowflake/merge_stmt_generator.go +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -25,12 +25,12 @@ type mergeStmtGenerator struct { peerdbCols *protos.PeerDBColumns } -func (c *mergeStmtGenerator) generateMergeStmt() (string, error) { - parsedDstTable, _ := utils.ParseSchemaTable(c.dstTableName) - columnNames := utils.TableSchemaColumnNames(c.normalizedTableSchema) +func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { + parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName) + columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema) - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(c.normalizedTableSchema)) - err := utils.IterColumnsError(c.normalizedTableSchema, func(columnName, genericColumnType string) error { + flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)) + err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error { qvKind := qvalue.QValueKind(genericColumnType) sfType, err := qValueKindToSnowflakeType(qvKind) if err != nil { @@ -82,7 +82,7 @@ func (c *mergeStmtGenerator) generateMergeStmt() (string, error) { } // append synced_at column quotedUpperColNames = append(quotedUpperColNames, - fmt.Sprintf(`"%s"`, strings.ToUpper(c.peerdbCols.SyncedAtColName)), + fmt.Sprintf(`"%s"`, strings.ToUpper(m.peerdbCols.SyncedAtColName)), ) insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",") @@ -95,14 +95,14 @@ func (c *mergeStmtGenerator) generateMergeStmt() (string, error) { // fill in synced_at column insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") insertValuesSQL := strings.Join(insertValuesSQLArray, ",") - updateStatementsforToastCols := c.generateUpdateStatements(columnNames) + updateStatementsforToastCols := m.generateUpdateStatements(columnNames) // handling the case when an insert and delete happen in the same batch, with updates in the middle // with soft-delete, we want the row to be in the destination with SOFT_DELETE true // the current merge statement doesn't do that, so we add another case to insert the DeleteRecord - if c.peerdbCols.SoftDelete && (c.peerdbCols.SoftDeleteColName != "") { + if m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") { softDeleteInsertColumnsSQL := strings.Join(append(quotedUpperColNames, - c.peerdbCols.SoftDeleteColName), ",") + m.peerdbCols.SoftDeleteColName), ",") softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE" updateStatementsforToastCols = append(updateStatementsforToastCols, fmt.Sprintf("WHEN NOT MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN INSERT (%s) VALUES(%s)", @@ -110,9 +110,9 @@ func (c *mergeStmtGenerator) generateMergeStmt() (string, error) { } updateStringToastCols := strings.Join(updateStatementsforToastCols, " ") - normalizedpkeyColsArray := make([]string, 0, len(c.normalizedTableSchema.PrimaryKeyColumns)) - pkeySelectSQLArray := make([]string, 0, len(c.normalizedTableSchema.PrimaryKeyColumns)) - for _, pkeyColName := range c.normalizedTableSchema.PrimaryKeyColumns { + normalizedpkeyColsArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns)) + pkeySelectSQLArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns)) + for _, pkeyColName := range m.normalizedTableSchema.PrimaryKeyColumns { normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName) normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName) pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s", @@ -122,16 +122,16 @@ func (c *mergeStmtGenerator) generateMergeStmt() (string, error) { pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ") deletePart := "DELETE" - if c.peerdbCols.SoftDelete { - colName := c.peerdbCols.SoftDeleteColName + if m.peerdbCols.SoftDelete { + colName := m.peerdbCols.SoftDeleteColName deletePart = fmt.Sprintf("UPDATE SET %s = TRUE", colName) - if c.peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf("%s, %s = CURRENT_TIMESTAMP", deletePart, c.peerdbCols.SyncedAtColName) + if m.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf("%s, %s = CURRENT_TIMESTAMP", deletePart, m.peerdbCols.SyncedAtColName) } } mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable), - toVariantColumnName, c.rawTableName, c.normalizeBatchID, c.syncBatchID, flattenedCastsSQL, + toVariantColumnName, m.rawTableName, m.normalizeBatchID, m.syncBatchID, flattenedCastsSQL, fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")), pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) @@ -162,10 +162,17 @@ and updating the other columns. 6. Repeat steps 1-5 for each unique set of unchanged toast column groups. 7. Return the list of generated update statements. */ -func (c *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string { - updateStmts := make([]string, 0, len(c.unchangedToastColumns)) +func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string { + handleSoftDelete := m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * len(m.unchangedToastColumns) + } + return len(m.unchangedToastColumns) + }()) - for _, cols := range c.unchangedToastColumns { + for _, cols := range m.unchangedToastColumns { unchangedColsArray := strings.Split(cols, ",") otherCols := utils.ArrayMinus(allCols, unchangedColsArray) tmpArray := make([]string, 0, len(otherCols)+2) @@ -175,14 +182,14 @@ func (c *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string } // set the synced at column to the current timestamp - if c.peerdbCols.SyncedAtColName != "" { + if m.peerdbCols.SyncedAtColName != "" { tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, - c.peerdbCols.SyncedAtColName)) + m.peerdbCols.SyncedAtColName)) } // set soft-deleted to false, tackles insert after soft-delete - if c.peerdbCols.SoftDelete && (c.peerdbCols.SoftDeleteColName != "") { + if handleSoftDelete { tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`, - c.peerdbCols.SoftDeleteColName)) + m.peerdbCols.SoftDeleteColName)) } ssep := strings.Join(tmpArray, ", ") @@ -194,9 +201,9 @@ func (c *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string // generates update statements for the case where updates and deletes happen in the same branch // the backfill has happened from the pull side already, so treat the DeleteRecord as an update // and then set soft-delete to true. - if c.peerdbCols.SoftDelete && (c.peerdbCols.SoftDeleteColName != "") { + if handleSoftDelete { tmpArray = append(tmpArray[:len(tmpArray)-1], fmt.Sprintf(`"%s" = TRUE`, - c.peerdbCols.SoftDeleteColName)) + m.peerdbCols.SoftDeleteColName)) ssep := strings.Join(tmpArray, ", ") updateStmt := fmt.Sprintf(`WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' diff --git a/flow/connectors/snowflake/merge_stmt_generator_test.go b/flow/connectors/snowflake/merge_stmt_generator_test.go index d3c7d93285..f8b70f566a 100644 --- a/flow/connectors/snowflake/merge_stmt_generator_test.go +++ b/flow/connectors/snowflake/merge_stmt_generator_test.go @@ -2,13 +2,13 @@ package connsnowflake import ( "reflect" - "strings" "testing" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) -func TestGenerateUpdateStatement_EmptyColumns(t *testing.T) { +func TestGenerateUpdateStatement(t *testing.T) { allCols := []string{"col1", "col2", "col3"} unchangedToastCols := []string{""} @@ -28,8 +28,40 @@ func TestGenerateUpdateStatement_EmptyColumns(t *testing.T) { result := mergeGen.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, + } + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { @@ -66,8 +98,8 @@ func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { result := mergeGen.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { @@ -116,18 +148,11 @@ func TestGenerateUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing. result := mergeGen.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) } } - -func removeSpacesTabsNewlines(s string) string { - s = strings.ReplaceAll(s, " ", "") - s = strings.ReplaceAll(s, "\t", "") - s = strings.ReplaceAll(s, "\n", "") - return s -} diff --git a/flow/connectors/utils/identifiers.go b/flow/connectors/utils/identifiers.go index 5318605a93..19867971a9 100644 --- a/flow/connectors/utils/identifiers.go +++ b/flow/connectors/utils/identifiers.go @@ -49,3 +49,10 @@ func IsLower(s string) bool { } return true } + +func RemoveSpacesTabsNewlines(s string) string { + s = strings.ReplaceAll(s, " ", "") + s = strings.ReplaceAll(s, "\t", "") + s = strings.ReplaceAll(s, "\n", "") + return s +}