diff --git a/flow/cmd/validate_mirror.go b/flow/cmd/validate_mirror.go
index 7f10020e58..8240f7df01 100644
--- a/flow/cmd/validate_mirror.go
+++ b/flow/cmd/validate_mirror.go
@@ -12,21 +12,20 @@ import (
 func (h *FlowRequestHandler) ValidateCDCMirror(
 	ctx context.Context, req *protos.CreateCDCFlowRequest,
 ) (*protos.ValidateCDCMirrorResponse, error) {
-	pgPeer, err := connpostgres.NewPostgresConnector(ctx, req.ConnectionConfigs.Source.GetPostgresConfig())
+	sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
+	if sourcePeerConfig == nil {
+		slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
+		return nil, fmt.Errorf("source peer config is nil")
+	}
+
+	pgPeer, err := connpostgres.NewPostgresConnector(ctx, sourcePeerConfig)
 	if err != nil {
 		return &protos.ValidateCDCMirrorResponse{
 			Ok: false,
 		}, fmt.Errorf("failed to create postgres connector: %v", err)
 	}
-
 	defer pgPeer.Close()
 
-	sourcePeerConfig := req.ConnectionConfigs.Source.GetPostgresConfig()
-	if sourcePeerConfig == nil {
-		slog.Error("/validatecdc source peer config is nil", slog.Any("peer", req.ConnectionConfigs.Source))
-		return nil, fmt.Errorf("source peer config is nil")
-	}
-
 	// Check permissions of postgres peer
 	err = pgPeer.CheckReplicationPermissions(sourcePeerConfig.User)
 	if err != nil {
diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 59498d927d..26d36df3d5 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -810,13 +810,14 @@ func (c *BigQueryConnector) SetupNormalizedTables(
 
 		// convert the column names and types to bigquery types
 		columns := make([]*bigquery.FieldSchema, 0, len(tableSchema.ColumnNames)+2)
-		utils.IterColumns(tableSchema, func(colName, genericColType string) {
+		for i, colName := range tableSchema.ColumnNames {
+			genericColType := tableSchema.ColumnTypes[i]
 			columns = append(columns, &bigquery.FieldSchema{
 				Name:     colName,
 				Type:     qValueKindToBigQueryType(genericColType),
 				Repeated: qvalue.QValueKind(genericColType).IsArray(),
 			})
-		})
+		}
 
 		if req.SoftDeleteColName != "" {
 			columns = append(columns, &bigquery.FieldSchema{
@@ -908,7 +909,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos
 			dstDatasetTable.string()))
 
 		if req.SoftDeleteColName != nil {
-			allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
+			allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
 			pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")
 
 			c.logger.InfoContext(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string()))
diff --git a/flow/connectors/bigquery/merge_stmt_generator.go b/flow/connectors/bigquery/merge_stmt_generator.go
index d87a83a290..798cb88cbb 100644
--- a/flow/connectors/bigquery/merge_stmt_generator.go
+++ b/flow/connectors/bigquery/merge_stmt_generator.go
@@ -34,7 +34,7 @@ type mergeStmtGenerator struct {
 func (m *mergeStmtGenerator) generateFlattenedCTE() string {
 	// for each column in the normalized table, generate CAST + JSON_EXTRACT_SCALAR
 	// statement.
-	flattenedProjs := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema)+3)
+	flattenedProjs := make([]string, 0, len(m.normalizedTableSchema.ColumnNames)+3)
 
 	for i, colName := range m.normalizedTableSchema.ColumnNames {
 		colType := m.normalizedTableSchema.ColumnTypes[i]
@@ -124,7 +124,7 @@ func (m *mergeStmtGenerator) generateDeDupedCTE() string {
 // generateMergeStmt generates a merge statement.
 func (m *mergeStmtGenerator) generateMergeStmt(unchangedToastColumns []string) string {
 	// comma separated list of column names
-	columnCount := utils.TableSchemaColumns(m.normalizedTableSchema)
+	columnCount := len(m.normalizedTableSchema.ColumnNames)
 	backtickColNames := make([]string, 0, columnCount)
 	shortBacktickColNames := make([]string, 0, columnCount)
 	pureColNames := make([]string, 0, columnCount)
diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go
index 5f1efd046f..c4c9a879dc 100644
--- a/flow/connectors/postgres/client.go
+++ b/flow/connectors/postgres/client.go
@@ -424,11 +424,12 @@ func generateCreateTableSQLForNormalizedTable(
 	softDeleteColName string,
 	syncedAtColName string,
 ) string {
-	createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
-	utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
+	createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
+	for i, columnName := range sourceTableSchema.ColumnNames {
+		genericColumnType := sourceTableSchema.ColumnTypes[i]
 		createTableSQLArray = append(createTableSQLArray,
 			fmt.Sprintf("%s %s", QuoteIdentifier(columnName), qValueKindToPostgresType(genericColumnType)))
-	})
+	}
 
 	if softDeleteColName != "" {
 		createTableSQLArray = append(createTableSQLArray,
diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go
index 932312073f..47da11f0c4 100644
--- a/flow/connectors/postgres/normalize_stmt_generator.go
+++ b/flow/connectors/postgres/normalize_stmt_generator.go
@@ -44,11 +44,12 @@ func (n *normalizeStmtGenerator) generateNormalizeStatements() []string {
 }
 
 func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
-	columnCount := utils.TableSchemaColumns(n.normalizedTableSchema)
+	columnCount := len(n.normalizedTableSchema.ColumnNames)
 	columnNames := make([]string, 0, columnCount)
 	flattenedCastsSQLArray := make([]string, 0, columnCount)
 	primaryKeyColumnCasts := make(map[string]string, len(n.normalizedTableSchema.PrimaryKeyColumns))
-	utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+	for i, columnName := range n.normalizedTableSchema.ColumnNames {
+		genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
 		quotedCol := QuoteIdentifier(columnName)
 		stringCol := QuoteLiteral(columnName)
 		columnNames = append(columnNames, quotedCol)
@@ -64,16 +65,16 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
 		if slices.Contains(n.normalizedTableSchema.PrimaryKeyColumns, columnName) {
 			primaryKeyColumnCasts[columnName] = fmt.Sprintf("(_peerdb_data->>%s)::%s", stringCol, pgType)
 		}
-	})
+	}
 	flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
 	parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)
 
 	insertColumnsSQL := strings.Join(columnNames, ",")
-	updateColumnsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
-	utils.IterColumns(n.normalizedTableSchema, func(columnName, _ string) {
+	updateColumnsSQLArray := make([]string, 0, columnCount)
+	for _, columnName := range n.normalizedTableSchema.ColumnNames {
 		quotedCol := QuoteIdentifier(columnName)
 		updateColumnsSQLArray = append(updateColumnsSQLArray, fmt.Sprintf(`%s=EXCLUDED.%s`, quotedCol, quotedCol))
-	})
+	}
 	updateColumnsSQL := strings.Join(updateColumnsSQLArray, ",")
 	deleteWhereClauseArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
 	for columnName, columnCast := range primaryKeyColumnCasts {
@@ -104,19 +105,20 @@ func (n *normalizeStmtGenerator) generateFallbackStatements() []string {
 }
 
 func (n *normalizeStmtGenerator) generateMergeStatement() string {
-	quotedColumnNames := utils.TableSchemaColumnNames(n.normalizedTableSchema)
-	for i, columnName := range quotedColumnNames {
-		quotedColumnNames[i] = QuoteIdentifier(columnName)
-	}
+	columnCount := len(n.normalizedTableSchema.ColumnNames)
+	quotedColumnNames := make([]string, columnCount)
 
-	flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(n.normalizedTableSchema))
+	flattenedCastsSQLArray := make([]string, 0, columnCount)
 	parsedDstTable, _ := utils.ParseSchemaTable(n.dstTableName)
 
 	primaryKeyColumnCasts := make(map[string]string)
 	primaryKeySelectSQLArray := make([]string, 0, len(n.normalizedTableSchema.PrimaryKeyColumns))
-	utils.IterColumns(n.normalizedTableSchema, func(columnName, genericColumnType string) {
+	for i, columnName := range n.normalizedTableSchema.ColumnNames {
+		genericColumnType := n.normalizedTableSchema.ColumnTypes[i]
 		quotedCol := QuoteIdentifier(columnName)
 		stringCol := QuoteLiteral(columnName)
+		quotedColumnNames[i] = quotedCol
+
 		pgType := qValueKindToPostgresType(genericColumnType)
 		if qvalue.QValueKind(genericColumnType).IsArray() {
 			flattenedCastsSQLArray = append(flattenedCastsSQLArray,
@@ -131,9 +133,9 @@ func (n *normalizeStmtGenerator) generateMergeStatement() string {
 			primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s",
 				quotedCol, quotedCol))
 		}
-	})
+	}
 	flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
-	insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2)
+	insertValuesSQLArray := make([]string, 0, columnCount+2)
 	for _, quotedCol := range quotedColumnNames {
 		insertValuesSQLArray = append(insertValuesSQLArray, fmt.Sprintf("src.%s", quotedCol))
 	}
diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go
index 4c3b012243..450303da42 100644
--- a/flow/connectors/postgres/postgres_schema_delta_test.go
+++ b/flow/connectors/postgres/postgres_schema_delta_test.go
@@ -9,7 +9,6 @@ import (
 	"github.com/jackc/pgx/v5"
 	"github.com/stretchr/testify/require"
 
-	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/e2eshared"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -121,14 +120,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() {
 		PrimaryKeyColumns: []string{"id"},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+	for i, columnName := range expectedTableSchema.ColumnNames {
 		if columnName != "id" {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
-				ColumnType: columnType,
+				ColumnType: expectedTableSchema.ColumnTypes[i],
 			})
 		}
-	})
+	}
 
 	err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
@@ -171,14 +170,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
 		PrimaryKeyColumns: []string{"id"},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+	for i, columnName := range expectedTableSchema.ColumnNames {
 		if columnName != "id" {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
-				ColumnType: columnType,
+				ColumnType: expectedTableSchema.ColumnTypes[i],
 			})
 		}
-	})
+	}
 
 	err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
@@ -212,14 +211,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() {
 		PrimaryKeyColumns: []string{" "},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+	for i, columnName := range expectedTableSchema.ColumnNames {
 		if columnName != " " {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
-				ColumnType: columnType,
+				ColumnType: expectedTableSchema.ColumnTypes[i],
 			})
 		}
-	})
+	}
 
 	err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
diff --git a/flow/connectors/snowflake/merge_stmt_generator.go b/flow/connectors/snowflake/merge_stmt_generator.go
index 291b3314d9..d849fe5a58 100644
--- a/flow/connectors/snowflake/merge_stmt_generator.go
+++ b/flow/connectors/snowflake/merge_stmt_generator.go
@@ -27,14 +27,15 @@ type mergeStmtGenerator struct {
 func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
 	parsedDstTable, _ := utils.ParseSchemaTable(m.dstTableName)
 
-	columnNames := utils.TableSchemaColumnNames(m.normalizedTableSchema)
+	columnNames := m.normalizedTableSchema.ColumnNames
 
-	flattenedCastsSQLArray := make([]string, 0, utils.TableSchemaColumns(m.normalizedTableSchema))
-	err := utils.IterColumnsError(m.normalizedTableSchema, func(columnName, genericColumnType string) error {
+	flattenedCastsSQLArray := make([]string, 0, len(columnNames))
+	for i, columnName := range columnNames {
+		genericColumnType := m.normalizedTableSchema.ColumnTypes[i]
 		qvKind := qvalue.QValueKind(genericColumnType)
 		sfType, err := qValueKindToSnowflakeType(qvKind)
 		if err != nil {
-			return fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
+			return "", fmt.Errorf("failed to convert column type %s to snowflake type: %w", genericColumnType, err)
 		}
 
 		targetColumnName := SnowflakeIdentifierNormalize(columnName)
@@ -69,10 +70,6 @@ func (m *mergeStmtGenerator) generateMergeStmt() (string, error) {
 					toVariantColumnName, columnName, sfType, targetColumnName))
 			}
 		}
-		return nil
-	})
-	if err != nil {
-		return "", err
 	}
 
 	flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index 0507ac954d..763de043dd 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -834,17 +834,18 @@ func generateCreateTableSQLForNormalizedTable(
 	softDeleteColName string,
 	syncedAtColName string,
 ) string {
-	createTableSQLArray := make([]string, 0, utils.TableSchemaColumns(sourceTableSchema)+2)
-	utils.IterColumns(sourceTableSchema, func(columnName, genericColumnType string) {
+	createTableSQLArray := make([]string, 0, len(sourceTableSchema.ColumnNames)+2)
+	for i, columnName := range sourceTableSchema.ColumnNames {
+		genericColumnType := sourceTableSchema.ColumnTypes[i]
 		normalizedColName := SnowflakeIdentifierNormalize(columnName)
 		sfColType, err := qValueKindToSnowflakeType(qvalue.QValueKind(genericColumnType))
 		if err != nil {
 			slog.Warn(fmt.Sprintf("failed to convert column type %s to snowflake type", genericColumnType),
 				slog.Any("error", err))
-			return
+			continue
 		}
 		createTableSQLArray = append(createTableSQLArray, fmt.Sprintf(`%s %s`, normalizedColName, sfColType))
-	})
+	}
 
 	// add a _peerdb_is_deleted column to the normalized table
 	// this is boolean default false, and is used to mark records as deleted
@@ -1000,7 +1001,7 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto
 	for _, renameRequest := range req.RenameTableOptions {
 		src := renameRequest.CurrentName
 		dst := renameRequest.NewName
-		allCols := strings.Join(utils.TableSchemaColumnNames(renameRequest.TableSchema), ",")
+		allCols := strings.Join(renameRequest.TableSchema.ColumnNames, ",")
 		pkeyCols := strings.Join(renameRequest.TableSchema.PrimaryKeyColumns, ",")
 
 		c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dst))
diff --git a/flow/connectors/utils/columns.go b/flow/connectors/utils/columns.go
deleted file mode 100644
index f1e0340f03..0000000000
--- a/flow/connectors/utils/columns.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package utils
-
-import (
-	"slices"
-
-	"github.com/PeerDB-io/peer-flow/generated/protos"
-)
-
-func TableSchemaColumns(schema *protos.TableSchema) int {
-	return len(schema.ColumnNames)
-}
-
-func TableSchemaColumnNames(schema *protos.TableSchema) []string {
-	return slices.Clone(schema.ColumnNames)
-}
-
-func IterColumns(schema *protos.TableSchema, iter func(k, v string)) {
-	for i, name := range schema.ColumnNames {
-		iter(name, schema.ColumnTypes[i])
-	}
-}
-
-func IterColumnsError(schema *protos.TableSchema, iter func(k, v string) error) error {
-	for i, name := range schema.ColumnNames {
-		err := iter(name, schema.ColumnTypes[i])
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
diff --git a/flow/e2e/snowflake/snowflake_schema_delta_test.go b/flow/e2e/snowflake/snowflake_schema_delta_test.go
index 52f02b005e..f83d1ac679 100644
--- a/flow/e2e/snowflake/snowflake_schema_delta_test.go
+++ b/flow/e2e/snowflake/snowflake_schema_delta_test.go
@@ -9,7 +9,6 @@ import (
 	"github.com/stretchr/testify/require"
 
 	connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake"
-	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/e2eshared"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/PeerDB-io/peer-flow/model/qvalue"
@@ -99,14 +98,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() {
 		},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+	for i, columnName := range expectedTableSchema.ColumnNames {
 		if columnName != "ID" {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
-				ColumnType: columnType,
+				ColumnType: expectedTableSchema.ColumnTypes[i],
 			})
 		}
-	})
+	}
 
 	err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
@@ -154,14 +153,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() {
 		},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+	for i, columnName := range expectedTableSchema.ColumnNames {
 		if columnName != "ID" {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
-				ColumnType: columnType,
+				ColumnType: expectedTableSchema.ColumnTypes[i],
 			})
 		}
-	})
+	}
 
 	err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
@@ -193,14 +192,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() {
 		},
 	}
 	addedColumns := make([]*protos.DeltaAddedColumn, 0)
-	utils.IterColumns(expectedTableSchema, func(columnName, columnType string) {
+	for i, columnName := range expectedTableSchema.ColumnNames {
 		if columnName != " " {
 			addedColumns = append(addedColumns, &protos.DeltaAddedColumn{
 				ColumnName: columnName,
-				ColumnType: columnType,
+				ColumnType: expectedTableSchema.ColumnTypes[i],
 			})
 		}
-	})
+	}
 
 	err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{
 		SrcTableName: tableName,
diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go
index 7bee648f4e..54d5c95a99 100644
--- a/flow/workflows/setup_flow.go
+++ b/flow/workflows/setup_flow.go
@@ -11,7 +11,6 @@ import (
 	"golang.org/x/exp/maps"
 
 	"github.com/PeerDB-io/peer-flow/activities"
-	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 )
 
@@ -207,15 +206,16 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables(
 	for _, mapping := range flowConnectionConfigs.TableMappings {
 		if mapping.SourceTableIdentifier == srcTableName {
 			if len(mapping.Exclude) != 0 {
-				columnCount := utils.TableSchemaColumns(tableSchema)
+				columnCount := len(tableSchema.ColumnNames)
 				columnNames := make([]string, 0, columnCount)
 				columnTypes := make([]string, 0, columnCount)
-				utils.IterColumns(tableSchema, func(columnName, columnType string) {
+				for i, columnName := range tableSchema.ColumnNames {
+					columnType := tableSchema.ColumnTypes[i]
 					if !slices.Contains(mapping.Exclude, columnName) {
 						columnNames = append(columnNames, columnName)
 						columnTypes = append(columnTypes, columnType)
 					}
-				})
+				}
 				tableSchema = &protos.TableSchema{
 					TableIdentifier:   tableSchema.TableIdentifier,
 					PrimaryKeyColumns: tableSchema.PrimaryKeyColumns,
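The recurring change in this patch replaces the callback-style column iterators (`utils.IterColumns`, `utils.IterColumnsError`) and the trivial accessors (`utils.TableSchemaColumns`, `utils.TableSchemaColumnNames`) with plain indexed `range` loops over the parallel `ColumnNames`/`ColumnTypes` slices. Below is a minimal standalone sketch of the two styles; the local `TableSchema` struct stands in for the generated `protos.TableSchema`, and the `"unsupported"` type check is a made-up failure condition for illustration only:

```go
package main

import "fmt"

// TableSchema stands in for protos.TableSchema: names and types are
// parallel slices, so index i fully describes one column.
type TableSchema struct {
	ColumnNames []string
	ColumnTypes []string
}

// iterColumnsError mirrors the deleted utils.IterColumnsError helper:
// the loop body is a callback, so a failure must be threaded through the
// callback's error return and then re-checked by the caller.
func iterColumnsError(schema *TableSchema, iter func(name, typ string) error) error {
	for i, name := range schema.ColumnNames {
		if err := iter(name, schema.ColumnTypes[i]); err != nil {
			return err
		}
	}
	return nil
}

// castsViaCallback shows the style this patch removes.
func castsViaCallback(schema *TableSchema) ([]string, error) {
	casts := make([]string, 0, len(schema.ColumnNames))
	err := iterColumnsError(schema, func(name, typ string) error {
		if typ == "unsupported" {
			return fmt.Errorf("cannot cast column %s", name)
		}
		casts = append(casts, fmt.Sprintf("CAST(%s AS %s)", name, typ))
		return nil
	})
	if err != nil {
		return nil, err
	}
	return casts, nil
}

// castsDirect shows the style the patch converts to: a plain indexed
// loop, where `continue` and an early `return` need no wrapper.
func castsDirect(schema *TableSchema) ([]string, error) {
	casts := make([]string, 0, len(schema.ColumnNames))
	for i, name := range schema.ColumnNames {
		typ := schema.ColumnTypes[i]
		if typ == "unsupported" {
			return nil, fmt.Errorf("cannot cast column %s", name)
		}
		casts = append(casts, fmt.Sprintf("CAST(%s AS %s)", name, typ))
	}
	return casts, nil
}

func main() {
	schema := &TableSchema{
		ColumnNames: []string{"id", "val"},
		ColumnTypes: []string{"numeric", "string"},
	}
	a, _ := castsViaCallback(schema)
	b, _ := castsDirect(schema)
	fmt.Println(a) // [CAST(id AS numeric) CAST(val AS string)]
	fmt.Println(b) // same result, simpler control flow
}
```

The direct loops are what allow the small control-flow cleanups scattered through the diff: in snowflake.go the callback's `return` (which skipped just that one column) becomes an explicit `continue`, and in the snowflake merge statement generator the error path becomes a direct `return "", err` instead of round-tripping through `IterColumnsError`.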