diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index eb0b3e7619..2813eb7e59 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -125,9 +125,11 @@ func (c *PostgresConnector) getReplicaIdentityType(schemaTable *utils.SchemaTabl return ReplicaIdentityType(replicaIdentity), nil } -// getPrimaryKeyColumns returns the primary key columns for a given table. -// Errors if there is no primary key column or if there is more than one primary key column. -func (c *PostgresConnector) getPrimaryKeyColumns( +// getUniqueColumns returns the unique columns (used to select in MERGE statement) for a given table. +// For replica identity 'd'/default, these are the primary key columns +// For replica identity 'i'/index, these are the columns in the selected index (indisreplident set) +// For replica identity 'f'/full, if there is a primary key we use that +func (c *PostgresConnector) getUniqueColumns( replicaIdentity ReplicaIdentityType, schemaTable *utils.SchemaTable, ) ([]string, error) { @@ -146,6 +148,10 @@ func (c *PostgresConnector) getPrimaryKeyColumns( `SELECT indexrelid FROM pg_index WHERE indrelid = $1 AND indisprimary`, relID).Scan(&pkIndexOID) if err != nil { + // don't error out if no pkey columns, this would happen in EnsurePullability or UI. + if err == pgx.ErrNoRows { + return []string{}, nil + } return nil, fmt.Errorf("error finding primary key index for table %s: %w", schemaTable, err) } @@ -158,7 +164,7 @@ func (c *PostgresConnector) getReplicaIdentityIndexColumns(relID uint32, schemaT // Fetch the OID of the index used as the replica identity err := c.pool.QueryRow(c.ctx, `SELECT indexrelid FROM pg_index - WHERE indrelid = $1 AND indisreplident = true`, + WHERE indrelid=$1 AND indisreplident=true`, relID).Scan(&indexRelID) if err != nil { return nil, fmt.Errorf("error finding replica identity index for table %s: %w", schemaTable, err) diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 7fdd6b032d..b7811c8208 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -579,7 +579,7 @@ func (c *PostgresConnector) GetTableSchema( ) (*protos.GetTableSchemaBatchOutput, error) { res := make(map[string]*protos.TableSchema) for _, tableName := range req.TableIdentifiers { - tableSchema, err := c.getTableSchemaForTable(tableName, req.SkipPkeyAndReplicaCheck) + tableSchema, err := c.getTableSchemaForTable(tableName) if err != nil { return nil, err } @@ -595,27 +595,19 @@ func (c *PostgresConnector) GetTableSchema( func (c *PostgresConnector) getTableSchemaForTable( tableName string, - skipPkeyAndReplicaCheck bool, ) (*protos.TableSchema, error) { schemaTable, err := utils.ParseSchemaTable(tableName) if err != nil { return nil, err } - var pKeyCols []string - var replicaIdentityType ReplicaIdentityType - if !skipPkeyAndReplicaCheck { - var replErr error - replicaIdentityType, replErr = c.getReplicaIdentityType(schemaTable) - if replErr != nil { - return nil, fmt.Errorf("[getTableSchema]:error getting replica identity for table %s: %w", schemaTable, replErr) - } - - var err error - pKeyCols, err = c.getPrimaryKeyColumns(replicaIdentityType, schemaTable) - if err != nil { - return nil, fmt.Errorf("[getTableSchema]:error getting primary key column for table %s: %w", schemaTable, err) - } + replicaIdentityType, err := c.getReplicaIdentityType(schemaTable) + if err != nil { + return nil, fmt.Errorf("[getTableSchema] error getting replica identity for table %s: %w", schemaTable, err) + } + pKeyCols, err := c.getUniqueColumns(replicaIdentityType, schemaTable) + if err != nil { + return nil, fmt.Errorf("[getTableSchema] error getting primary key column for table %s: %w", schemaTable, err) } // Get the column names and types @@ -798,7 +790,7 @@ func (c *PostgresConnector) EnsurePullability( return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr) } - pKeyCols, err := c.getPrimaryKeyColumns(replicaIdentity, schemaTable) + pKeyCols, err := c.getUniqueColumns(replicaIdentity, schemaTable) if err != nil { return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) } diff --git a/flow/workflows/qrep_flow.go b/flow/workflows/qrep_flow.go index 1ae1518e21..e2ddadd2ae 100644 --- a/flow/workflows/qrep_flow.go +++ b/flow/workflows/qrep_flow.go @@ -109,10 +109,9 @@ func (q *QRepFlowExecution) SetupWatermarkTableOnDestination(ctx workflow.Contex }) tableSchemaInput := &protos.GetTableSchemaBatchInput{ - PeerConnectionConfig: q.config.SourcePeer, - TableIdentifiers: []string{q.config.WatermarkTable}, - FlowName: q.config.FlowJobName, - SkipPkeyAndReplicaCheck: true, + PeerConnectionConfig: q.config.SourcePeer, + TableIdentifiers: []string{q.config.WatermarkTable}, + FlowName: q.config.FlowJobName, } future := workflow.ExecuteActivity(ctx, flowable.GetTableSchema, tableSchemaInput) diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index 8959c07edf..7bee648f4e 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -182,10 +182,9 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables( sort.Strings(sourceTables) tableSchemaInput := &protos.GetTableSchemaBatchInput{ - PeerConnectionConfig: flowConnectionConfigs.Source, - TableIdentifiers: sourceTables, - FlowName: s.cdcFlowName, - SkipPkeyAndReplicaCheck: flowConnectionConfigs.InitialSnapshotOnly, + PeerConnectionConfig: flowConnectionConfigs.Source, + TableIdentifiers: sourceTables, + FlowName: s.cdcFlowName, } future := workflow.ExecuteActivity(ctx, flowable.GetTableSchema, tableSchemaInput) diff --git a/protos/flow.proto b/protos/flow.proto index cef3f5c7e9..46451674a4 100644 --- a/protos/flow.proto +++ b/protos/flow.proto @@ -183,7 +183,6 @@ message GetTableSchemaBatchInput { peerdb_peers.Peer peer_connection_config = 1; repeated string table_identifiers = 2; string flow_name = 3; - bool skip_pkey_and_replica_check = 4; } message GetTableSchemaBatchOutput {