diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index b34ccf6f48..2ffdc816d3 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -10,7 +10,6 @@ import (
 	"sync/atomic"
 	"time"

-	"github.com/jackc/pglogrepl"
 	"github.com/jackc/pgx/v5"
 	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/jackc/pgx/v5/pgxpool"
@@ -352,11 +351,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 	logger.Info(fmt.Sprintf("pushed %d records in %d seconds",
 		numRecords, int(syncDuration.Seconds())))

-	lastCheckpoint, err := recordBatch.GetLastCheckpoint()
-	if err != nil {
-		a.Alerter.LogFlowError(ctx, flowName, err)
-		return nil, fmt.Errorf("failed to get last checkpoint: %w", err)
-	}
+	lastCheckpoint := recordBatch.GetLastCheckpoint()

 	err = monitoring.UpdateNumRowsAndEndLSNForCDCBatch(
 		ctx,
@@ -364,7 +359,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 		input.FlowConnectionConfigs.FlowJobName,
 		res.CurrentSyncBatchID,
 		uint32(numRecords),
-		pglogrepl.LSN(lastCheckpoint),
+		lastCheckpoint,
 	)
 	if err != nil {
 		a.Alerter.LogFlowError(ctx, flowName, err)
@@ -375,7 +370,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 		ctx,
 		a.CatalogPool,
 		input.FlowConnectionConfigs.FlowJobName,
-		pglogrepl.LSN(lastCheckpoint),
+		lastCheckpoint,
 	)
 	if err != nil {
 		a.Alerter.LogFlowError(ctx, flowName, err)
diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go
index 60b6400c75..0237f7aed0 100644
--- a/flow/connectors/bigquery/qrep_avro_sync.go
+++ b/flow/connectors/bigquery/qrep_avro_sync.go
@@ -74,11 +74,6 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 	insertStmt := fmt.Sprintf("INSERT INTO `%s` SELECT * FROM `%s`;",
 		rawTableName, stagingTable)

-	lastCP, err := req.Records.GetLastCheckpoint()
-	if err != nil {
-		return nil, fmt.Errorf("failed to get last checkpoint: %w", err)
-	}
-
 	activity.RecordHeartbeat(ctx,
 		fmt.Sprintf("Flow job %s: performing insert and update transaction"+
 			" for destination table %s and sync batch ID %d",
@@ -98,6 +93,7 @@ func (s *QRepAvroSyncMethod) SyncRecords(
 		return nil, fmt.Errorf("failed to execute statements in a transaction: %w", err)
 	}

+	lastCP := req.Records.GetLastCheckpoint()
 	err = s.connector.pgMetadata.FinishBatch(ctx, req.FlowJobName, syncBatchID, lastCP)
 	if err != nil {
 		return nil, fmt.Errorf("failed to update metadata: %w", err)
diff --git a/flow/connectors/clickhouse/cdc.go b/flow/connectors/clickhouse/cdc.go
index 24c9a8f0f5..2dcd6a9824 100644
--- a/flow/connectors/clickhouse/cdc.go
+++ b/flow/connectors/clickhouse/cdc.go
@@ -107,13 +107,8 @@ func (c *ClickhouseConnector) syncRecordsViaAvro(
 		return nil, fmt.Errorf("failed to sync schema changes: %w", err)
 	}

-	lastCheckpoint, err := req.Records.GetLastCheckpoint()
-	if err != nil {
-		return nil, err
-	}
-
 	return &model.SyncResponse{
-		LastSyncedCheckpointID: lastCheckpoint,
+		LastSyncedCheckpointID: req.Records.GetLastCheckpoint(),
 		NumRecordsSynced:       int64(numRecords),
 		CurrentSyncBatchID:     syncBatchID,
 		TableNameRowsMapping:   tableNameRowsMapping,
@@ -130,12 +125,7 @@ func (c *ClickhouseConnector) SyncRecords(ctx context.Context, req *model.SyncRe
 		return nil, err
 	}

-	lastCheckpoint, err := req.Records.GetLastCheckpoint()
-	if err != nil {
-		return nil, fmt.Errorf("failed to get last checkpoint: %w", err)
-	}
-
-	err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint)
+	err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, res.LastSyncedCheckpointID)
 	if err != nil {
 		c.logger.Error("failed to increment id", slog.Any("error", err))
 		return nil, err
diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go
index 2641dc5ff2..2da33c1043 100644
--- a/flow/connectors/eventhub/eventhub.go
+++ b/flow/connectors/eventhub/eventhub.go
@@ -208,12 +208,7 @@ func (c *EventHubConnector) SyncRecords(ctx context.Context, req *model.SyncReco
 		return nil, err
 	}

-	lastCheckpoint, err := req.Records.GetLastCheckpoint()
-	if err != nil {
-		c.logger.Error("failed to get last checkpoint", slog.Any("error", err))
-		return nil, err
-	}
-
+	lastCheckpoint := req.Records.GetLastCheckpoint()
 	err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint)
 	if err != nil {
 		c.logger.Error("failed to increment id", slog.Any("error", err))
diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go
index a9cc01787e..5173f76705 100644
--- a/flow/connectors/postgres/cdc.go
+++ b/flow/connectors/postgres/cdc.go
@@ -52,6 +52,7 @@ type PostgresCDCConfig struct {
 	SrcTableIDNameMapping  map[uint32]string
 	TableNameMapping       map[string]model.NameAndExclude
 	RelationMessageMapping model.RelationMessageMapping
+	ChildToParentRelIDMap  map[uint32]uint32
 	CatalogPool            *pgxpool.Pool
 	FlowJobName            string
 }
@@ -63,12 +64,7 @@ type startReplicationOpts struct {
 }

 // Create a new PostgresCDCSource
-func (c *PostgresConnector) NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig) (*PostgresCDCSource, error) {
-	childToParentRelIDMap, err := getChildToParentRelIDMap(ctx, cdcConfig.Connection)
-	if err != nil {
-		return nil, fmt.Errorf("error getting child to parent relid map: %w", err)
-	}
-
+func (c *PostgresConnector) NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig) *PostgresCDCSource {
 	return &PostgresCDCSource{
 		PostgresConnector:         c,
 		replConn:                  cdcConfig.Connection,
@@ -78,21 +74,19 @@ func (c *PostgresConnector) NewPostgresCDCSource(ctx context.Context, cdcConfig
 		publication:               cdcConfig.Publication,
 		relationMessageMapping:    cdcConfig.RelationMessageMapping,
 		typeMap:                   pgtype.NewMap(),
-		childToParentRelIDMapping: childToParentRelIDMap,
+		childToParentRelIDMapping: cdcConfig.ChildToParentRelIDMap,
 		commitLock:                false,
 		catalogPool:               cdcConfig.CatalogPool,
 		flowJobName:               cdcConfig.FlowJobName,
-	}, nil
+	}
 }

 func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]uint32, error) {
 	query := `
-		SELECT
-			parent.oid AS parentrelid,
-			child.oid AS childrelid
+		SELECT parent.oid AS parentrelid, child.oid AS childrelid
 		FROM pg_inherits
-			JOIN pg_class parent ON pg_inherits.inhparent = parent.oid
-			JOIN pg_class child ON pg_inherits.inhrelid = child.oid
+		JOIN pg_class parent ON pg_inherits.inhparent = parent.oid
+		JOIN pg_class child ON pg_inherits.inhrelid = child.oid
 		WHERE parent.relkind='p';
 	`

diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index ca195f4c02..fa2ebdc201 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -219,19 +219,22 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo
 	}
 	defer replConn.Close(ctx)

-	cdc, err := c.NewPostgresCDCSource(ctx, &PostgresCDCConfig{
+	childToParentRelIDMap, err := getChildToParentRelIDMap(ctx, replConn)
+	if err != nil {
+		return fmt.Errorf("error getting child to parent relid map: %w", err)
+	}
+
+	cdc := c.NewPostgresCDCSource(ctx, &PostgresCDCConfig{
 		Connection:             replConn,
 		SrcTableIDNameMapping:  req.SrcTableIDNameMapping,
 		Slot:                   slotName,
 		Publication:            publicationName,
 		TableNameMapping:       req.TableNameMapping,
 		RelationMessageMapping: req.RelationMessageMapping,
+		ChildToParentRelIDMap:  childToParentRelIDMap,
 		CatalogPool:            catalogPool,
 		FlowJobName:            req.FlowJobName,
 	})
-	if err != nil {
-		return fmt.Errorf("failed to create cdc source: %w", err)
-	}

 	err = cdc.PullRecords(ctx, req)
 	if err != nil {
@@ -242,7 +245,7 @@ func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpoo
 	if err != nil {
 		return fmt.Errorf("failed to get current LSN: %w", err)
 	}
-	err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(ctx, catalogPool, req.FlowJobName, latestLSN)
+	err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(ctx, catalogPool, req.FlowJobName, int64(latestLSN))
 	if err != nil {
 		return fmt.Errorf("failed to update latest LSN at source for CDC flow: %w", err)
 	}
@@ -373,12 +376,8 @@ func (c *PostgresConnector) SyncRecords(ctx context.Context, req *model.SyncReco
 	c.logger.Info(fmt.Sprintf("synced %d records to Postgres table %s via COPY",
 		syncedRecordsCount, rawTableIdentifier))

-	lastCP, err := req.Records.GetLastCheckpoint()
-	if err != nil {
-		return nil, fmt.Errorf("error getting last checkpoint: %w", err)
-	}
-
 	// updating metadata with new offset and syncBatchID
+	lastCP := req.Records.GetLastCheckpoint()
 	err = c.updateSyncMetadata(ctx, req.FlowJobName, lastCP, req.SyncBatchID, syncRecordsTx)
 	if err != nil {
 		return nil, err
diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go
index 5628a3c3ff..dabfc75ef5 100644
--- a/flow/connectors/s3/s3.go
+++ b/flow/connectors/s3/s3.go
@@ -184,11 +184,7 @@ func (c *S3Connector) SyncRecords(ctx context.Context, req *model.SyncRecordsReq
 	}
 	c.logger.Info(fmt.Sprintf("Synced %d records", numRecords))

-	lastCheckpoint, err := req.Records.GetLastCheckpoint()
-	if err != nil {
-		return nil, fmt.Errorf("failed to get last checkpoint: %w", err)
-	}
-
+	lastCheckpoint := req.Records.GetLastCheckpoint()
 	err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint)
 	if err != nil {
 		c.logger.Error("failed to increment id", "error", err)
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index cb86bc4389..2f6c0210e5 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -478,13 +478,8 @@ func (c *SnowflakeConnector) syncRecordsViaAvro(
 		return nil, fmt.Errorf("failed to sync schema changes: %w", err)
 	}

-	lastCheckpoint, err := req.Records.GetLastCheckpoint()
-	if err != nil {
-		return nil, err
-	}
-
 	return &model.SyncResponse{
-		LastSyncedCheckpointID: lastCheckpoint,
+		LastSyncedCheckpointID: req.Records.GetLastCheckpoint(),
 		NumRecordsSynced:       int64(numRecords),
 		CurrentSyncBatchID:     syncBatchID,
 		TableNameRowsMapping:   tableNameRowsMapping,
diff --git a/flow/connectors/utils/monitoring/monitoring.go b/flow/connectors/utils/monitoring/monitoring.go
index 77ff006510..ebb73a0b60 100644
--- a/flow/connectors/utils/monitoring/monitoring.go
+++ b/flow/connectors/utils/monitoring/monitoring.go
@@ -7,7 +7,6 @@ import (
 	"strconv"
 	"time"

-	"github.com/jackc/pglogrepl"
 	"github.com/jackc/pgx/v5"
 	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/jackc/pgx/v5/pgxpool"
@@ -21,7 +20,7 @@ import (
 type CDCBatchInfo struct {
 	BatchID     int64
 	RowsInBatch uint32
-	BatchEndlSN pglogrepl.LSN
+	BatchEndlSN int64
 	StartTime   time.Time
 }

@@ -36,7 +35,7 @@ func InitializeCDCFlow(ctx context.Context, pool *pgxpool.Pool, flowJobName stri
 }

 func UpdateLatestLSNAtSourceForCDCFlow(ctx context.Context, pool *pgxpool.Pool, flowJobName string,
-	latestLSNAtSource pglogrepl.LSN,
+	latestLSNAtSource int64,
 ) error {
 	_, err := pool.Exec(ctx,
 		"UPDATE peerdb_stats.cdc_flows SET latest_lsn_at_source=$1 WHERE flow_name=$2",
@@ -48,7 +47,7 @@ func UpdateLatestLSNAtSourceForCDCFlow(ctx context.Context, pool *pgxpool.Pool,
 }

 func UpdateLatestLSNAtTargetForCDCFlow(ctx context.Context, pool *pgxpool.Pool, flowJobName string,
-	latestLSNAtTarget pglogrepl.LSN,
+	latestLSNAtTarget int64,
 ) error {
 	_, err := pool.Exec(ctx,
 		"UPDATE peerdb_stats.cdc_flows SET latest_lsn_at_target=$1 WHERE flow_name=$2",
@@ -80,7 +79,7 @@ func UpdateNumRowsAndEndLSNForCDCBatch(
 	flowJobName string,
 	batchID int64,
 	numRows uint32,
-	batchEndLSN pglogrepl.LSN,
+	batchEndLSN int64,
 ) error {
 	_, err := pool.Exec(ctx,
 		"UPDATE peerdb_stats.cdc_batches SET rows_in_batch=$1,batch_end_lsn=$2 WHERE flow_name=$3 AND batch_id=$4",
diff --git a/flow/model/cdc_record_stream.go b/flow/model/cdc_record_stream.go
index 29833112a9..dcdadfbb67 100644
--- a/flow/model/cdc_record_stream.go
+++ b/flow/model/cdc_record_stream.go
@@ -1,7 +1,6 @@
 package model

 import (
-	"errors"
 	"sync/atomic"

 	"github.com/PeerDB-io/peer-flow/generated/protos"
@@ -41,11 +40,11 @@ func (r *CDCRecordStream) UpdateLatestCheckpoint(val int64) {
 	}
 }

-func (r *CDCRecordStream) GetLastCheckpoint() (int64, error) {
+func (r *CDCRecordStream) GetLastCheckpoint() int64 {
 	if !r.lastCheckpointSet {
-		return 0, errors.New("last checkpoint not set, stream is still active")
+		panic("last checkpoint not set, stream is still active")
 	}
-	return r.lastCheckpointID.Load(), nil
+	return r.lastCheckpointID.Load()
 }

 func (r *CDCRecordStream) AddRecord(record Record) {
@@ -66,9 +65,11 @@ func (r *CDCRecordStream) WaitAndCheckEmpty() bool {
 }

 func (r *CDCRecordStream) Close() {
-	close(r.emptySignal)
-	close(r.records)
-	r.lastCheckpointSet = true
+	if !r.lastCheckpointSet {
+		close(r.emptySignal)
+		close(r.records)
+		r.lastCheckpointSet = true
+	}
 }

 func (r *CDCRecordStream) GetRecords() <-chan Record {
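
Note (not part of the diff): below is a minimal, self-contained sketch of the stream contract the connectors above now rely on. The stream type and its fields here are hypothetical stand-ins that mirror the hunks in flow/model/cdc_record_stream.go; the actual PeerDB type is CDCRecordStream. The point it illustrates is that GetLastCheckpoint no longer returns an error but panics if the stream is still open, and Close is guarded so a repeated call is a no-op rather than a double-close panic.

package main

import (
	"fmt"
	"sync/atomic"
)

// stream is a reduced, hypothetical stand-in for model.CDCRecordStream,
// keeping only the fields exercised by this change.
type stream struct {
	records           chan int
	emptySignal       chan bool
	lastCheckpointID  atomic.Int64
	lastCheckpointSet bool
}

// Close mirrors the new guarded version: the channels are closed only the
// first time, so calling Close again is harmless.
func (s *stream) Close() {
	if !s.lastCheckpointSet {
		close(s.emptySignal)
		close(s.records)
		s.lastCheckpointSet = true
	}
}

// GetLastCheckpoint mirrors the new error-free signature: reading the
// checkpoint while the stream is still active is treated as a programming
// error and panics, which is what lets every connector drop its error path.
func (s *stream) GetLastCheckpoint() int64 {
	if !s.lastCheckpointSet {
		panic("last checkpoint not set, stream is still active")
	}
	return s.lastCheckpointID.Load()
}

func main() {
	s := &stream{records: make(chan int, 1), emptySignal: make(chan bool, 1)}
	s.lastCheckpointID.Store(42)
	s.Close()
	s.Close()                          // second Close is a no-op
	fmt.Println(s.GetLastCheckpoint()) // prints 42; would panic before Close
}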