diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 4683adb072..acf3ba7a40 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -739,10 +739,13 @@ func (p *PostgresCDCSource) processRelationMessage( for _, column := range currRel.Columns { // not present in previous relation message, but in current one, so added. if _, ok := prevRelMap[column.Name]; !ok { - schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{ - ColumnName: column.Name, - ColumnType: string(currRelMap[column.Name]), - }) + // only add to delta if not excluded + if _, ok := p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Exclude[column.Name]; !ok { + schemaDelta.AddedColumns = append(schemaDelta.AddedColumns, &protos.DeltaAddedColumn{ + ColumnName: column.Name, + ColumnType: string(currRelMap[column.Name]), + }) + } // present in previous and current relation messages, but data types have changed. // so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first. } else if prevRelMap[column.Name] != currRelMap[column.Name] { diff --git a/flow/model/cdc_record_stream.go b/flow/model/cdc_record_stream.go index dcdadfbb67..0e2e633d4c 100644 --- a/flow/model/cdc_record_stream.go +++ b/flow/model/cdc_record_stream.go @@ -76,22 +76,8 @@ func (r *CDCRecordStream) GetRecords() <-chan Record { return r.records } -func (r *CDCRecordStream) AddSchemaDelta(tableNameMapping map[string]NameAndExclude, delta *protos.TableSchemaDelta) { - if tm, ok := tableNameMapping[delta.SrcTableName]; ok && len(tm.Exclude) != 0 { - added := make([]*protos.DeltaAddedColumn, 0, len(delta.AddedColumns)) - for _, column := range delta.AddedColumns { - if _, has := tm.Exclude[column.ColumnName]; !has { - added = append(added, column) - } - } - if len(added) != 0 { - r.SchemaDeltas = append(r.SchemaDeltas, &protos.TableSchemaDelta{ - SrcTableName: delta.SrcTableName, - DstTableName: delta.DstTableName, - AddedColumns: added, - }) - } - } else { - r.SchemaDeltas = append(r.SchemaDeltas, delta) - } +func (r *CDCRecordStream) AddSchemaDelta(tableNameMapping map[string]NameAndExclude, + delta *protos.TableSchemaDelta, +) { + r.SchemaDeltas = append(r.SchemaDeltas, delta) } diff --git a/flow/shared/additional_tables.go b/flow/shared/additional_tables.go deleted file mode 100644 index 0eb0b79f35..0000000000 --- a/flow/shared/additional_tables.go +++ /dev/null @@ -1,26 +0,0 @@ -package shared - -import ( - "github.com/PeerDB-io/peer-flow/generated/protos" -) - -func AdditionalTablesHasOverlap(currentTableMappings []*protos.TableMapping, - additionalTableMappings []*protos.TableMapping, -) bool { - currentSrcTables := make([]string, 0, len(currentTableMappings)) - currentDstTables := make([]string, 0, len(currentTableMappings)) - additionalSrcTables := make([]string, 0, len(additionalTableMappings)) - additionalDstTables := make([]string, 0, len(additionalTableMappings)) - - for _, currentTableMapping := range currentTableMappings { - currentSrcTables = append(currentSrcTables, currentTableMapping.SourceTableIdentifier) - currentDstTables = append(currentDstTables, currentTableMapping.DestinationTableIdentifier) - } - for _, additionalTableMapping := range additionalTableMappings { - additionalSrcTables = append(additionalSrcTables, additionalTableMapping.SourceTableIdentifier) - additionalDstTables = append(additionalDstTables, additionalTableMapping.DestinationTableIdentifier) - } - - return ArraysHaveOverlap(currentSrcTables, additionalSrcTables) || - ArraysHaveOverlap(currentDstTables, additionalDstTables) -} diff --git a/flow/shared/schema_helpers.go b/flow/shared/schema_helpers.go new file mode 100644 index 0000000000..2c92195e6f --- /dev/null +++ b/flow/shared/schema_helpers.go @@ -0,0 +1,76 @@ +package shared + +import ( + "log/slog" + "slices" + + "go.temporal.io/sdk/log" + "golang.org/x/exp/maps" + + "github.com/PeerDB-io/peer-flow/generated/protos" +) + +func AdditionalTablesHasOverlap(currentTableMappings []*protos.TableMapping, + additionalTableMappings []*protos.TableMapping, +) bool { + currentSrcTables := make([]string, 0, len(currentTableMappings)) + currentDstTables := make([]string, 0, len(currentTableMappings)) + additionalSrcTables := make([]string, 0, len(additionalTableMappings)) + additionalDstTables := make([]string, 0, len(additionalTableMappings)) + + for _, currentTableMapping := range currentTableMappings { + currentSrcTables = append(currentSrcTables, currentTableMapping.SourceTableIdentifier) + currentDstTables = append(currentDstTables, currentTableMapping.DestinationTableIdentifier) + } + for _, additionalTableMapping := range additionalTableMappings { + additionalSrcTables = append(additionalSrcTables, additionalTableMapping.SourceTableIdentifier) + additionalDstTables = append(additionalDstTables, additionalTableMapping.DestinationTableIdentifier) + } + + return ArraysHaveOverlap(currentSrcTables, additionalSrcTables) || + ArraysHaveOverlap(currentDstTables, additionalDstTables) +} + +// given the output of GetTableSchema, processes it to be used by CDCFlow +// 1) changes the map key to be the destination table name instead of the source table name +// 2) performs column exclusion using protos.TableMapping as input. +func BuildProcessedSchemaMapping(tableMappings []*protos.TableMapping, + tableNameSchemaMapping map[string]*protos.TableSchema, + logger log.Logger, +) map[string]*protos.TableSchema { + processedSchemaMapping := make(map[string]*protos.TableSchema) + sortedSourceTables := maps.Keys(tableNameSchemaMapping) + slices.Sort(sortedSourceTables) + + for _, srcTableName := range sortedSourceTables { + tableSchema := tableNameSchemaMapping[srcTableName] + var dstTableName string + for _, mapping := range tableMappings { + if mapping.SourceTableIdentifier == srcTableName { + dstTableName = mapping.DestinationTableIdentifier + if len(mapping.Exclude) != 0 { + columnCount := len(tableSchema.Columns) + columns := make([]*protos.FieldDescription, 0, columnCount) + for _, column := range tableSchema.Columns { + if !slices.Contains(mapping.Exclude, column.Name) { + columns = append(columns, column) + } + } + tableSchema = &protos.TableSchema{ + TableIdentifier: tableSchema.TableIdentifier, + PrimaryKeyColumns: tableSchema.PrimaryKeyColumns, + IsReplicaIdentityFull: tableSchema.IsReplicaIdentityFull, + Columns: columns, + } + } + break + } + } + processedSchemaMapping[dstTableName] = tableSchema + + logger.Info("normalized table schema", + slog.String("table", dstTableName), + slog.Any("schema", tableSchema)) + } + return processedSchemaMapping +} diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index 0574f0d24e..4355bd832c 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -3,7 +3,6 @@ package peerflow import ( "fmt" "log/slog" - "slices" "sort" "time" @@ -201,34 +200,8 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables( sort.Strings(sortedSourceTables) s.logger.Info("setting up normalized tables for peer flow") - normalizedTableMapping := make(map[string]*protos.TableSchema) - for _, srcTableName := range sortedSourceTables { - tableSchema := tableNameSchemaMapping[srcTableName] - normalizedTableName := s.tableNameMapping[srcTableName] - for _, mapping := range flowConnectionConfigs.TableMappings { - if mapping.SourceTableIdentifier == srcTableName { - if len(mapping.Exclude) != 0 { - columnCount := len(tableSchema.Columns) - columns := make([]*protos.FieldDescription, 0, columnCount) - for _, column := range tableSchema.Columns { - if !slices.Contains(mapping.Exclude, column.Name) { - columns = append(columns, column) - } - } - tableSchema = &protos.TableSchema{ - TableIdentifier: tableSchema.TableIdentifier, - PrimaryKeyColumns: tableSchema.PrimaryKeyColumns, - IsReplicaIdentityFull: tableSchema.IsReplicaIdentityFull, - Columns: columns, - } - } - break - } - } - normalizedTableMapping[normalizedTableName] = tableSchema - - s.logger.Info("normalized table schema", slog.String("table", normalizedTableName), slog.Any("schema", tableSchema)) - } + normalizedTableMapping := shared.BuildProcessedSchemaMapping(flowConnectionConfigs.TableMappings, + tableNameSchemaMapping, s.logger) // now setup the normalized tables on the destination peer setupConfig := &protos.SetupNormalizedTableBatchInput{ diff --git a/flow/workflows/sync_flow.go b/flow/workflows/sync_flow.go index 9958c1c79c..890da2e0fa 100644 --- a/flow/workflows/sync_flow.go +++ b/flow/workflows/sync_flow.go @@ -6,6 +6,7 @@ import ( "go.temporal.io/sdk/log" "go.temporal.io/sdk/workflow" + "golang.org/x/exp/maps" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -139,12 +140,10 @@ func SyncFlowWorkflow( tableSchemaDeltasCount := len(childSyncFlowRes.TableSchemaDeltas) // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes. - if tableSchemaDeltasCount != 0 { + if tableSchemaDeltasCount > 0 { modifiedSrcTables := make([]string, 0, tableSchemaDeltasCount) - modifiedDstTables := make([]string, 0, tableSchemaDeltasCount) for _, tableSchemaDelta := range childSyncFlowRes.TableSchemaDeltas { modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName) - modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName) } getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ @@ -167,10 +166,9 @@ func SyncFlowWorkflow( nil, ).Get(ctx, nil) } else { - for i, srcTable := range modifiedSrcTables { - dstTable := modifiedDstTables[i] - options.TableNameSchemaMapping[dstTable] = getModifiedSchemaRes.TableNameSchemaMapping[srcTable] - } + processedSchemaMapping := shared.BuildProcessedSchemaMapping(options.TableMappings, + getModifiedSchemaRes.TableNameSchemaMapping, logger) + maps.Copy(options.TableNameSchemaMapping, processedSchemaMapping) } }