diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index fa5e156c48..bd3f9fc40e 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -12,6 +12,7 @@ import ( "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" @@ -32,6 +33,9 @@ type PostgresCDCSource struct { startLSN pglogrepl.LSN commitLock bool customTypeMapping map[uint32]string + + // for partitioned tables, maps child relid to parent relid + chIdToParRelId map[uint32]uint32 } type PostgresCDCConfig struct { @@ -46,6 +50,11 @@ type PostgresCDCConfig struct { // Create a new PostgresCDCSource func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32]string) (*PostgresCDCSource, error) { + childToParentRelIdMap, err := getChildToParentRelIdMap(cdcConfig.AppContext, cdcConfig.Connection) + if err != nil { + return nil, fmt.Errorf("error getting child to parent relid map: %w", err) + } + return &PostgresCDCSource{ ctx: cdcConfig.AppContext, replPool: cdcConfig.Connection, @@ -55,11 +64,44 @@ func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32 publication: cdcConfig.Publication, relationMessageMapping: cdcConfig.RelationMessageMapping, typeMap: pgtype.NewMap(), + chIdToParRelId: childToParentRelIdMap, commitLock: false, customTypeMapping: customTypeMap, }, nil } +func getChildToParentRelIdMap(ctx context.Context, pool *pgxpool.Pool) (map[uint32]uint32, error) { + query := ` + SELECT + parent.oid AS parentrelid, + child.oid AS childrelid + FROM pg_inherits + JOIN pg_class parent ON pg_inherits.inhparent = parent.oid + JOIN pg_class child ON pg_inherits.inhrelid = child.oid + WHERE parent.relkind='p'; + ` + + rows, err := pool.Query(ctx, query, pgx.QueryExecModeSimpleProtocol) + if err != nil { + return nil, fmt.Errorf("error querying for child to parent relid map: %w", err) + } + + defer rows.Close() + + childToParentRelIdMap := make(map[uint32]uint32) + for rows.Next() { + var parentRelId uint32 + var childRelId uint32 + err := rows.Scan(&parentRelId, &childRelId) + if err != nil { + return nil, fmt.Errorf("error scanning child to parent relid map: %w", err) + } + childToParentRelIdMap[childRelId] = parentRelId + } + + return childToParentRelIdMap, nil +} + // PullRecords pulls records from the cdc stream func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) ( *model.RecordsWithTableSchemaDelta, error) { @@ -348,6 +390,9 @@ func (p *PostgresCDCSource) processMessage(batch *model.RecordBatch, xld pglogre batch.LastCheckPointID = int64(xld.WALStart) p.commitLock = false case *pglogrepl.RelationMessage: + // treat all relation messages as correponding to parent if partitioned. + msg.RelationID = p.getParentRelIdIfPartitioned(msg.RelationID) + // TODO (kaushik): consider persistent state for a mirror job // to be stored somewhere in temporal state. We might need to persist // the state of the relation message somewhere @@ -373,17 +418,19 @@ func (p *PostgresCDCSource) processInsertMessage( lsn pglogrepl.LSN, msg *pglogrepl.InsertMessage, ) (model.Record, error) { - tableName, exists := p.SrcTableIDNameMapping[msg.RelationID] + relId := p.getParentRelIdIfPartitioned(msg.RelationID) + + tableName, exists := p.SrcTableIDNameMapping[relId] if !exists { return nil, nil } // log lsn and relation id for debugging - log.Debugf("InsertMessage => LSN: %d, RelationID: %d, Relation Name: %s", lsn, msg.RelationID, tableName) + log.Debugf("InsertMessage => LSN: %d, RelationID: %d, Relation Name: %s", lsn, relId, tableName) - rel, ok := p.relationMessageMapping[msg.RelationID] + rel, ok := p.relationMessageMapping[relId] if !ok { - return nil, fmt.Errorf("unknown relation id: %d", msg.RelationID) + return nil, fmt.Errorf("unknown relation id: %d", relId) } // create empty map of string to interface{} @@ -405,17 +452,19 @@ func (p *PostgresCDCSource) processUpdateMessage( lsn pglogrepl.LSN, msg *pglogrepl.UpdateMessage, ) (model.Record, error) { - tableName, exists := p.SrcTableIDNameMapping[msg.RelationID] + relID := p.getParentRelIdIfPartitioned(msg.RelationID) + + tableName, exists := p.SrcTableIDNameMapping[relID] if !exists { return nil, nil } // log lsn and relation id for debugging - log.Debugf("UpdateMessage => LSN: %d, RelationID: %d, Relation Name: %s", lsn, msg.RelationID, tableName) + log.Debugf("UpdateMessage => LSN: %d, RelationID: %d, Relation Name: %s", lsn, relID, tableName) - rel, ok := p.relationMessageMapping[msg.RelationID] + rel, ok := p.relationMessageMapping[relID] if !ok { - return nil, fmt.Errorf("unknown relation id: %d", msg.RelationID) + return nil, fmt.Errorf("unknown relation id: %d", relID) } // create empty map of string to interface{} @@ -444,17 +493,19 @@ func (p *PostgresCDCSource) processDeleteMessage( lsn pglogrepl.LSN, msg *pglogrepl.DeleteMessage, ) (model.Record, error) { - tableName, exists := p.SrcTableIDNameMapping[msg.RelationID] + relID := p.getParentRelIdIfPartitioned(msg.RelationID) + + tableName, exists := p.SrcTableIDNameMapping[relID] if !exists { return nil, nil } // log lsn and relation id for debugging - log.Debugf("DeleteMessage => LSN: %d, RelationID: %d, Relation Name: %s", lsn, msg.RelationID, tableName) + log.Debugf("DeleteMessage => LSN: %d, RelationID: %d, Relation Name: %s", lsn, relID, tableName) - rel, ok := p.relationMessageMapping[msg.RelationID] + rel, ok := p.relationMessageMapping[relID] if !ok { - return nil, fmt.Errorf("unknown relation id: %d", msg.RelationID) + return nil, fmt.Errorf("unknown relation id: %d", relID) } // create empty map of string to interface{} @@ -668,3 +719,12 @@ func (p *PostgresCDCSource) compositePKeyToString(req *model.PullRecordsRequest, hasher.Write(pkeyColsMerged) return fmt.Sprintf("%x", hasher.Sum(nil)), nil } + +func (p *PostgresCDCSource) getParentRelIdIfPartitioned(relId uint32) uint32 { + parentRelId, ok := p.chIdToParRelId[relId] + if ok { + return parentRelId + } + + return relId +}