diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 9a86549097..a4a8b21205 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -120,7 +120,6 @@ func (a *FlowableActivity) CreateRawTable( ctx context.Context, config *protos.CreateRawTableInput, ) (*protos.CreateRawTableOutput, error) { - ctx = context.WithValue(ctx, shared.CDCMirrorMonitorKey, a.CatalogPool) dstConn, err := connectors.GetCDCSyncConnector(ctx, config.PeerConnectionConfig) if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) @@ -215,7 +214,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, input *protos.StartFlowInput) (*model.SyncResponse, error) { activity.RecordHeartbeat(ctx, "starting flow...") conn := input.FlowConnectionConfigs - ctx = context.WithValue(ctx, shared.CDCMirrorMonitorKey, a.CatalogPool) dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination) if err != nil { return nil, fmt.Errorf("failed to get destination connector: %w", err) @@ -253,7 +251,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, // start a goroutine to pull records from the source errGroup.Go(func() error { - return srcConn.PullRecords(&model.PullRecordsRequest{ + return srcConn.PullRecords(a.CatalogPool, &model.PullRecordsRequest{ FlowJobName: input.FlowConnectionConfigs.FlowJobName, SrcTableIDNameMapping: input.FlowConnectionConfigs.SrcTableIdNameMapping, TableNameMapping: tblNameMapping, diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 2845e371d0..f7a518a6a5 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -14,6 +14,7 @@ import ( connsqlserver "github.com/PeerDB-io/peer-flow/connectors/sqlserver" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/jackc/pgx/v5/pgxpool" ) var ErrUnsupportedFunctionality = errors.New("requested connector does not support functionality") @@ -37,7 +38,7 @@ type CDCPullConnector interface { // PullRecords pulls records from the source, and returns a RecordBatch. // This method should be idempotent, and should be able to be called multiple times with the same request. - PullRecords(req *model.PullRecordsRequest) error + PullRecords(catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error // PullFlowCleanup drops both the Postgres publication and replication slot, as a part of DROP MIRROR PullFlowCleanup(jobName string) error diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 5ac489a6df..cf2a92b5b2 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -193,7 +193,7 @@ func (c *PostgresConnector) GetLastOffset(jobName string) (int64, error) { } // PullRecords pulls records from the source. -func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) error { +func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error { defer func() { req.RecordStream.Close() }() @@ -246,16 +246,13 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) error { return err } - catalogPool, ok := c.ctx.Value(shared.CDCMirrorMonitorKey).(*pgxpool.Pool) - if ok { - latestLSN, err := c.getCurrentLSN() - if err != nil { - return fmt.Errorf("failed to get current LSN: %w", err) - } - err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(c.ctx, catalogPool, req.FlowJobName, latestLSN) - if err != nil { - return fmt.Errorf("failed to update latest LSN at source for CDC flow: %w", err) - } + latestLSN, err := c.getCurrentLSN() + if err != nil { + return fmt.Errorf("failed to get current LSN: %w", err) + } + err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(c.ctx, catalogPool, req.FlowJobName, latestLSN) + if err != nil { + return fmt.Errorf("failed to update latest LSN at source for CDC flow: %w", err) } return nil diff --git a/flow/shared/constants.go b/flow/shared/constants.go index e49de60189..8379b6718f 100644 --- a/flow/shared/constants.go +++ b/flow/shared/constants.go @@ -23,10 +23,9 @@ const ( ShutdownSignal PauseSignal - CDCMirrorMonitorKey ContextKey = "cdcMirrorMonitor" - FlowNameKey ContextKey = "flowName" - PartitionIDKey ContextKey = "partitionId" - DeploymentUIDKey ContextKey = "deploymentUid" + FlowNameKey ContextKey = "flowName" + PartitionIDKey ContextKey = "partitionId" + DeploymentUIDKey ContextKey = "deploymentUid" ) type TaskQueueID int64