From c0061a28444fd302badfe5245547410129810ef9 Mon Sep 17 00:00:00 2001
From: Kevin Biju <52661649+heavycrystal@users.noreply.github.com>
Date: Tue, 9 Jan 2024 23:27:52 +0530
Subject: [PATCH 1/3] refactored SF and PG to standalone merge generator (#1026)

Refactored SF and PG to standalone merge statement generators to allow
better testing in the future. Soft-delete now only happens when SoftDelete
is set to true and a SoftDeleteColName is set. Added merge statement tests
for PG, SF and BQ. Removed some whitespace in Postgres merge statements.
---
 ...t_generator.go => merge_stmt_generator.go} |  34 +--
 .../bigquery/merge_stmt_generator_test.go     | 181 +++++++++----
 flow/connectors/postgres/client.go            | 212 ---------------
 .../postgres/normalize_stmt_generator.go      | 235 +++++++++++++++++
 .../postgres/normalize_stmt_generator_test.go | 148 +++++++++++
 flow/connectors/postgres/postgres.go          |  20 +-
 .../snowflake/merge_stmt_generator.go         | 215 +++++++++++++++
 .../snowflake/merge_stmt_generator_test.go    | 137 ++++++++--
 flow/connectors/snowflake/snowflake.go        | 248 +++---------------
 flow/connectors/utils/identifiers.go          |   7 +
 10 files changed, 913 insertions(+), 524 deletions(-)
 rename flow/connectors/bigquery/{merge_statement_generator.go => merge_stmt_generator.go} (93%)
 create mode 100644 flow/connectors/postgres/normalize_stmt_generator.go
 create mode 100644 flow/connectors/postgres/normalize_stmt_generator_test.go
 create mode 100644 flow/connectors/snowflake/merge_stmt_generator.go

diff --git a/flow/connectors/bigquery/merge_statement_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go
similarity index 93%
rename from flow/connectors/bigquery/merge_statement_generator.go
rename to flow/connectors/bigquery/merge_stmt_generator.go
index 01affcddfd..3af5af179f 100644
--- a/flow/connectors/bigquery/merge_statement_generator.go
+++ b/flow/connectors/bigquery/merge_stmt_generator.go
@@ -139,8 +139,7 @@ func (m *mergeStmtGenerator) generateMergeStmt() string {
 	insertColumnsSQL := csep + fmt.Sprintf(", `%s`", m.peerdbCols.SyncedAtColName)
 	insertValuesSQL := shortCsep + ",CURRENT_TIMESTAMP"
 
-	updateStatementsforToastCols := m.generateUpdateStatements(pureColNames,
-		m.unchangedToastColumns, m.peerdbCols)
+	updateStatementsforToastCols := m.generateUpdateStatements(pureColNames)
 	if m.peerdbCols.SoftDelete {
 		softDeleteInsertColumnsSQL := insertColumnsSQL + fmt.Sprintf(",`%s`", m.peerdbCols.SoftDeleteColName)
 		softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE"
@@ -196,14 +195,17 @@ and updating the other columns (not the unchanged toast columns)
 6. Repeat steps 1-5 for each unique unchanged toast column group.
 7. Return the list of generated update statements.
*/ -func (m *mergeStmtGenerator) generateUpdateStatements( - allCols []string, - unchangedToastCols []string, - peerdbCols *protos.PeerDBColumns, -) []string { - updateStmts := make([]string, 0, len(unchangedToastCols)) - - for _, cols := range unchangedToastCols { +func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string { + handleSoftDelete := m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * len(m.unchangedToastColumns) + } + return len(m.unchangedToastColumns) + }()) + + for _, cols := range m.unchangedToastColumns { unchangedColsArray := strings.Split(cols, ",") otherCols := utils.ArrayMinus(allCols, unchangedColsArray) tmpArray := make([]string, 0, len(otherCols)) @@ -212,14 +214,14 @@ func (m *mergeStmtGenerator) generateUpdateStatements( } // set the synced at column to the current timestamp - if peerdbCols.SyncedAtColName != "" { + if m.peerdbCols.SyncedAtColName != "" { tmpArray = append(tmpArray, fmt.Sprintf("`%s`=CURRENT_TIMESTAMP", - peerdbCols.SyncedAtColName)) + m.peerdbCols.SyncedAtColName)) } // set soft-deleted to false, tackles insert after soft-delete - if peerdbCols.SoftDeleteColName != "" { + if handleSoftDelete { tmpArray = append(tmpArray, fmt.Sprintf("`%s`=FALSE", - peerdbCols.SoftDeleteColName)) + m.peerdbCols.SoftDeleteColName)) } ssep := strings.Join(tmpArray, ",") @@ -231,9 +233,9 @@ func (m *mergeStmtGenerator) generateUpdateStatements( // generates update statements for the case where updates and deletes happen in the same branch // the backfill has happened from the pull side already, so treat the DeleteRecord as an update // and then set soft-delete to true. 
- if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { + if handleSoftDelete { tmpArray = append(tmpArray[:len(tmpArray)-1], - fmt.Sprintf("`%s`=TRUE", peerdbCols.SoftDeleteColName)) + fmt.Sprintf("`%s`=TRUE", m.peerdbCols.SoftDeleteColName)) ssep := strings.Join(tmpArray, ",") updateStmt := fmt.Sprintf(`WHEN MATCHED AND _rt=2 AND _ut='%s' diff --git a/flow/connectors/bigquery/merge_stmt_generator_test.go b/flow/connectors/bigquery/merge_stmt_generator_test.go index 141b3999b7..cc49b17cbd 100644 --- a/flow/connectors/bigquery/merge_stmt_generator_test.go +++ b/flow/connectors/bigquery/merge_stmt_generator_test.go @@ -2,77 +2,67 @@ package connbigquery import ( "reflect" - "strings" "testing" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" ) -func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { +func TestGenerateUpdateStatement(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} m := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, shortColumn: map[string]string{ "col1": "_c0", "col2": "_c1", "col3": "_c2", }, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, } - allCols := []string{"col1", "col2", "col3"} - unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} expected := []string{ - "WHEN MATCHED AND _rt!=2 AND _ut=''" + - " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", - "WHEN MATCHED AND _rt=2 " + - "AND _ut='' " + - "THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1," + - "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", - "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + - "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", - "WHEN MATCHED AND _rt=2 AND _ut='col2,col3' " + - "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", "WHEN MATCHED AND _rt!=2 " + - "AND _ut='col2' " + - "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", - "WHEN MATCHED AND _rt=2 " + - "AND _ut='col2' " + - "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + - "`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE ", - "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + - "THEN UPDATE SET `col1`=_d._c0," + - "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", - "WHEN MATCHED AND _rt=2 AND _ut='col3' " + - "THEN UPDATE SET `col1`=_d._c0," + - "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "AND _ut=''" + + "THEN UPDATE SET " + + "`col1`=_d._c0," + + "`col2`=_d._c1," + + "`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", } - result := m.generateUpdateStatements(allCols, unchangedToastCols, &protos.PeerDBColumns{ - SoftDelete: true, - SoftDeleteColName: "deleted", - SyncedAtColName: "synced_at", - }) + result := m.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { - t.Errorf("Unexpected result. Expected: %v,\nbut got: %v", expected, result) + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) } } -func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) { +func TestGenerateUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} m := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, shortColumn: map[string]string{ "col1": "_c0", "col2": "_c1", "col3": "_c2", }, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, } - allCols := []string{"col1", "col2", "col3"} - unchangedToastCols := []string{""} expected := []string{ "WHEN MATCHED AND _rt!=2 " + @@ -89,26 +79,115 @@ func TestGenerateUpdateStatement_NoUnchangedToastCols(t *testing.T) { "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", } - result := m.generateUpdateStatements(allCols, unchangedToastCols, - &protos.PeerDBColumns{ - SoftDelete: true, + result := m.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + m := &mergeStmtGenerator{ + shortColumn: map[string]string{ + "col1": "_c0", + "col2": "_c1", + "col3": "_c2", + }, + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, SoftDeleteColName: "deleted", SyncedAtColName: "synced_at", - }) + }, + } + + expected := []string{ + "WHEN MATCHED AND _rt!=2 AND _ut=''" + + " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP", + "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP", + } + + result := m.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { - t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + t.Errorf("Unexpected result. 
Expected: %v,\nbut got: %v", expected, result) } } -func removeSpacesTabsNewlines(s string) string { - s = strings.ReplaceAll(s, " ", "") - s = strings.ReplaceAll(s, "\t", "") - s = strings.ReplaceAll(s, "\n", "") - return s +func TestGenerateUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + m := &mergeStmtGenerator{ + shortColumn: map[string]string{ + "col1": "_c0", + "col2": "_c1", + "col3": "_c2", + }, + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SoftDeleteColName: "deleted", + SyncedAtColName: "synced_at", + }, + } + + expected := []string{ + "WHEN MATCHED AND _rt!=2 AND _ut=''" + + " THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", + "WHEN MATCHED AND _rt=2 " + + "AND _ut='' " + + "THEN UPDATE SET `col1`=_d._c0,`col2`=_d._c1," + + "`col3`=_d._c2,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "WHEN MATCHED AND _rt!=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", + "WHEN MATCHED AND _rt=2 AND _ut='col2,col3' " + + "THEN UPDATE SET `col1`=_d._c0,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + "WHEN MATCHED AND _rt!=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE", + "WHEN MATCHED AND _rt=2 " + + "AND _ut='col2' " + + "THEN UPDATE SET `col1`=_d._c0,`col3`=_d._c2," + + "`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE ", + "WHEN MATCHED AND _rt!=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=FALSE ", + "WHEN MATCHED AND _rt=2 AND _ut='col3' " + + "THEN UPDATE SET `col1`=_d._c0," + + "`col2`=_d._c1,`synced_at`=CURRENT_TIMESTAMP,`deleted`=TRUE", + } + + result := m.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v,\nbut got: %v", expected, result) + } } diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index b51741aff6..0a99bce668 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -5,18 +5,15 @@ import ( "fmt" "log" "regexp" - "slices" "strings" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq/oid" - "golang.org/x/exp/maps" ) const ( @@ -599,215 +596,6 @@ func (c *PostgresConnector) getTableNametoUnchangedCols(flowJobName string, sync return resultMap, nil } -func (c *PostgresConnector) generateNormalizeStatements(destinationTableIdentifier string, - unchangedToastColumns []string, rawTableIdentifier string, supportsMerge bool, - peerdbCols *protos.PeerDBColumns, -) []string { - if supportsMerge { - return []string{c.generateMergeStatement(destinationTableIdentifier, unchangedToastColumns, - rawTableIdentifier, peerdbCols)} - } - c.logger.Warn("Postgres version is not high enough to support MERGE, falling back to UPSERT + DELETE") - c.logger.Warn("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres") - return c.generateFallbackStatements(destinationTableIdentifier, rawTableIdentifier, peerdbCols) -} - -func (c *PostgresConnector) generateFallbackStatements(destinationTableIdentifier string, - rawTableIdentifier string, peerdbCols *protos.PeerDBColumns, -) []string { - normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] - columnCount := utils.TableSchemaColumns(normalizedTableSchema) - columnNames := make([]string, 0, columnCount) - flattenedCastsSQLArray := make([]string, 0, columnCount) - primaryKeyColumnCasts := make(map[string]string) - utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) { - columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) - pgType := qValueKindToPostgresType(genericColumnType) - if qvalue.QValueKind(genericColumnType).IsArray() { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } - if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) - } - }) - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") - parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) - - insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") - updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - utils.IterColumns(normalizedTableSchema, func(columnName, _ string) { - updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName)) - }) - updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",") - deleteWhereClauseArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - for columnName, columnCast := range 
primaryKeyColumnCasts { - deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `, - parsedDstTable.String(), columnName, columnCast)) - } - deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") - deletePart := fmt.Sprintf( - "DELETE FROM %s USING", - parsedDstTable.String()) - - if peerdbCols.SoftDelete { - deletePart = fmt.Sprintf(`UPDATE %s SET "%s" = TRUE`, - parsedDstTable.String(), peerdbCols.SoftDeleteColName) - if peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s, "%s" = CURRENT_TIMESTAMP`, - deletePart, peerdbCols.SyncedAtColName) - } - deletePart += " FROM" - } - fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, - strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), c.metadataSchema, - rawTableIdentifier, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL, - strings.Join(normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) - fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, - strings.Join(maps.Values(primaryKeyColumnCasts), ","), c.metadataSchema, - rawTableIdentifier, deletePart, deleteWhereClauseSQL) - - return []string{fallbackUpsertStatement, fallbackDeleteStatement} -} - -func (c *PostgresConnector) generateMergeStatement( - destinationTableIdentifier string, - unchangedToastColumns []string, - rawTableIdentifier string, - peerdbCols *protos.PeerDBColumns, -) string { - normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] - columnNames := utils.TableSchemaColumnNames(normalizedTableSchema) - for i, columnName := range columnNames { - columnNames[i] = fmt.Sprintf("\"%s\"", columnName) - } - - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - parsedDstTable, _ := utils.ParseSchemaTable(destinationTableIdentifier) - - primaryKeyColumnCasts := make(map[string]string) - primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - utils.IterColumns(normalizedTableSchema, func(columnName, genericColumnType string) { - pgType := qValueKindToPostgresType(genericColumnType) - if qvalue.QValueKind(genericColumnType).IsArray() { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", - strings.Trim(columnName, "\""), pgType, columnName)) - } - if slices.Contains(normalizedTableSchema.PrimaryKeyColumns, columnName) { - primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) - primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", - columnName, columnName)) - } - }) - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") - insertValuesSQLArray := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { - insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) - } - - updateStatementsforToastCols := c.generateUpdateStatement(columnNames, unchangedToastColumns, peerdbCols) - // append synced_at column - columnNames = append(columnNames, fmt.Sprintf(`"%s"`, peerdbCols.SyncedAtColName)) - insertColumnsSQL := strings.Join(columnNames, ",") - // fill in synced_at column - insertValuesSQLArray = 
append(insertValuesSQLArray, "CURRENT_TIMESTAMP") - insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",") - - if peerdbCols.SoftDelete { - softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, - fmt.Sprintf(`"%s"`, peerdbCols.SoftDeleteColName)), ","), ",") - softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") - - updateStatementsforToastCols = append(updateStatementsforToastCols, - fmt.Sprintf("WHEN NOT MATCHED AND (src._peerdb_record_type = 2) THEN INSERT (%s) VALUES(%s)", - softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) - } - updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n") - - deletePart := "DELETE" - if peerdbCols.SoftDelete { - colName := peerdbCols.SoftDeleteColName - deletePart = fmt.Sprintf(`UPDATE SET "%s" = TRUE`, colName) - if peerdbCols.SyncedAtColName != "" { - deletePart = fmt.Sprintf(`%s, "%s" = CURRENT_TIMESTAMP`, - deletePart, peerdbCols.SyncedAtColName) - } - } - - mergeStmt := fmt.Sprintf( - mergeStatementSQL, - strings.Join(maps.Values(primaryKeyColumnCasts), ","), - c.metadataSchema, - rawTableIdentifier, - parsedDstTable.String(), - flattenedCastsSQL, - strings.Join(primaryKeySelectSQLArray, " AND "), - insertColumnsSQL, - insertValuesSQL, - updateStringToastCols, - deletePart, - ) - - return mergeStmt -} - -func (c *PostgresConnector) generateUpdateStatement(allCols []string, - unchangedToastColsLists []string, peerdbCols *protos.PeerDBColumns, -) []string { - updateStmts := make([]string, 0, len(unchangedToastColsLists)) - - for _, cols := range unchangedToastColsLists { - unquotedUnchangedColsArray := strings.Split(cols, ",") - unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) - for _, unchangedToastCol := range unquotedUnchangedColsArray { - unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) - } - otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0, len(otherCols)) - for _, colName := range otherCols { - tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) - } - // set the synced at column to the current timestamp - if peerdbCols.SyncedAtColName != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, - peerdbCols.SyncedAtColName)) - } - // set soft-deleted to false, tackles insert after soft-delete - if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`, - peerdbCols.SoftDeleteColName)) - } - - ssep := strings.Join(tmpArray, ",") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - - // generates update statements for the case where updates and deletes happen in the same branch - // the backfill has happened from the pull side already, so treat the DeleteRecord as an update - // and then set soft-delete to true. 
- if peerdbCols.SoftDelete && (peerdbCols.SoftDeleteColName != "") { - tmpArray = append(tmpArray[:len(tmpArray)-1], - fmt.Sprintf(`"%s" = TRUE`, peerdbCols.SoftDeleteColName)) - ssep := strings.Join(tmpArray, ", ") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - } - } - return updateStmts -} - func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) { row := c.pool.QueryRow(c.ctx, "SELECT CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END") diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go new file mode 100644 index 0000000000..b541543fe2 --- /dev/null +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -0,0 +1,235 @@ +package connpostgres + +import ( + "fmt" + "log/slog" + "slices" + "strings" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" + "golang.org/x/exp/maps" +) + +type normalizeStmtGenerator struct { + rawTableName string + // destination table name, used to retrieve records from raw table + dstTableName string + // the schema of the table to merge into + normalizedTableSchema *protos.TableSchema + // array of toast column combinations that are unchanged + unchangedToastColumns []string + // _PEERDB_IS_DELETED and _SYNCED_AT columns + peerdbCols *protos.PeerDBColumns + // Postgres version 15 introduced MERGE, fallback statements before that + supportsMerge bool + // Postgres metadata schema + metadataSchema string + // to log fallback statement selection + logger slog.Logger +} + +func (n *normalizeStmtGenerator) generateNormalizeStatements() []string { + if n.supportsMerge { + return []string{n.generateMergeStatement()} + } + n.logger.Warn("Postgres version is not high enough to support MERGE, falling back to UPSERT+DELETE") + n.logger.Warn("TOAST columns will not be updated properly, use REPLICA IDENTITY FULL or upgrade Postgres") + if n.peerdbCols.SoftDelete { + n.logger.Warn("soft delete enabled with fallback statements! 
this combination is unsupported") + } + return n.generateFallbackStatements() +} + +func (n *normalizeStmtGenerator) generateFallbackStatements() []string { + columnCount := utils.TableSchemaColumns(n.normalizedTableSchema) + columnNames := make([]string, 0, columnCount) + flattenedCastsSQLArray := make([]string, 0, columnCount) + primaryKeyColumnCasts := make(map[string]string) + utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + columnNames = append(columnNames, fmt.Sprintf("\"%s\"", columnName)) + pgType := qValueKindToPostgresType(genericColumnType) + if qvalue.QValueKind(genericColumnType).IsArray() { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } + if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + } + }) + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") + parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) + + insertColumnsSQL := strings.TrimSuffix(strings.Join(columnNames, ","), ",") + updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema)) + utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) { + updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`"%s"=EXCLUDED."%s"`, columnName, columnName)) + }) + updateColumnsSQL := strings.TrimSuffix(strings.Join(updateColumnsSQLArray, ","), ",") + deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) + for columnName, columnCast := range primaryKeyColumnCasts { + deleteWhereClauseArray = append(deleteWhereClauseArray, fmt.Sprintf(`%s."%s"=%s AND `, + parsedDstTable.String(), columnName, columnCast)) + } + deleteWhereClauseSQL := strings.TrimSuffix(strings.Join(deleteWhereClauseArray, ""), "AND ") + deletePart := fmt.Sprintf( + "DELETE FROM %s USING", + parsedDstTable.String()) + + if n.peerdbCols.SoftDelete { + deletePart = fmt.Sprintf(`UPDATE %s SET "%s"=TRUE`, + parsedDstTable.String(), n.peerdbCols.SoftDeleteColName) + if n.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, + deletePart, n.peerdbCols.SyncedAtColName) + } + deletePart += " FROM" + } + fallbackUpsertStatement := fmt.Sprintf(fallbackUpsertStatementSQL, + strings.TrimSuffix(strings.Join(maps.Values(primaryKeyColumnCasts), ","), ","), n.metadataSchema, + n.rawTableName, parsedDstTable.String(), insertColumnsSQL, flattenedCastsSQL, + strings.Join(n.normalizedTableSchema.PrimaryKeyColumns, ","), updateColumnsSQL) + fallbackDeleteStatement := fmt.Sprintf(fallbackDeleteStatementSQL, + strings.Join(maps.Values(primaryKeyColumnCasts), ","), n.metadataSchema, + n.rawTableName, deletePart, deleteWhereClauseSQL) + + return []string{fallbackUpsertStatement, fallbackDeleteStatement} +} + +func (n *normalizeStmtGenerator) generateMergeStatement() string { + columnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema) + for i, columnName := range columnNames { + columnNames[i] = fmt.Sprintf("\"%s\"", columnName) + } + + flattenedCastsSQLArray := make([]string, 0, 
utils.TableSchemaColumns(n.normalizedTableSchema)) + parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName) + + primaryKeyColumnCasts := make(map[string]string) + primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns)) + utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) { + pgType := qValueKindToPostgresType(genericColumnType) + if qvalue.QValueKind(genericColumnType).IsArray() { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("ARRAY(SELECT * FROM JSON_ARRAY_ELEMENTS_TEXT((_peerdb_data->>'%s')::JSON))::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("(_peerdb_data->>'%s')::%s AS \"%s\"", + strings.Trim(columnName, "\""), pgType, columnName)) + } + if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) { + primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>'%s')::%s", columnName, pgType) + primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", + columnName, columnName)) + } + }) + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ","), ",") + insertValuesSQLArray := make([]string, 0, len(columnNames)) + for _, columnName := range columnNames { + insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", columnName)) + } + + updateStatementsforToastCols := n.generateUpdateStatements(columnNames) + // append synced_at column + columnNames = append(columnNames, fmt.Sprintf(`"%s"`, n.peerdbCols.SyncedAtColName)) + insertColumnsSQL := strings.Join(columnNames, ",") + // fill in synced_at column + insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") + insertValuesSQL := strings.TrimSuffix(strings.Join(insertValuesSQLArray, ","), ",") + + if n.peerdbCols.SoftDelete { + softDeleteInsertColumnsSQL := strings.TrimSuffix(strings.Join(append(columnNames, + fmt.Sprintf(`"%s"`, n.peerdbCols.SoftDeleteColName)), ","), ",") + softDeleteInsertValuesSQL := strings.Join(append(insertValuesSQLArray, "TRUE"), ",") + + updateStatementsforToastCols = append(updateStatementsforToastCols, + fmt.Sprintf("WHEN NOT MATCHED AND (src._peerdb_record_type=2) THEN INSERT (%s) VALUES(%s)", + softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) + } + updateStringToastCols := strings.Join(updateStatementsforToastCols, "\n") + + deletePart := "DELETE" + if n.peerdbCols.SoftDelete { + colName := n.peerdbCols.SoftDeleteColName + deletePart = fmt.Sprintf(`UPDATE SET "%s"=TRUE`, colName) + if n.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf(`%s,"%s"=CURRENT_TIMESTAMP`, + deletePart, n.peerdbCols.SyncedAtColName) + } + } + + mergeStmt := fmt.Sprintf( + mergeStatementSQL, + strings.Join(maps.Values(primaryKeyColumnCasts), ","), + n.metadataSchema, + n.rawTableName, + parsedDstTable.String(), + flattenedCastsSQL, + strings.Join(primaryKeySelectSQLArray, " AND "), + insertColumnsSQL, + insertValuesSQL, + updateStringToastCols, + deletePart, + ) + + return mergeStmt +} + +func (n *normalizeStmtGenerator) generateUpdateStatements(allCols []string) []string { + handleSoftDelete := n.peerdbCols.SoftDelete && (n.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * len(n.unchangedToastColumns) + } + return len(n.unchangedToastColumns) + }()) + + for _, cols := range 
n.unchangedToastColumns { + unquotedUnchangedColsArray := strings.Split(cols, ",") + unchangedColsArray := make([]string, 0, len(unquotedUnchangedColsArray)) + for _, unchangedToastCol := range unquotedUnchangedColsArray { + unchangedColsArray = append(unchangedColsArray, fmt.Sprintf(`"%s"`, unchangedToastCol)) + } + otherCols := utils.ArrayMinus(allCols, unchangedColsArray) + tmpArray := make([]string, 0, len(otherCols)) + for _, colName := range otherCols { + tmpArray = append(tmpArray, fmt.Sprintf("%s=src.%s", colName, colName)) + } + // set the synced at column to the current timestamp + if n.peerdbCols.SyncedAtColName != "" { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=CURRENT_TIMESTAMP`, + n.peerdbCols.SyncedAtColName)) + } + // set soft-deleted to false, tackles insert after soft-delete + if handleSoftDelete { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s"=FALSE`, + n.peerdbCols.SoftDeleteColName)) + } + + ssep := strings.Join(tmpArray, ",") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='%s' + THEN UPDATE SET %s`, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + + // generates update statements for the case where updates and deletes happen in the same branch + // the backfill has happened from the pull side already, so treat the DeleteRecord as an update + // and then set soft-delete to true. + if handleSoftDelete { + tmpArray = append(tmpArray[:len(tmpArray)-1], + fmt.Sprintf(`"%s"=TRUE`, n.peerdbCols.SoftDeleteColName)) + ssep := strings.Join(tmpArray, ", ") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='%s' + THEN UPDATE SET %s `, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + } + } + return updateStmts +} diff --git a/flow/connectors/postgres/normalize_stmt_generator_test.go b/flow/connectors/postgres/normalize_stmt_generator_test.go new file mode 100644 index 0000000000..637c69f2bc --- /dev/null +++ b/flow/connectors/postgres/normalize_stmt_generator_test.go @@ -0,0 +1,148 @@ +package connpostgres + +import ( + "reflect" + "testing" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" +) + +func TestGenerateMergeUpdateStatement(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithUnchangedToastCols(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET "col1"=src."col1","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET "col1"=src."col1","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","_peerdb_synced_at"=CURRENT_TIMESTAMP`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateMergeUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { + allCols := []string{`"col1"`, `"col2"`, `"col3"`} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} + + expected := []string{ + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","col3"=src."col3", + "_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET "col1"=src."col1","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col2,col3' + THEN UPDATE SET "col1"=src."col1","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET "col1"=src."col1","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col2' + THEN UPDATE SET "col1"=src."col1","col3"=src."col3","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + `WHEN MATCHED AND src._peerdb_record_type!=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=FALSE`, + `WHEN MATCHED AND src._peerdb_record_type=2 AND _peerdb_unchanged_toast_columns='col3' + THEN UPDATE SET "col1"=src."col1","col2"=src."col2","_peerdb_synced_at"=CURRENT_TIMESTAMP,"_peerdb_soft_delete"=TRUE`, + } + normalizeGen := &normalizeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_peerdb_synced_at", + SoftDeleteColName: "_peerdb_soft_delete", + }, + } + result := normalizeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 43a11a2e72..71ede101c5 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -461,13 +461,21 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) mergeStatementsBatch := &pgx.Batch{} totalRowsAffected := 0 for _, destinationTableName := range destinationTableNames { - peerdbCols := protos.PeerDBColumns{ - SoftDeleteColName: req.SoftDeleteColName, - SyncedAtColName: req.SyncedAtColName, - SoftDelete: req.SoftDelete, + normalizeStmtGen := &normalizeStmtGenerator{ + rawTableName: rawTableIdentifier, + dstTableName: destinationTableName, + normalizedTableSchema: c.tableSchemaMapping[destinationTableName], + unchangedToastColumns: unchangedToastColsMap[destinationTableName], + peerdbCols: &protos.PeerDBColumns{ + SoftDeleteColName: req.SoftDeleteColName, + SyncedAtColName: req.SyncedAtColName, + SoftDelete: req.SoftDelete, + }, + supportsMerge: supportsMerge, + metadataSchema: c.metadataSchema, + logger: c.logger, } - normalizeStatements := c.generateNormalizeStatements(destinationTableName, unchangedToastColsMap[destinationTableName], - rawTableIdentifier, supportsMerge, &peerdbCols) + normalizeStatements := normalizeStmtGen.generateNormalizeStatements() for _, normalizeStatement := range normalizeStatements { mergeStatementsBatch.Queue(normalizeStatement, batchIDs.NormalizeBatchID, batchIDs.SyncBatchID, destinationTableName).Exec( func(ct pgconn.CommandTag) error { diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go new file mode 100644 index 0000000000..684f922dfd --- /dev/null +++ b/flow/connectors/snowflake/merge_stmt_generator.go @@ -0,0 +1,215 @@ +package connsnowflake + +import ( + "fmt" + "strings" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" +) + +type mergeStmtGenerator struct { + rawTableName string + // destination table name, used to retrieve records from raw table + dstTableName string + // last synced batchID. + syncBatchID int64 + // last normalized batchID. 
+ normalizeBatchID int64 + // the schema of the table to merge into + normalizedTableSchema *protos.TableSchema + // array of toast column combinations that are unchanged + unchangedToastColumns []string + // _PEERDB_IS_DELETED and _SYNCED_AT columns + peerdbCols *protos.PeerDBColumns +} + +func (m *mergeStmtGenerator) generateMergeStmt() (string, error) { + parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName) + columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema) + + flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)) + err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error { + qvKind := qvalue.QValueKind(genericColumnType) + sfType, err := qValueKindToSnowflakeType(qvKind) + if err != nil { + return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) + } + + targetColumnName := SnowflakeIdentifierNormalize(columnName) + switch qvalue.QValueKind(genericColumnType) { + case qvalue.QValueKindBytes, qvalue.QValueKindBit: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ + "AS %s,", toVariantColumnName, columnName, targetColumnName)) + case qvalue.QValueKindGeography: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s,", + toVariantColumnName, columnName, targetColumnName)) + case qvalue.QValueKindGeometry: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s,", + toVariantColumnName, columnName, targetColumnName)) + case qvalue.QValueKindJSON: + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s,", + toVariantColumnName, columnName, targetColumnName)) + // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types + // case model.ColumnTypeTime: + // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ + // "Microseconds*1000) "+ + // "AS %s,", toVariantColumnName, columnName, columnName)) + default: + if qvKind == qvalue.QValueKindNumeric { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, + fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s,", + toVariantColumnName, columnName, sfType, targetColumnName)) + } else { + flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s,", + toVariantColumnName, columnName, sfType, targetColumnName)) + } + } + return nil + }) + if err != nil { + return "", err + } + flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",") + + quotedUpperColNames := make([]string, 0, len(columnNames)) + for _, columnName := range columnNames { + quotedUpperColNames = append(quotedUpperColNames, SnowflakeIdentifierNormalize(columnName)) + } + // append synced_at column + quotedUpperColNames = append(quotedUpperColNames, + fmt.Sprintf(`"%s"`, strings.ToUpper(m.peerdbCols.SyncedAtColName)), + ) + + insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",") + + insertValuesSQLArray := make([]string, 0, len(columnNames)) + for _, columnName := range columnNames { + normalizedColName := SnowflakeIdentifierNormalize(columnName) + insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", normalizedColName)) + } + // fill in synced_at column + insertValuesSQLArray = 
append(insertValuesSQLArray, "CURRENT_TIMESTAMP") + insertValuesSQL := strings.Join(insertValuesSQLArray, ",") + updateStatementsforToastCols := m.generateUpdateStatements(columnNames) + + // handling the case when an insert and delete happen in the same batch, with updates in the middle + // with soft-delete, we want the row to be in the destination with SOFT_DELETE true + // the current merge statement doesn't do that, so we add another case to insert the DeleteRecord + if m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") { + softDeleteInsertColumnsSQL := strings.Join(append(quotedUpperColNames, + m.peerdbCols.SoftDeleteColName), ",") + softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE" + updateStatementsforToastCols = append(updateStatementsforToastCols, + fmt.Sprintf("WHEN NOT MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN INSERT (%s) VALUES(%s)", + softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) + } + updateStringToastCols := strings.Join(updateStatementsforToastCols, " ") + + normalizedpkeyColsArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns)) + pkeySelectSQLArray := make([]string, 0, len(m.normalizedTableSchema.PrimaryKeyColumns)) + for _, pkeyColName := range m.normalizedTableSchema.PrimaryKeyColumns { + normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName) + normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName) + pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s", + normalizedPkeyColName, normalizedPkeyColName)) + } + // TARGET. = SOURCE. AND TARGET. = SOURCE. ... + pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ") + + deletePart := "DELETE" + if m.peerdbCols.SoftDelete { + colName := m.peerdbCols.SoftDeleteColName + deletePart = fmt.Sprintf("UPDATE SET %s = TRUE", colName) + if m.peerdbCols.SyncedAtColName != "" { + deletePart = fmt.Sprintf("%s, %s = CURRENT_TIMESTAMP", deletePart, m.peerdbCols.SyncedAtColName) + } + } + + mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable), + toVariantColumnName, m.rawTableName, m.normalizeBatchID, m.syncBatchID, flattenedCastsSQL, + fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")), + pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) + + return mergeStatement, nil +} + +/* +This function generates UPDATE statements for a MERGE operation based on the provided inputs. + +Inputs: +1. allCols: An array of all column names. +2. unchangedToastCols: An array capturing unique sets of unchanged toast column groups. +3. softDeleteCol: just set to false in the case we see an insert after a soft-deleted column +4. syncedAtCol: set to the CURRENT_TIMESTAMP + +Algorithm: +1. Iterate over each unique set of unchanged toast column groups. +2. For each group, split it into individual column names. +3. Calculate the other columns by finding the set difference between allCols and the unchanged columns. +4. Generate an update statement for the current group by setting the appropriate conditions +and updating the other columns. + - The condition includes checking if the _PEERDB_RECORD_TYPE is not 2 (not a DELETE) and if the + _PEERDB_UNCHANGED_TOAST_COLUMNS match the current group. + - The update sets the other columns to their corresponding values + from the SOURCE table. It doesn't set (make null the Unchanged toast columns. + +5. Append the update statement to the list of generated statements. +6. 
Repeat steps 1-5 for each unique set of unchanged toast column groups. +7. Return the list of generated update statements. +*/ +func (m *mergeStmtGenerator) generateUpdateStatements(allCols []string) []string { + handleSoftDelete := m.peerdbCols.SoftDelete && (m.peerdbCols.SoftDeleteColName != "") + // weird way of doing it but avoids prealloc lint + updateStmts := make([]string, 0, func() int { + if handleSoftDelete { + return 2 * len(m.unchangedToastColumns) + } + return len(m.unchangedToastColumns) + }()) + + for _, cols := range m.unchangedToastColumns { + unchangedColsArray := strings.Split(cols, ",") + otherCols := utils.ArrayMinus(allCols, unchangedColsArray) + tmpArray := make([]string, 0, len(otherCols)+2) + for _, colName := range otherCols { + normalizedColName := SnowflakeIdentifierNormalize(colName) + tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", normalizedColName, normalizedColName)) + } + + // set the synced at column to the current timestamp + if m.peerdbCols.SyncedAtColName != "" { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, + m.peerdbCols.SyncedAtColName)) + } + // set soft-deleted to false, tackles insert after soft-delete + if handleSoftDelete { + tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`, + m.peerdbCols.SoftDeleteColName)) + } + + ssep := strings.Join(tmpArray, ", ") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' + THEN UPDATE SET %s `, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + + // generates update statements for the case where updates and deletes happen in the same branch + // the backfill has happened from the pull side already, so treat the DeleteRecord as an update + // and then set soft-delete to true. 
+ if handleSoftDelete { + tmpArray = append(tmpArray[:len(tmpArray)-1], fmt.Sprintf(`"%s" = TRUE`, + m.peerdbCols.SoftDeleteColName)) + ssep := strings.Join(tmpArray, ", ") + updateStmt := fmt.Sprintf(`WHEN MATCHED AND + (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' + THEN UPDATE SET %s `, cols, ssep) + updateStmts = append(updateStmts, updateStmt) + } + } + return updateStmts +} diff --git a/flow/connectors/snowflake/merge_stmt_generator_test.go b/flow/connectors/snowflake/merge_stmt_generator_test.go index c4eb2e973e..f8b70f566a 100644 --- a/flow/connectors/snowflake/merge_stmt_generator_test.go +++ b/flow/connectors/snowflake/merge_stmt_generator_test.go @@ -2,34 +2,104 @@ package connsnowflake import ( "reflect" - "strings" "testing" + + "github.com/PeerDB-io/peer-flow/connectors/utils" + "github.com/PeerDB-io/peer-flow/generated/protos" ) +func TestGenerateUpdateStatement(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, + } + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) + } +} + +func TestGenerateUpdateStatement_WithSoftDelete(t *testing.T) { + allCols := []string{"col1", "col2", "col3"} + unchangedToastCols := []string{""} + + expected := []string{ + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, + } + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) + + for i := range expected { + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) + } + + if !reflect.DeepEqual(result, expected) { + t.Errorf("Unexpected result. 
Expected: %v, but got: %v", expected, result) + } +} + func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { - c := &SnowflakeConnector{} allCols := []string{"col1", "col2", "col3"} unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} expected := []string{ `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3' THEN UPDATE SET "COL1" = SOURCE."COL1", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2' THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3' THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", - "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP`, } - result := c.generateUpdateStatements("_PEERDB_SYNCED_AT", "_PEERDB_SOFT_DELETE", false, allCols, unchangedToastCols) + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: false, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { @@ -37,31 +107,52 @@ func TestGenerateUpdateStatement_WithUnchangedToastCols(t *testing.T) { } } -func TestGenerateUpdateStatement_EmptyColumns(t *testing.T) { - c := &SnowflakeConnector{} +func TestGenerateUpdateStatement_WithUnchangedToastColsAndSoftDelete(t *testing.T) { allCols := []string{"col1", "col2", "col3"} - unchangedToastCols := []string{""} + unchangedToastCols := []string{"", "col2,col3", "col2", "col3"} expected := []string{ `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' - THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2,col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, 
"_PEERDB_SOFT_DELETE" = TRUE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col2' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL3" = SOURCE."COL3", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = FALSE`, + `WHEN MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='col3' + THEN UPDATE SET "COL1" = SOURCE."COL1", "COL2" = SOURCE."COL2", + "_PEERDB_SYNCED_AT" = CURRENT_TIMESTAMP, "_PEERDB_SOFT_DELETE" = TRUE`, } - result := c.generateUpdateStatements("_PEERDB_SYNCED_AT", "_PEERDB_SOFT_DELETE", false, allCols, unchangedToastCols) + mergeGen := &mergeStmtGenerator{ + unchangedToastColumns: unchangedToastCols, + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: true, + SyncedAtColName: "_PEERDB_SYNCED_AT", + SoftDeleteColName: "_PEERDB_SOFT_DELETE", + }, + } + result := mergeGen.generateUpdateStatements(allCols) for i := range expected { - expected[i] = removeSpacesTabsNewlines(expected[i]) - result[i] = removeSpacesTabsNewlines(result[i]) + expected[i] = utils.RemoveSpacesTabsNewlines(expected[i]) + result[i] = utils.RemoveSpacesTabsNewlines(result[i]) } if !reflect.DeepEqual(result, expected) { t.Errorf("Unexpected result. Expected: %v, but got: %v", expected, result) } } - -func removeSpacesTabsNewlines(s string) string { - s = strings.ReplaceAll(s, " ", "") - s = strings.ReplaceAll(s, "\t", "") - s = strings.ReplaceAll(s, "\n", "") - return s -} diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index bb1eb4240c..fe1326ebee 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -638,18 +638,43 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest tableName := destinationTableName // local variable for the closure g.Go(func() error { - rowsAffected, err := c.generateAndExecuteMergeStatement( - gCtx, - tableName, - tableNametoUnchangedToastCols[tableName], - getRawTableIdentifier(req.FlowJobName), - batchIDs.SyncBatchID, batchIDs.NormalizeBatchID, - req) + mergeGen := &mergeStmtGenerator{ + rawTableName: getRawTableIdentifier(req.FlowJobName), + dstTableName: tableName, + syncBatchID: batchIDs.SyncBatchID, + normalizeBatchID: batchIDs.NormalizeBatchID, + normalizedTableSchema: c.tableSchemaMapping[tableName], + unchangedToastColumns: tableNametoUnchangedToastCols[tableName], + peerdbCols: &protos.PeerDBColumns{ + SoftDelete: req.SoftDelete, + SoftDeleteColName: req.SoftDeleteColName, + SyncedAtColName: req.SyncedAtColName, + }, + } + mergeStatement, err := mergeGen.generateMergeStmt() + + startTime := time.Now() + c.logger.Info("[merge] merging records...", slog.String("destTable", tableName)) + + result, err := c.database.ExecContext(gCtx, mergeStatement, tableName) + if err != nil { + return fmt.Errorf("failed to merge records into %s (statement: %s): %w", + tableName, mergeStatement, err) + } + + endTime := time.Now() + c.logger.Info(fmt.Sprintf("[merge] merged records into %s, took: %d seconds", + tableName, 
endTime.Sub(startTime)/time.Second)) if err != nil { c.logger.Error("[merge] error while normalizing records", slog.Any("error", err)) return err } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected by merge statement for table %s: %w", tableName, err) + } + atomic.AddInt64(&totalRowsAffected, rowsAffected) return nil }) @@ -813,146 +838,6 @@ func getRawTableIdentifier(jobName string) string { return fmt.Sprintf("%s_%s", rawTablePrefix, jobName) } -func (c *SnowflakeConnector) generateAndExecuteMergeStatement( - ctx context.Context, - destinationTableIdentifier string, - unchangedToastColumns []string, - rawTableIdentifier string, - syncBatchID int64, - normalizeBatchID int64, - normalizeReq *model.NormalizeRecordsRequest, -) (int64, error) { - normalizedTableSchema := c.tableSchemaMapping[destinationTableIdentifier] - parsedDstTable, err := utils.ParseSchemaTable(destinationTableIdentifier) - if err != nil { - return 0, fmt.Errorf("unable to parse destination table '%s'", parsedDstTable) - } - columnNames := utils.TableSchemaColumnNames(normalizedTableSchema) - - flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(normalizedTableSchema)) - err = utils.IterColumnsError(normalizedTableSchema, func(columnName, genericColumnType string) error { - qvKind := qvalue.QValueKind(genericColumnType) - sfType, err := qValueKindToSnowflakeType(qvKind) - if err != nil { - return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err) - } - - targetColumnName := SnowflakeIdentifierNormalize(columnName) - switch qvalue.QValueKind(genericColumnType) { - case qvalue.QValueKindBytes, qvalue.QValueKindBit: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:\"%s\") "+ - "AS %s,", toVariantColumnName, columnName, targetColumnName)) - case qvalue.QValueKindGeography: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TO_GEOGRAPHY(CAST(%s:\"%s\" AS STRING),true) AS %s,", - toVariantColumnName, columnName, targetColumnName)) - case qvalue.QValueKindGeometry: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TO_GEOMETRY(CAST(%s:\"%s\" AS STRING),true) AS %s,", - toVariantColumnName, columnName, targetColumnName)) - case qvalue.QValueKindJSON: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("PARSE_JSON(CAST(%s:\"%s\" AS STRING)) AS %s,", - toVariantColumnName, columnName, targetColumnName)) - // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types - // case model.ColumnTypeTime: - // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ - // "Microseconds*1000) "+ - // "AS %s,", toVariantColumnName, columnName, columnName)) - default: - if qvKind == qvalue.QValueKindNumeric { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, - fmt.Sprintf("TRY_CAST((%s:\"%s\")::text AS %s) AS %s,", - toVariantColumnName, columnName, sfType, targetColumnName)) - } else { - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:\"%s\" AS %s) AS %s,", - toVariantColumnName, columnName, sfType, targetColumnName)) - } - } - return nil - }) - if err != nil { - return 0, err - } - flattenedCastsSQL := strings.TrimSuffix(strings.Join(flattenedCastsSQLArray, ""), ",") - - quotedUpperColNames := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { - quotedUpperColNames = 
append(quotedUpperColNames, SnowflakeIdentifierNormalize(columnName)) - } - // append synced_at column - quotedUpperColNames = append(quotedUpperColNames, - fmt.Sprintf(`"%s"`, strings.ToUpper(normalizeReq.SyncedAtColName)), - ) - - insertColumnsSQL := strings.TrimSuffix(strings.Join(quotedUpperColNames, ","), ",") - - insertValuesSQLArray := make([]string, 0, len(columnNames)) - for _, columnName := range columnNames { - normalizedColName := SnowflakeIdentifierNormalize(columnName) - insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("SOURCE.%s", normalizedColName)) - } - // fill in synced_at column - insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") - insertValuesSQL := strings.Join(insertValuesSQLArray, ",") - updateStatementsforToastCols := c.generateUpdateStatements(normalizeReq.SyncedAtColName, - normalizeReq.SoftDeleteColName, normalizeReq.SoftDelete, - columnNames, unchangedToastColumns) - - // handling the case when an insert and delete happen in the same batch, with updates in the middle - // with soft-delete, we want the row to be in the destination with SOFT_DELETE true - // the current merge statement doesn't do that, so we add another case to insert the DeleteRecord - if normalizeReq.SoftDelete { - softDeleteInsertColumnsSQL := strings.Join(append(quotedUpperColNames, - normalizeReq.SoftDeleteColName), ",") - softDeleteInsertValuesSQL := insertValuesSQL + ",TRUE" - updateStatementsforToastCols = append(updateStatementsforToastCols, - fmt.Sprintf("WHEN NOT MATCHED AND (SOURCE._PEERDB_RECORD_TYPE = 2) THEN INSERT (%s) VALUES(%s)", - softDeleteInsertColumnsSQL, softDeleteInsertValuesSQL)) - } - updateStringToastCols := strings.Join(updateStatementsforToastCols, " ") - - normalizedpkeyColsArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - pkeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - for _, pkeyColName := range normalizedTableSchema.PrimaryKeyColumns { - normalizedPkeyColName := SnowflakeIdentifierNormalize(pkeyColName) - normalizedpkeyColsArray = append(normalizedpkeyColsArray, normalizedPkeyColName) - pkeySelectSQLArray = append(pkeySelectSQLArray, fmt.Sprintf("TARGET.%s = SOURCE.%s", - normalizedPkeyColName, normalizedPkeyColName)) - } - // TARGET. = SOURCE. AND TARGET. = SOURCE. ... 
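The removed loop above pairs each normalized primary-key column between TARGET and SOURCE to form the MERGE match predicate (joined with AND just below). A minimal, self-contained Go sketch of that step, with invented column names; this is illustrative only, not the connector's code:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// assumed, already-normalized primary-key identifiers
	pkeyCols := []string{"ID", "UPDATED_AT"}
	conds := make([]string, 0, len(pkeyCols))
	for _, col := range pkeyCols {
		conds = append(conds, fmt.Sprintf("TARGET.%s = SOURCE.%s", col, col))
	}
	// prints: TARGET.ID = SOURCE.ID AND TARGET.UPDATED_AT = SOURCE.UPDATED_AT
	fmt.Println(strings.Join(conds, " AND "))
}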
- pkeySelectSQL := strings.Join(pkeySelectSQLArray, " AND ") - - deletePart := "DELETE" - if normalizeReq.SoftDelete { - colName := normalizeReq.SoftDeleteColName - deletePart = fmt.Sprintf("UPDATE SET %s = TRUE", colName) - if normalizeReq.SyncedAtColName != "" { - deletePart = fmt.Sprintf("%s, %s = CURRENT_TIMESTAMP", deletePart, normalizeReq.SyncedAtColName) - } - } - - mergeStatement := fmt.Sprintf(mergeStatementSQL, snowflakeSchemaTableNormalize(parsedDstTable), - toVariantColumnName, rawTableIdentifier, normalizeBatchID, syncBatchID, flattenedCastsSQL, - fmt.Sprintf("(%s)", strings.Join(normalizedpkeyColsArray, ",")), - pkeySelectSQL, insertColumnsSQL, insertValuesSQL, updateStringToastCols, deletePart) - - startTime := time.Now() - c.logger.Info("[merge] merging records...", slog.String("destTable", destinationTableIdentifier)) - - result, err := c.database.ExecContext(ctx, mergeStatement, destinationTableIdentifier) - if err != nil { - return 0, fmt.Errorf("failed to merge records into %s (statement: %s): %w", - destinationTableIdentifier, mergeStatement, err) - } - - endTime := time.Now() - c.logger.Info(fmt.Sprintf("[merge] merged records into %s, took: %d seconds", - destinationTableIdentifier, endTime.Sub(startTime)/time.Second)) - - return result.RowsAffected() -} - func (c *SnowflakeConnector) jobMetadataExists(jobName string) (bool, error) { var result pgtype.Bool err := c.database.QueryRowContext(c.ctx, @@ -1039,75 +924,6 @@ func (c *SnowflakeConnector) createPeerDBInternalSchema(createSchemaTx *sql.Tx) return nil } -/* -This function generates UPDATE statements for a MERGE operation based on the provided inputs. - -Inputs: -1. allCols: An array of all column names. -2. unchangedToastCols: An array capturing unique sets of unchanged toast column groups. -3. softDeleteCol: just set to false in the case we see an insert after a soft-deleted column -4. syncedAtCol: set to the CURRENT_TIMESTAMP - -Algorithm: -1. Iterate over each unique set of unchanged toast column groups. -2. For each group, split it into individual column names. -3. Calculate the other columns by finding the set difference between allCols and the unchanged columns. -4. Generate an update statement for the current group by setting the appropriate conditions -and updating the other columns. - - The condition includes checking if the _PEERDB_RECORD_TYPE is not 2 (not a DELETE) and if the - _PEERDB_UNCHANGED_TOAST_COLUMNS match the current group. - - The update sets the other columns to their corresponding values - from the SOURCE table. It doesn't set (make null the Unchanged toast columns. - -5. Append the update statement to the list of generated statements. -6. Repeat steps 1-5 for each unique set of unchanged toast column groups. -7. Return the list of generated update statements. 
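A minimal, hypothetical Go sketch of the set-difference in step 3 above (an ArrayMinus-style helper; not the connector's actual utils.ArrayMinus), showing how an unchanged-toast group reduces the columns that get SET from SOURCE:

package main

import "fmt"

// arrayMinus returns the elements of all that are not in exclude.
func arrayMinus(all, exclude []string) []string {
	skip := make(map[string]struct{}, len(exclude))
	for _, e := range exclude {
		skip[e] = struct{}{}
	}
	out := make([]string, 0, len(all))
	for _, col := range all {
		if _, unchanged := skip[col]; !unchanged {
			out = append(out, col)
		}
	}
	return out
}

func main() {
	// for the unchanged-toast group "col2,col3", only col1 is updated from SOURCE
	fmt.Println(arrayMinus([]string{"col1", "col2", "col3"}, []string{"col2", "col3"}))
}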
-*/ -func (c *SnowflakeConnector) generateUpdateStatements( - syncedAtCol string, softDeleteCol string, softDelete bool, - allCols []string, unchangedToastCols []string, -) []string { - updateStmts := make([]string, 0, len(unchangedToastCols)) - - for _, cols := range unchangedToastCols { - unchangedColsArray := strings.Split(cols, ",") - otherCols := utils.ArrayMinus(allCols, unchangedColsArray) - tmpArray := make([]string, 0, len(otherCols)+2) - for _, colName := range otherCols { - normalizedColName := SnowflakeIdentifierNormalize(colName) - tmpArray = append(tmpArray, fmt.Sprintf("%s = SOURCE.%s", normalizedColName, normalizedColName)) - } - - // set the synced at column to the current timestamp - if syncedAtCol != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = CURRENT_TIMESTAMP`, syncedAtCol)) - } - // set soft-deleted to false, tackles insert after soft-delete - if softDeleteCol != "" { - tmpArray = append(tmpArray, fmt.Sprintf(`"%s" = FALSE`, softDeleteCol)) - } - - ssep := strings.Join(tmpArray, ", ") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - (SOURCE._PEERDB_RECORD_TYPE != 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - - // generates update statements for the case where updates and deletes happen in the same branch - // the backfill has happened from the pull side already, so treat the DeleteRecord as an update - // and then set soft-delete to true. - if softDelete && (softDeleteCol != "") { - tmpArray = append(tmpArray[:len(tmpArray)-1], fmt.Sprintf(`"%s" = TRUE`, softDeleteCol)) - ssep := strings.Join(tmpArray, ", ") - updateStmt := fmt.Sprintf(`WHEN MATCHED AND - (SOURCE._PEERDB_RECORD_TYPE = 2) AND _PEERDB_UNCHANGED_TOAST_COLUMNS='%s' - THEN UPDATE SET %s `, cols, ssep) - updateStmts = append(updateStmts, updateStmt) - } - } - return updateStmts -} - func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { renameTablesTx, err := c.database.BeginTx(c.ctx, nil) if err != nil { diff --git a/flow/connectors/utils/identifiers.go b/flow/connectors/utils/identifiers.go index 5318605a93..19867971a9 100644 --- a/flow/connectors/utils/identifiers.go +++ b/flow/connectors/utils/identifiers.go @@ -49,3 +49,10 @@ func IsLower(s string) bool { } return true } + +func RemoveSpacesTabsNewlines(s string) string { + s = strings.ReplaceAll(s, " ", "") + s = strings.ReplaceAll(s, "\t", "") + s = strings.ReplaceAll(s, "\n", "") + return s +} From 921203570a050142e5019a0b8b619688fe5f42cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 9 Jan 2024 19:21:56 +0000 Subject: [PATCH 2/3] SendableStream: remove mutex, require sync (#1039) Cursor is mutable so type system already knows it has exclusive access --- nexus/peer-bigquery/src/cursor.rs | 8 +++----- nexus/peer-cursor/src/lib.rs | 2 +- nexus/peer-snowflake/src/cursor.rs | 8 +++----- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/nexus/peer-bigquery/src/cursor.rs b/nexus/peer-bigquery/src/cursor.rs index 23812a382a..6b8e758a5e 100644 --- a/nexus/peer-bigquery/src/cursor.rs +++ b/nexus/peer-bigquery/src/cursor.rs @@ -1,5 +1,4 @@ use dashmap::DashMap; -use tokio::sync::Mutex; use futures::StreamExt; use peer_cursor::{QueryExecutor, QueryOutput, Records, SchemaRef, SendableStream}; @@ -10,7 +9,7 @@ use crate::BigQueryQueryExecutor; pub struct BigQueryCursor { position: usize, - stream: Mutex, + stream: SendableStream, schema: SchemaRef, } @@ -42,7 +41,7 @@ impl 
BigQueryCursorManager { // Create a new cursor let cursor = BigQueryCursor { position: 0, - stream: Mutex::new(stream), + stream, schema, }; @@ -75,9 +74,8 @@ impl BigQueryCursorManager { let prev_end = cursor.position; let mut cursor_position = cursor.position; { - let mut stream = cursor.stream.lock().await; while cursor_position - prev_end < count { - match stream.next().await { + match cursor.stream.next().await { Some(Ok(record)) => { records.push(record); cursor_position += 1; diff --git a/nexus/peer-cursor/src/lib.rs b/nexus/peer-cursor/src/lib.rs index 7d2525a7df..e4029ab003 100644 --- a/nexus/peer-cursor/src/lib.rs +++ b/nexus/peer-cursor/src/lib.rs @@ -23,7 +23,7 @@ pub trait RecordStream: Stream> { fn schema(&self) -> SchemaRef; } -pub type SendableStream = Pin>; +pub type SendableStream = Pin>; pub struct Records { pub records: Vec, diff --git a/nexus/peer-snowflake/src/cursor.rs b/nexus/peer-snowflake/src/cursor.rs index 475a2d7f35..ef247d8243 100644 --- a/nexus/peer-snowflake/src/cursor.rs +++ b/nexus/peer-snowflake/src/cursor.rs @@ -4,11 +4,10 @@ use futures::StreamExt; use peer_cursor::{QueryExecutor, QueryOutput, Records, SchemaRef, SendableStream}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use sqlparser::ast::Statement; -use tokio::sync::Mutex; pub struct SnowflakeCursor { position: usize, - stream: Mutex, + stream: SendableStream, schema: SchemaRef, } @@ -39,7 +38,7 @@ impl SnowflakeCursorManager { // Create a new cursor let cursor = SnowflakeCursor { position: 0, - stream: Mutex::new(stream), + stream, schema, }; @@ -72,9 +71,8 @@ impl SnowflakeCursorManager { let prev_end = cursor.position; let mut cursor_position = cursor.position; { - let mut stream = cursor.stream.lock().await; while cursor_position - prev_end < count { - match stream.next().await { + match cursor.stream.next().await { Some(Ok(record)) => { records.push(record); cursor_position += 1; From 9124df237a296982c28c314fea83e079e0c3eb1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 9 Jan 2024 19:22:29 +0000 Subject: [PATCH 3/3] nexus: don't wrap Catalog in Mutex (#1040) Catalog's methods already handle synchronization, besides when running migrations, which already uses exclusive connections Also query source/destination peer in parallel --- nexus/parser/src/lib.rs | 8 +-- nexus/server/src/main.rs | 123 +++++++++++++++++---------------------- 2 files changed, 58 insertions(+), 73 deletions(-) diff --git a/nexus/parser/src/lib.rs b/nexus/parser/src/lib.rs index f99dbe8751..4a305c7899 100644 --- a/nexus/parser/src/lib.rs +++ b/nexus/parser/src/lib.rs @@ -11,13 +11,12 @@ use pgwire::{ error::{ErrorInfo, PgWireError, PgWireResult}, }; use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; -use tokio::sync::Mutex; const DIALECT: PostgreSqlDialect = PostgreSqlDialect {}; #[derive(Clone)] pub struct NexusQueryParser { - catalog: Arc>, + catalog: Arc, } #[derive(Debug, Clone)] @@ -93,13 +92,12 @@ pub struct NexusParsedStatement { } impl NexusQueryParser { - pub fn new(catalog: Arc>) -> Self { + pub fn new(catalog: Arc) -> Self { Self { catalog } } pub async fn get_peers_bridge(&self) -> PgWireResult> { - let catalog = self.catalog.lock().await; - let peers = catalog.get_peers().await; + let peers = self.catalog.get_peers().await; peers.map_err(|e| { PgWireError::UserError(Box::new(ErrorInfo::new( diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs index bb2219512e..55e096cb7f 100644 --- a/nexus/server/src/main.rs +++ 
b/nexus/server/src/main.rs @@ -13,6 +13,7 @@ use clap::Parser; use cursor::PeerCursors; use dashmap::{mapref::entry::Entry as DashEntry, DashMap}; use flow_rs::grpc::{FlowGrpcClient, PeerValidationResult}; +use futures::join; use peer_bigquery::BigQueryQueryExecutor; use peer_connections::{PeerConnectionTracker, PeerConnections}; use peer_cursor::{ @@ -40,7 +41,7 @@ use pt::{ }; use rand::Rng; use tokio::signal::unix::{signal, SignalKind}; -use tokio::sync::{Mutex, MutexGuard}; +use tokio::sync::Mutex; use tokio::{io::AsyncWriteExt, net::TcpListener}; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{fmt, prelude::*, EnvFilter}; @@ -78,7 +79,7 @@ impl AuthSource for FixedPasswordAuthSource { } pub struct NexusBackend { - catalog: Arc>, + catalog: Arc, peer_connections: PeerConnectionTracker, query_parser: NexusQueryParser, peer_cursors: Mutex, @@ -89,7 +90,7 @@ pub struct NexusBackend { impl NexusBackend { pub fn new( - catalog: Arc>, + catalog: Arc, peer_connections: PeerConnectionTracker, flow_handler: Option>>, peerdb_fdw_mode: bool, @@ -161,7 +162,7 @@ impl NexusBackend { } async fn check_for_mirror( - catalog: &MutexGuard<'_, Catalog>, + catalog: &Catalog, flow_name: &str, ) -> PgWireResult> { let workflow_details = catalog @@ -175,10 +176,7 @@ impl NexusBackend { Ok(workflow_details) } - async fn get_peer_of_mirror( - catalog: &MutexGuard<'_, Catalog>, - peer_name: &str, - ) -> PgWireResult { + async fn get_peer_of_mirror(catalog: &Catalog, peer_name: &str) -> PgWireResult { let peer = catalog.get_peer(peer_name).await.map_err(|err| { PgWireError::ApiError(format!("unable to get peer {:?}: {:?}", peer_name, err).into()) })?; @@ -251,13 +249,13 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "DROP MIRROR: mirror_name: {}, if_exists: {}", flow_job_name, if_exists ); - let workflow_details = catalog + let workflow_details = self + .catalog .get_workflow_details_for_flow_job(flow_job_name) .await .map_err(|err| { @@ -284,7 +282,7 @@ impl NexusBackend { format!("unable to shutdown flow job: {:?}", err).into(), ) })?; - catalog + self.catalog .delete_flow_job_entry(flow_job_name) .await .map_err(|err| { @@ -334,14 +332,13 @@ impl NexusBackend { } let mirror_details; { - let catalog = self.catalog.lock().await; mirror_details = - Self::check_for_mirror(&catalog, &qrep_flow_job.name).await?; + Self::check_for_mirror(self.catalog.as_ref(), &qrep_flow_job.name) + .await?; } if mirror_details.is_none() { { - let catalog = self.catalog.lock().await; - catalog + self.catalog .create_qrep_flow_job_entry(qrep_flow_job) .await .map_err(|err| { @@ -399,8 +396,7 @@ impl NexusBackend { })?; } - let catalog = self.catalog.lock().await; - catalog.create_peer(peer.as_ref()).await.map_err(|e| { + self.catalog.create_peer(peer.as_ref()).await.map_err(|e| { PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), "internal_error".to_owned(), @@ -420,8 +416,8 @@ impl NexusBackend { "flow service is not configured".into(), )); } - let catalog = self.catalog.lock().await; - let mirror_details = Self::check_for_mirror(&catalog, &flow_job.name).await?; + let mirror_details = + Self::check_for_mirror(self.catalog.as_ref(), &flow_job.name).await?; if mirror_details.is_none() { // reject duplicate source tables or duplicate target tables let table_mappings_count = flow_job.table_mappings.len(); @@ -450,7 +446,7 @@ impl NexusBackend { } } - catalog + self.catalog .create_cdc_flow_job_entry(flow_job) .await .map_err(|err| { @@ -460,10 +456,12 
@@ impl NexusBackend { })?; // get source and destination peers - let src_peer = - Self::get_peer_of_mirror(&catalog, &flow_job.source_peer).await?; - let dst_peer = - Self::get_peer_of_mirror(&catalog, &flow_job.target_peer).await?; + let (src_peer, dst_peer) = join!( + Self::get_peer_of_mirror(self.catalog.as_ref(), &flow_job.source_peer), + Self::get_peer_of_mirror(self.catalog.as_ref(), &flow_job.target_peer), + ); + let src_peer = src_peer?; + let dst_peer = dst_peer?; // make a request to the flow service to start the job. let mut flow_handler = self.flow_handler.as_ref().unwrap().lock().await; @@ -476,7 +474,7 @@ impl NexusBackend { ) })?; - catalog + self.catalog .update_workflow_id_for_flow_job(&flow_job.name, &workflow_id) .await .map_err(|err| { @@ -505,8 +503,7 @@ impl NexusBackend { } if let Some(job) = { - let catalog = self.catalog.lock().await; - catalog + self.catalog .get_qrep_flow_job_by_name(flow_job_name) .await .map_err(|err| { @@ -540,17 +537,21 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "DROP PEER: peer_name: {}, if_exists: {}", peer_name, if_exists ); - let peer_exists = catalog.check_peer_entry(peer_name).await.map_err(|err| { - PgWireError::ApiError( - format!("unable to query catalog for peer metadata: {:?}", err).into(), - ) - })?; + let peer_exists = + self.catalog + .check_peer_entry(peer_name) + .await + .map_err(|err| { + PgWireError::ApiError( + format!("unable to query catalog for peer metadata: {:?}", err) + .into(), + ) + })?; tracing::info!("peer exist count: {}", peer_exists); if peer_exists != 0 { let mut flow_handler = self.flow_handler.as_ref().unwrap().lock().await; @@ -590,8 +591,7 @@ impl NexusBackend { let qrep_config = { // retrieve the mirror job since DROP MIRROR will delete the row later. 
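Earlier in this hunk, CREATE MIRROR now resolves the source and destination peers with a single join! instead of two sequential awaits. A loose Go analogy of that concurrent lookup (illustrative only; getPeer and the errgroup wiring are assumptions, not the nexus code, which is Rust):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// getPeer stands in for a catalog lookup.
func getPeer(ctx context.Context, name string) (string, error) {
	return "peer:" + name, nil
}

func main() {
	g, ctx := errgroup.WithContext(context.Background())
	var srcPeer, dstPeer string
	// run both lookups concurrently; Wait returns the first error, if any
	g.Go(func() error {
		var err error
		srcPeer, err = getPeer(ctx, "source")
		return err
	})
	g.Go(func() error {
		var err error
		dstPeer, err = getPeer(ctx, "target")
		return err
	})
	if err := g.Wait(); err != nil {
		fmt.Println("peer lookup failed:", err)
		return
	}
	fmt.Println(srcPeer, dstPeer)
}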
- let catalog = self.catalog.lock().await; - catalog + self.catalog .get_qrep_config_proto(mirror_name) .await .map_err(|err| { @@ -632,8 +632,7 @@ impl NexusBackend { ) })?; // relock catalog, DROP MIRROR is done with it now - let catalog = self.catalog.lock().await; - catalog + self.catalog .update_workflow_id_for_flow_job( &qrep_config.flow_job_name, &workflow_id, @@ -674,13 +673,13 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "[PAUSE MIRROR] mirror_name: {}, if_exists: {}", flow_job_name, if_exists ); - let workflow_details = catalog + let workflow_details = self + .catalog .get_workflow_details_for_flow_job(flow_job_name) .await .map_err(|err| { @@ -737,13 +736,13 @@ impl NexusBackend { )); } - let catalog = self.catalog.lock().await; tracing::info!( "[RESUME MIRROR] mirror_name: {}, if_exists: {}", flow_job_name, if_exists ); - let workflow_details = catalog + let workflow_details = self + .catalog .get_workflow_details_for_flow_job(flow_job_name) .await .map_err(|err| { @@ -805,8 +804,7 @@ impl NexusBackend { } QueryAssociation::Catalog => { tracing::info!("handling catalog query: {}", stmt); - let catalog = self.catalog.lock().await; - Arc::clone(catalog.get_executor()) + Arc::clone(self.catalog.get_executor()) } }; @@ -829,10 +827,7 @@ impl NexusBackend { analyzer::CursorEvent::Close(c) => peer_cursors.get_peer(&c), }; match peer { - None => { - let catalog = self.catalog.lock().await; - Arc::clone(catalog.get_executor()) - } + None => Arc::clone(self.catalog.get_executor()), Some(peer) => self.get_peer_executor(peer).await.map_err(|err| { PgWireError::ApiError( format!("unable to get peer executor: {:?}", err).into(), @@ -850,22 +845,18 @@ impl NexusBackend { } async fn run_qrep_mirror(&self, qrep_flow_job: &QRepFlowJob) -> PgWireResult { - let catalog = self.catalog.lock().await; - + let (src_peer, dst_peer) = join!( + self.catalog.get_peer(&qrep_flow_job.source_peer), + self.catalog.get_peer(&qrep_flow_job.target_peer), + ); // get source and destination peers - let src_peer = catalog - .get_peer(&qrep_flow_job.source_peer) - .await - .map_err(|err| { - PgWireError::ApiError(format!("unable to get source peer: {:?}", err).into()) - })?; + let src_peer = src_peer.map_err(|err| { + PgWireError::ApiError(format!("unable to get source peer: {:?}", err).into()) + })?; - let dst_peer = catalog - .get_peer(&qrep_flow_job.target_peer) - .await - .map_err(|err| { - PgWireError::ApiError(format!("unable to get destination peer: {:?}", err).into()) - })?; + let dst_peer = dst_peer.map_err(|err| { + PgWireError::ApiError(format!("unable to get destination peer: {:?}", err).into()) + })?; // make a request to the flow service to start the job. let mut flow_handler = self.flow_handler.as_ref().unwrap().lock().await; @@ -876,7 +867,7 @@ impl NexusBackend { PgWireError::ApiError(format!("unable to submit job: {:?}", err).into()) })?; - catalog + self.catalog .update_workflow_id_for_flow_job(&qrep_flow_job.name, &workflow_id) .await .map_err(|err| { @@ -1087,11 +1078,7 @@ impl ExtendedQueryHandler for NexusBackend { } } } - QueryAssociation::Catalog => { - let catalog = self.catalog.lock().await; - let executor = catalog.get_executor(); - executor.describe(stmt).await? 
- } + QueryAssociation::Catalog => self.catalog.get_executor().describe(stmt).await?, }; if let Some(described_schema) = schema { if self.peerdb_fdw_mode { @@ -1320,7 +1307,7 @@ pub async fn main() -> anyhow::Result<()> { let tracker = PeerConnectionTracker::new(conn_uuid, conn_peer_conns); let processor = Arc::new(NexusBackend::new( - Arc::new(Mutex::new(catalog)), + Arc::new(catalog), tracker, conn_flow_handler, peerdb_fdw_mode,
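Just above, main() now hands NexusBackend a plain Arc::new(catalog) with no surrounding Mutex, since the catalog synchronizes internally. A loose Go analogy (invented types, not the nexus code): a value whose methods do their own locking can be shared by pointer across goroutines without an outer mutex.

package main

import (
	"fmt"
	"sync"
)

type catalog struct {
	mu    sync.Mutex
	peers map[string]string
}

// getPeer handles its own locking, so callers never need an external mutex.
func (c *catalog) getPeer(name string) (string, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	p, ok := c.peers[name]
	return p, ok
}

func main() {
	cat := &catalog{peers: map[string]string{"source": "pg", "target": "sf"}}

	var wg sync.WaitGroup
	for _, name := range []string{"source", "target"} {
		wg.Add(1)
		go func(n string) {
			defer wg.Done()
			if p, ok := cat.getPeer(n); ok {
				fmt.Println(n, "->", p) // concurrent readers share the same catalog pointer
			}
		}(name)
	}
	wg.Wait()
}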