From 1213f4419e6835a2bdfa168d79d15e093891f86b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Sat, 10 Feb 2024 15:47:47 +0000
Subject: [PATCH] add lint: containedctx (#1240)

Required updating connector interfaces. Bit annoying to have `ctx` everywhere, but that's ultimately the correct way. Was running into context complications in #1211 with the connector being shared between activities.

Putting a context in a struct essentially makes that struct a context, but it is not necessarily the context we want. For more context, see https://zenhorace.dev/blog/context-control-go

Changes made:
1. GetCatalog takes a context now instead of using `context.Background()`
2. eventhubs processBatch now takes a context instead of using `context.Background()`
3. many instances of `Query`/`Exec` in snowflake/clickhouse converted to `QueryContext`/`ExecContext`
4. got rid of the cancel context in the ssh tunnel; the context being passed in is sufficient

Follow-up to #1238
---
 flow/.golangci.yml | 1 +
 flow/activities/flowable.go | 116 +++++------
 flow/activities/snapshot_activity.go | 8 +-
 flow/cmd/api.go | 2 +-
 flow/cmd/peer_data.go | 4 +-
 flow/cmd/snapshot_worker.go | 3 +-
 flow/cmd/validate_mirror.go | 6 +-
 flow/cmd/validate_peer.go | 6 +-
 flow/cmd/worker.go | 3 +-
 flow/connectors/bigquery/bigquery.go | 123 ++++++------
 flow/connectors/bigquery/qrep.go | 23 ++-
 flow/connectors/bigquery/qrep_avro_sync.go | 4 +-
 flow/connectors/clickhouse/cdc.go | 44 ++---
 flow/connectors/clickhouse/clickhouse.go | 10 +-
 flow/connectors/clickhouse/client.go | 7 +-
 flow/connectors/clickhouse/normalize.go | 22 ++-
 flow/connectors/clickhouse/qrep.go | 32 +--
 flow/connectors/clickhouse/qrep_avro_sync.go | 13 +-
 flow/connectors/core.go | 54 ++---
 flow/connectors/eventhub/eventhub.go | 58 +++---
 flow/connectors/external_metadata/store.go | 7 +-
 flow/connectors/postgres/cdc.go | 56 +++---
 flow/connectors/postgres/client.go | 125 ++++++------
 flow/connectors/postgres/postgres.go | 184 +++++++++---------
 .../postgres/postgres_schema_delta_test.go | 22 +--
 flow/connectors/postgres/qrep.go | 98 ++++++----
 flow/connectors/postgres/qrep_bench_test.go | 2 +-
 .../postgres/qrep_partition_test.go | 3 +-
 .../postgres/qrep_query_executor.go | 61 +++---
 .../postgres/qrep_query_executor_test.go | 4 +-
 flow/connectors/postgres/ssh_wrapped_pool.go | 13 +-
 flow/connectors/s3/qrep.go | 9 +-
 flow/connectors/s3/s3.go | 44 ++---
 flow/connectors/snowflake/client.go | 12 +-
 .../snowflake/get_schema_for_tests.go | 10 +-
 flow/connectors/snowflake/qrep.go | 44 +++--
 .../snowflake/qrep_avro_consolidate.go | 4 +-
 flow/connectors/snowflake/qrep_avro_sync.go | 2 +-
 flow/connectors/snowflake/snowflake.go | 133 +++++++------
 flow/connectors/sql/query_executor.go | 102 +++++-----
 flow/connectors/sqlserver/qrep.go | 13 +-
 flow/connectors/sqlserver/sqlserver.go | 14 +-
 flow/connectors/utils/catalog/env.go | 6 +-
 flow/dynamicconf/dynamicconf.go | 2 +-
 flow/e2e/postgres/peer_flow_pg_test.go | 2 +-
 flow/e2e/postgres/qrep_flow_pg_test.go | 2 +-
 flow/e2e/snowflake/peer_flow_sf_test.go | 10 +-
 flow/e2e/snowflake/snowflake_helper.go | 20 +-
 .../snowflake/snowflake_schema_delta_test.go | 18 +-
 .../e2e/sqlserver/qrep_flow_sqlserver_test.go | 1 +
 flow/e2e/sqlserver/sqlserver_helper.go | 10 +-
 flow/e2e/test_utils.go | 1 +
 flow/model/model.go | 2 -
 53 files changed, 818 insertions(+), 757 deletions(-)

diff --git a/flow/.golangci.yml b/flow/.golangci.yml index a34a27b73a..50582063f1 100644 --- a/flow/.golangci.yml +++ b/flow/.golangci.yml @@ -3,6 +3,7 @@ run: - generated linters: 
enable: + - containedctx - dogsled - durationcheck - errcheck diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 0d8a796ef5..c7a2bd6147 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -32,7 +32,6 @@ import ( // CheckConnectionResult is the result of a CheckConnection call. type CheckConnectionResult struct { - // True of metadata tables need to be set up. NeedsSetupMetadataTables bool } @@ -47,7 +46,6 @@ type FlowableActivity struct { Alerter *alerting.Alerter } -// CheckConnection implements CheckConnection. func (a *FlowableActivity) CheckConnection( ctx context.Context, config *protos.SetupInput, @@ -58,25 +56,24 @@ func (a *FlowableActivity) CheckConnection( a.Alerter.LogFlowError(ctx, config.FlowName, err) return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) - needsSetup := dstConn.NeedsSetupMetadataTables() + needsSetup := dstConn.NeedsSetupMetadataTables(ctx) return &CheckConnectionResult{ NeedsSetupMetadataTables: needsSetup, }, nil } -// SetupMetadataTables implements SetupMetadataTables. func (a *FlowableActivity) SetupMetadataTables(ctx context.Context, config *protos.SetupInput) error { ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName) dstConn, err := connectors.GetCDCSyncConnector(ctx, config.Peer) if err != nil { return fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) - if err := dstConn.SetupMetadataTables(); err != nil { + if err := dstConn.SetupMetadataTables(ctx); err != nil { a.Alerter.LogFlowError(ctx, config.FlowName, err) return fmt.Errorf("failed to setup metadata tables: %w", err) } @@ -84,7 +81,6 @@ func (a *FlowableActivity) SetupMetadataTables(ctx context.Context, config *prot return nil } -// GetLastSyncedID implements GetLastSyncedID. func (a *FlowableActivity) GetLastSyncedID( ctx context.Context, config *protos.GetLastSyncedIDInput, @@ -94,17 +90,16 @@ func (a *FlowableActivity) GetLastSyncedID( if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) var lastOffset int64 - lastOffset, err = dstConn.GetLastOffset(config.FlowJobName) + lastOffset, err = dstConn.GetLastOffset(ctx, config.FlowJobName) if err != nil { return nil, err } return &protos.LastSyncState{Checkpoint: lastOffset}, nil } -// EnsurePullability implements EnsurePullability. 
func (a *FlowableActivity) EnsurePullability( ctx context.Context, config *protos.EnsurePullabilityBatchInput, @@ -114,9 +109,9 @@ func (a *FlowableActivity) EnsurePullability( if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) - output, err := srcConn.EnsurePullability(config) + output, err := srcConn.EnsurePullability(ctx, config) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return nil, fmt.Errorf("failed to ensure pullability: %w", err) @@ -135,9 +130,9 @@ func (a *FlowableActivity) CreateRawTable( if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) - res, err := dstConn.CreateRawTable(config) + res, err := dstConn.CreateRawTable(ctx, config) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return nil, err @@ -160,9 +155,9 @@ func (a *FlowableActivity) GetTableSchema( if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) - return srcConn.GetTableSchema(config) + return srcConn.GetTableSchema(ctx, config) } // CreateNormalizedTable creates a normalized table in the destination flowable. @@ -175,9 +170,9 @@ func (a *FlowableActivity) CreateNormalizedTable( if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(conn) + defer connectors.CloseConnector(ctx, conn) - setupNormalizedTablesOutput, err := conn.SetupNormalizedTables(config) + setupNormalizedTablesOutput, err := conn.SetupNormalizedTables(ctx, config) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowName, err) return nil, fmt.Errorf("failed to setup normalized tables: %w", err) @@ -197,7 +192,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, if err != nil { return nil, fmt.Errorf("failed to get destination connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) logger.Info("pulling records...") tblNameMapping := make(map[string]model.NameAndExclude, len(input.SyncFlowOptions.TableMappings)) @@ -210,7 +205,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, if err != nil { return nil, fmt.Errorf("failed to get source connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) slotNameForMetrics := fmt.Sprintf("peerflow_slot_%s", input.FlowConnectionConfigs.FlowJobName) if input.FlowConnectionConfigs.ReplicationSlotName != "" { @@ -235,7 +230,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, startTime := time.Now() flowName := input.FlowConnectionConfigs.FlowJobName errGroup.Go(func() error { - return srcConn.PullRecords(a.CatalogPool, &model.PullRecordsRequest{ + return srcConn.PullRecords(ctx, a.CatalogPool, &model.PullRecordsRequest{ FlowJobName: flowName, SrcTableIDNameMapping: input.SrcTableIdNameMapping, TableNameMapping: tblNameMapping, @@ -249,9 +244,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, OverrideReplicationSlotName: input.FlowConnectionConfigs.ReplicationSlotName, RelationMessageMapping: input.RelationMessageMapping, RecordStream: recordBatch, - SetLastOffset: func(lastOffset int64) error { - return dstConn.SetLastOffset(flowName, lastOffset) - }, }) }) @@ -267,7 +259,7 @@ func (a *FlowableActivity) 
StartFlow(ctx context.Context, } logger.Info("no records to push") - err := dstConn.ReplayTableSchemaDeltas(flowName, recordBatch.SchemaDeltas) + err := dstConn.ReplayTableSchemaDeltas(ctx, flowName, recordBatch.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema: %w", err) } @@ -279,7 +271,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, }, nil } - syncBatchID, err := dstConn.GetLastSyncBatchID(flowName) + syncBatchID, err := dstConn.GetLastSyncBatchID(ctx, flowName) if err != nil && config.Destination.Type != protos.DBType_EVENTHUB { return nil, err } @@ -298,7 +290,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, } syncStartTime := time.Now() - res, err := dstConn.SyncRecords(&model.SyncRecordsRequest{ + res, err := dstConn.SyncRecords(ctx, &model.SyncRecordsRequest{ SyncBatchID: syncBatchID, Records: recordBatch, FlowJobName: input.FlowConnectionConfigs.FlowJobName, @@ -385,7 +377,7 @@ func (a *FlowableActivity) StartNormalize( if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName, input.SyncBatchID) @@ -393,14 +385,14 @@ func (a *FlowableActivity) StartNormalize( } else if err != nil { return nil, err } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) shutdown := utils.HeartbeatRoutine(ctx, func() string { return fmt.Sprintf("normalizing records from batch for job - %s", input.FlowConnectionConfigs.FlowJobName) }) defer shutdown() - res, err := dstConn.NormalizeRecords(&model.NormalizeRecordsRequest{ + res, err := dstConn.NormalizeRecords(ctx, &model.NormalizeRecordsRequest{ FlowJobName: input.FlowConnectionConfigs.FlowJobName, SyncBatchID: input.SyncBatchID, SoftDelete: input.FlowConnectionConfigs.SoftDelete, @@ -439,9 +431,9 @@ func (a *FlowableActivity) SetupQRepMetadataTables(ctx context.Context, config * if err != nil { return fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(conn) + defer connectors.CloseConnector(ctx, conn) - err = conn.SetupQRepMetadataTables(config) + err = conn.SetupQRepMetadataTables(ctx, config) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to setup metadata tables: %w", err) @@ -461,14 +453,14 @@ func (a *FlowableActivity) GetQRepPartitions(ctx context.Context, if err != nil { return nil, fmt.Errorf("failed to get qrep pull connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) shutdown := utils.HeartbeatRoutine(ctx, func() string { return fmt.Sprintf("getting partitions for job - %s", config.FlowJobName) }) defer shutdown() - partitions, err := srcConn.GetQRepPartitions(config, last) + partitions, err := srcConn.GetQRepPartitions(ctx, config, last) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return nil, fmt.Errorf("failed to get partitions from source: %w", err) @@ -546,14 +538,14 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to get qrep source connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) dstConn, err := connectors.GetQRepSyncConnector(ctx, config.DestinationPeer) if err != nil { a.Alerter.LogFlowError(ctx, 
config.FlowJobName, err) return fmt.Errorf("failed to get qrep destination connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) logger.Info(fmt.Sprintf("replicating partition %s", partition.PartitionId)) shutdown := utils.HeartbeatRoutine(ctx, func() string { @@ -572,7 +564,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, go func() { pgConn := srcConn.(*connpostgres.PostgresConnector) - tmp, err := pgConn.PullQRepRecordStream(config, partition, stream) + tmp, err := pgConn.PullQRepRecordStream(ctx, config, partition, stream) numRecords := int64(tmp) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) @@ -589,7 +581,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, wg.Done() }() } else { - recordBatch, err := srcConn.PullQRepRecords(config, partition) + recordBatch, err := srcConn.PullQRepRecords(ctx, config, partition) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to pull qrep records: %w", err) @@ -608,7 +600,7 @@ func (a *FlowableActivity) replicateQRepPartition(ctx context.Context, } } - rowsSynced, err := dstConn.SyncQRepRecords(config, partition, stream) + rowsSynced, err := dstConn.SyncQRepRecords(ctx, config, partition, stream) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to sync records: %w", err) @@ -645,14 +637,14 @@ func (a *FlowableActivity) ConsolidateQRepPartitions(ctx context.Context, config } else if err != nil { return err } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) shutdown := utils.HeartbeatRoutine(ctx, func() string { return fmt.Sprintf("consolidating partitions for job - %s", config.FlowJobName) }) defer shutdown() - err = dstConn.ConsolidateQRepPartitions(config) + err = dstConn.ConsolidateQRepPartitions(ctx, config) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return err @@ -670,9 +662,9 @@ func (a *FlowableActivity) CleanupQRepFlow(ctx context.Context, config *protos.Q return err } - defer dst.Close() + defer dst.Close(ctx) - return dst.CleanupQRepFlow(config) + return dst.CleanupQRepFlow(ctx, config) } func (a *FlowableActivity) DropFlowSource(ctx context.Context, config *protos.ShutdownRequest) error { @@ -680,9 +672,9 @@ func (a *FlowableActivity) DropFlowSource(ctx context.Context, config *protos.Sh if err != nil { return fmt.Errorf("failed to get source connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) - return srcConn.PullFlowCleanup(config.FlowJobName) + return srcConn.PullFlowCleanup(ctx, config.FlowJobName) } func (a *FlowableActivity) DropFlowDestination(ctx context.Context, config *protos.ShutdownRequest) error { @@ -690,9 +682,9 @@ func (a *FlowableActivity) DropFlowDestination(ctx context.Context, config *prot if err != nil { return fmt.Errorf("failed to get destination connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) - return dstConn.SyncFlowCleanup(config.FlowJobName) + return dstConn.SyncFlowCleanup(ctx, config.FlowJobName) } func (a *FlowableActivity) getPostgresPeerConfigs(ctx context.Context) ([]*protos.Peer, error) { @@ -802,7 +794,7 @@ func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context, a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to get qrep source connector: %w", err) } 
- defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) pgSrcConn := srcConn.(*connpostgres.PostgresConnector) logger.Info(fmt.Sprintf("current last partition value is %v", last)) attemptCount := 1 @@ -824,7 +816,7 @@ func (a *FlowableActivity) QRepWaitUntilNewRows(ctx context.Context, } } - result, err := pgSrcConn.CheckForUpdatedMaxValue(config, last) + result, err := pgSrcConn.CheckForUpdatedMaxValue(ctx, config, last) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return fmt.Errorf("failed to check for new rows: %w", err) @@ -848,7 +840,7 @@ func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.Rena a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) shutdown := utils.HeartbeatRoutine(ctx, func() string { return fmt.Sprintf("renaming tables for job - %s", config.FlowJobName) @@ -861,14 +853,14 @@ func (a *FlowableActivity) RenameTables(ctx context.Context, config *protos.Rena a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return nil, fmt.Errorf("failed to cast connector to snowflake connector") } - return sfConn.RenameTables(config) + return sfConn.RenameTables(ctx, config) } else if config.Peer.Type == protos.DBType_BIGQUERY { bqConn, ok := dstConn.(*connbigquery.BigQueryConnector) if !ok { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return nil, fmt.Errorf("failed to cast connector to bigquery connector") } - return bqConn.RenameTables(config) + return bqConn.RenameTables(ctx, config) } return nil, fmt.Errorf("rename tables is only supported on snowflake and bigquery") } @@ -881,20 +873,20 @@ func (a *FlowableActivity) CreateTablesFromExisting(ctx context.Context, req *pr if err != nil { return nil, fmt.Errorf("failed to get connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) if req.Peer.Type == protos.DBType_SNOWFLAKE { sfConn, ok := dstConn.(*connsnowflake.SnowflakeConnector) if !ok { return nil, fmt.Errorf("failed to cast connector to snowflake connector") } - return sfConn.CreateTablesFromExisting(req) + return sfConn.CreateTablesFromExisting(ctx, req) } else if req.Peer.Type == protos.DBType_BIGQUERY { bqConn, ok := dstConn.(*connbigquery.BigQueryConnector) if !ok { return nil, fmt.Errorf("failed to cast connector to bigquery connector") } - return bqConn.CreateTablesFromExisting(req) + return bqConn.CreateTablesFromExisting(ctx, req) } a.Alerter.LogFlowError(ctx, req.FlowJobName, err) return nil, fmt.Errorf("create tables from existing is only supported on snowflake and bigquery") @@ -914,13 +906,13 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context, if err != nil { return 0, fmt.Errorf("failed to get qrep source connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) dstConn, err := connectors.GetQRepSyncConnector(ctx, config.DestinationPeer) if err != nil { return 0, fmt.Errorf("failed to get qrep destination connector: %w", err) } - defer connectors.CloseConnector(dstConn) + defer connectors.CloseConnector(ctx, dstConn) logger.Info("replicating xmin") @@ -934,7 +926,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context, pgConn := srcConn.(*connpostgres.PostgresConnector) var pullErr error var numRecords int - numRecords, currentSnapshotXmin, pullErr = 
pgConn.PullXminRecordStream(config, partition, stream) + numRecords, currentSnapshotXmin, pullErr = pgConn.PullXminRecordStream(ctx, config, partition, stream) if pullErr != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) logger.Warn(fmt.Sprintf("[xmin] failed to pull records: %v", err)) @@ -981,7 +973,7 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context, }) defer shutdown() - rowsSynced, err := dstConn.SyncQRepRecords(config, partition, stream) + rowsSynced, err := dstConn.SyncQRepRecords(ctx, config, partition, stream) if err != nil { a.Alerter.LogFlowError(ctx, config.FlowJobName, err) return 0, fmt.Errorf("failed to sync records: %w", err) @@ -1020,9 +1012,9 @@ func (a *FlowableActivity) AddTablesToPublication(ctx context.Context, cfg *prot if err != nil { return fmt.Errorf("failed to get source connector: %w", err) } - defer connectors.CloseConnector(srcConn) + defer connectors.CloseConnector(ctx, srcConn) - err = srcConn.AddTablesToPublication(&protos.AddTablesToPublicationInput{ + err = srcConn.AddTablesToPublication(ctx, &protos.AddTablesToPublicationInput{ FlowJobName: cfg.FlowJobName, PublicationName: cfg.PublicationName, AdditionalTables: additionalTableMappings, diff --git a/flow/activities/snapshot_activity.go b/flow/activities/snapshot_activity.go index 518b868d44..95faba6b54 100644 --- a/flow/activities/snapshot_activity.go +++ b/flow/activities/snapshot_activity.go @@ -20,14 +20,14 @@ type SnapshotActivity struct { } // closes the slot signal -func (a *SnapshotActivity) CloseSlotKeepAlive(flowJobName string) error { +func (a *SnapshotActivity) CloseSlotKeepAlive(ctx context.Context, flowJobName string) error { if a.SnapshotConnections == nil { return nil } if s, ok := a.SnapshotConnections[flowJobName]; ok { close(s.signal.CloneComplete) - s.connector.Close() + s.connector.Close(ctx) } return nil @@ -60,7 +60,7 @@ func (a *SnapshotActivity) SetupReplication( logger.Error("failed to setup replication", slog.Any("error", err)) a.Alerter.LogFlowError(ctx, config.FlowJobName, err) // it is important to close the connection here as it is not closed in CloseSlotKeepAlive - connCloseErr := conn.Close() + connCloseErr := conn.Close(ctx) if connCloseErr != nil { logger.Error("failed to close connection", slog.Any("error", connCloseErr)) } @@ -69,7 +69,7 @@ func (a *SnapshotActivity) SetupReplication( // This now happens in a goroutine go func() { pgConn := conn.(*connpostgres.PostgresConnector) - err = pgConn.SetupReplication(slotSignal, config) + err = pgConn.SetupReplication(ctx, slotSignal, config) if err != nil { closeConnectionForError(err) replicationErr <- err diff --git a/flow/cmd/api.go b/flow/cmd/api.go index 88a435767f..2634aa3bf8 100644 --- a/flow/cmd/api.go +++ b/flow/cmd/api.go @@ -119,7 +119,7 @@ func APIMain(ctx context.Context, args *APIServerParams) error { grpcServer := grpc.NewServer() - catalogConn, err := utils.GetCatalogConnectionPoolFromEnv() + catalogConn, err := utils.GetCatalogConnectionPoolFromEnv(ctx) if err != nil { return fmt.Errorf("unable to get catalog connection pool: %w", err) } diff --git a/flow/cmd/peer_data.go b/flow/cmd/peer_data.go index 4a3bf534ac..1cf854e3b0 100644 --- a/flow/cmd/peer_data.go +++ b/flow/cmd/peer_data.go @@ -233,9 +233,9 @@ func (h *FlowRequestHandler) GetSlotInfo( slog.Error("Failed to create postgres connector", slog.Any("error", err)) return &protos.PeerSlotResponse{SlotData: nil}, err } - defer pgConnector.Close() + defer pgConnector.Close(ctx) - slotInfo, err := 
pgConnector.GetSlotInfo("") + slotInfo, err := pgConnector.GetSlotInfo(ctx, "") if err != nil { slog.Error("Failed to get slot info", slog.Any("error", err)) return &protos.PeerSlotResponse{SlotData: nil}, err diff --git a/flow/cmd/snapshot_worker.go b/flow/cmd/snapshot_worker.go index e917818ea4..35f2e81039 100644 --- a/flow/cmd/snapshot_worker.go +++ b/flow/cmd/snapshot_worker.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "fmt" "log/slog" @@ -61,7 +62,7 @@ func SnapshotWorkerMain(opts *SnapshotWorkerOptions) error { EnableSessionWorker: true, }) - conn, err := utils.GetCatalogConnectionPoolFromEnv() + conn, err := utils.GetCatalogConnectionPoolFromEnv(context.Background()) if err != nil { return fmt.Errorf("unable to create catalog connection pool: %w", err) } diff --git a/flow/cmd/validate_mirror.go b/flow/cmd/validate_mirror.go index d72b69781c..848ce4138a 100644 --- a/flow/cmd/validate_mirror.go +++ b/flow/cmd/validate_mirror.go @@ -30,10 +30,10 @@ func (h *FlowRequestHandler) ValidateCDCMirror( Ok: false, }, fmt.Errorf("failed to create postgres connector: %v", err) } - defer pgPeer.Close() + defer pgPeer.Close(ctx) // Check permissions of postgres peer - err = pgPeer.CheckReplicationPermissions(sourcePeerConfig.User) + err = pgPeer.CheckReplicationPermissions(ctx, sourcePeerConfig.User) if err != nil { return &protos.ValidateCDCMirrorResponse{ Ok: false, @@ -46,7 +46,7 @@ func (h *FlowRequestHandler) ValidateCDCMirror( sourceTables = append(sourceTables, tableMapping.SourceTableIdentifier) } - err = pgPeer.CheckSourceTables(sourceTables, req.ConnectionConfigs.PublicationName) + err = pgPeer.CheckSourceTables(ctx, sourceTables, req.ConnectionConfigs.PublicationName) if err != nil { return &protos.ValidateCDCMirrorResponse{ Ok: false, diff --git a/flow/cmd/validate_peer.go b/flow/cmd/validate_peer.go index 366f9b24ab..d8ed016bbf 100644 --- a/flow/cmd/validate_peer.go +++ b/flow/cmd/validate_peer.go @@ -37,10 +37,10 @@ func (h *FlowRequestHandler) ValidatePeer( }, nil } - defer conn.Close() + defer conn.Close(ctx) if req.Peer.Type == protos.DBType_POSTGRES { - isValid, version, err := conn.(*connpostgres.PostgresConnector).MajorVersionCheck(connpostgres.POSTGRES_12) + isValid, version, err := conn.(*connpostgres.PostgresConnector).MajorVersionCheck(ctx, connpostgres.POSTGRES_12) if err != nil { slog.Error("/peer/validate: pg version check", slog.Any("error", err)) return nil, err @@ -55,7 +55,7 @@ func (h *FlowRequestHandler) ValidatePeer( } } - connErr := conn.ConnectionActive() + connErr := conn.ConnectionActive(ctx) if connErr != nil { return &protos.ValidatePeerResponse{ Status: protos.ValidatePeerStatus_INVALID, diff --git a/flow/cmd/worker.go b/flow/cmd/worker.go index 7a45f2306b..c63511470d 100644 --- a/flow/cmd/worker.go +++ b/flow/cmd/worker.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "fmt" "log" @@ -109,7 +110,7 @@ func WorkerMain(opts *WorkerOptions) error { clientOptions.ConnectionOptions = connOptions } - conn, err := utils.GetCatalogConnectionPoolFromEnv() + conn, err := utils.GetCatalogConnectionPoolFromEnv(context.Background()) if err != nil { return fmt.Errorf("unable to create catalog connection pool: %w", err) } diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 81f1986205..8f15a5f325 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -50,7 +50,6 @@ type BigQueryServiceAccount struct { // BigQueryConnector is a Connector implementation 
for BigQuery. type BigQueryConnector struct { - ctx context.Context bqConfig *protos.BigqueryConfig client *bigquery.Client storageClient *storage.Client @@ -223,13 +222,12 @@ func NewBigQueryConnector(ctx context.Context, config *protos.BigqueryConfig) (* return nil, fmt.Errorf("failed to create Storage client: %v", err) } - catalogPool, err := cc.GetCatalogConnectionPoolFromEnv() + catalogPool, err := cc.GetCatalogConnectionPoolFromEnv(ctx) if err != nil { return nil, fmt.Errorf("failed to create catalog connection pool: %v", err) } return &BigQueryConnector{ - ctx: ctx, bqConfig: config, client: client, datasetID: datasetID, @@ -242,7 +240,7 @@ func NewBigQueryConnector(ctx context.Context, config *protos.BigqueryConfig) (* } // Close closes the BigQuery driver. -func (c *BigQueryConnector) Close() error { +func (c *BigQueryConnector) Close(_ context.Context) error { if c == nil || c.client == nil { return nil } @@ -250,8 +248,8 @@ func (c *BigQueryConnector) Close() error { } // ConnectionActive returns true if the connection is active. -func (c *BigQueryConnector) ConnectionActive() error { - _, err := c.client.DatasetInProject(c.projectID, c.datasetID).Metadata(c.ctx) +func (c *BigQueryConnector) ConnectionActive(ctx context.Context) error { + _, err := c.client.DatasetInProject(c.projectID, c.datasetID).Metadata(ctx) if err != nil { return fmt.Errorf("failed to get dataset metadata: %v", err) } @@ -262,11 +260,11 @@ func (c *BigQueryConnector) ConnectionActive() error { return nil } -func (c *BigQueryConnector) NeedsSetupMetadataTables() bool { +func (c *BigQueryConnector) NeedsSetupMetadataTables(_ context.Context) bool { return false } -func (c *BigQueryConnector) waitForTableReady(datasetTable *datasetTable) error { +func (c *BigQueryConnector) waitForTableReady(ctx context.Context, datasetTable *datasetTable) error { table := c.client.DatasetInProject(c.projectID, datasetTable.dataset).Table(datasetTable.table) maxDuration := 5 * time.Minute deadline := time.Now().Add(maxDuration) @@ -278,7 +276,7 @@ func (c *BigQueryConnector) waitForTableReady(datasetTable *datasetTable) error return fmt.Errorf("timeout reached while waiting for table %s to be ready", datasetTable) } - _, err := table.Metadata(c.ctx) + _, err := table.Metadata(ctx) if err == nil { return nil } @@ -292,7 +290,9 @@ func (c *BigQueryConnector) waitForTableReady(datasetTable *datasetTable) error // ReplayTableSchemaDeltas changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. 
-func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string, +func (c *BigQueryConnector) ReplayTableSchemaDeltas( + ctx context.Context, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { for _, schemaDelta := range schemaDeltas { @@ -308,7 +308,7 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string, qValueKindToBigQueryType(addedColumn.ColumnType))) query.DefaultProjectID = c.projectID query.DefaultDatasetID = dstDatasetTable.dataset - _, err := query.Read(c.ctx) + _, err := query.Read(ctx) if err != nil { return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.ColumnName, schemaDelta.DstTableName, err) @@ -321,27 +321,30 @@ func (c *BigQueryConnector) ReplayTableSchemaDeltas(flowJobName string, return nil } -func (c *BigQueryConnector) SetupMetadataTables() error { +func (c *BigQueryConnector) SetupMetadataTables(_ context.Context) error { return nil } -func (c *BigQueryConnector) GetLastOffset(jobName string) (int64, error) { - return c.pgMetadata.FetchLastOffset(c.ctx, jobName) +func (c *BigQueryConnector) GetLastOffset(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.FetchLastOffset(ctx, jobName) } -func (c *BigQueryConnector) SetLastOffset(jobName string, offset int64) error { - return c.pgMetadata.UpdateLastOffset(c.ctx, jobName, offset) +func (c *BigQueryConnector) SetLastOffset(ctx context.Context, jobName string, offset int64) error { + return c.pgMetadata.UpdateLastOffset(ctx, jobName, offset) } -func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { - return c.pgMetadata.GetLastBatchID(c.ctx, jobName) +func (c *BigQueryConnector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.GetLastBatchID(ctx, jobName) } -func (c *BigQueryConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { - return c.pgMetadata.GetLastNormalizeBatchID(c.ctx, jobName) +func (c *BigQueryConnector) GetLastNormalizeBatchID(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.GetLastNormalizeBatchID(ctx, jobName) } -func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, +func (c *BigQueryConnector) getDistinctTableNamesInBatch( + ctx context.Context, + flowJobName string, + syncBatchID int64, normalizeBatchID int64, ) ([]string, error) { rawTableName := c.getRawTableName(flowJobName) @@ -354,7 +357,7 @@ func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syn q := c.client.Query(query) q.DefaultProjectID = c.projectID q.DefaultDatasetID = c.datasetID - it, err := q.Read(c.ctx) + it, err := q.Read(ctx) if err != nil { err = fmt.Errorf("failed to run query %s on BigQuery:\n %w", query, err) return nil, err @@ -380,7 +383,10 @@ func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syn return distinctTableNames, nil } -func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, syncBatchID int64, +func (c *BigQueryConnector) getTableNametoUnchangedCols( + ctx context.Context, + flowJobName string, + syncBatchID int64, normalizeBatchID int64, ) (map[string][]string, error) { rawTableName := c.getRawTableName(flowJobName) @@ -398,7 +404,7 @@ func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, sync q := c.client.Query(query) q.DefaultDatasetID = c.datasetID q.DefaultProjectID = c.projectID - it, err := q.Read(c.ctx) + it, err := q.Read(ctx) if err != nil { err = fmt.Errorf("failed to 
run query %s on BigQuery:\n %w", query, err) return nil, err @@ -427,12 +433,12 @@ func (c *BigQueryConnector) getTableNametoUnchangedCols(flowJobName string, sync // SyncRecords pushes records to the destination. // Currently only supports inserts, updates, and deletes. // More record types will be added in the future. -func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { +func (c *BigQueryConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { rawTableName := c.getRawTableName(req.FlowJobName) c.logger.Info(fmt.Sprintf("pushing records to %s.%s...", c.datasetID, rawTableName)) - res, err := c.syncRecordsViaAvro(req, rawTableName, req.SyncBatchID) + res, err := c.syncRecordsViaAvro(ctx, req, rawTableName, req.SyncBatchID) if err != nil { return nil, err } @@ -442,6 +448,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S } func (c *BigQueryConnector) syncRecordsViaAvro( + ctx context.Context, req *model.SyncRecordsRequest, rawTableName string, syncBatchID int64, @@ -454,12 +461,12 @@ func (c *BigQueryConnector) syncRecordsViaAvro( } avroSync := NewQRepAvroSyncMethod(c, req.StagingPath, req.FlowJobName) - rawTableMetadata, err := c.client.DatasetInProject(c.projectID, c.datasetID).Table(rawTableName).Metadata(c.ctx) + rawTableMetadata, err := c.client.DatasetInProject(c.projectID, c.datasetID).Table(rawTableName).Metadata(ctx) if err != nil { return nil, fmt.Errorf("failed to get metadata of destination table: %w", err) } - res, err := avroSync.SyncRecords(c.ctx, req, rawTableName, + res, err := avroSync.SyncRecords(ctx, req, rawTableName, rawTableMetadata, syncBatchID, streamRes.Stream, streamReq.TableMapping) if err != nil { return nil, fmt.Errorf("failed to sync records via avro: %w", err) @@ -469,10 +476,10 @@ func (c *BigQueryConnector) syncRecordsViaAvro( } // NormalizeRecords normalizes raw table to destination table. 
-func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { +func (c *BigQueryConnector) NormalizeRecords(ctx context.Context, req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { rawTableName := c.getRawTableName(req.FlowJobName) - normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(ctx, req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err) } @@ -487,6 +494,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } distinctTableNames, err := c.getDistinctTableNamesInBatch( + ctx, req.FlowJobName, req.SyncBatchID, normBatchID, @@ -496,6 +504,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols( + ctx, req.FlowJobName, req.SyncBatchID, normBatchID, @@ -544,7 +553,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) q := c.client.Query(mergeStmt) q.DefaultProjectID = c.projectID q.DefaultDatasetID = dstDatasetTable.dataset - _, err := q.Read(c.ctx) + _, err := q.Read(ctx) if err != nil { return fmt.Errorf("failed to execute merge statement %s: %v", mergeStmt, err) } @@ -555,7 +564,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } } - err = c.pgMetadata.UpdateNormalizeBatchID(c.ctx, req.FlowJobName, req.SyncBatchID) + err = c.pgMetadata.UpdateNormalizeBatchID(ctx, req.FlowJobName, req.SyncBatchID) if err != nil { return nil, err } @@ -574,7 +583,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) // _peerdb_data STRING // _peerdb_record_type INT - 0 for insert, 1 for update, 2 for delete // _peerdb_match_data STRING - json of the match data (only for update and delete) -func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { +func (c *BigQueryConnector) CreateRawTable(ctx context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { rawTableName := c.getRawTableName(req.FlowJobName) schema := bigquery.Schema{ @@ -592,7 +601,7 @@ func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr table := c.client.DatasetInProject(c.projectID, c.datasetID).Table(rawTableName) // check if the table exists - tableRef, err := table.Metadata(c.ctx) + tableRef, err := table.Metadata(ctx) if err == nil { // table exists, check if the schema matches if !reflect.DeepEqual(tableRef.Schema, schema) { @@ -629,7 +638,7 @@ func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr } // table does not exist, create it - err = table.Create(c.ctx, metadata) + err = table.Create(ctx, metadata) if err != nil { return nil, fmt.Errorf("failed to create table %s.%s: %w", c.datasetID, rawTableName, err) } @@ -642,12 +651,13 @@ func (c *BigQueryConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr // SetupNormalizedTables sets up normalized tables, implementing the Connector interface. // This runs CREATE TABLE IF NOT EXISTS on bigquery, using the schema and table name provided. 
func (c *BigQueryConnector) SetupNormalizedTables( + ctx context.Context, req *protos.SetupNormalizedTableBatchInput, ) (*protos.SetupNormalizedTableBatchOutput, error) { numTablesSetup := atomic.Uint32{} totalTables := uint32(len(req.TableNameSchemaMapping)) - shutdown := utils.HeartbeatRoutine(c.ctx, func() string { + shutdown := utils.HeartbeatRoutine(ctx, func() string { return fmt.Sprintf("setting up normalized tables - %d of %d done", numTablesSetup.Load(), totalTables) }) @@ -667,7 +677,7 @@ func (c *BigQueryConnector) SetupNormalizedTables( datasetTable.string()) } dataset := c.client.DatasetInProject(c.projectID, datasetTable.dataset) - _, err = dataset.Metadata(c.ctx) + _, err = dataset.Metadata(ctx) // just assume this means dataset don't exist, and create it if err != nil { // if err message does not contain `notFound`, then other error happened. @@ -676,7 +686,7 @@ func (c *BigQueryConnector) SetupNormalizedTables( datasetTable.dataset, err) } c.logger.Info(fmt.Sprintf("creating dataset %s...", dataset.DatasetID)) - err = dataset.Create(c.ctx, nil) + err = dataset.Create(ctx, nil) if err != nil { return nil, fmt.Errorf("failed to create BigQuery dataset %s: %w", dataset.DatasetID, err) } @@ -684,7 +694,7 @@ func (c *BigQueryConnector) SetupNormalizedTables( table := dataset.Table(datasetTable.table) // check if the table exists - _, err = table.Metadata(c.ctx) + _, err = table.Metadata(ctx) if err == nil { // table exists, go to next table tableExistsMapping[tableIdentifier] = true @@ -755,7 +765,7 @@ func (c *BigQueryConnector) SetupNormalizedTables( Clustering: clustering, } - err = table.Create(c.ctx, metadata) + err = table.Create(ctx, metadata) if err != nil { return nil, fmt.Errorf("failed to create table %s: %w", tableIdentifier, err) } @@ -772,15 +782,15 @@ func (c *BigQueryConnector) SetupNormalizedTables( }, nil } -func (c *BigQueryConnector) SyncFlowCleanup(jobName string) error { - err := c.pgMetadata.DropMetadata(c.ctx, jobName) +func (c *BigQueryConnector) SyncFlowCleanup(ctx context.Context, jobName string) error { + err := c.pgMetadata.DropMetadata(ctx, jobName) if err != nil { return fmt.Errorf("unable to clear metadata for sync flow cleanup: %w", err) } dataset := c.client.DatasetInProject(c.projectID, c.datasetID) // deleting PeerDB specific tables - err = dataset.Table(c.getRawTableName(jobName)).Delete(c.ctx) + err = dataset.Table(c.getRawTableName(jobName)).Delete(ctx) if err != nil { return fmt.Errorf("failed to delete raw table: %w", err) } @@ -795,7 +805,7 @@ func (c *BigQueryConnector) getRawTableName(flowJobName string) string { return fmt.Sprintf("_peerdb_raw_%s", flowJobName) } -func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { +func (c *BigQueryConnector) RenameTables(ctx context.Context, req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { // BigQuery doesn't really do transactions properly anyway so why bother? 
for _, renameRequest := range req.RenameTableOptions { srcDatasetTable, _ := c.convertToDatasetTable(renameRequest.CurrentName) @@ -803,7 +813,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos c.logger.Info(fmt.Sprintf("renaming table '%s' to '%s'...", srcDatasetTable.string(), dstDatasetTable.string())) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("renaming table '%s' to '%s'...", srcDatasetTable.string(), + activity.RecordHeartbeat(ctx, fmt.Sprintf("renaming table '%s' to '%s'...", srcDatasetTable.string(), dstDatasetTable.string())) columnNames := make([]string, 0, len(renameRequest.TableSchema.Columns)) @@ -817,7 +827,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string())) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string())) + activity.RecordHeartbeat(ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dstDatasetTable.string())) c.logger.Info(fmt.Sprintf("INSERT INTO %s(%s) SELECT %s,true AS %s FROM %s WHERE (%s) NOT IN (SELECT %s FROM %s)", srcDatasetTable.string(), fmt.Sprintf("%s,%s", allCols, *req.SoftDeleteColName), @@ -831,7 +841,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos query.DefaultProjectID = c.projectID query.DefaultDatasetID = c.datasetID - _, err := query.Read(c.ctx) + _, err := query.Read(ctx) if err != nil { return nil, fmt.Errorf("unable to handle soft-deletes for table %s: %w", dstDatasetTable.string(), err) } @@ -840,7 +850,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos if req.SyncedAtColName != nil { c.logger.Info(fmt.Sprintf("setting synced at column for table '%s'...", srcDatasetTable.string())) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("setting synced at column for table '%s'...", + activity.RecordHeartbeat(ctx, fmt.Sprintf("setting synced at column for table '%s'...", srcDatasetTable.string())) c.logger.Info( @@ -852,7 +862,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos query.DefaultProjectID = c.projectID query.DefaultDatasetID = c.datasetID - _, err := query.Read(c.ctx) + _, err := query.Read(ctx) if err != nil { return nil, fmt.Errorf("unable to set synced at column for table %s: %w", srcDatasetTable.string(), err) } @@ -865,7 +875,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos dstDatasetTable.string())) dropQuery.DefaultProjectID = c.projectID dropQuery.DefaultDatasetID = c.datasetID - _, err := dropQuery.Read(c.ctx) + _, err := dropQuery.Read(ctx) if err != nil { return nil, fmt.Errorf("unable to drop table %s: %w", dstDatasetTable.string(), err) } @@ -877,7 +887,7 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos srcDatasetTable.string(), dstDatasetTable.table)) query.DefaultProjectID = c.projectID query.DefaultDatasetID = c.datasetID - _, err = query.Read(c.ctx) + _, err = query.Read(ctx) if err != nil { return nil, fmt.Errorf("unable to rename table %s to %s: %w", srcDatasetTable.string(), dstDatasetTable.string(), err) @@ -892,22 +902,23 @@ func (c *BigQueryConnector) RenameTables(req *protos.RenameTablesInput) (*protos }, nil } -func (c *BigQueryConnector) CreateTablesFromExisting(req *protos.CreateTablesFromExistingInput) ( - *protos.CreateTablesFromExistingOutput, error, -) { +func (c *BigQueryConnector) 
CreateTablesFromExisting( + ctx context.Context, + req *protos.CreateTablesFromExistingInput, +) (*protos.CreateTablesFromExistingOutput, error) { for newTable, existingTable := range req.NewToExistingTableMapping { newDatasetTable, _ := c.convertToDatasetTable(newTable) existingDatasetTable, _ := c.convertToDatasetTable(existingTable) c.logger.Info(fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) + activity.RecordHeartbeat(ctx, fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) // rename the src table to dst query := c.client.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` LIKE `%s`", newDatasetTable.string(), existingDatasetTable.string())) query.DefaultProjectID = c.projectID query.DefaultDatasetID = c.datasetID - _, err := query.Read(c.ctx) + _, err := query.Read(ctx) if err != nil { return nil, fmt.Errorf("unable to create table %s: %w", newTable, err) } diff --git a/flow/connectors/bigquery/qrep.go b/flow/connectors/bigquery/qrep.go index df2070e24d..4bc76986fc 100644 --- a/flow/connectors/bigquery/qrep.go +++ b/flow/connectors/bigquery/qrep.go @@ -1,6 +1,7 @@ package connbigquery import ( + "context" "fmt" "log/slog" "strings" @@ -13,6 +14,7 @@ import ( ) func (c *BigQueryConnector) SyncQRepRecords( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream, @@ -23,12 +25,12 @@ func (c *BigQueryConnector) SyncQRepRecords( if err != nil { return 0, fmt.Errorf("failed to get schema of source table %s: %w", config.WatermarkTable, err) } - tblMetadata, err := c.replayTableSchemaDeltasQRep(config, partition, srcSchema) + tblMetadata, err := c.replayTableSchemaDeltasQRep(ctx, config, partition, srcSchema) if err != nil { return 0, err } - done, err := c.pgMetadata.IsQrepPartitionSynced(c.ctx, config.FlowJobName, partition.PartitionId) + done, err := c.pgMetadata.IsQrepPartitionSynced(ctx, config.FlowJobName, partition.PartitionId) if err != nil { return 0, fmt.Errorf("failed to check if partition %s is synced: %w", partition.PartitionId, err) } @@ -42,16 +44,19 @@ func (c *BigQueryConnector) SyncQRepRecords( partition.PartitionId, destTable)) avroSync := NewQRepAvroSyncMethod(c, config.StagingPath, config.FlowJobName) - return avroSync.SyncQRepRecords(c.ctx, config.FlowJobName, destTable, partition, + return avroSync.SyncQRepRecords(ctx, config.FlowJobName, destTable, partition, tblMetadata, stream, config.SyncedAtColName, config.SoftDeleteColName) } -func (c *BigQueryConnector) replayTableSchemaDeltasQRep(config *protos.QRepConfig, partition *protos.QRepPartition, +func (c *BigQueryConnector) replayTableSchemaDeltasQRep( + ctx context.Context, + config *protos.QRepConfig, + partition *protos.QRepPartition, srcSchema *model.QRecordSchema, ) (*bigquery.TableMetadata, error) { destDatasetTable, _ := c.convertToDatasetTable(config.DestinationTableIdentifier) bqTable := c.client.DatasetInProject(c.projectID, destDatasetTable.dataset).Table(destDatasetTable.table) - dstTableMetadata, err := bqTable.Metadata(c.ctx) + dstTableMetadata, err := bqTable.Metadata(ctx) if err != nil { return nil, fmt.Errorf("failed to get metadata of table %s: %w", destDatasetTable, err) } @@ -82,23 +87,23 @@ func (c *BigQueryConnector) replayTableSchemaDeltasQRep(config *protos.QRepConfi } } - err = c.ReplayTableSchemaDeltas(config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta}) + err = 
c.ReplayTableSchemaDeltas(ctx, config.FlowJobName, []*protos.TableSchemaDelta{tableSchemaDelta}) if err != nil { return nil, fmt.Errorf("failed to add columns to destination table: %w", err) } - dstTableMetadata, err = bqTable.Metadata(c.ctx) + dstTableMetadata, err = bqTable.Metadata(ctx) if err != nil { return nil, fmt.Errorf("failed to get metadata of table %s: %w", destDatasetTable, err) } return dstTableMetadata, nil } -func (c *BigQueryConnector) SetupQRepMetadataTables(config *protos.QRepConfig) error { +func (c *BigQueryConnector) SetupQRepMetadataTables(ctx context.Context, config *protos.QRepConfig) error { if config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE { query := c.client.Query(fmt.Sprintf("TRUNCATE TABLE %s", config.DestinationTableIdentifier)) query.DefaultDatasetID = c.datasetID query.DefaultProjectID = c.projectID - _, err := query.Read(c.ctx) + _, err := query.Read(ctx) if err != nil { return fmt.Errorf("failed to TRUNCATE table before query replication: %w", err) } diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index 0910c5c104..43df1b9d8c 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -85,7 +85,7 @@ func (s *QRepAvroSyncMethod) SyncRecords( req.FlowJobName, rawTableName, syncBatchID), ) - err = s.connector.ReplayTableSchemaDeltas(req.FlowJobName, req.Records.SchemaDeltas) + err = s.connector.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -487,7 +487,7 @@ func (s *QRepAvroSyncMethod) writeToStage( } s.connector.logger.Info(fmt.Sprintf("Pushed from %s to BigQuery", avroFile.FilePath), idLog) - err = s.connector.waitForTableReady(stagingTable) + err = s.connector.waitForTableReady(ctx, stagingTable) if err != nil { return 0, fmt.Errorf("failed to wait for table to be ready: %w", err) } diff --git a/flow/connectors/clickhouse/cdc.go b/flow/connectors/clickhouse/cdc.go index 31d46b9a47..3074df7bad 100644 --- a/flow/connectors/clickhouse/cdc.go +++ b/flow/connectors/clickhouse/cdc.go @@ -1,6 +1,7 @@ package connclickhouse import ( + "context" "database/sql" "fmt" "log/slog" @@ -27,9 +28,9 @@ func (c *ClickhouseConnector) getRawTableName(flowJobName string) string { return fmt.Sprintf("_peerdb_raw_%s", flowJobName) } -func (c *ClickhouseConnector) checkIfTableExists(databaseName string, tableIdentifier string) (bool, error) { +func (c *ClickhouseConnector) checkIfTableExists(ctx context.Context, databaseName string, tableIdentifier string) (bool, error) { var result sql.NullInt32 - err := c.database.QueryRowContext(c.ctx, checkIfTableExistsSQL, databaseName, tableIdentifier).Scan(&result) + err := c.database.QueryRowContext(ctx, checkIfTableExistsSQL, databaseName, tableIdentifier).Scan(&result) if err != nil { return false, fmt.Errorf("error while reading result row: %w", err) } @@ -48,7 +49,7 @@ type MirrorJobRow struct { NormalizeBatchID int } -func (c *ClickhouseConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { +func (c *ClickhouseConnector) CreateRawTable(ctx context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { rawTableName := c.getRawTableName(req.FlowJobName) createRawTableSQL := `CREATE TABLE IF NOT EXISTS %s ( @@ -62,7 +63,7 @@ func (c *ClickhouseConnector) CreateRawTable(req *protos.CreateRawTableInput) (* 
_peerdb_unchanged_toast_columns String ) ENGINE = ReplacingMergeTree ORDER BY _peerdb_uid;` - _, err := c.database.ExecContext(c.ctx, + _, err := c.database.ExecContext(ctx, fmt.Sprintf(createRawTableSQL, rawTableName)) if err != nil { return nil, fmt.Errorf("unable to create raw table: %w", err) @@ -73,6 +74,7 @@ func (c *ClickhouseConnector) CreateRawTable(req *protos.CreateRawTableInput) (* } func (c *ClickhouseConnector) syncRecordsViaAvro( + ctx context.Context, req *model.SyncRecordsRequest, rawTableIdentifier string, syncBatchID int64, @@ -95,12 +97,12 @@ func (c *ClickhouseConnector) syncRecordsViaAvro( return nil, err } - numRecords, err := avroSyncer.SyncRecords(c.ctx, destinationTableSchema, streamRes.Stream, req.FlowJobName) + numRecords, err := avroSyncer.SyncRecords(ctx, destinationTableSchema, streamRes.Stream, req.FlowJobName) if err != nil { return nil, err } - err = c.ReplayTableSchemaDeltas(req.FlowJobName, req.Records.SchemaDeltas) + err = c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -119,11 +121,11 @@ func (c *ClickhouseConnector) syncRecordsViaAvro( }, nil } -func (c *ClickhouseConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { +func (c *ClickhouseConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { rawTableName := c.getRawTableName(req.FlowJobName) c.logger.Info(fmt.Sprintf("pushing records to Clickhouse table %s", rawTableName)) - res, err := c.syncRecordsViaAvro(req, rawTableName, req.SyncBatchID) + res, err := c.syncRecordsViaAvro(ctx, req, rawTableName, req.SyncBatchID) if err != nil { return nil, err } @@ -133,7 +135,7 @@ func (c *ClickhouseConnector) SyncRecords(req *model.SyncRecordsRequest) (*model return nil, fmt.Errorf("failed to get last checkpoint: %w", err) } - err = c.pgMetadata.FinishBatch(c.ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) + err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) if err != nil { c.logger.Error("failed to increment id", slog.Any("error", err)) return nil, err @@ -142,41 +144,39 @@ func (c *ClickhouseConnector) SyncRecords(req *model.SyncRecordsRequest) (*model return res, nil } -func (c *ClickhouseConnector) SyncFlowCleanup(jobName string) error { - err := c.pgMetadata.DropMetadata(c.ctx, jobName) +func (c *ClickhouseConnector) SyncFlowCleanup(ctx context.Context, jobName string) error { + err := c.pgMetadata.DropMetadata(ctx, jobName) if err != nil { return err } return nil } -// ReplayTableSchemaDeltas changes a destination table to match the schema at source -// This could involve adding or dropping multiple columns. 
-func (c *ClickhouseConnector) ReplayTableSchemaDeltas(flowJobName string, +func (c *ClickhouseConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { return nil } -func (c *ClickhouseConnector) NeedsSetupMetadataTables() bool { +func (c *ClickhouseConnector) NeedsSetupMetadataTables(_ context.Context) bool { return false } -func (c *ClickhouseConnector) SetupMetadataTables() error { +func (c *ClickhouseConnector) SetupMetadataTables(_ context.Context) error { return nil } -func (c *ClickhouseConnector) GetLastSyncBatchID(jobName string) (int64, error) { - return c.pgMetadata.GetLastBatchID(c.ctx, jobName) +func (c *ClickhouseConnector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.GetLastBatchID(ctx, jobName) } -func (c *ClickhouseConnector) GetLastOffset(jobName string) (int64, error) { - return c.pgMetadata.FetchLastOffset(c.ctx, jobName) +func (c *ClickhouseConnector) GetLastOffset(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.FetchLastOffset(ctx, jobName) } // update offset for a job -func (c *ClickhouseConnector) SetLastOffset(jobName string, offset int64) error { - err := c.pgMetadata.UpdateLastOffset(c.ctx, jobName, offset) +func (c *ClickhouseConnector) SetLastOffset(ctx context.Context, jobName string, offset int64) error { + err := c.pgMetadata.UpdateLastOffset(ctx, jobName, offset) if err != nil { c.logger.Error("failed to update last offset: ", slog.Any("error", err)) return err diff --git a/flow/connectors/clickhouse/clickhouse.go b/flow/connectors/clickhouse/clickhouse.go index 065a63d3da..80e2e80f5b 100644 --- a/flow/connectors/clickhouse/clickhouse.go +++ b/flow/connectors/clickhouse/clickhouse.go @@ -17,7 +17,6 @@ import ( ) type ClickhouseConnector struct { - ctx context.Context database *sql.DB pgMetadata *metadataStore.PostgresMetadataStore tableSchemaMapping map[string]*protos.TableSchema @@ -51,7 +50,7 @@ func NewClickhouseConnector( return nil, fmt.Errorf("failed to open connection to Clickhouse peer: %w", err) } - pgMetadata, err := metadataStore.NewPostgresMetadataStore(logger) + pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx) if err != nil { logger.Error("failed to create postgres metadata store", "error", err) return nil, err @@ -69,7 +68,6 @@ func NewClickhouseConnector( } return &ClickhouseConnector{ - ctx: ctx, database: database, pgMetadata: pgMetadata, tableSchemaMapping: nil, @@ -101,7 +99,7 @@ func connect(ctx context.Context, config *protos.ClickhouseConfig) (*sql.DB, err return conn, nil } -func (c *ClickhouseConnector) Close() error { +func (c *ClickhouseConnector) Close(_ context.Context) error { if c == nil || c.database == nil { return nil } @@ -113,12 +111,12 @@ func (c *ClickhouseConnector) Close() error { return nil } -func (c *ClickhouseConnector) ConnectionActive() error { +func (c *ClickhouseConnector) ConnectionActive(ctx context.Context) error { if c == nil || c.database == nil { return fmt.Errorf("ClickhouseConnector is nil") } // This also checks if database exists - err := c.database.PingContext(c.ctx) + err := c.database.PingContext(ctx) return err } diff --git a/flow/connectors/clickhouse/client.go b/flow/connectors/clickhouse/client.go index 8bd5a0221e..8bed8f48ed 100644 --- a/flow/connectors/clickhouse/client.go +++ b/flow/connectors/clickhouse/client.go @@ -8,14 +8,12 @@ import ( peersql "github.com/PeerDB-io/peer-flow/connectors/sql" 
"github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/model/qvalue" ) type ClickhouseClient struct { peersql.GenericSQLQueryExecutor - // ctx is the context. - ctx context.Context - // config is the Snowflake config. Config *protos.ClickhouseConfig } @@ -28,11 +26,10 @@ func NewClickhouseClient(ctx context.Context, config *protos.ClickhouseConfig) ( } genericExecutor := *peersql.NewGenericSQLQueryExecutor( - ctx, database, clickhouseTypeToQValueKindMap, qvalue.QValueKindToSnowflakeTypeMap) + logger.LoggerFromCtx(ctx), database, clickhouseTypeToQValueKindMap, qvalue.QValueKindToSnowflakeTypeMap) return &ClickhouseClient{ GenericSQLQueryExecutor: genericExecutor, - ctx: ctx, Config: config, }, nil } diff --git a/flow/connectors/clickhouse/normalize.go b/flow/connectors/clickhouse/normalize.go index 2c1e5d4ab8..d6e91b06ec 100644 --- a/flow/connectors/clickhouse/normalize.go +++ b/flow/connectors/clickhouse/normalize.go @@ -1,6 +1,7 @@ package connclickhouse import ( + "context" "database/sql" "fmt" "strconv" @@ -19,11 +20,12 @@ const ( ) func (c *ClickhouseConnector) SetupNormalizedTables( + ctx context.Context, req *protos.SetupNormalizedTableBatchInput, ) (*protos.SetupNormalizedTableBatchOutput, error) { tableExistsMapping := make(map[string]bool) for tableIdentifier, tableSchema := range req.TableNameSchemaMapping { - tableAlreadyExists, err := c.checkIfTableExists(c.config.Database, tableIdentifier) + tableAlreadyExists, err := c.checkIfTableExists(ctx, c.config.Database, tableIdentifier) if err != nil { return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) } @@ -42,7 +44,7 @@ func (c *ClickhouseConnector) SetupNormalizedTables( return nil, fmt.Errorf("error while generating create table sql for normalized table: %w", err) } - _, err = c.database.ExecContext(c.ctx, normalizedTableCreateSQL) + _, err = c.database.ExecContext(ctx, normalizedTableCreateSQL) if err != nil { return nil, fmt.Errorf("[sf] error while creating normalized table: %w", err) } @@ -103,8 +105,8 @@ func generateCreateTableSQLForNormalizedTable( return stmtBuilder.String(), nil } -func (c *ClickhouseConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { - normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) +func (c *ClickhouseConnector) NormalizeRecords(ctx context.Context, req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { + normBatchID, err := c.GetLastNormalizeBatchID(ctx, req.FlowJobName) if err != nil { c.logger.Error("[clickhouse] error while getting last sync and normalize batch id", "error", err) return nil, err @@ -120,6 +122,7 @@ func (c *ClickhouseConnector) NormalizeRecords(req *model.NormalizeRecordsReques } destinationTableNames, err := c.getDistinctTableNamesInBatch( + ctx, req.FlowJobName, req.SyncBatchID, normBatchID, @@ -192,14 +195,14 @@ func (c *ClickhouseConnector) NormalizeRecords(req *model.NormalizeRecordsReques q := insertIntoSelectQuery.String() c.logger.Info(fmt.Sprintf("[clickhouse] insert into select query %s", q)) - _, err = c.database.ExecContext(c.ctx, q) + _, err = c.database.ExecContext(ctx, q) if err != nil { return nil, fmt.Errorf("error while inserting into normalized table: %w", err) } } endNormalizeBatchId := normBatchID + 1 - err = c.pgMetadata.UpdateNormalizeBatchID(c.ctx, req.FlowJobName, endNormalizeBatchId) + err = c.pgMetadata.UpdateNormalizeBatchID(ctx, req.FlowJobName, endNormalizeBatchId) 
if err != nil { c.logger.Error("[clickhouse] error while updating normalize batch id", "error", err) return nil, err @@ -213,6 +216,7 @@ func (c *ClickhouseConnector) NormalizeRecords(req *model.NormalizeRecordsReques } func (c *ClickhouseConnector) getDistinctTableNamesInBatch( + ctx context.Context, flowJobName string, syncBatchID int64, normalizeBatchID int64, @@ -224,7 +228,7 @@ func (c *ClickhouseConnector) getDistinctTableNamesInBatch( `SELECT DISTINCT _peerdb_destination_table_name FROM %s WHERE _peerdb_batch_id > %d AND _peerdb_batch_id <= %d`, rawTbl, normalizeBatchID, syncBatchID) - rows, err := c.database.QueryContext(c.ctx, q) + rows, err := c.database.QueryContext(ctx, q) if err != nil { return nil, fmt.Errorf("error while querying raw table for distinct table names in batch: %w", err) } @@ -252,8 +256,8 @@ func (c *ClickhouseConnector) getDistinctTableNamesInBatch( return tableNames, nil } -func (c *ClickhouseConnector) GetLastNormalizeBatchID(flowJobName string) (int64, error) { - normalizeBatchID, err := c.pgMetadata.GetLastNormalizeBatchID(c.ctx, flowJobName) +func (c *ClickhouseConnector) GetLastNormalizeBatchID(ctx context.Context, flowJobName string) (int64, error) { + normalizeBatchID, err := c.pgMetadata.GetLastNormalizeBatchID(ctx, flowJobName) if err != nil { return 0, fmt.Errorf("error while getting last normalize batch id: %w", err) } diff --git a/flow/connectors/clickhouse/qrep.go b/flow/connectors/clickhouse/qrep.go index 50e1b74a46..0abff100ef 100644 --- a/flow/connectors/clickhouse/qrep.go +++ b/flow/connectors/clickhouse/qrep.go @@ -1,6 +1,7 @@ package connclickhouse import ( + "context" "database/sql" "fmt" "log/slog" @@ -20,6 +21,7 @@ import ( const qRepMetadataTableName = "_peerdb_query_replication_metadata" func (c *ClickhouseConnector) SyncQRepRecords( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream, @@ -31,7 +33,7 @@ func (c *ClickhouseConnector) SyncQRepRecords( slog.String("destinationTable", destTable), ) - done, err := c.isPartitionSynced(partition.PartitionId) + done, err := c.isPartitionSynced(ctx, partition.PartitionId) if err != nil { return 0, fmt.Errorf("failed to check if partition %s is synced: %w", partition.PartitionId, err) } @@ -49,7 +51,7 @@ func (c *ClickhouseConnector) SyncQRepRecords( avroSync := NewClickhouseAvroSyncMethod(config, c) - return avroSync.SyncQRepRecords(c.ctx, config, partition, tblSchema, stream) + return avroSync.SyncQRepRecords(ctx, config, partition, tblSchema, stream) } func (c *ClickhouseConnector) createMetadataInsertStatement( @@ -94,11 +96,11 @@ func (c *ClickhouseConnector) getTableSchema(tableName string) ([]*sql.ColumnTyp return columnTypes, nil } -func (c *ClickhouseConnector) isPartitionSynced(partitionID string) (bool, error) { +func (c *ClickhouseConnector) isPartitionSynced(ctx context.Context, partitionID string) (bool, error) { //nolint:gosec queryString := fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE partitionID = '%s'`, qRepMetadataTableName, partitionID) - row := c.database.QueryRow(queryString) + row := c.database.QueryRowContext(ctx, queryString) var count int if err := row.Scan(&count); err != nil { @@ -107,14 +109,14 @@ func (c *ClickhouseConnector) isPartitionSynced(partitionID string) (bool, error return count > 0, nil } -func (c *ClickhouseConnector) SetupQRepMetadataTables(config *protos.QRepConfig) error { - err := c.createQRepMetadataTable() +func (c *ClickhouseConnector) SetupQRepMetadataTables(ctx context.Context, config 
*protos.QRepConfig) error { + err := c.createQRepMetadataTable(ctx) if err != nil { return err } if config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE { - _, err = c.database.Exec(fmt.Sprintf("TRUNCATE TABLE %s", config.DestinationTableIdentifier)) + _, err = c.database.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", config.DestinationTableIdentifier)) if err != nil { return fmt.Errorf("failed to TRUNCATE table before query replication: %w", err) } @@ -123,7 +125,7 @@ func (c *ClickhouseConnector) SetupQRepMetadataTables(config *protos.QRepConfig) return nil } -func (c *ClickhouseConnector) createQRepMetadataTable() error { +func (c *ClickhouseConnector) createQRepMetadataTable(ctx context.Context) error { // Define the schema schemaStatement := ` CREATE TABLE IF NOT EXISTS %s ( @@ -136,7 +138,7 @@ func (c *ClickhouseConnector) createQRepMetadataTable() error { ORDER BY partitionID; ` queryString := fmt.Sprintf(schemaStatement, qRepMetadataTableName) - _, err := c.database.Exec(queryString) + _, err := c.database.ExecContext(ctx, queryString) if err != nil { c.logger.Error(fmt.Sprintf("failed to create table %s", qRepMetadataTableName), slog.Any("error", err)) @@ -147,19 +149,19 @@ func (c *ClickhouseConnector) createQRepMetadataTable() error { return nil } -func (c *ClickhouseConnector) ConsolidateQRepPartitions(config *protos.QRepConfig) error { +func (c *ClickhouseConnector) ConsolidateQRepPartitions(_ context.Context, config *protos.QRepConfig) error { c.logger.Info("Consolidating partitions noop") return nil } // CleanupQRepFlow function for clickhouse connector -func (c *ClickhouseConnector) CleanupQRepFlow(config *protos.QRepConfig) error { +func (c *ClickhouseConnector) CleanupQRepFlow(ctx context.Context, config *protos.QRepConfig) error { c.logger.Info("Cleaning up flow job") - return c.dropStage(config.StagingPath, config.FlowJobName) + return c.dropStage(ctx, config.StagingPath, config.FlowJobName) } // dropStage drops the stage for the given job. 
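Aside (not part of the patch): the qrep hunks above convert bare Exec/Query calls into their ExecContext/QueryContext counterparts. A sketch of why that matters, assuming a stand-alone helper and the raw-table column names shown earlier; this is illustrative, not the connector's code:

package example

import (
	"context"
	"database/sql"
	"fmt"
	"time"
)

// QueryContext stops waiting as soon as ctx is cancelled or times out,
// whereas Query would keep running regardless of the caller's lifetime.
func distinctTablesInBatch(ctx context.Context, db *sql.DB, rawTable string, normBatchID, syncBatchID int64) ([]string, error) {
	// Derive a per-call deadline without affecting the caller's context.
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	q := fmt.Sprintf(
		`SELECT DISTINCT _peerdb_destination_table_name FROM %s
		 WHERE _peerdb_batch_id > %d AND _peerdb_batch_id <= %d`,
		rawTable, normBatchID, syncBatchID)

	rows, err := db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		names = append(names, name)
	}
	return names, rows.Err()
}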
-func (c *ClickhouseConnector) dropStage(stagingPath string, job string) error { +func (c *ClickhouseConnector) dropStage(ctx context.Context, stagingPath string, job string) error { // if s3 we need to delete the contents of the bucket if strings.HasPrefix(stagingPath, "s3://") { s3o, err := utils.NewS3BucketAndPrefix(stagingPath) @@ -183,14 +185,14 @@ func (c *ClickhouseConnector) dropStage(stagingPath string, job string) error { Prefix: aws.String(fmt.Sprintf("%s/%s", s3o.Prefix, job)), }) for pages.HasMorePages() { - page, err := pages.NextPage(c.ctx) + page, err := pages.NextPage(ctx) if err != nil { c.logger.Error("failed to list objects from bucket", slog.Any("error", err)) return fmt.Errorf("failed to list objects from bucket: %w", err) } for _, object := range page.Contents { - _, err = s3svc.DeleteObject(c.ctx, &s3.DeleteObjectInput{ + _, err = s3svc.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(s3o.Bucket), Key: object.Key, }) diff --git a/flow/connectors/clickhouse/qrep_avro_sync.go b/flow/connectors/clickhouse/qrep_avro_sync.go index c32b23ecf7..94813e84f1 100644 --- a/flow/connectors/clickhouse/qrep_avro_sync.go +++ b/flow/connectors/clickhouse/qrep_avro_sync.go @@ -33,7 +33,7 @@ func NewClickhouseAvroSyncMethod( } } -func (s *ClickhouseAvroSyncMethod) CopyStageToDestination(avroFile *avro.AvroFile) error { +func (s *ClickhouseAvroSyncMethod) CopyStageToDestination(ctx context.Context, avroFile *avro.AvroFile) error { stagingPath := s.config.StagingPath if stagingPath == "" { stagingPath = s.config.DestinationPeer.GetClickhouseConfig().S3Path // "s3://avro-clickhouse" @@ -52,7 +52,7 @@ func (s *ClickhouseAvroSyncMethod) CopyStageToDestination(avroFile *avro.AvroFil query := fmt.Sprintf("INSERT INTO %s SELECT * FROM s3('%s','%s','%s', 'Avro')", s.config.DestinationTableIdentifier, avroFileUrl, awsCreds.AccessKeyID, awsCreds.SecretAccessKey) - _, err = s.connector.database.Exec(query) + _, err = s.connector.database.ExecContext(ctx, query) return err } @@ -85,7 +85,7 @@ func (s *ClickhouseAvroSyncMethod) SyncRecords( } defer avroFile.Cleanup() s.connector.logger.Info(fmt.Sprintf("written %d records to Avro file", avroFile.NumRecords), tableLog) - err = s.CopyStageToDestination(avroFile) + err = s.CopyStageToDestination(ctx, avroFile) if err != nil { return 0, err } @@ -134,13 +134,13 @@ func (s *ClickhouseAvroSyncMethod) SyncQRepRecords( query := fmt.Sprintf("INSERT INTO %s (%s) SELECT * FROM s3('%s','%s','%s', 'Avro')", config.DestinationTableIdentifier, selector, avroFileUrl, awsCreds.AccessKeyID, awsCreds.SecretAccessKey) - _, err = s.connector.database.Exec(query) + _, err = s.connector.database.ExecContext(ctx, query) if err != nil { return 0, err } - err = s.insertMetadata(partition, config.FlowJobName, startTime) + err = s.insertMetadata(ctx, partition, config.FlowJobName, startTime) if err != nil { return -1, err } @@ -189,6 +189,7 @@ func (s *ClickhouseAvroSyncMethod) writeToAvroFile( } func (s *ClickhouseAvroSyncMethod) insertMetadata( + ctx context.Context, partition *protos.QRepPartition, flowJobName string, startTime time.Time, @@ -201,7 +202,7 @@ func (s *ClickhouseAvroSyncMethod) insertMetadata( return fmt.Errorf("failed to create metadata insert statement: %w", err) } - if _, err := s.connector.database.Exec(insertMetadataStmt); err != nil { + if _, err := s.connector.database.ExecContext(ctx, insertMetadataStmt); err != nil { return fmt.Errorf("failed to execute metadata insert statement: %w", err) } diff --git a/flow/connectors/core.go 
b/flow/connectors/core.go index 78e1a27b7e..d3e4299bd9 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -23,76 +23,76 @@ import ( var ErrUnsupportedFunctionality = errors.New("requested connector does not support functionality") type Connector interface { - Close() error - ConnectionActive() error + Close(context.Context) error + ConnectionActive(context.Context) error } type CDCPullConnector interface { Connector // GetTableSchema returns the schema of a table. - GetTableSchema(req *protos.GetTableSchemaBatchInput) (*protos.GetTableSchemaBatchOutput, error) + GetTableSchema(ctx context.Context, req *protos.GetTableSchemaBatchInput) (*protos.GetTableSchemaBatchOutput, error) // EnsurePullability ensures that the connector is pullable. - EnsurePullability(req *protos.EnsurePullabilityBatchInput) ( + EnsurePullability(ctx context.Context, req *protos.EnsurePullabilityBatchInput) ( *protos.EnsurePullabilityBatchOutput, error) // Methods related to retrieving and pushing records for this connector as a source and destination. // PullRecords pulls records from the source, and returns a RecordBatch. // This method should be idempotent, and should be able to be called multiple times with the same request. - PullRecords(catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error + PullRecords(ctx context.Context, catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error // PullFlowCleanup drops both the Postgres publication and replication slot, as a part of DROP MIRROR - PullFlowCleanup(jobName string) error + PullFlowCleanup(ctx context.Context, jobName string) error // HandleSlotInfo update monitoring info on slot size etc // threadsafe HandleSlotInfo(ctx context.Context, alerter *alerting.Alerter, catalogPool *pgxpool.Pool, slotName string, peerName string) error // GetSlotInfo returns the WAL (or equivalent) info of a slot for the connector. - GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) + GetSlotInfo(ctx context.Context, slotName string) ([]*protos.SlotInfo, error) // AddTablesToPublication adds additional tables added to a mirror to the publication also - AddTablesToPublication(req *protos.AddTablesToPublicationInput) error + AddTablesToPublication(ctx context.Context, req *protos.AddTablesToPublicationInput) error } type CDCSyncConnector interface { Connector // NeedsSetupMetadataTables checks if the metadata table [PEERDB_MIRROR_JOBS] needs to be created. - NeedsSetupMetadataTables() bool + NeedsSetupMetadataTables(ctx context.Context) bool // SetupMetadataTables creates the metadata table [PEERDB_MIRROR_JOBS] if necessary. - SetupMetadataTables() error + SetupMetadataTables(ctx context.Context) error // GetLastOffset gets the last offset from the metadata table on the destination - GetLastOffset(jobName string) (int64, error) + GetLastOffset(ctx context.Context, jobName string) (int64, error) // SetLastOffset updates the last offset on the metadata table on the destination - SetLastOffset(jobName string, lastOffset int64) error + SetLastOffset(ctx context.Context, jobName string, lastOffset int64) error // GetLastSyncBatchID gets the last batch synced to the destination from the metadata table - GetLastSyncBatchID(jobName string) (int64, error) + GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) // CreateRawTable creates a raw table for the connector with a given name and a fixed schema. 
- CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) + CreateRawTable(ctx context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) // ReplayTableSchemaDelta changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. // Connectors which are non-normalizing should implement this as a nop. - ReplayTableSchemaDeltas(flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error + ReplayTableSchemaDeltas(ctx context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error // SetupNormalizedTables sets up the normalized table on the connector. - SetupNormalizedTables(req *protos.SetupNormalizedTableBatchInput) ( + SetupNormalizedTables(ctx context.Context, req *protos.SetupNormalizedTableBatchInput) ( *protos.SetupNormalizedTableBatchOutput, error) // SyncRecords pushes records to the destination peer and stores it in PeerDB specific tables. // This method should be idempotent, and should be able to be called multiple times with the same request. - SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) + SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) // SyncFlowCleanup drops metadata tables on the destination, as a part of DROP MIRROR. - SyncFlowCleanup(jobName string) error + SyncFlowCleanup(ctx context.Context, jobName string) error } type CDCNormalizeConnector interface { @@ -100,28 +100,28 @@ type CDCNormalizeConnector interface { // NormalizeRecords merges records pushed earlier into the destination table. // This method should be idempotent, and should be able to be called multiple times with the same request. - NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) + NormalizeRecords(ctx context.Context, req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) } type QRepPullConnector interface { Connector // GetQRepPartitions returns the partitions for a given table that haven't been synced yet. - GetQRepPartitions(config *protos.QRepConfig, last *protos.QRepPartition) ([]*protos.QRepPartition, error) + GetQRepPartitions(ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition) ([]*protos.QRepPartition, error) // PullQRepRecords returns the records for a given partition. - PullQRepRecords(config *protos.QRepConfig, partition *protos.QRepPartition) (*model.QRecordBatch, error) + PullQRepRecords(ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition) (*model.QRecordBatch, error) } type QRepSyncConnector interface { Connector // SetupQRepMetadataTables sets up the metadata tables for QRep. - SetupQRepMetadataTables(config *protos.QRepConfig) error + SetupQRepMetadataTables(ctx context.Context, config *protos.QRepConfig) error // SyncQRepRecords syncs the records for a given partition. // returns the number of records synced. - SyncQRepRecords(config *protos.QRepConfig, partition *protos.QRepPartition, + SyncQRepRecords(ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream) (int, error) } @@ -129,10 +129,10 @@ type QRepConsolidateConnector interface { Connector // ConsolidateQRepPartitions consolidates the partitions for a given table. - ConsolidateQRepPartitions(config *protos.QRepConfig) error + ConsolidateQRepPartitions(ctx context.Context, config *protos.QRepConfig) error // CleanupQRepFlow cleans up the QRep flow for a given table. 
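Aside (not part of the patch): core.go turns the plumbing into a contract — every connector method now takes context.Context as its first parameter, so the caller decides each call's lifetime. A trimmed, hypothetical sketch of the shape and of a caller threading its context through (not the full interfaces above):

package example

import "context"

// Hypothetical, reduced versions of the interfaces; only the shape matters:
// context.Context is always the first parameter.
type Connector interface {
	Close(context.Context) error
	ConnectionActive(context.Context) error
}

type SyncConnector interface {
	Connector
	SyncFlowCleanup(ctx context.Context, jobName string) error
}

// Callers now bound each call with their own context instead of the
// connector reusing whatever context it was constructed with.
func dropMirror(ctx context.Context, conn SyncConnector, jobName string) error {
	defer func() { _ = conn.Close(ctx) }()
	return conn.SyncFlowCleanup(ctx, jobName)
}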
- CleanupQRepFlow(config *protos.QRepConfig) error + CleanupQRepFlow(ctx context.Context, config *protos.QRepConfig) error } func GetCDCPullConnector(ctx context.Context, config *protos.Peer) (CDCPullConnector, error) { @@ -278,12 +278,12 @@ func GetQRepConsolidateConnector(ctx context.Context, } } -func CloseConnector(conn Connector) { +func CloseConnector(ctx context.Context, conn Connector) { if conn == nil { return } - err := conn.Close() + err := conn.Close(ctx) if err != nil { slog.Error("error closing connector", slog.Any("error", err)) } diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 25f1fd812d..55d2a4b6c4 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -19,7 +19,6 @@ import ( ) type EventHubConnector struct { - ctx context.Context config *protos.EventHubGroupConfig pgMetadata *metadataStore.PostgresMetadataStore creds *azidentity.DefaultAzureCredential @@ -40,14 +39,13 @@ func NewEventHubConnector( } hubManager := NewEventHubManager(defaultAzureCreds, config) - pgMetadata, err := metadataStore.NewPostgresMetadataStore(logger) + pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx) if err != nil { logger.Error("failed to create postgres metadata store", "error", err) return nil, err } return &EventHubConnector{ - ctx: ctx, config: config, pgMetadata: pgMetadata, creds: defaultAzureCreds, @@ -55,8 +53,8 @@ func NewEventHubConnector( }, nil } -func (c *EventHubConnector) Close() error { - err := c.hubManager.Close(context.Background()) +func (c *EventHubConnector) Close(ctx context.Context) error { + err := c.hubManager.Close(ctx) if err != nil { c.logger.Error("failed to close event hub manager", slog.Any("error", err)) return err @@ -65,28 +63,28 @@ func (c *EventHubConnector) Close() error { return nil } -func (c *EventHubConnector) ConnectionActive() error { +func (c *EventHubConnector) ConnectionActive(_ context.Context) error { return nil } -func (c *EventHubConnector) NeedsSetupMetadataTables() bool { +func (c *EventHubConnector) NeedsSetupMetadataTables(_ context.Context) bool { return false } -func (c *EventHubConnector) SetupMetadataTables() error { +func (c *EventHubConnector) SetupMetadataTables(_ context.Context) error { return nil } -func (c *EventHubConnector) GetLastSyncBatchID(jobName string) (int64, error) { - return c.pgMetadata.GetLastBatchID(c.ctx, jobName) +func (c *EventHubConnector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.GetLastBatchID(ctx, jobName) } -func (c *EventHubConnector) GetLastOffset(jobName string) (int64, error) { - return c.pgMetadata.FetchLastOffset(c.ctx, jobName) +func (c *EventHubConnector) GetLastOffset(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.FetchLastOffset(ctx, jobName) } -func (c *EventHubConnector) SetLastOffset(jobName string, offset int64) error { - err := c.pgMetadata.UpdateLastOffset(c.ctx, jobName, offset) +func (c *EventHubConnector) SetLastOffset(ctx context.Context, jobName string, offset int64) error { + err := c.pgMetadata.UpdateLastOffset(ctx, jobName, offset) if err != nil { c.logger.Error(fmt.Sprintf("failed to update last offset: %v", err)) return err @@ -97,10 +95,10 @@ func (c *EventHubConnector) SetLastOffset(jobName string, offset int64) error { // returns the number of records synced func (c *EventHubConnector) processBatch( + ctx context.Context, flowJobName string, batch *model.CDCRecordStream, ) (uint32, error) { - ctx := 
context.Background() batchPerTopic := NewHubBatches(c.hubManager) toJSONOpts := model.NewToJSONOptions(c.config.UnnestColumns, false) @@ -111,7 +109,7 @@ func (c *EventHubConnector) processBatch( lastUpdatedOffset := int64(0) numRecords := atomic.Uint32{} - shutdown := utils.HeartbeatRoutine(c.ctx, func() string { + shutdown := utils.HeartbeatRoutine(ctx, func() string { return fmt.Sprintf("processed %d records for flow %s", numRecords.Load(), flowJobName) }) defer shutdown() @@ -180,8 +178,8 @@ func (c *EventHubConnector) processBatch( c.logger.Info("processBatch", slog.Int("number of records processed for sending", int(curNumRecords))) } - case <-c.ctx.Done(): - return 0, fmt.Errorf("[eventhub] context cancelled %w", c.ctx.Err()) + case <-ctx.Done(): + return 0, fmt.Errorf("[eventhub] context cancelled %w", ctx.Err()) case <-ticker.C: err := batchPerTopic.flushAllBatches(ctx, flowJobName) @@ -190,7 +188,7 @@ func (c *EventHubConnector) processBatch( } if lastSeenLSN > lastUpdatedOffset { - err = c.SetLastOffset(flowJobName, lastSeenLSN) + err = c.SetLastOffset(ctx, flowJobName, lastSeenLSN) lastUpdatedOffset = lastSeenLSN c.logger.Info("processBatch", slog.Int64("updated last offset", lastSeenLSN)) if err != nil { @@ -201,10 +199,10 @@ func (c *EventHubConnector) processBatch( } } -func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { +func (c *EventHubConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { batch := req.Records - numRecords, err := c.processBatch(req.FlowJobName, batch) + numRecords, err := c.processBatch(ctx, req.FlowJobName, batch) if err != nil { c.logger.Error("failed to process batch", slog.Any("error", err)) return nil, err @@ -216,7 +214,7 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S return nil, err } - err = c.pgMetadata.FinishBatch(c.ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) + err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) if err != nil { c.logger.Error("failed to increment id", slog.Any("error", err)) return nil, err @@ -231,7 +229,7 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S }, nil } -func (c *EventHubConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { +func (c *EventHubConnector) CreateRawTable(ctx context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { // create topics for each table // key is the source table and value is the "eh_peer.eh_topic" that ought to be used. 
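Aside (not part of the patch): processBatch previously manufactured its own context.Background(), so cancelling the activity never reached the flush loop. With the caller's ctx, the select above exits promptly. A sketch of that loop shape with hypothetical stand-ins (flush replaces batchPerTopic.flushAllBatches; this is not the Event Hubs connector's real API):

package example

import (
	"context"
	"fmt"
	"time"
)

func processLoop(ctx context.Context, flush func(context.Context) error, work <-chan struct{}) error {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Cancellation now propagates from the caller instead of being
			// hidden behind context.Background().
			return fmt.Errorf("processLoop cancelled: %w", ctx.Err())
		case <-ticker.C:
			if err := flush(ctx); err != nil {
				return err
			}
		case _, ok := <-work:
			if !ok {
				// Upstream finished: flush once more and return.
				return flush(ctx)
			}
			// Handle one record/batch here.
		}
	}
}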
tableMap := req.GetTableNameMapping() @@ -245,7 +243,7 @@ func (c *EventHubConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr return nil, err } - err = c.hubManager.EnsureEventHubExists(c.ctx, name) + err = c.hubManager.EnsureEventHubExists(ctx, name) if err != nil { c.logger.Error("failed to ensure eventhub exists", slog.Any("error", err), slog.String("destinationTable", destinationTable)) @@ -258,21 +256,21 @@ func (c *EventHubConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr }, nil } -func (c *EventHubConnector) ReplayTableSchemaDeltas(flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *EventHubConnector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { c.logger.Info("ReplayTableSchemaDeltas for event hub is a no-op") return nil } func (c *EventHubConnector) SetupNormalizedTables( - req *protos.SetupNormalizedTableBatchInput) ( - *protos.SetupNormalizedTableBatchOutput, error, -) { + _ context.Context, + req *protos.SetupNormalizedTableBatchInput, +) (*protos.SetupNormalizedTableBatchOutput, error) { c.logger.Info("normalization for event hub is a no-op") return &protos.SetupNormalizedTableBatchOutput{ TableExistsMapping: nil, }, nil } -func (c *EventHubConnector) SyncFlowCleanup(jobName string) error { - return c.pgMetadata.DropMetadata(c.ctx, jobName) +func (c *EventHubConnector) SyncFlowCleanup(ctx context.Context, jobName string) error { + return c.pgMetadata.DropMetadata(ctx, jobName) } diff --git a/flow/connectors/external_metadata/store.go b/flow/connectors/external_metadata/store.go index cef4eafd05..5fffec0f69 100644 --- a/flow/connectors/external_metadata/store.go +++ b/flow/connectors/external_metadata/store.go @@ -14,6 +14,7 @@ import ( cc "github.com/PeerDB-io/peer-flow/connectors/utils/catalog" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/logger" ) const ( @@ -26,15 +27,15 @@ type PostgresMetadataStore struct { logger log.Logger } -func NewPostgresMetadataStore(logger log.Logger) (*PostgresMetadataStore, error) { - pool, err := cc.GetCatalogConnectionPoolFromEnv() +func NewPostgresMetadataStore(ctx context.Context) (*PostgresMetadataStore, error) { + pool, err := cc.GetCatalogConnectionPoolFromEnv(ctx) if err != nil { return nil, fmt.Errorf("failed to create catalog connection pool: %w", err) } return &PostgresMetadataStore{ pool: pool, - logger: logger, + logger: logger.LoggerFromCtx(ctx), }, nil } diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 69fbca87d4..aef4f81c47 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -26,7 +26,6 @@ import ( ) type PostgresCDCSource struct { - ctx context.Context replConn *pgx.Conn SrcTableIDNameMapping map[uint32]string TableNameMapping map[string]model.NameAndExclude @@ -47,7 +46,6 @@ type PostgresCDCSource struct { } type PostgresCDCConfig struct { - AppContext context.Context Connection *pgx.Conn Slot string Publication string @@ -65,15 +63,14 @@ type startReplicationOpts struct { } // Create a new PostgresCDCSource -func NewPostgresCDCSource(cdcConfig *PostgresCDCConfig, customTypeMap map[uint32]string) (*PostgresCDCSource, error) { - childToParentRelIDMap, err := getChildToParentRelIDMap(cdcConfig.AppContext, cdcConfig.Connection) +func NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig, customTypeMap map[uint32]string) (*PostgresCDCSource, error) { + childToParentRelIDMap, err := 
getChildToParentRelIDMap(ctx, cdcConfig.Connection) if err != nil { return nil, fmt.Errorf("error getting child to parent relid map: %w", err) } - flowName, _ := cdcConfig.AppContext.Value(shared.FlowNameKey).(string) + flowName, _ := ctx.Value(shared.FlowNameKey).(string) return &PostgresCDCSource{ - ctx: cdcConfig.AppContext, replConn: cdcConfig.Connection, SrcTableIDNameMapping: cdcConfig.SrcTableIDNameMapping, TableNameMapping: cdcConfig.TableNameMapping, @@ -123,7 +120,7 @@ func getChildToParentRelIDMap(ctx context.Context, conn *pgx.Conn) (map[uint32]u } // PullRecords pulls records from the cdc stream -func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) error { +func (p *PostgresCDCSource) PullRecords(ctx context.Context, req *model.PullRecordsRequest) error { replicationOpts, err := p.replicationOptions() if err != nil { return fmt.Errorf("error getting replication options: %w", err) @@ -145,18 +142,18 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) error { replicationOpts: *replicationOpts, } - err = p.startReplication(opts) + err = p.startReplication(ctx, opts) if err != nil { return fmt.Errorf("error starting replication: %w", err) } p.logger.Info(fmt.Sprintf("started replication on slot %s at startLSN: %d", p.slot, startLSN)) - return p.consumeStream(pgConn, req, clientXLogPos, req.RecordStream) + return p.consumeStream(ctx, pgConn, req, clientXLogPos, req.RecordStream) } -func (p *PostgresCDCSource) startReplication(opts startReplicationOpts) error { - err := pglogrepl.StartReplication(p.ctx, opts.conn, p.slot, opts.startLSN, opts.replicationOpts) +func (p *PostgresCDCSource) startReplication(ctx context.Context, opts startReplicationOpts) error { + err := pglogrepl.StartReplication(ctx, opts.conn, p.slot, opts.startLSN, opts.replicationOpts) if err != nil { p.logger.Error("error starting replication", slog.Any("error", err)) return fmt.Errorf("error starting replication at startLsn - %d: %w", opts.startLSN, err) @@ -183,13 +180,14 @@ func (p *PostgresCDCSource) replicationOptions() (*pglogrepl.StartReplicationOpt // start consuming the cdc stream func (p *PostgresCDCSource) consumeStream( + ctx context.Context, conn *pgconn.PgConn, req *model.PullRecordsRequest, clientXLogPos pglogrepl.LSN, records *model.CDCRecordStream, ) error { defer func() { - err := conn.Close(p.ctx) + err := conn.Close(ctx) if err != nil { p.logger.Error("error closing replication connection", slog.Any("error", err)) } @@ -202,7 +200,7 @@ func (p *PostgresCDCSource) consumeStream( if clientXLogPos > 0 { consumedXLogPos = clientXLogPos - err := pglogrepl.SendStandbyStatusUpdate(p.ctx, conn, + err := pglogrepl.SendStandbyStatusUpdate(ctx, conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: consumedXLogPos}) if err != nil { return fmt.Errorf("[initial-flush] SendStandbyStatusUpdate failed: %w", err) @@ -222,7 +220,7 @@ func (p *PostgresCDCSource) consumeStream( } }() - shutdown := utils.HeartbeatRoutine(p.ctx, func() string { + shutdown := utils.HeartbeatRoutine(ctx, func() string { jobName := p.flowJobName currRecords := cdcRecordsStorage.Len() msg := fmt.Sprintf("pulling records for job - %s, currently have %d records", jobName, currRecords) @@ -254,7 +252,7 @@ func (p *PostgresCDCSource) consumeStream( for { if pkmRequiresResponse { - err := pglogrepl.SendStandbyStatusUpdate(p.ctx, conn, + err := pglogrepl.SendStandbyStatusUpdate(ctx, conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: consumedXLogPos}) if err != nil { return 
fmt.Errorf("SendStandbyStatusUpdate failed: %w", err) @@ -305,18 +303,18 @@ func (p *PostgresCDCSource) consumeStream( nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout) } - var ctx context.Context + var receiveCtx context.Context var cancel context.CancelFunc if cdcRecordsStorage.IsEmpty() { - ctx, cancel = context.WithCancel(p.ctx) + receiveCtx, cancel = context.WithCancel(ctx) } else { - ctx, cancel = context.WithDeadline(p.ctx, nextStandbyMessageDeadline) + receiveCtx, cancel = context.WithDeadline(ctx, nextStandbyMessageDeadline) } - rawMsg, err := conn.ReceiveMessage(ctx) + rawMsg, err := conn.ReceiveMessage(receiveCtx) cancel() - ctxErr := p.ctx.Err() + ctxErr := ctx.Err() if ctxErr != nil { return fmt.Errorf("consumeStream preempted: %w", ctxErr) } @@ -368,7 +366,7 @@ func (p *PostgresCDCSource) consumeStream( p.logger.Debug(fmt.Sprintf("XLogData => WALStart %s ServerWALEnd %s ServerTime %s\n", xld.WALStart, xld.ServerWALEnd, xld.ServerTime)) - rec, err := p.processMessage(records, xld, clientXLogPos) + rec, err := p.processMessage(ctx, records, xld, clientXLogPos) if err != nil { return fmt.Errorf("error processing message: %w", err) } @@ -488,7 +486,10 @@ func (p *PostgresCDCSource) consumeStream( } } -func (p *PostgresCDCSource) processMessage(batch *model.CDCRecordStream, xld pglogrepl.XLogData, +func (p *PostgresCDCSource) processMessage( + ctx context.Context, + batch *model.CDCRecordStream, + xld pglogrepl.XLogData, currentClientXlogPos pglogrepl.LSN, ) (model.Record, error) { logicalMsg, err := pglogrepl.Parse(xld.WALData) @@ -528,7 +529,7 @@ func (p *PostgresCDCSource) processMessage(batch *model.CDCRecordStream, xld pgl } else { // RelationMessages don't contain an LSN, so we use current clientXlogPos instead. // https://github.com/postgres/postgres/blob/8b965c549dc8753be8a38c4a1b9fabdb535a4338/src/backend/replication/logical/proto.c#L670 - return p.processRelationMessage(currentClientXlogPos, convertRelationMessageToProto(msg)) + return p.processRelationMessage(ctx, currentClientXlogPos, convertRelationMessageToProto(msg)) } case *pglogrepl.TruncateMessage: @@ -771,12 +772,12 @@ func convertRelationMessageToProto(msg *pglogrepl.RelationMessage) *protos.Relat } } -func (p *PostgresCDCSource) auditSchemaDelta(flowJobName string, rec *model.RelationRecord) error { - activityInfo := activity.GetInfo(p.ctx) +func (p *PostgresCDCSource) auditSchemaDelta(ctx context.Context, flowJobName string, rec *model.RelationRecord) error { + activityInfo := activity.GetInfo(ctx) workflowID := activityInfo.WorkflowExecution.ID runID := activityInfo.WorkflowExecution.RunID - _, err := p.catalogPool.Exec(p.ctx, + _, err := p.catalogPool.Exec(ctx, `INSERT INTO peerdb_stats.schema_deltas_audit_log(flow_job_name,workflow_id,run_id,delta_info) VALUES($1,$2,$3,$4)`, @@ -789,6 +790,7 @@ func (p *PostgresCDCSource) auditSchemaDelta(flowJobName string, rec *model.Rela // processRelationMessage processes a RelationMessage and returns a TableSchemaDelta func (p *PostgresCDCSource) processRelationMessage( + ctx context.Context, lsn pglogrepl.LSN, currRel *protos.RelationMessage, ) (model.Record, error) { @@ -849,7 +851,7 @@ func (p *PostgresCDCSource) processRelationMessage( TableSchemaDelta: schemaDelta, CheckpointID: int64(lsn), } - return rec, p.auditSchemaDelta(p.flowJobName, rec) + return rec, p.auditSchemaDelta(ctx, p.flowJobName, rec) } func (p *PostgresCDCSource) recToTablePKey(req *model.PullRecordsRequest, diff --git a/flow/connectors/postgres/client.go 
b/flow/connectors/postgres/client.go index f45ae2e2bf..fce83d11b3 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -96,9 +96,9 @@ const ( ) // getRelIDForTable returns the relation ID for a table. -func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (uint32, error) { +func (c *PostgresConnector) getRelIDForTable(ctx context.Context, schemaTable *utils.SchemaTable) (uint32, error) { var relID pgtype.Uint32 - err := c.conn.QueryRow(c.ctx, + err := c.conn.QueryRow(ctx, `SELECT c.oid FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE n.nspname=$1 AND c.relname=$2`, schemaTable.Schema, schemaTable.Table).Scan(&relID) @@ -110,14 +110,14 @@ func (c *PostgresConnector) getRelIDForTable(schemaTable *utils.SchemaTable) (ui } // getReplicaIdentity returns the replica identity for a table. -func (c *PostgresConnector) getReplicaIdentityType(schemaTable *utils.SchemaTable) (ReplicaIdentityType, error) { - relID, relIDErr := c.getRelIDForTable(schemaTable) +func (c *PostgresConnector) getReplicaIdentityType(ctx context.Context, schemaTable *utils.SchemaTable) (ReplicaIdentityType, error) { + relID, relIDErr := c.getRelIDForTable(ctx, schemaTable) if relIDErr != nil { return ReplicaIdentityDefault, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, relIDErr) } var replicaIdentity rune - err := c.conn.QueryRow(c.ctx, + err := c.conn.QueryRow(ctx, `SELECT relreplident FROM pg_class WHERE oid = $1;`, relID).Scan(&replicaIdentity) if err != nil { @@ -135,21 +135,22 @@ func (c *PostgresConnector) getReplicaIdentityType(schemaTable *utils.SchemaTabl // For replica identity 'i'/index, these are the columns in the selected index (indisreplident set) // For replica identity 'f'/full, if there is a primary key we use that, else we return all columns func (c *PostgresConnector) getUniqueColumns( + ctx context.Context, replicaIdentity ReplicaIdentityType, schemaTable *utils.SchemaTable, ) ([]string, error) { - relID, err := c.getRelIDForTable(schemaTable) + relID, err := c.getRelIDForTable(ctx, schemaTable) if err != nil { return nil, fmt.Errorf("failed to get relation id for table %s: %w", schemaTable, err) } if replicaIdentity == ReplicaIdentityIndex { - return c.getReplicaIdentityIndexColumns(relID, schemaTable) + return c.getReplicaIdentityIndexColumns(ctx, relID, schemaTable) } // Find the primary key index OID, for replica identity 'd'/default or 'f'/full var pkIndexOID oid.Oid - err = c.conn.QueryRow(c.ctx, + err = c.conn.QueryRow(ctx, `SELECT indexrelid FROM pg_index WHERE indrelid = $1 AND indisprimary`, relID).Scan(&pkIndexOID) if err != nil { @@ -160,14 +161,18 @@ func (c *PostgresConnector) getUniqueColumns( return nil, fmt.Errorf("error finding primary key index for table %s: %w", schemaTable, err) } - return c.getColumnNamesForIndex(pkIndexOID) + return c.getColumnNamesForIndex(ctx, pkIndexOID) } // getReplicaIdentityIndexColumns returns the columns used in the replica identity index. 
-func (c *PostgresConnector) getReplicaIdentityIndexColumns(relID uint32, schemaTable *utils.SchemaTable) ([]string, error) { +func (c *PostgresConnector) getReplicaIdentityIndexColumns( + ctx context.Context, + relID uint32, + schemaTable *utils.SchemaTable, +) ([]string, error) { var indexRelID oid.Oid // Fetch the OID of the index used as the replica identity - err := c.conn.QueryRow(c.ctx, + err := c.conn.QueryRow(ctx, `SELECT indexrelid FROM pg_index WHERE indrelid=$1 AND indisreplident=true`, relID).Scan(&indexRelID) @@ -175,12 +180,12 @@ func (c *PostgresConnector) getReplicaIdentityIndexColumns(relID uint32, schemaT return nil, fmt.Errorf("error finding replica identity index for table %s: %w", schemaTable, err) } - return c.getColumnNamesForIndex(indexRelID) + return c.getColumnNamesForIndex(ctx, indexRelID) } // getColumnNamesForIndex returns the column names for a given index. -func (c *PostgresConnector) getColumnNamesForIndex(indexOID oid.Oid) ([]string, error) { - rows, err := c.conn.Query(c.ctx, +func (c *PostgresConnector) getColumnNamesForIndex(ctx context.Context, indexOID oid.Oid) ([]string, error) { + rows, err := c.conn.Query(ctx, `SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) WHERE i.indexrelid = $1 ORDER BY a.attname ASC`, @@ -196,9 +201,9 @@ func (c *PostgresConnector) getColumnNamesForIndex(indexOID oid.Oid) ([]string, return cols, nil } -func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (bool, error) { +func (c *PostgresConnector) tableExists(ctx context.Context, schemaTable *utils.SchemaTable) (bool, error) { var exists pgtype.Bool - err := c.conn.QueryRow(c.ctx, + err := c.conn.QueryRow(ctx, `SELECT EXISTS ( SELECT FROM pg_tables WHERE schemaname = $1 @@ -215,13 +220,13 @@ func (c *PostgresConnector) tableExists(schemaTable *utils.SchemaTable) (bool, e } // checkSlotAndPublication checks if the replication slot and publication exist. -func (c *PostgresConnector) checkSlotAndPublication(slot string, publication string) (SlotCheckResult, error) { +func (c *PostgresConnector) checkSlotAndPublication(ctx context.Context, slot string, publication string) (SlotCheckResult, error) { slotExists := false publicationExists := false // Check if the replication slot exists var slotName pgtype.Text - err := c.conn.QueryRow(c.ctx, + err := c.conn.QueryRow(ctx, "SELECT slot_name FROM pg_replication_slots WHERE slot_name = $1", slot).Scan(&slotName) if err != nil { @@ -235,7 +240,7 @@ func (c *PostgresConnector) checkSlotAndPublication(slot string, publication str // Check if the publication exists var pubName pgtype.Text - err = c.conn.QueryRow(c.ctx, + err = c.conn.QueryRow(ctx, "SELECT pubname FROM pg_publication WHERE pubname = $1", publication).Scan(&pubName) if err != nil { @@ -308,12 +313,13 @@ func getSlotInfo(ctx context.Context, conn *pgx.Conn, slotName string, database // GetSlotInfo gets the information about the replication slot size and LSNs // If slotName input is empty, all slot info rows are returned - this is for UI. // Else, only the row pertaining to that slotName will be returned. -func (c *PostgresConnector) GetSlotInfo(slotName string) ([]*protos.SlotInfo, error) { - return getSlotInfo(c.ctx, c.conn, slotName, c.config.Database) +func (c *PostgresConnector) GetSlotInfo(ctx context.Context, slotName string) ([]*protos.SlotInfo, error) { + return getSlotInfo(ctx, c.conn, slotName, c.config.Database) } // createSlotAndPublication creates the replication slot and publication. 
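Aside (not part of the patch): the client.go helpers above follow the same rule with pgx — pass ctx straight through to QueryRow/Exec so a cancelled activity also cancels the catalog query. A sketch mirroring the relation-ID lookup, with an added per-call timeout; the helper itself is hypothetical:

package example

import (
	"context"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

func relIDForTable(ctx context.Context, conn *pgx.Conn, schema, table string) (uint32, error) {
	// Bound this one lookup without touching the caller's context.
	ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
	defer cancel()

	var relID pgtype.Uint32
	err := conn.QueryRow(ctx,
		`SELECT c.oid FROM pg_class c
		 JOIN pg_namespace n ON n.oid = c.relnamespace
		 WHERE n.nspname = $1 AND c.relname = $2`,
		schema, table).Scan(&relID)
	if err != nil {
		return 0, fmt.Errorf("failed to get relation id for %s.%s: %w", schema, table, err)
	}
	return relID.Uint32, nil
}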
func (c *PostgresConnector) createSlotAndPublication( + ctx context.Context, signal SlotSignal, s SlotCheckResult, slot string, @@ -337,7 +343,7 @@ func (c *PostgresConnector) createSlotAndPublication( if !s.PublicationExists { // check and enable publish_via_partition_root - supportsPubViaRoot, _, err := c.MajorVersionCheck(POSTGRES_13) + supportsPubViaRoot, _, err := c.MajorVersionCheck(ctx, POSTGRES_13) if err != nil { return fmt.Errorf("error checking Postgres version: %w", err) } @@ -347,7 +353,7 @@ func (c *PostgresConnector) createSlotAndPublication( } // Create the publication to help filter changes only for the given tables stmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s %s", publication, tableNameString, pubViaRootString) - _, err = c.conn.Exec(c.ctx, stmt) + _, err = c.conn.Exec(ctx, stmt) if err != nil { c.logger.Warn(fmt.Sprintf("Error creating publication '%s': %v", publication, err)) return fmt.Errorf("error creating publication '%s' : %w", publication, err) @@ -356,15 +362,15 @@ func (c *PostgresConnector) createSlotAndPublication( // create slot only after we succeeded in creating publication. if !s.SlotExists { - conn, err := c.CreateReplConn(c.ctx) + conn, err := c.CreateReplConn(ctx) if err != nil { return fmt.Errorf("[slot] error acquiring connection: %w", err) } - defer conn.Close(c.ctx) + defer conn.Close(ctx) c.logger.Warn(fmt.Sprintf("Creating replication slot '%s'", slot)) - _, err = conn.Exec(c.ctx, "SET idle_in_transaction_session_timeout = 0") + _, err = conn.Exec(ctx, "SET idle_in_transaction_session_timeout = 0") if err != nil { return fmt.Errorf("[slot] error setting idle_in_transaction_session_timeout: %w", err) } @@ -373,7 +379,7 @@ func (c *PostgresConnector) createSlotAndPublication( Temporary: false, Mode: pglogrepl.LogicalReplication, } - res, err := pglogrepl.CreateReplicationSlot(c.ctx, conn.PgConn(), slot, "pgoutput", opts) + res, err := pglogrepl.CreateReplicationSlot(ctx, conn.PgConn(), slot, "pgoutput", opts) if err != nil { return fmt.Errorf("[slot] error creating replication slot: %w", err) } @@ -405,8 +411,8 @@ func (c *PostgresConnector) createSlotAndPublication( return nil } -func (c *PostgresConnector) createMetadataSchema() error { - _, err := c.conn.Exec(c.ctx, fmt.Sprintf(createSchemaSQL, c.metadataSchema)) +func (c *PostgresConnector) createMetadataSchema(ctx context.Context) error { + _, err := c.conn.Exec(ctx, fmt.Sprintf(createSchemaSQL, c.metadataSchema)) if err != nil && !utils.IsUniqueError(err) { return fmt.Errorf("error while creating internal schema: %w", err) } @@ -460,9 +466,9 @@ func generateCreateTableSQLForNormalizedTable( return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier, strings.Join(createTableSQLArray, ",")) } -func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) { +func (c *PostgresConnector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) { var result pgtype.Int8 - err := c.conn.QueryRow(c.ctx, fmt.Sprintf( + err := c.conn.QueryRow(ctx, fmt.Sprintf( getLastSyncBatchID_SQL, c.metadataSchema, mirrorJobsTableIdentifier, @@ -477,9 +483,9 @@ func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) { return result.Int64, nil } -func (c *PostgresConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { +func (c *PostgresConnector) GetLastNormalizeBatchID(ctx context.Context, jobName string) (int64, error) { var result pgtype.Int8 - err := c.conn.QueryRow(c.ctx, fmt.Sprintf( + err := c.conn.QueryRow(ctx, fmt.Sprintf( 
getLastNormalizeBatchID_SQL, c.metadataSchema, mirrorJobsTableIdentifier, @@ -494,9 +500,9 @@ func (c *PostgresConnector) GetLastNormalizeBatchID(jobName string) (int64, erro return result.Int64, nil } -func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) { +func (c *PostgresConnector) jobMetadataExists(ctx context.Context, jobName string) (bool, error) { var result pgtype.Bool - err := c.conn.QueryRow(c.ctx, + err := c.conn.QueryRow(ctx, fmt.Sprintf(checkIfJobMetadataExistsSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) if err != nil { return false, fmt.Errorf("error reading result row: %w", err) @@ -514,14 +520,14 @@ func majorVersionCheck(ctx context.Context, conn *pgx.Conn, majorVersion PGVersi return version.Int64 >= int64(majorVersion), version.Int64, nil } -func (c *PostgresConnector) MajorVersionCheck(majorVersion PGVersion) (bool, int64, error) { - return majorVersionCheck(c.ctx, c.conn, majorVersion) +func (c *PostgresConnector) MajorVersionCheck(ctx context.Context, majorVersion PGVersion) (bool, int64, error) { + return majorVersionCheck(ctx, c.conn, majorVersion) } -func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64, syncBatchID int64, +func (c *PostgresConnector) updateSyncMetadata(ctx context.Context, flowJobName string, lastCP int64, syncBatchID int64, syncRecordsTx pgx.Tx, ) error { - _, err := syncRecordsTx.Exec(c.ctx, + _, err := syncRecordsTx.Exec(ctx, fmt.Sprintf(upsertJobMetadataForSyncSQL, c.metadataSchema, mirrorJobsTableIdentifier), flowJobName, lastCP, syncBatchID, 0) if err != nil { @@ -531,10 +537,13 @@ func (c *PostgresConnector) updateSyncMetadata(flowJobName string, lastCP int64, return nil } -func (c *PostgresConnector) updateNormalizeMetadata(flowJobName string, normalizeBatchID int64, +func (c *PostgresConnector) updateNormalizeMetadata( + ctx context.Context, + flowJobName string, + normalizeBatchID int64, normalizeRecordsTx pgx.Tx, ) error { - _, err := normalizeRecordsTx.Exec(c.ctx, + _, err := normalizeRecordsTx.Exec(ctx, fmt.Sprintf(updateMetadataForNormalizeRecordsSQL, c.metadataSchema, mirrorJobsTableIdentifier), normalizeBatchID, flowJobName) if err != nil { @@ -544,12 +553,15 @@ func (c *PostgresConnector) updateNormalizeMetadata(flowJobName string, normaliz return nil } -func (c *PostgresConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, +func (c *PostgresConnector) getDistinctTableNamesInBatch( + ctx context.Context, + flowJobName string, + syncBatchID int64, normalizeBatchID int64, ) ([]string, error) { rawTableIdentifier := getRawTableIdentifier(flowJobName) - rows, err := c.conn.Query(c.ctx, fmt.Sprintf(getDistinctDestinationTableNamesSQL, c.metadataSchema, + rows, err := c.conn.Query(ctx, fmt.Sprintf(getDistinctDestinationTableNamesSQL, c.metadataSchema, rawTableIdentifier), normalizeBatchID, syncBatchID) if err != nil { return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err) @@ -562,12 +574,15 @@ func (c *PostgresConnector) getDistinctTableNamesInBatch(flowJobName string, syn return destinationTableNames, nil } -func (c *PostgresConnector) getTableNametoUnchangedCols(flowJobName string, syncBatchID int64, +func (c *PostgresConnector) getTableNametoUnchangedCols( + ctx context.Context, + flowJobName string, + syncBatchID int64, normalizeBatchID int64, ) (map[string][]string, error) { rawTableIdentifier := getRawTableIdentifier(flowJobName) - rows, err := c.conn.Query(c.ctx, 
fmt.Sprintf(getTableNameToUnchangedToastColsSQL, c.metadataSchema, + rows, err := c.conn.Query(ctx, fmt.Sprintf(getTableNameToUnchangedToastColsSQL, c.metadataSchema, rawTableIdentifier), normalizeBatchID, syncBatchID) if err != nil { return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err) @@ -592,8 +607,8 @@ func (c *PostgresConnector) getTableNametoUnchangedCols(flowJobName string, sync return resultMap, nil } -func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) { - row := c.conn.QueryRow(c.ctx, +func (c *PostgresConnector) getCurrentLSN(ctx context.Context) (pglogrepl.LSN, error) { + row := c.conn.QueryRow(ctx, "SELECT CASE WHEN pg_is_in_recovery() THEN pg_last_wal_receive_lsn() ELSE pg_current_wal_lsn() END") var result pgtype.Text err := row.Scan(&result) @@ -607,7 +622,7 @@ func (c *PostgresConnector) getDefaultPublicationName(jobName string) string { return fmt.Sprintf("peerflow_pub_%s", jobName) } -func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName string) error { +func (c *PostgresConnector) CheckSourceTables(ctx context.Context, tableNames []string, pubName string) error { if c.conn == nil { return fmt.Errorf("check tables: conn is nil") } @@ -622,7 +637,7 @@ func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName strin } tableArr = append(tableArr, fmt.Sprintf(`(%s::text, %s::text)`, QuoteLiteral(schemaName), QuoteLiteral(tableName))) - err := c.conn.QueryRow(c.ctx, + err := c.conn.QueryRow(ctx, fmt.Sprintf("SELECT * FROM %s.%s LIMIT 0;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName))).Scan(&row) if err != nil && err != pgx.ErrNoRows { return err @@ -633,7 +648,7 @@ func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName strin tableStr := strings.Join(tableArr, ",") if pubName != "" { var pubTableCount int - err := c.conn.QueryRow(c.ctx, fmt.Sprintf(` + err := c.conn.QueryRow(ctx, fmt.Sprintf(` with source_table_components (sname, tname) as (values %s) select COUNT(DISTINCT(schemaname,tablename)) from pg_publication_tables INNER JOIN source_table_components stc @@ -650,13 +665,13 @@ func (c *PostgresConnector) CheckSourceTables(tableNames []string, pubName strin return nil } -func (c *PostgresConnector) CheckReplicationPermissions(username string) error { +func (c *PostgresConnector) CheckReplicationPermissions(ctx context.Context, username string) error { if c.conn == nil { return fmt.Errorf("check replication permissions: conn is nil") } var replicationRes bool - err := c.conn.QueryRow(c.ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1;", username).Scan(&replicationRes) + err := c.conn.QueryRow(ctx, "SELECT rolreplication FROM pg_roles WHERE rolname = $1", username).Scan(&replicationRes) if err != nil { return err } @@ -664,7 +679,7 @@ func (c *PostgresConnector) CheckReplicationPermissions(username string) error { if !replicationRes { // RDS case: check pg_settings for rds.logical_replication var setting string - err := c.conn.QueryRow(c.ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication';").Scan(&setting) + err := c.conn.QueryRow(ctx, "SELECT setting FROM pg_settings WHERE name = 'rds.logical_replication'").Scan(&setting) if err != nil || setting != "on" { return fmt.Errorf("postgres user does not have replication role") } @@ -672,7 +687,7 @@ func (c *PostgresConnector) CheckReplicationPermissions(username string) error { // check wal_level var walLevel string - err = c.conn.QueryRow(c.ctx, "SHOW 
wal_level;").Scan(&walLevel) + err = c.conn.QueryRow(ctx, "SHOW wal_level").Scan(&walLevel) if err != nil { return err } @@ -683,7 +698,7 @@ func (c *PostgresConnector) CheckReplicationPermissions(username string) error { // max_wal_senders must be at least 2 var maxWalSendersRes string - err = c.conn.QueryRow(c.ctx, "SHOW max_wal_senders;").Scan(&maxWalSendersRes) + err = c.conn.QueryRow(ctx, "SHOW max_wal_senders").Scan(&maxWalSendersRes) if err != nil { return err } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 57c31f0574..1b0fff15af 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -29,7 +29,6 @@ import ( // PostgresConnector is a Connector implementation for Postgres. type PostgresConnector struct { connStr string - ctx context.Context config *protos.PostgresConfig ssh *SSHTunnel conn *pgx.Conn @@ -81,7 +80,6 @@ func NewPostgresConnector(ctx context.Context, pgConfig *protos.PostgresConfig) return &PostgresConnector{ connStr: connectionString, - ctx: ctx, config: pgConfig, ssh: tunnel, conn: conn, @@ -103,26 +101,26 @@ func (c *PostgresConnector) CreateReplConn(ctx context.Context) (*pgx.Conn, erro } // Close closes all connections. -func (c *PostgresConnector) Close() error { +func (c *PostgresConnector) Close(ctx context.Context) error { if c != nil { - c.conn.Close(c.ctx) + c.conn.Close(ctx) c.ssh.Close() } return nil } // ConnectionActive returns true if the connection is active. -func (c *PostgresConnector) ConnectionActive() error { +func (c *PostgresConnector) ConnectionActive(ctx context.Context) error { if c.conn == nil { return fmt.Errorf("connection is nil") } - pingErr := c.conn.Ping(c.ctx) + pingErr := c.conn.Ping(ctx) return pingErr } // NeedsSetupMetadataTables returns true if the metadata tables need to be set up. -func (c *PostgresConnector) NeedsSetupMetadataTables() bool { - result, err := c.tableExists(&utils.SchemaTable{ +func (c *PostgresConnector) NeedsSetupMetadataTables(ctx context.Context) bool { + result, err := c.tableExists(ctx, &utils.SchemaTable{ Schema: c.metadataSchema, Table: mirrorJobsTableIdentifier, }) @@ -133,13 +131,13 @@ func (c *PostgresConnector) NeedsSetupMetadataTables() bool { } // SetupMetadataTables sets up the metadata tables. -func (c *PostgresConnector) SetupMetadataTables() error { - err := c.createMetadataSchema() +func (c *PostgresConnector) SetupMetadataTables(ctx context.Context) error { + err := c.createMetadataSchema(ctx) if err != nil { return err } - _, err = c.conn.Exec(c.ctx, fmt.Sprintf(createMirrorJobsTableSQL, + _, err = c.conn.Exec(ctx, fmt.Sprintf(createMirrorJobsTableSQL, c.metadataSchema, mirrorJobsTableIdentifier)) if err != nil && !utils.IsUniqueError(err) { return fmt.Errorf("error creating table %s: %w", mirrorJobsTableIdentifier, err) @@ -149,9 +147,9 @@ func (c *PostgresConnector) SetupMetadataTables() error { } // GetLastOffset returns the last synced offset for a job. 
-func (c *PostgresConnector) GetLastOffset(jobName string) (int64, error) { +func (c *PostgresConnector) GetLastOffset(ctx context.Context, jobName string) (int64, error) { var result pgtype.Int8 - err := c.conn.QueryRow(c.ctx, fmt.Sprintf(getLastOffsetSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) + err := c.conn.QueryRow(ctx, fmt.Sprintf(getLastOffsetSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName).Scan(&result) if err != nil { if err == pgx.ErrNoRows { c.logger.Info("No row found, returning nil") @@ -167,9 +165,9 @@ func (c *PostgresConnector) GetLastOffset(jobName string) (int64, error) { } // SetLastOffset updates the last synced offset for a job. -func (c *PostgresConnector) SetLastOffset(jobName string, lastOffset int64) error { +func (c *PostgresConnector) SetLastOffset(ctx context.Context, jobName string, lastOffset int64) error { _, err := c.conn. - Exec(c.ctx, fmt.Sprintf(setLastOffsetSQL, c.metadataSchema, mirrorJobsTableIdentifier), lastOffset, jobName) + Exec(ctx, fmt.Sprintf(setLastOffsetSQL, c.metadataSchema, mirrorJobsTableIdentifier), lastOffset, jobName) if err != nil { return fmt.Errorf("error setting last offset for job %s: %w", jobName, err) } @@ -178,7 +176,7 @@ func (c *PostgresConnector) SetLastOffset(jobName string, lastOffset int64) erro } // PullRecords pulls records from the source. -func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error { +func (c *PostgresConnector) PullRecords(ctx context.Context, catalogPool *pgxpool.Pool, req *model.PullRecordsRequest) error { defer func() { req.RecordStream.Close() }() @@ -195,7 +193,7 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu } // Check if the replication slot and publication exist - exists, err := c.checkSlotAndPublication(slotName, publicationName) + exists, err := c.checkSlotAndPublication(ctx, slotName, publicationName) if err != nil { return err } @@ -212,14 +210,13 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu c.logger.Info("PullRecords: performed checks for slot and publication") - replConn, err := c.CreateReplConn(c.ctx) + replConn, err := c.CreateReplConn(ctx) if err != nil { return err } - defer replConn.Close(c.ctx) + defer replConn.Close(ctx) - cdc, err := NewPostgresCDCSource(&PostgresCDCConfig{ - AppContext: c.ctx, + cdc, err := NewPostgresCDCSource(ctx, &PostgresCDCConfig{ Connection: replConn, SrcTableIDNameMapping: req.SrcTableIDNameMapping, Slot: slotName, @@ -233,16 +230,16 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu return fmt.Errorf("failed to create cdc source: %w", err) } - err = cdc.PullRecords(req) + err = cdc.PullRecords(ctx, req) if err != nil { return err } - latestLSN, err := c.getCurrentLSN() + latestLSN, err := c.getCurrentLSN(ctx) if err != nil { return fmt.Errorf("failed to get current LSN: %w", err) } - err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(c.ctx, catalogPool, req.FlowJobName, latestLSN) + err = monitoring.UpdateLatestLSNAtSourceForCDCFlow(ctx, catalogPool, req.FlowJobName, latestLSN) if err != nil { return fmt.Errorf("failed to update latest LSN at source for CDC flow: %w", err) } @@ -251,7 +248,7 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu } // SyncRecords pushes records to the destination. 
-func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { +func (c *PostgresConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) c.logger.Info(fmt.Sprintf("pushing records to Postgres table %s via COPY", rawTableIdentifier)) @@ -340,23 +337,23 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S } } - err := c.ReplayTableSchemaDeltas(req.FlowJobName, req.Records.SchemaDeltas) + err := c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } - syncRecordsTx, err := c.conn.Begin(c.ctx) + syncRecordsTx, err := c.conn.Begin(ctx) if err != nil { return nil, fmt.Errorf("error starting transaction for syncing records: %w", err) } defer func() { - deferErr := syncRecordsTx.Rollback(c.ctx) + deferErr := syncRecordsTx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for syncing records", slog.Any("error", err)) } }() - syncedRecordsCount, err := syncRecordsTx.CopyFrom(c.ctx, pgx.Identifier{c.metadataSchema, rawTableIdentifier}, + syncedRecordsCount, err := syncRecordsTx.CopyFrom(ctx, pgx.Identifier{c.metadataSchema, rawTableIdentifier}, []string{ "_peerdb_uid", "_peerdb_timestamp", "_peerdb_destination_table_name", "_peerdb_data", "_peerdb_record_type", "_peerdb_match_data", "_peerdb_batch_id", "_peerdb_unchanged_toast_columns", @@ -379,12 +376,12 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S } // updating metadata with new offset and syncBatchID - err = c.updateSyncMetadata(req.FlowJobName, lastCP, req.SyncBatchID, syncRecordsTx) + err = c.updateSyncMetadata(ctx, req.FlowJobName, lastCP, req.SyncBatchID, syncRecordsTx) if err != nil { return nil, err } // transaction commits - err = syncRecordsTx.Commit(c.ctx) + err = syncRecordsTx.Commit(ctx) if err != nil { return nil, err } @@ -398,10 +395,10 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S }, nil } -func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { +func (c *PostgresConnector) NormalizeRecords(ctx context.Context, req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) - jobMetadataExists, err := c.jobMetadataExists(req.FlowJobName) + jobMetadataExists, err := c.jobMetadataExists(ctx, req.FlowJobName) if err != nil { return nil, err } @@ -413,7 +410,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) }, nil } - normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(ctx, req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err) } @@ -430,28 +427,28 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } destinationTableNames, err := c.getDistinctTableNamesInBatch( - req.FlowJobName, req.SyncBatchID, normBatchID) + ctx, req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, err } - unchangedToastColsMap, err := c.getTableNametoUnchangedCols(req.FlowJobName, + unchangedToastColsMap, err := c.getTableNametoUnchangedCols(ctx, req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, 
err } - normalizeRecordsTx, err := c.conn.Begin(c.ctx) + normalizeRecordsTx, err := c.conn.Begin(ctx) if err != nil { return nil, fmt.Errorf("error starting transaction for normalizing records: %w", err) } defer func() { - deferErr := normalizeRecordsTx.Rollback(c.ctx) + deferErr := normalizeRecordsTx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for normalizing records", slog.Any("error", err)) } }() - supportsMerge, _, err := c.MajorVersionCheck(POSTGRES_15) + supportsMerge, _, err := c.MajorVersionCheck(ctx, POSTGRES_15) if err != nil { return nil, err } @@ -482,7 +479,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } } if mergeStatementsBatch.Len() > 0 { - mergeResults := normalizeRecordsTx.SendBatch(c.ctx, mergeStatementsBatch) + mergeResults := normalizeRecordsTx.SendBatch(ctx, mergeStatementsBatch) err = mergeResults.Close() if err != nil { return nil, fmt.Errorf("error executing merge statements: %w", err) @@ -491,12 +488,12 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) c.logger.Info(fmt.Sprintf("normalized %d records", totalRowsAffected)) // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID, normalizeRecordsTx) + err = c.updateNormalizeMetadata(ctx, req.FlowJobName, req.SyncBatchID, normalizeRecordsTx) if err != nil { return nil, err } // transaction commits - err = normalizeRecordsTx.Commit(c.ctx) + err = normalizeRecordsTx.Commit(ctx) if err != nil { return nil, err } @@ -514,41 +511,41 @@ type SlotCheckResult struct { } // CreateRawTable creates a raw table, implementing the Connector interface. -func (c *PostgresConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { +func (c *PostgresConnector) CreateRawTable(ctx context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) - err := c.createMetadataSchema() + err := c.createMetadataSchema(ctx) if err != nil { return nil, fmt.Errorf("error creating internal schema: %w", err) } - createRawTableTx, err := c.conn.Begin(c.ctx) + createRawTableTx, err := c.conn.Begin(ctx) if err != nil { return nil, fmt.Errorf("error starting transaction for creating raw table: %w", err) } defer func() { - deferErr := createRawTableTx.Rollback(c.ctx) + deferErr := createRawTableTx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for creating raw table.", slog.Any("error", err)) } }() - _, err = createRawTableTx.Exec(c.ctx, fmt.Sprintf(createRawTableSQL, c.metadataSchema, rawTableIdentifier)) + _, err = createRawTableTx.Exec(ctx, fmt.Sprintf(createRawTableSQL, c.metadataSchema, rawTableIdentifier)) if err != nil { return nil, fmt.Errorf("error creating raw table: %w", err) } - _, err = createRawTableTx.Exec(c.ctx, fmt.Sprintf(createRawTableBatchIDIndexSQL, rawTableIdentifier, + _, err = createRawTableTx.Exec(ctx, fmt.Sprintf(createRawTableBatchIDIndexSQL, rawTableIdentifier, c.metadataSchema, rawTableIdentifier)) if err != nil { return nil, fmt.Errorf("error creating batch ID index on raw table: %w", err) } - _, err = createRawTableTx.Exec(c.ctx, fmt.Sprintf(createRawTableDstTableIndexSQL, rawTableIdentifier, + _, err = createRawTableTx.Exec(ctx, fmt.Sprintf(createRawTableDstTableIndexSQL, rawTableIdentifier, c.metadataSchema, rawTableIdentifier)) if err != 
nil { return nil, fmt.Errorf("error creating destion table index on raw table: %w", err) } - err = createRawTableTx.Commit(c.ctx) + err = createRawTableTx.Commit(ctx) if err != nil { return nil, fmt.Errorf("error committing transaction for creating raw table: %w", err) } @@ -558,16 +555,17 @@ func (c *PostgresConnector) CreateRawTable(req *protos.CreateRawTableInput) (*pr // GetTableSchema returns the schema for a table, implementing the Connector interface. func (c *PostgresConnector) GetTableSchema( + ctx context.Context, req *protos.GetTableSchemaBatchInput, ) (*protos.GetTableSchemaBatchOutput, error) { res := make(map[string]*protos.TableSchema) for _, tableName := range req.TableIdentifiers { - tableSchema, err := c.getTableSchemaForTable(tableName) + tableSchema, err := c.getTableSchemaForTable(ctx, tableName) if err != nil { return nil, err } res[tableName] = tableSchema - utils.RecordHeartbeat(c.ctx, fmt.Sprintf("fetched schema for table %s", tableName)) + utils.RecordHeartbeat(ctx, fmt.Sprintf("fetched schema for table %s", tableName)) c.logger.Info(fmt.Sprintf("fetched schema for table %s", tableName)) } @@ -577,6 +575,7 @@ func (c *PostgresConnector) GetTableSchema( } func (c *PostgresConnector) getTableSchemaForTable( + ctx context.Context, tableName string, ) (*protos.TableSchema, error) { schemaTable, err := utils.ParseSchemaTable(tableName) @@ -584,17 +583,17 @@ func (c *PostgresConnector) getTableSchemaForTable( return nil, err } - replicaIdentityType, err := c.getReplicaIdentityType(schemaTable) + replicaIdentityType, err := c.getReplicaIdentityType(ctx, schemaTable) if err != nil { return nil, fmt.Errorf("[getTableSchema] error getting replica identity for table %s: %w", schemaTable, err) } - pKeyCols, err := c.getUniqueColumns(replicaIdentityType, schemaTable) + pKeyCols, err := c.getUniqueColumns(ctx, replicaIdentityType, schemaTable) if err != nil { return nil, fmt.Errorf("[getTableSchema] error getting primary key column for table %s: %w", schemaTable, err) } // Get the column names and types - rows, err := c.conn.Query(c.ctx, + rows, err := c.conn.Query(ctx, fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, schemaTable.String()), pgx.QueryExecModeSimpleProtocol) if err != nil { @@ -641,18 +640,19 @@ func (c *PostgresConnector) getTableSchemaForTable( } // SetupNormalizedTable sets up a normalized table, implementing the Connector interface. -func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTableBatchInput) ( - *protos.SetupNormalizedTableBatchOutput, error, -) { +func (c *PostgresConnector) SetupNormalizedTables( + ctx context.Context, + req *protos.SetupNormalizedTableBatchInput, +) (*protos.SetupNormalizedTableBatchOutput, error) { tableExistsMapping := make(map[string]bool) // Postgres is cool and supports transactional DDL. So we use a transaction. 
- createNormalizedTablesTx, err := c.conn.Begin(c.ctx) + createNormalizedTablesTx, err := c.conn.Begin(ctx) if err != nil { return nil, fmt.Errorf("error starting transaction for creating raw table: %w", err) } defer func() { - deferErr := createNormalizedTablesTx.Rollback(c.ctx) + deferErr := createNormalizedTablesTx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for creating raw table", slog.Any("error", err)) } @@ -663,7 +663,7 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab if err != nil { return nil, fmt.Errorf("error while parsing table schema and name: %w", err) } - tableAlreadyExists, err := c.tableExists(parsedNormalizedTable) + tableAlreadyExists, err := c.tableExists(ctx, parsedNormalizedTable) if err != nil { return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) } @@ -675,17 +675,17 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab // convert the column names and types to Postgres types normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable( parsedNormalizedTable.String(), tableSchema, req.SoftDeleteColName, req.SyncedAtColName) - _, err = createNormalizedTablesTx.Exec(c.ctx, normalizedTableCreateSQL) + _, err = createNormalizedTablesTx.Exec(ctx, normalizedTableCreateSQL) if err != nil { return nil, fmt.Errorf("error while creating normalized table: %w", err) } tableExistsMapping[tableIdentifier] = false c.logger.Info(fmt.Sprintf("created table %s", tableIdentifier)) - utils.RecordHeartbeat(c.ctx, fmt.Sprintf("created table %s", tableIdentifier)) + utils.RecordHeartbeat(ctx, fmt.Sprintf("created table %s", tableIdentifier)) } - err = createNormalizedTablesTx.Commit(c.ctx) + err = createNormalizedTablesTx.Commit(ctx) if err != nil { return nil, fmt.Errorf("error committing transaction for creating normalized tables: %w", err) } @@ -698,6 +698,7 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab // ReplayTableSchemaDelta changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. func (c *PostgresConnector) ReplayTableSchemaDeltas( + ctx context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { @@ -706,13 +707,13 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( } // Postgres is cool and supports transactional DDL. So we use a transaction. 
- tableSchemaModifyTx, err := c.conn.Begin(c.ctx) + tableSchemaModifyTx, err := c.conn.Begin(ctx) if err != nil { return fmt.Errorf("error starting transaction for schema modification: %w", err) } defer func() { - deferErr := tableSchemaModifyTx.Rollback(c.ctx) + deferErr := tableSchemaModifyTx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for table schema modification", slog.Any("error", err)) } @@ -724,7 +725,7 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( } for _, addedColumn := range schemaDelta.AddedColumns { - _, err = tableSchemaModifyTx.Exec(c.ctx, fmt.Sprintf( + _, err = tableSchemaModifyTx.Exec(ctx, fmt.Sprintf( "ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s %s", schemaDelta.DstTableName, QuoteIdentifier(addedColumn.ColumnName), qValueKindToPostgresType(addedColumn.ColumnType))) @@ -740,7 +741,7 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( } } - err = tableSchemaModifyTx.Commit(c.ctx) + err = tableSchemaModifyTx.Commit(ctx) if err != nil { return fmt.Errorf("failed to commit transaction for table schema modification: %w", err) @@ -751,6 +752,7 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( // EnsurePullability ensures that a table is pullable, implementing the Connector interface. func (c *PostgresConnector) EnsurePullability( + ctx context.Context, req *protos.EnsurePullabilityBatchInput, ) (*protos.EnsurePullabilityBatchOutput, error) { tableIdentifierMapping := make(map[string]*protos.PostgresTableIdentifier) @@ -761,7 +763,7 @@ func (c *PostgresConnector) EnsurePullability( } // check if the table exists by getting the relation ID - relID, err := c.getRelIDForTable(schemaTable) + relID, err := c.getRelIDForTable(ctx, schemaTable) if err != nil { return nil, err } @@ -772,16 +774,16 @@ func (c *PostgresConnector) EnsurePullability( if !req.CheckConstraints { msg := fmt.Sprintf("[no-constraints] ensured pullability table %s", tableName) - utils.RecordHeartbeat(c.ctx, msg) + utils.RecordHeartbeat(ctx, msg) continue } - replicaIdentity, replErr := c.getReplicaIdentityType(schemaTable) + replicaIdentity, replErr := c.getReplicaIdentityType(ctx, schemaTable) if replErr != nil { return nil, fmt.Errorf("error getting replica identity for table %s: %w", schemaTable, replErr) } - pKeyCols, err := c.getUniqueColumns(replicaIdentity, schemaTable) + pKeyCols, err := c.getUniqueColumns(ctx, replicaIdentity, schemaTable) if err != nil { return nil, fmt.Errorf("error getting primary key column for table %s: %w", schemaTable, err) } @@ -792,14 +794,14 @@ func (c *PostgresConnector) EnsurePullability( return nil, fmt.Errorf("table %s has no primary keys and does not have REPLICA IDENTITY FULL", schemaTable) } - utils.RecordHeartbeat(c.ctx, fmt.Sprintf("ensured pullability table %s", tableName)) + utils.RecordHeartbeat(ctx, fmt.Sprintf("ensured pullability table %s", tableName)) } return &protos.EnsurePullabilityBatchOutput{TableIdentifierMapping: tableIdentifierMapping}, nil } // SetupReplication sets up replication for the source connector. 
-func (c *PostgresConnector) SetupReplication(signal SlotSignal, req *protos.SetupReplicationInput) error { +func (c *PostgresConnector) SetupReplication(ctx context.Context, signal SlotSignal, req *protos.SetupReplicationInput) error { // ensure that the flowjob name is [a-z0-9_] only reg := regexp.MustCompile(`^[a-z0-9_]+$`) if !reg.MatchString(req.FlowJobName) { @@ -818,7 +820,7 @@ func (c *PostgresConnector) SetupReplication(signal SlotSignal, req *protos.Setu } // Check if the replication slot and publication exist - exists, err := c.checkSlotAndPublication(slotName, publicationName) + exists, err := c.checkSlotAndPublication(ctx, slotName, publicationName) if err != nil { return err } @@ -831,7 +833,7 @@ func (c *PostgresConnector) SetupReplication(signal SlotSignal, req *protos.Setu } } // Create the replication slot and publication - err = c.createSlotAndPublication(signal, exists, + err = c.createSlotAndPublication(ctx, signal, exists, slotName, publicationName, tableNameMapping, req.DoInitialSnapshot) if err != nil { return fmt.Errorf("error creating replication slot and publication: %w", err) @@ -840,35 +842,35 @@ func (c *PostgresConnector) SetupReplication(signal SlotSignal, req *protos.Setu return nil } -func (c *PostgresConnector) PullFlowCleanup(jobName string) error { +func (c *PostgresConnector) PullFlowCleanup(ctx context.Context, jobName string) error { // Slotname would be the job name prefixed with "peerflow_slot_" slotName := fmt.Sprintf("peerflow_slot_%s", jobName) publicationName := c.getDefaultPublicationName(jobName) - pullFlowCleanupTx, err := c.conn.Begin(c.ctx) + pullFlowCleanupTx, err := c.conn.Begin(ctx) if err != nil { return fmt.Errorf("error starting transaction for flow cleanup: %w", err) } defer func() { - deferErr := pullFlowCleanupTx.Rollback(c.ctx) + deferErr := pullFlowCleanupTx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for flow cleanup", slog.Any("error", err)) } }() - _, err = pullFlowCleanupTx.Exec(c.ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s", publicationName)) + _, err = pullFlowCleanupTx.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s", publicationName)) if err != nil { return fmt.Errorf("error dropping publication: %w", err) } - _, err = pullFlowCleanupTx.Exec(c.ctx, `SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots + _, err = pullFlowCleanupTx.Exec(ctx, `SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name=$1`, slotName) if err != nil { return fmt.Errorf("error dropping replication slot: %w", err) } - err = pullFlowCleanupTx.Commit(c.ctx) + err = pullFlowCleanupTx.Commit(ctx) if err != nil { return fmt.Errorf("error committing transaction for flow cleanup: %w", err) } @@ -876,29 +878,29 @@ func (c *PostgresConnector) PullFlowCleanup(jobName string) error { return nil } -func (c *PostgresConnector) SyncFlowCleanup(jobName string) error { - syncFlowCleanupTx, err := c.conn.Begin(c.ctx) +func (c *PostgresConnector) SyncFlowCleanup(ctx context.Context, jobName string) error { + syncFlowCleanupTx, err := c.conn.Begin(ctx) if err != nil { return fmt.Errorf("unable to begin transaction for sync flow cleanup: %w", err) } defer func() { - deferErr := syncFlowCleanupTx.Rollback(c.ctx) + deferErr := syncFlowCleanupTx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error while rolling back transaction for flow cleanup", slog.Any("error", deferErr)) } }() - _, err = 
syncFlowCleanupTx.Exec(c.ctx, fmt.Sprintf(dropTableIfExistsSQL, c.metadataSchema, + _, err = syncFlowCleanupTx.Exec(ctx, fmt.Sprintf(dropTableIfExistsSQL, c.metadataSchema, getRawTableIdentifier(jobName))) if err != nil { return fmt.Errorf("unable to drop raw table: %w", err) } - _, err = syncFlowCleanupTx.Exec(c.ctx, + _, err = syncFlowCleanupTx.Exec(ctx, fmt.Sprintf(deleteJobMetadataSQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName) if err != nil { return fmt.Errorf("unable to delete job metadata: %w", err) } - err = syncFlowCleanupTx.Commit(c.ctx) + err = syncFlowCleanupTx.Commit(ctx) if err != nil { return fmt.Errorf("unable to commit transaction for sync flow cleanup: %w", err) } @@ -982,7 +984,7 @@ func getOpenConnectionsForUser(ctx context.Context, conn *pgx.Conn, user string) }, nil } -func (c *PostgresConnector) AddTablesToPublication(req *protos.AddTablesToPublicationInput) error { +func (c *PostgresConnector) AddTablesToPublication(ctx context.Context, req *protos.AddTablesToPublicationInput) error { // don't modify custom publications if req == nil || len(req.AdditionalTables) == 0 { return nil @@ -995,7 +997,7 @@ func (c *PostgresConnector) AddTablesToPublication(req *protos.AddTablesToPublic // just check if we have all the tables already in the publication for custom publications if req.PublicationName != "" { - rows, err := c.conn.Query(c.ctx, + rows, err := c.conn.Query(ctx, "SELECT tablename FROM pg_publication_tables WHERE pubname=$1", req.PublicationName) if err != nil { return fmt.Errorf("failed to check tables in publication: %w", err) @@ -1016,7 +1018,7 @@ func (c *PostgresConnector) AddTablesToPublication(req *protos.AddTablesToPublic if err != nil { return err } - _, err = c.conn.Exec(c.ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s", + _, err = c.conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s", utils.QuoteIdentifier(c.getDefaultPublicationName(req.FlowJobName)), schemaTable.String())) // don't error out if table is already added to our publication diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index bad9a7a7a2..6f40501d74 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -63,7 +63,7 @@ func (s PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY)", tableName)) require.NoError(s.t, err) - err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: []*protos.DeltaAddedColumn{{ @@ -73,7 +73,7 @@ func (s PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { }}) require.NoError(s.t, err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -116,14 +116,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { } } - err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) require.NoError(s.t, 
err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -151,14 +151,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddTrickyColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) require.NoError(s.t, err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -186,14 +186,14 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) require.NoError(s.t, err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -216,9 +216,9 @@ func TestPostgresSchemaDeltaTestSuite(t *testing.T) { err = teardownTx.Commit(context.Background()) require.NoError(s.t, err) - require.NoError(s.t, s.connector.ConnectionActive()) - err = s.connector.Close() + require.NoError(s.t, s.connector.ConnectionActive(context.Background())) + err = s.connector.Close(context.Background()) require.NoError(s.t, err) - require.Error(s.t, s.connector.ConnectionActive()) + require.Error(s.t, s.connector.ConnectionActive(context.Background())) }) } diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index f53bcbe15f..d07c043c3b 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -2,6 +2,7 @@ package connpostgres import ( "bytes" + "context" "fmt" "log/slog" "strconv" @@ -22,6 +23,7 @@ import ( const qRepMetadataTableName = "_peerdb_query_replication_metadata" func (c *PostgresConnector) GetQRepPartitions( + ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition, ) ([]*protos.QRepPartition, error) { @@ -36,7 +38,7 @@ func (c *PostgresConnector) GetQRepPartitions( } // begin a transaction - tx, err := c.conn.BeginTx(c.ctx, pgx.TxOptions{ + tx, err := c.conn.BeginTx(ctx, pgx.TxOptions{ AccessMode: pgx.ReadOnly, IsoLevel: pgx.RepeatableRead, }) @@ -44,13 +46,13 @@ func (c *PostgresConnector) GetQRepPartitions( return nil, fmt.Errorf("failed to begin transaction: %w", err) } defer func() { - deferErr := tx.Rollback(c.ctx) + deferErr := tx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for get partitions", slog.Any("error", deferErr)) } }() - err = c.setTransactionSnapshot(tx) + err = c.setTransactionSnapshot(ctx, tx) if err != nil { return nil, fmt.Errorf("failed to set transaction snapshot: %w", err) } @@ -63,13 +65,13 @@ func (c *PostgresConnector) GetQRepPartitions( // log.Warnf("failed to lock table %s: %v", config.WatermarkTable, err) // } - 
return c.getNumRowsPartitions(tx, config, last) + return c.getNumRowsPartitions(ctx, tx, config, last) } -func (c *PostgresConnector) setTransactionSnapshot(tx pgx.Tx) error { +func (c *PostgresConnector) setTransactionSnapshot(ctx context.Context, tx pgx.Tx) error { snapshot := c.config.TransactionSnapshot if snapshot != "" { - if _, err := tx.Exec(c.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(snapshot))); err != nil { + if _, err := tx.Exec(ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(snapshot))); err != nil { return fmt.Errorf("failed to set transaction snapshot: %w", err) } } @@ -78,6 +80,7 @@ func (c *PostgresConnector) setTransactionSnapshot(tx pgx.Tx) error { } func (c *PostgresConnector) getNumRowsPartitions( + ctx context.Context, tx pgx.Tx, config *protos.QRepConfig, last *protos.QRepPartition, @@ -108,9 +111,9 @@ func (c *PostgresConnector) getNumRowsPartitions( minVal = lastRange.TimestampRange.End.AsTime() } - row = tx.QueryRow(c.ctx, countQuery, minVal) + row = tx.QueryRow(ctx, countQuery, minVal) } else { - row = tx.QueryRow(c.ctx, countQuery) + row = tx.QueryRow(ctx, countQuery) } var totalRows pgtype.Int8 @@ -148,7 +151,7 @@ func (c *PostgresConnector) getNumRowsPartitions( parsedWatermarkTable.String(), ) c.logger.Info(fmt.Sprintf("[row_based_next] partitions query: %s", partitionsQuery)) - rows, err = tx.Query(c.ctx, partitionsQuery, minVal) + rows, err = tx.Query(ctx, partitionsQuery, minVal) } else { partitionsQuery := fmt.Sprintf( `SELECT bucket, MIN(%[2]s) AS start, MAX(%[2]s) AS end @@ -163,7 +166,7 @@ func (c *PostgresConnector) getNumRowsPartitions( parsedWatermarkTable.String(), ) c.logger.Info(fmt.Sprintf("[row_based] partitions query: %s", partitionsQuery)) - rows, err = tx.Query(c.ctx, partitionsQuery) + rows, err = tx.Query(ctx, partitionsQuery) } if err != nil { c.logger.Error(fmt.Sprintf("failed to query for partitions: %v", err)) @@ -190,7 +193,7 @@ func (c *PostgresConnector) getNumRowsPartitions( return nil, fmt.Errorf("failed to read rows: %w", err) } - err = tx.Commit(c.ctx) + err = tx.Commit(ctx) if err != nil { return nil, fmt.Errorf("failed to commit transaction: %w", err) } @@ -199,6 +202,7 @@ func (c *PostgresConnector) getNumRowsPartitions( } func (c *PostgresConnector) getMinMaxValues( + ctx context.Context, tx pgx.Tx, config *protos.QRepConfig, last *protos.QRepPartition, @@ -213,7 +217,7 @@ func (c *PostgresConnector) getMinMaxValues( // Get the maximum value from the database maxQuery := fmt.Sprintf("SELECT MAX(%[1]s) FROM %[2]s", quotedWatermarkColumn, parsedWatermarkTable.String()) - row := tx.QueryRow(c.ctx, maxQuery) + row := tx.QueryRow(ctx, maxQuery) if err := row.Scan(&maxValue); err != nil { return nil, nil, fmt.Errorf("failed to query for max value: %w", err) } @@ -241,7 +245,7 @@ func (c *PostgresConnector) getMinMaxValues( } else { // Otherwise get the minimum value from the database minQuery := fmt.Sprintf("SELECT MIN(%[1]s) FROM %[2]s", quotedWatermarkColumn, parsedWatermarkTable.String()) - row := tx.QueryRow(c.ctx, minQuery) + row := tx.QueryRow(ctx, minQuery) if err := row.Scan(&minValue); err != nil { c.logger.Error(fmt.Sprintf("failed to query [%s] for min value: %v", minQuery, err)) return nil, nil, fmt.Errorf("failed to query for min value: %w", err) @@ -266,7 +270,7 @@ func (c *PostgresConnector) getMinMaxValues( } } - err = tx.Commit(c.ctx) + err = tx.Commit(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to commit transaction: %w", err) } @@ -274,21 +278,23 @@ func (c 
*PostgresConnector) getMinMaxValues( return minValue, maxValue, nil } -func (c *PostgresConnector) CheckForUpdatedMaxValue(config *protos.QRepConfig, +func (c *PostgresConnector) CheckForUpdatedMaxValue( + ctx context.Context, + config *protos.QRepConfig, last *protos.QRepPartition, ) (bool, error) { - tx, err := c.conn.Begin(c.ctx) + tx, err := c.conn.Begin(ctx) if err != nil { return false, fmt.Errorf("unable to begin transaction for getting max value: %w", err) } defer func() { - deferErr := tx.Rollback(c.ctx) + deferErr := tx.Rollback(ctx) if deferErr != pgx.ErrTxClosed && deferErr != nil { c.logger.Error("error rolling back transaction for getting max value", slog.Any("error", err)) } }() - _, maxValue, err := c.getMinMaxValues(tx, config, last) + _, maxValue, err := c.getMinMaxValues(ctx, tx, config, last) if err != nil { return false, fmt.Errorf("error while getting min and max values: %w", err) } @@ -310,20 +316,21 @@ func (c *PostgresConnector) CheckForUpdatedMaxValue(config *protos.QRepConfig, } func (c *PostgresConnector) PullQRepRecords( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, ) (*model.QRecordBatch, error) { partitionIdLog := slog.String(string(shared.PartitionIDKey), partition.PartitionId) if partition.FullTablePartition { c.logger.Info("pulling full table partition", partitionIdLog) - executor, err := NewQRepQueryExecutorSnapshot( - c.conn, c.ctx, c.config.TransactionSnapshot, + executor, err := NewQRepQueryExecutorSnapshot(ctx, + c.conn, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return nil, err } query := config.Query - return executor.ExecuteAndProcessQuery(query) + return executor.ExecuteAndProcessQuery(ctx, query) } var rangeStart interface{} @@ -361,14 +368,13 @@ func (c *PostgresConnector) PullQRepRecords( } executor, err := NewQRepQueryExecutorSnapshot( - c.conn, c.ctx, c.config.TransactionSnapshot, + ctx, c.conn, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return nil, err } - records, err := executor.ExecuteAndProcessQuery(query, - rangeStart, rangeEnd) + records, err := executor.ExecuteAndProcessQuery(ctx, query, rangeStart, rangeEnd) if err != nil { return nil, err } @@ -377,6 +383,7 @@ func (c *PostgresConnector) PullQRepRecords( } func (c *PostgresConnector) PullQRepRecordStream( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream, @@ -385,14 +392,14 @@ func (c *PostgresConnector) PullQRepRecordStream( if partition.FullTablePartition { c.logger.Info("pulling full table partition", partitionIdLog) executor, err := NewQRepQueryExecutorSnapshot( - c.conn, c.ctx, c.config.TransactionSnapshot, + ctx, c.conn, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return 0, err } query := config.Query - _, err = executor.ExecuteAndProcessQueryStream(stream, query) + _, err = executor.ExecuteAndProcessQueryStream(ctx, stream, query) return 0, err } c.logger.Info("Obtained ranges for partition for PullQRepStream", partitionIdLog) @@ -431,13 +438,13 @@ func (c *PostgresConnector) PullQRepRecordStream( } executor, err := NewQRepQueryExecutorSnapshot( - c.conn, c.ctx, c.config.TransactionSnapshot, + ctx, c.conn, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return 0, err } - numRecords, err := executor.ExecuteAndProcessQueryStream(stream, query, rangeStart, rangeEnd) + numRecords, err := 
executor.ExecuteAndProcessQueryStream(ctx, stream, query, rangeStart, rangeEnd) if err != nil { return 0, err } @@ -447,6 +454,7 @@ func (c *PostgresConnector) PullQRepRecordStream( } func (c *PostgresConnector) SyncQRepRecords( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream, @@ -456,7 +464,7 @@ func (c *PostgresConnector) SyncQRepRecords( return 0, fmt.Errorf("failed to parse destination table identifier: %w", err) } - exists, err := c.tableExists(dstTable) + exists, err := c.tableExists(ctx, dstTable) if err != nil { return 0, fmt.Errorf("failed to check if table exists: %w", err) } @@ -465,7 +473,7 @@ func (c *PostgresConnector) SyncQRepRecords( return 0, fmt.Errorf("table %s does not exist, used schema: %s", dstTable.Table, dstTable.Schema) } - done, err := c.isPartitionSynced(partition.PartitionId) + done, err := c.isPartitionSynced(ctx, partition.PartitionId) if err != nil { return 0, fmt.Errorf("failed to check if partition is synced: %w", err) } @@ -477,14 +485,14 @@ func (c *PostgresConnector) SyncQRepRecords( c.logger.Info("SyncRecords called and initial checks complete.") stagingTableSync := &QRepStagingTableSync{connector: c} - return stagingTableSync.SyncQRepRecords(c.ctx, + return stagingTableSync.SyncQRepRecords(ctx, config.FlowJobName, dstTable, partition, stream, config.WriteMode, config.SyncedAtColName) } // SetupQRepMetadataTables function for postgres connector -func (c *PostgresConnector) SetupQRepMetadataTables(config *protos.QRepConfig) error { - err := c.createMetadataSchema() +func (c *PostgresConnector) SetupQRepMetadataTables(ctx context.Context, config *protos.QRepConfig) error { + err := c.createMetadataSchema(ctx) if err != nil { return fmt.Errorf("error creating metadata schema: %w", err) } @@ -498,7 +506,7 @@ func (c *PostgresConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e syncFinishTime TIMESTAMP DEFAULT NOW() )`, metadataTableIdentifier.Sanitize()) // execute create table query - _, err = c.conn.Exec(c.ctx, createQRepMetadataTableSQL) + _, err = c.conn.Exec(ctx, createQRepMetadataTableSQL) if err != nil && !utils.IsUniqueError(err) { return fmt.Errorf("failed to create table %s: %w", qRepMetadataTableName, err) } @@ -506,7 +514,7 @@ func (c *PostgresConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e if config.WriteMode != nil && config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE { - _, err = c.conn.Exec(c.ctx, + _, err = c.conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", config.DestinationTableIdentifier)) if err != nil { return fmt.Errorf("failed to TRUNCATE table before query replication: %w", err) @@ -517,6 +525,7 @@ func (c *PostgresConnector) SetupQRepMetadataTables(config *protos.QRepConfig) e } func (c *PostgresConnector) PullXminRecordStream( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream, @@ -530,7 +539,7 @@ func (c *PostgresConnector) PullXminRecordStream( } executor, err := NewQRepQueryExecutorSnapshot( - c.conn, c.ctx, c.config.TransactionSnapshot, + ctx, c.conn, c.config.TransactionSnapshot, config.FlowJobName, partition.PartitionId) if err != nil { return 0, currentSnapshotXmin, err @@ -538,9 +547,18 @@ func (c *PostgresConnector) PullXminRecordStream( var numRecords int if partition.Range != nil { - numRecords, currentSnapshotXmin, err = executor.ExecuteAndProcessQueryStreamGettingCurrentSnapshotXmin(stream, query, oldxid) + numRecords, currentSnapshotXmin, err = 
executor.ExecuteAndProcessQueryStreamGettingCurrentSnapshotXmin( + ctx, + stream, + query, + oldxid, + ) } else { - numRecords, currentSnapshotXmin, err = executor.ExecuteAndProcessQueryStreamGettingCurrentSnapshotXmin(stream, query) + numRecords, currentSnapshotXmin, err = executor.ExecuteAndProcessQueryStreamGettingCurrentSnapshotXmin( + ctx, + stream, + query, + ) } if err != nil { return 0, currentSnapshotXmin, err @@ -574,7 +592,7 @@ func BuildQuery(query string, flowJobName string) (string, error) { } // isPartitionSynced checks whether a specific partition is synced -func (c *PostgresConnector) isPartitionSynced(partitionID string) (bool, error) { +func (c *PostgresConnector) isPartitionSynced(ctx context.Context, partitionID string) (bool, error) { // setup the query string metadataTableIdentifier := pgx.Identifier{c.metadataSchema, qRepMetadataTableName} queryString := fmt.Sprintf( @@ -584,7 +602,7 @@ func (c *PostgresConnector) isPartitionSynced(partitionID string) (bool, error) // prepare and execute the query var result bool - err := c.conn.QueryRow(c.ctx, queryString, partitionID).Scan(&result) + err := c.conn.QueryRow(ctx, queryString, partitionID).Scan(&result) if err != nil { return false, fmt.Errorf("failed to execute query: %w", err) } diff --git a/flow/connectors/postgres/qrep_bench_test.go b/flow/connectors/postgres/qrep_bench_test.go index 3848171e83..e8f514bc38 100644 --- a/flow/connectors/postgres/qrep_bench_test.go +++ b/flow/connectors/postgres/qrep_bench_test.go @@ -30,7 +30,7 @@ func BenchmarkQRepQueryExecutor(b *testing.B) { b.Logf("iteration %d", i) // Execute the query and process the rows - _, err := qe.ExecuteAndProcessQuery(query) + _, err := qe.ExecuteAndProcessQuery(ctx, query) if err != nil { b.Fatalf("failed to execute query: %v", err) } diff --git a/flow/connectors/postgres/qrep_partition_test.go b/flow/connectors/postgres/qrep_partition_test.go index 0ad98c25ae..0567136bb3 100644 --- a/flow/connectors/postgres/qrep_partition_test.go +++ b/flow/connectors/postgres/qrep_partition_test.go @@ -170,13 +170,12 @@ func TestGetQRepPartitions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { c := &PostgresConnector{ connStr: connStr, - ctx: context.Background(), config: &protos.PostgresConfig{}, conn: conn, logger: log.NewStructuredLogger(slog.With(slog.String(string(shared.FlowNameKey), "testGetQRepPartitions"))), } - got, err := c.GetQRepPartitions(tc.config, tc.last) + got, err := c.GetQRepPartitions(context.Background(), tc.config, tc.last) if (err != nil) != tc.wantErr { t.Fatalf("GetQRepPartitions() error = %v, wantErr %v", err, tc.wantErr) } diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index e1cebcc95b..54294f3537 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -22,7 +22,6 @@ import ( type QRepQueryExecutor struct { conn *pgx.Conn - ctx context.Context snapshot string testEnv bool flowJobName string @@ -36,7 +35,6 @@ func NewQRepQueryExecutor(conn *pgx.Conn, ctx context.Context, ) *QRepQueryExecutor { return &QRepQueryExecutor{ conn: conn, - ctx: ctx, snapshot: "", flowJobName: flowJobName, partitionID: partitionID, @@ -47,7 +45,7 @@ func NewQRepQueryExecutor(conn *pgx.Conn, ctx context.Context, } } -func NewQRepQueryExecutorSnapshot(conn *pgx.Conn, ctx context.Context, snapshot string, +func NewQRepQueryExecutorSnapshot(ctx context.Context, conn *pgx.Conn, snapshot string, flowJobName string, partitionID string, ) 
(*QRepQueryExecutor, error) { CustomTypeMap, err := utils.GetCustomDataTypes(ctx, conn) @@ -56,7 +54,6 @@ func NewQRepQueryExecutorSnapshot(conn *pgx.Conn, ctx context.Context, snapshot } return &QRepQueryExecutor{ conn: conn, - ctx: ctx, snapshot: snapshot, flowJobName: flowJobName, partitionID: partitionID, @@ -72,8 +69,8 @@ func (qe *QRepQueryExecutor) SetTestEnv(testEnv bool) { qe.testEnv = testEnv } -func (qe *QRepQueryExecutor) ExecuteQuery(query string, args ...interface{}) (pgx.Rows, error) { - rows, err := qe.conn.Query(qe.ctx, query, args...) +func (qe *QRepQueryExecutor) ExecuteQuery(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) { + rows, err := qe.conn.Query(ctx, query, args...) if err != nil { qe.logger.Error("[pg_query_executor] failed to execute query", slog.Any("error", err)) return nil, err @@ -81,19 +78,19 @@ func (qe *QRepQueryExecutor) ExecuteQuery(query string, args ...interface{}) (pg return rows, nil } -func (qe *QRepQueryExecutor) executeQueryInTx(tx pgx.Tx, cursorName string, fetchSize int) (pgx.Rows, error) { +func (qe *QRepQueryExecutor) executeQueryInTx(ctx context.Context, tx pgx.Tx, cursorName string, fetchSize int) (pgx.Rows, error) { qe.logger.Info("Executing query in transaction") q := fmt.Sprintf("FETCH %d FROM %s", fetchSize, cursorName) if !qe.testEnv { - shutdown := utils.HeartbeatRoutine(qe.ctx, func() string { + shutdown := utils.HeartbeatRoutine(ctx, func() string { qe.logger.Info(fmt.Sprintf("still running '%s'...", q)) return fmt.Sprintf("running '%s'", q) }) defer shutdown() } - rows, err := tx.Query(qe.ctx, q) + rows, err := tx.Query(ctx, q) if err != nil { qe.logger.Error("[pg_query_executor] failed to execute query in tx", slog.Any("error", err)) return nil, err @@ -172,6 +169,7 @@ func (qe *QRepQueryExecutor) ProcessRows( } func (qe *QRepQueryExecutor) processRowsStream( + ctx context.Context, cursorName string, stream *model.QRecordStream, rows pgx.Rows, @@ -180,8 +178,6 @@ func (qe *QRepQueryExecutor) processRowsStream( numRows := 0 const heartBeatNumRows = 5000 - ctx := qe.ctx - // Iterate over the rows for rows.Next() { select { @@ -205,35 +201,36 @@ func (qe *QRepQueryExecutor) processRowsStream( } if numRows%heartBeatNumRows == 0 { - qe.recordHeartbeat("cursor: %s - fetched %d records", cursorName, numRows) + qe.recordHeartbeat(ctx, "cursor: %s - fetched %d records", cursorName, numRows) } numRows++ } } - qe.recordHeartbeat("cursor %s - fetch completed - %d records", cursorName, numRows) + qe.recordHeartbeat(ctx, "cursor %s - fetch completed - %d records", cursorName, numRows) qe.logger.Info("processed row stream") return numRows, nil } -func (qe *QRepQueryExecutor) recordHeartbeat(x string, args ...interface{}) { +func (qe *QRepQueryExecutor) recordHeartbeat(ctx context.Context, x string, args ...interface{}) { if qe.testEnv { qe.logger.Info(fmt.Sprintf(x, args...)) return } msg := fmt.Sprintf(x, args...) 
- activity.RecordHeartbeat(qe.ctx, msg) + activity.RecordHeartbeat(ctx, msg) } func (qe *QRepQueryExecutor) processFetchedRows( + ctx context.Context, query string, tx pgx.Tx, cursorName string, fetchSize int, stream *model.QRecordStream, ) (int, error) { - rows, err := qe.executeQueryInTx(tx, cursorName, fetchSize) + rows, err := qe.executeQueryInTx(ctx, tx, cursorName, fetchSize) if err != nil { stream.Records <- model.QRecordOrError{ Err: err, @@ -251,7 +248,7 @@ func (qe *QRepQueryExecutor) processFetchedRows( _ = stream.SetSchema(schema) } - numRows, err := qe.processRowsStream(cursorName, stream, rows, fieldDescriptions) + numRows, err := qe.processRowsStream(ctx, cursorName, stream, rows, fieldDescriptions) if err != nil { qe.logger.Error("[pg_query_executor] failed to process rows", slog.Any("error", err)) return 0, fmt.Errorf("failed to process rows: %w", err) @@ -270,6 +267,7 @@ func (qe *QRepQueryExecutor) processFetchedRows( } func (qe *QRepQueryExecutor) ExecuteAndProcessQuery( + ctx context.Context, query string, args ...interface{}, ) (*model.QRecordBatch, error) { @@ -280,7 +278,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQuery( // must wait on errors to close before returning to maintain qe.conn exclusion go func() { defer close(errors) - _, err := qe.ExecuteAndProcessQueryStream(stream, query, args...) + _, err := qe.ExecuteAndProcessQueryStream(ctx, stream, query, args...) if err != nil { qe.logger.Error("[pg_query_executor] failed to execute and process query stream", slog.Any("error", err)) errors <- err @@ -314,6 +312,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQuery( } func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStream( + ctx context.Context, stream *model.QRecordStream, query string, args ...interface{}, @@ -321,7 +320,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStream( qe.logger.Info("Executing and processing query stream", slog.String("query", query)) defer close(stream.Records) - tx, err := qe.conn.BeginTx(qe.ctx, pgx.TxOptions{ + tx, err := qe.conn.BeginTx(ctx, pgx.TxOptions{ AccessMode: pgx.ReadOnly, IsoLevel: pgx.RepeatableRead, }) @@ -330,11 +329,12 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStream( return 0, fmt.Errorf("[pg_query_executor] failed to begin transaction: %w", err) } - totalRecordsFetched, err := qe.ExecuteAndProcessQueryStreamWithTx(tx, stream, query, args...) + totalRecordsFetched, err := qe.ExecuteAndProcessQueryStreamWithTx(ctx, tx, stream, query, args...) 
return totalRecordsFetched, err } func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamGettingCurrentSnapshotXmin( + ctx context.Context, stream *model.QRecordStream, query string, args ...interface{}, @@ -343,7 +343,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamGettingCurrentSnapshotX qe.logger.Info("Executing and processing query stream", slog.String("query", query)) defer close(stream.Records) - tx, err := qe.conn.BeginTx(qe.ctx, pgx.TxOptions{ + tx, err := qe.conn.BeginTx(ctx, pgx.TxOptions{ AccessMode: pgx.ReadOnly, IsoLevel: pgx.RepeatableRead, }) @@ -352,17 +352,18 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamGettingCurrentSnapshotX return 0, currentSnapshotXmin.Int64, fmt.Errorf("[pg_query_executor] failed to begin transaction: %w", err) } - err = tx.QueryRow(qe.ctx, "select txid_snapshot_xmin(txid_current_snapshot())").Scan(&currentSnapshotXmin) + err = tx.QueryRow(ctx, "select txid_snapshot_xmin(txid_current_snapshot())").Scan(&currentSnapshotXmin) if err != nil { qe.logger.Error("[pg_query_executor] failed to get current snapshot xmin", slog.Any("error", err)) return 0, currentSnapshotXmin.Int64, err } - totalRecordsFetched, err := qe.ExecuteAndProcessQueryStreamWithTx(tx, stream, query, args...) + totalRecordsFetched, err := qe.ExecuteAndProcessQueryStreamWithTx(ctx, tx, stream, query, args...) return totalRecordsFetched, currentSnapshotXmin.Int64, err } func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx( + ctx context.Context, tx pgx.Tx, stream *model.QRecordStream, query string, @@ -371,14 +372,14 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx( var err error defer func() { - err := tx.Rollback(qe.ctx) + err := tx.Rollback(ctx) if err != nil && err != pgx.ErrTxClosed { qe.logger.Error("[pg_query_executor] failed to rollback transaction", slog.Any("error", err)) } }() if qe.snapshot != "" { - _, err = tx.Exec(qe.ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(qe.snapshot))) + _, err = tx.Exec(ctx, fmt.Sprintf("SET TRANSACTION SNAPSHOT %s", QuoteLiteral(qe.snapshot))) if err != nil { stream.Records <- model.QRecordOrError{ Err: fmt.Errorf("failed to set snapshot: %w", err), @@ -402,7 +403,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx( fetchSize := shared.FetchAndChannelSize cursorQuery := fmt.Sprintf("DECLARE %s CURSOR FOR %s", cursorName, query) qe.logger.Info(fmt.Sprintf("[pg_query_executor] executing cursor declaration for %v with args %v", cursorQuery, args)) - _, err = tx.Exec(qe.ctx, cursorQuery, args...) + _, err = tx.Exec(ctx, cursorQuery, args...)
if err != nil { stream.Records <- model.QRecordOrError{ Err: fmt.Errorf("failed to declare cursor: %w", err), @@ -417,7 +418,7 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx( totalRecordsFetched := 0 numFetchOpsComplete := 0 for { - numRows, err := qe.processFetchedRows(query, tx, cursorName, fetchSize, stream) + numRows, err := qe.processFetchedRows(ctx, query, tx, cursorName, fetchSize, stream) if err != nil { qe.logger.Error("[pg_query_executor] failed to process fetched rows", slog.Any("error", err)) return 0, err @@ -430,12 +431,12 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQueryStreamWithTx( break } - numFetchOpsComplete++ - qe.recordHeartbeat("#%d fetched %d rows", numFetchOpsComplete, numRows) + numFetchOpsComplete += 1 + qe.recordHeartbeat(ctx, "#%d fetched %d rows", numFetchOpsComplete, numRows) } qe.logger.Info("Committing transaction") - err = tx.Commit(qe.ctx) + err = tx.Commit(ctx) if err != nil { qe.logger.Error("[pg_query_executor] failed to commit transaction", slog.Any("error", err)) stream.Records <- model.QRecordOrError{ diff --git a/flow/connectors/postgres/qrep_query_executor_test.go b/flow/connectors/postgres/qrep_query_executor_test.go index f012dec05a..bffb28214c 100644 --- a/flow/connectors/postgres/qrep_query_executor_test.go +++ b/flow/connectors/postgres/qrep_query_executor_test.go @@ -70,7 +70,7 @@ func TestExecuteAndProcessQuery(t *testing.T) { qe.SetTestEnv(true) query = fmt.Sprintf("SELECT * FROM %s.test;", schemaName) - batch, err := qe.ExecuteAndProcessQuery(query) + batch, err := qe.ExecuteAndProcessQuery(context.Background(), query) if err != nil { t.Fatalf("error while executing and processing query: %v", err) } @@ -173,7 +173,7 @@ func TestAllDataTypes(t *testing.T) { qe := NewQRepQueryExecutor(conn, ctx, "test flow", "test part") // Select the row back out of the table query = fmt.Sprintf("SELECT * FROM %s.test;", schemaName) - rows, err := qe.ExecuteQuery(query) + rows, err := qe.ExecuteQuery(context.Background(), query) if err != nil { t.Fatalf("error while executing query: %v", err) } diff --git a/flow/connectors/postgres/ssh_wrapped_pool.go b/flow/connectors/postgres/ssh_wrapped_pool.go index 8f048d8d42..619a9e3dbe 100644 --- a/flow/connectors/postgres/ssh_wrapped_pool.go +++ b/flow/connectors/postgres/ssh_wrapped_pool.go @@ -20,16 +20,12 @@ type SSHTunnel struct { sshServer string once sync.Once sshClient *ssh.Client - ctx context.Context - cancel context.CancelFunc } func NewSSHTunnel( ctx context.Context, sshConfig *protos.SSHConfig, ) (*SSHTunnel, error) { - swCtx, cancel := context.WithCancel(ctx) - var sshServer string var clientConfig *ssh.ClientConfig @@ -39,7 +35,6 @@ func NewSSHTunnel( clientConfig, err = utils.GetSSHClientConfig(sshConfig) if err != nil { slog.Error("Failed to get SSH client config", slog.Any("error", err)) - cancel() return nil, err } } @@ -47,8 +42,6 @@ func NewSSHTunnel( pool := &SSHTunnel{ sshConfig: clientConfig, sshServer: sshServer, - ctx: swCtx, - cancel: cancel, } err := pool.connect() @@ -85,8 +78,6 @@ func (tunnel *SSHTunnel) setupSSH() error { } func (tunnel *SSHTunnel) Close() { - tunnel.cancel() - if tunnel.sshClient != nil { tunnel.sshClient.Close() } @@ -121,7 +112,7 @@ func (tunnel *SSHTunnel) NewPostgresConnFromConfig( } } - conn, err := pgx.ConnectConfig(tunnel.ctx, connConfig) + conn, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { slog.Error("Failed to create pool:", slog.Any("error", err)) return nil, err @@ -129,7 +120,7 @@ func (tunnel *SSHTunnel) 
NewPostgresConnFromConfig( host := connConfig.Host err = retryWithBackoff(func() error { - err = conn.Ping(tunnel.ctx) + err = conn.Ping(ctx) if err != nil { slog.Error("Failed to ping pool", slog.Any("error", err), slog.String("host", host)) return err diff --git a/flow/connectors/s3/qrep.go b/flow/connectors/s3/qrep.go index 99e07c8111..7a9432c1d3 100644 --- a/flow/connectors/s3/qrep.go +++ b/flow/connectors/s3/qrep.go @@ -1,6 +1,7 @@ package conns3 import ( + "context" "fmt" "log/slog" @@ -13,6 +14,7 @@ import ( ) func (c *S3Connector) SyncQRepRecords( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream, @@ -31,7 +33,7 @@ func (c *S3Connector) SyncQRepRecords( return 0, err } - numRecords, err := c.writeToAvroFile(stream, avroSchema, partition.PartitionId, config.FlowJobName) + numRecords, err := c.writeToAvroFile(ctx, stream, avroSchema, partition.PartitionId, config.FlowJobName) if err != nil { return 0, err } @@ -52,6 +54,7 @@ func getAvroSchema( } func (c *S3Connector) writeToAvroFile( + ctx context.Context, stream *model.QRecordStream, avroSchema *model.QRecordAvroSchemaDefinition, partitionID string, @@ -64,7 +67,7 @@ func (c *S3Connector) writeToAvroFile( s3AvroFileKey := fmt.Sprintf("%s/%s/%s.avro", s3o.Prefix, jobName, partitionID) writer := avro.NewPeerDBOCFWriter(stream, avroSchema, avro.CompressNone, qvalue.QDWHTypeSnowflake) - avroFile, err := writer.WriteRecordsToS3(c.ctx, s3o.Bucket, s3AvroFileKey, c.creds) + avroFile, err := writer.WriteRecordsToS3(ctx, s3o.Bucket, s3AvroFileKey, c.creds) if err != nil { return 0, fmt.Errorf("failed to write records to S3: %w", err) } @@ -74,7 +77,7 @@ func (c *S3Connector) writeToAvroFile( } // S3 just sets up destination, not metadata tables -func (c *S3Connector) SetupQRepMetadataTables(config *protos.QRepConfig) error { +func (c *S3Connector) SetupQRepMetadataTables(_ context.Context, config *protos.QRepConfig) error { c.logger.Info("QRep metadata setup not needed for S3.") return nil } diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index 42e6a52ddc..930f8f2204 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -23,7 +23,6 @@ const ( ) type S3Connector struct { - ctx context.Context url string pgMetadata *metadataStore.PostgresMetadataStore client s3.Client @@ -35,6 +34,7 @@ func NewS3Connector( ctx context.Context, config *protos.S3Config, ) (*S3Connector, error) { + logger := logger.LoggerFromCtx(ctx) keyID := "" if config.AccessKeyId != nil { keyID = *config.AccessKeyId @@ -66,14 +66,12 @@ func NewS3Connector( if err != nil { return nil, fmt.Errorf("failed to create S3 client: %w", err) } - logger := logger.LoggerFromCtx(ctx) - pgMetadata, err := metadataStore.NewPostgresMetadataStore(logger) + pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx) if err != nil { logger.Error("failed to create postgres metadata store", "error", err) return nil, err } return &S3Connector{ - ctx: ctx, url: config.Url, pgMetadata: pgMetadata, client: *s3Client, @@ -82,12 +80,12 @@ func NewS3Connector( }, nil } -func (c *S3Connector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { +func (c *S3Connector) CreateRawTable(_ context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { c.logger.Info("CreateRawTable for S3 is a no-op") return nil, nil } -func (c *S3Connector) Close() error { +func (c *S3Connector) Close(_ context.Context) error { return nil } @@ -135,8 +133,8 @@ func 
ValidCheck(ctx context.Context, s3Client *s3.Client, bucketURL string, meta return nil } -func (c *S3Connector) ConnectionActive() error { - validErr := ValidCheck(c.ctx, &c.client, c.url, c.pgMetadata) +func (c *S3Connector) ConnectionActive(ctx context.Context) error { + validErr := ValidCheck(ctx, &c.client, c.url, c.pgMetadata) if validErr != nil { c.logger.Error("failed to validate s3 connector:", "error", validErr) return validErr @@ -145,27 +143,27 @@ func (c *S3Connector) ConnectionActive() error { return nil } -func (c *S3Connector) NeedsSetupMetadataTables() bool { +func (c *S3Connector) NeedsSetupMetadataTables(_ context.Context) bool { return false } -func (c *S3Connector) SetupMetadataTables() error { +func (c *S3Connector) SetupMetadataTables(_ context.Context) error { return nil } -func (c *S3Connector) GetLastSyncBatchID(jobName string) (int64, error) { - return c.pgMetadata.GetLastBatchID(c.ctx, jobName) +func (c *S3Connector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.GetLastBatchID(ctx, jobName) } -func (c *S3Connector) GetLastOffset(jobName string) (int64, error) { - return c.pgMetadata.FetchLastOffset(c.ctx, jobName) +func (c *S3Connector) GetLastOffset(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.FetchLastOffset(ctx, jobName) } -func (c *S3Connector) SetLastOffset(jobName string, offset int64) error { - return c.pgMetadata.UpdateLastOffset(c.ctx, jobName, offset) +func (c *S3Connector) SetLastOffset(ctx context.Context, jobName string, offset int64) error { + return c.pgMetadata.UpdateLastOffset(ctx, jobName, offset) } -func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { +func (c *S3Connector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { tableNameRowsMapping := make(map[string]uint32) streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, req.SyncBatchID) streamRes, err := utils.RecordsToRawTableStream(streamReq) @@ -180,7 +178,7 @@ func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncRes partition := &protos.QRepPartition{ PartitionId: strconv.FormatInt(req.SyncBatchID, 10), } - numRecords, err := c.SyncQRepRecords(qrepConfig, partition, recordStream) + numRecords, err := c.SyncQRepRecords(ctx, qrepConfig, partition, recordStream) if err != nil { return nil, err } @@ -191,7 +189,7 @@ func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncRes return nil, fmt.Errorf("failed to get last checkpoint: %w", err) } - err = c.pgMetadata.FinishBatch(c.ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) + err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, lastCheckpoint) if err != nil { c.logger.Error("failed to increment id", "error", err) return nil, err @@ -205,12 +203,12 @@ func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncRes }, nil } -func (c *S3Connector) ReplayTableSchemaDeltas(flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { +func (c *S3Connector) ReplayTableSchemaDeltas(_ context.Context, flowJobName string, schemaDeltas []*protos.TableSchemaDelta) error { c.logger.Info("ReplayTableSchemaDeltas for S3 is a no-op") return nil } -func (c *S3Connector) SetupNormalizedTables(req *protos.SetupNormalizedTableBatchInput) ( +func (c *S3Connector) SetupNormalizedTables(_ context.Context, req *protos.SetupNormalizedTableBatchInput) ( 
*protos.SetupNormalizedTableBatchOutput, error, ) { @@ -218,6 +216,6 @@ func (c *S3Connector) SetupNormalizedTables(req *protos.SetupNormalizedTableBatc return nil, nil } -func (c *S3Connector) SyncFlowCleanup(jobName string) error { - return c.pgMetadata.DropMetadata(c.ctx, jobName) +func (c *S3Connector) SyncFlowCleanup(ctx context.Context, jobName string) error { + return c.pgMetadata.DropMetadata(ctx, jobName) } diff --git a/flow/connectors/snowflake/client.go b/flow/connectors/snowflake/client.go index 82a09cb1ed..3ee20362c7 100644 --- a/flow/connectors/snowflake/client.go +++ b/flow/connectors/snowflake/client.go @@ -13,15 +13,13 @@ import ( peersql "github.com/PeerDB-io/peer-flow/connectors/sql" "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/logger" "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" ) type SnowflakeClient struct { peersql.GenericSQLQueryExecutor - // ctx is the context. - ctx context.Context - // config is the Snowflake config. Config *protos.SnowflakeConfig } @@ -58,24 +56,24 @@ func NewSnowflakeClient(ctx context.Context, config *protos.SnowflakeConfig) (*S return nil, fmt.Errorf("failed to open connection to Snowflake peer: %w", err) } + logger := logger.LoggerFromCtx(ctx) genericExecutor := *peersql.NewGenericSQLQueryExecutor( - ctx, database, snowflakeTypeToQValueKindMap, qvalue.QValueKindToSnowflakeTypeMap) + logger, database, snowflakeTypeToQValueKindMap, qvalue.QValueKindToSnowflakeTypeMap) return &SnowflakeClient{ GenericSQLQueryExecutor: genericExecutor, - ctx: ctx, Config: config, }, nil } -func (c *SnowflakeConnector) getTableCounts(tables []string) (int64, error) { +func (c *SnowflakeConnector) getTableCounts(ctx context.Context, tables []string) (int64, error) { var totalRecords int64 for _, table := range tables { _, err := utils.ParseSchemaTable(table) if err != nil { return 0, fmt.Errorf("failed to parse table name %s: %w", table, err) } - row := c.database.QueryRowContext(c.ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)) + row := c.database.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)) var count pgtype.Int8 err = row.Scan(&count) if err != nil { diff --git a/flow/connectors/snowflake/get_schema_for_tests.go b/flow/connectors/snowflake/get_schema_for_tests.go index 107df93f78..5a4ed9fd4d 100644 --- a/flow/connectors/snowflake/get_schema_for_tests.go +++ b/flow/connectors/snowflake/get_schema_for_tests.go @@ -1,6 +1,7 @@ package connsnowflake import ( + "context" "fmt" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -8,8 +9,8 @@ import ( "github.com/PeerDB-io/peer-flow/model/qvalue" ) -func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.TableSchema, error) { - colNames, colTypes, err := c.getColsFromTable(tableName) +func (c *SnowflakeConnector) getTableSchemaForTable(ctx context.Context, tableName string) (*protos.TableSchema, error) { + colNames, colTypes, err := c.getColsFromTable(ctx, tableName) if err != nil { return nil, err } @@ -37,16 +38,17 @@ func (c *SnowflakeConnector) getTableSchemaForTable(tableName string) (*protos.T // only used for testing atm. doesn't return info about pkey or ReplicaIdentity [which is PG specific anyway]. 
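The same mechanical refactor repeats across every connector in this patch: drop the stored ctx field and accept context.Context as the first argument of each method, so every call runs under the caller's lifetime instead of whatever context the constructor happened to capture. A minimal before/after sketch of that shape, using a hypothetical widgetConnector (the names are illustrative, not from this repository):

package example

import (
	"context"
	"database/sql"
)

// Before: the connector owns a context, so every method silently reuses the
// constructor's lifetime. This is exactly what the containedctx lint flags.
type widgetConnectorOld struct {
	ctx context.Context
	db  *sql.DB
}

func (c *widgetConnectorOld) rowCount(table string) (int64, error) {
	var n int64
	err := c.db.QueryRowContext(c.ctx, "SELECT COUNT(*) FROM "+table).Scan(&n)
	return n, err
}

// After: the caller decides the lifetime of each individual operation.
type widgetConnector struct {
	db *sql.DB
}

func (c *widgetConnector) rowCount(ctx context.Context, table string) (int64, error) {
	var n int64
	err := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table).Scan(&n)
	return n, err
}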
func (c *SnowflakeConnector) GetTableSchema( + ctx context.Context, req *protos.GetTableSchemaBatchInput, ) (*protos.GetTableSchemaBatchOutput, error) { res := make(map[string]*protos.TableSchema, len(req.TableIdentifiers)) for _, tableName := range req.TableIdentifiers { - tableSchema, err := c.getTableSchemaForTable(tableName) + tableSchema, err := c.getTableSchemaForTable(ctx, tableName) if err != nil { return nil, err } res[tableName] = tableSchema - utils.RecordHeartbeat(c.ctx, fmt.Sprintf("fetched schema for table %s", tableName)) + utils.RecordHeartbeat(ctx, fmt.Sprintf("fetched schema for table %s", tableName)) } return &protos.GetTableSchemaBatchOutput{ diff --git a/flow/connectors/snowflake/qrep.go b/flow/connectors/snowflake/qrep.go index 51c100f19a..0f3de892b2 100644 --- a/flow/connectors/snowflake/qrep.go +++ b/flow/connectors/snowflake/qrep.go @@ -1,6 +1,7 @@ package connsnowflake import ( + "context" "database/sql" "fmt" "log/slog" @@ -17,6 +18,7 @@ import ( ) func (c *SnowflakeConnector) SyncQRepRecords( + ctx context.Context, config *protos.QRepConfig, partition *protos.QRepPartition, stream *model.QRecordStream, @@ -27,13 +29,13 @@ func (c *SnowflakeConnector) SyncQRepRecords( slog.String(string(shared.PartitionIDKey), partition.PartitionId), slog.String("destinationTable", destTable), ) - tblSchema, err := c.getTableSchema(destTable) + tblSchema, err := c.getTableSchema(ctx, destTable) if err != nil { return 0, fmt.Errorf("failed to get schema of table %s: %w", destTable, err) } c.logger.Info("Called QRep sync function and obtained table schema", flowLog) - done, err := c.pgMetadata.IsQrepPartitionSynced(c.ctx, config.FlowJobName, partition.PartitionId) + done, err := c.pgMetadata.IsQrepPartitionSynced(ctx, config.FlowJobName, partition.PartitionId) if err != nil { return 0, fmt.Errorf("failed to check if partition %s is synced: %w", partition.PartitionId, err) } @@ -44,10 +46,10 @@ func (c *SnowflakeConnector) SyncQRepRecords( } avroSync := NewSnowflakeAvroSyncHandler(config, c) - return avroSync.SyncQRepRecords(c.ctx, config, partition, tblSchema, stream) + return avroSync.SyncQRepRecords(ctx, config, partition, tblSchema, stream) } -func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType, error) { +func (c *SnowflakeConnector) getTableSchema(ctx context.Context, tableName string) ([]*sql.ColumnType, error) { schematable, err := utils.ParseSchemaTable(tableName) if err != nil { return nil, fmt.Errorf("failed to parse table '%s'", tableName) @@ -57,7 +59,7 @@ func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType queryString := fmt.Sprintf("SELECT * FROM %s LIMIT 0", snowflakeSchemaTableNormalize(schematable)) //nolint:rowserrcheck - rows, err := c.database.QueryContext(c.ctx, queryString) + rows, err := c.database.QueryContext(ctx, queryString) if err != nil { return nil, fmt.Errorf("failed to execute query: %w", err) } @@ -71,20 +73,20 @@ func (c *SnowflakeConnector) getTableSchema(tableName string) ([]*sql.ColumnType return columnTypes, nil } -func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig) error { - _, err := c.database.ExecContext(c.ctx, fmt.Sprintf(createSchemaSQL, c.rawSchema)) +func (c *SnowflakeConnector) SetupQRepMetadataTables(ctx context.Context, config *protos.QRepConfig) error { + _, err := c.database.ExecContext(ctx, fmt.Sprintf(createSchemaSQL, c.rawSchema)) if err != nil { return err } stageName := c.getStageNameForJob(config.FlowJobName) - err = c.createStage(stageName, 
config) + err = c.createStage(ctx, stageName, config) if err != nil { return err } if config.WriteMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE { - _, err = c.database.Exec(fmt.Sprintf("TRUNCATE TABLE %s", config.DestinationTableIdentifier)) + _, err = c.database.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", config.DestinationTableIdentifier)) if err != nil { return fmt.Errorf("failed to TRUNCATE table before query replication: %w", err) } @@ -93,7 +95,7 @@ func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig) return nil } -func (c *SnowflakeConnector) createStage(stageName string, config *protos.QRepConfig) error { +func (c *SnowflakeConnector) createStage(ctx context.Context, stageName string, config *protos.QRepConfig) error { var createStageStmt string if strings.HasPrefix(config.StagingPath, "s3://") { stmt, err := c.createExternalStage(stageName, config) @@ -110,7 +112,7 @@ func (c *SnowflakeConnector) createStage(stageName string, config *protos.QRepCo } // Execute the query - _, err := c.database.Exec(createStageStmt) + _, err := c.database.ExecContext(ctx, createStageStmt) if err != nil { c.logger.Error(fmt.Sprintf("failed to create stage %s", stageName), slog.Any("error", err)) return fmt.Errorf("failed to create stage %s: %w", stageName, err) @@ -156,14 +158,14 @@ func (c *SnowflakeConnector) createExternalStage(stageName string, config *proto } } -func (c *SnowflakeConnector) ConsolidateQRepPartitions(config *protos.QRepConfig) error { +func (c *SnowflakeConnector) ConsolidateQRepPartitions(ctx context.Context, config *protos.QRepConfig) error { c.logger.Info("Consolidating partitions") destTable := config.DestinationTableIdentifier stageName := c.getStageNameForJob(config.FlowJobName) writeHandler := NewSnowflakeAvroConsolidateHandler(c, config, destTable, stageName) - err := writeHandler.CopyStageToDestination(c.ctx) + err := writeHandler.CopyStageToDestination(ctx) if err != nil { c.logger.Error("failed to copy stage to destination", slog.Any("error", err)) return fmt.Errorf("failed to copy stage to destination: %w", err) @@ -173,12 +175,12 @@ func (c *SnowflakeConnector) ConsolidateQRepPartitions(config *protos.QRepConfig } // CleanupQRepFlow function for snowflake connector -func (c *SnowflakeConnector) CleanupQRepFlow(config *protos.QRepConfig) error { +func (c *SnowflakeConnector) CleanupQRepFlow(ctx context.Context, config *protos.QRepConfig) error { c.logger.Info("Cleaning up flow job") - return c.dropStage(config.StagingPath, config.FlowJobName) + return c.dropStage(ctx, config.StagingPath, config.FlowJobName) } -func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, []string, error) { +func (c *SnowflakeConnector) getColsFromTable(ctx context.Context, tableName string) ([]string, []string, error) { // parse the table name to get the schema and table name schemaTable, err := utils.ParseSchemaTable(tableName) if err != nil { @@ -186,7 +188,7 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, []str } rows, err := c.database.QueryContext( - c.ctx, + ctx, getTableSchemaSQL, strings.ToUpper(schemaTable.Schema), strings.ToUpper(schemaTable.Table), @@ -220,11 +222,11 @@ func (c *SnowflakeConnector) getColsFromTable(tableName string) ([]string, []str } // dropStage drops the stage for the given job. 
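Converting bare Exec/Query calls to ExecContext/QueryContext is what makes the threaded context useful: long-running warehouse statements such as the TRUNCATE and CREATE STAGE above become cancellable. A small sketch of the idea, assuming nothing beyond database/sql:

package example

import (
	"context"
	"database/sql"
	"time"
)

// truncateWithDeadline shows why ExecContext matters: if the caller's ctx is
// cancelled or the deadline passes, database/sql gives up on the statement
// instead of blocking until the warehouse finishes on its own schedule.
func truncateWithDeadline(ctx context.Context, db *sql.DB, table string) error {
	ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()
	_, err := db.ExecContext(ctx, "TRUNCATE TABLE "+table)
	return err
}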
-func (c *SnowflakeConnector) dropStage(stagingPath string, job string) error { +func (c *SnowflakeConnector) dropStage(ctx context.Context, stagingPath string, job string) error { stageName := c.getStageNameForJob(job) stmt := fmt.Sprintf("DROP STAGE IF EXISTS %s", stageName) - _, err := c.database.Exec(stmt) + _, err := c.database.ExecContext(ctx, stmt) if err != nil { return fmt.Errorf("failed to drop stage %s: %w", stageName, err) } @@ -252,14 +254,14 @@ func (c *SnowflakeConnector) dropStage(stagingPath string, job string) error { Prefix: aws.String(fmt.Sprintf("%s/%s", s3o.Prefix, job)), }) for pages.HasMorePages() { - page, err := pages.NextPage(c.ctx) + page, err := pages.NextPage(ctx) if err != nil { c.logger.Error("failed to list objects from bucket", slog.Any("error", err)) return fmt.Errorf("failed to list objects from bucket: %w", err) } for _, object := range page.Contents { - _, err = s3svc.DeleteObject(c.ctx, &s3.DeleteObjectInput{ + _, err = s3svc.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(s3o.Bucket), Key: object.Key, }) diff --git a/flow/connectors/snowflake/qrep_avro_consolidate.go b/flow/connectors/snowflake/qrep_avro_consolidate.go index 8de09398e6..26ca452bd2 100644 --- a/flow/connectors/snowflake/qrep_avro_consolidate.go +++ b/flow/connectors/snowflake/qrep_avro_consolidate.go @@ -39,7 +39,7 @@ func NewSnowflakeAvroConsolidateHandler( func (s *SnowflakeAvroConsolidateHandler) CopyStageToDestination(ctx context.Context) error { s.connector.logger.Info("Copying stage to destination " + s.dstTableName) - colNames, colTypes, colsErr := s.connector.getColsFromTable(s.dstTableName) + colNames, colTypes, colsErr := s.connector.getColsFromTable(ctx, s.dstTableName) if colsErr != nil { return fmt.Errorf("failed to get columns from destination table: %w", colsErr) } @@ -228,7 +228,7 @@ func (s *SnowflakeAvroConsolidateHandler) handleUpsertMode(ctx context.Context) } rowCount, err := rows.RowsAffected() if err == nil { - totalRowsAtTarget, err := s.connector.getTableCounts([]string{s.dstTableName}) + totalRowsAtTarget, err := s.connector.getTableCounts(ctx, []string{s.dstTableName}) if err != nil { return err } diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 0cdd8f78ba..d8b7921289 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -65,7 +65,7 @@ func (s *SnowflakeAvroSyncHandler) SyncRecords( s.connector.logger.Info(fmt.Sprintf("written %d records to Avro file", avroFile.NumRecords), tableLog) stage := s.connector.getStageNameForJob(s.config.FlowJobName) - err = s.connector.createStage(stage, s.config) + err = s.connector.createStage(ctx, stage, s.config) if err != nil { return 0, err } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 44efa760f0..4a6d47434b 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -75,7 +75,6 @@ const ( ) type SnowflakeConnector struct { - ctx context.Context database *sql.DB pgMetadata *metadataStore.PostgresMetadataStore rawSchema string @@ -207,13 +206,12 @@ func NewSnowflakeConnector( rawSchema = *snowflakeProtoConfig.MetadataSchema } - pgMetadata, err := metadataStore.NewPostgresMetadataStore(logger) + pgMetadata, err := metadataStore.NewPostgresMetadataStore(ctx) if err != nil { return nil, fmt.Errorf("could not connect to metadata store: %w", err) } return &SnowflakeConnector{ - ctx: ctx, database: database, 
pgMetadata: pgMetadata, rawSchema: rawSchema, @@ -221,7 +219,7 @@ func NewSnowflakeConnector( }, nil } -func (c *SnowflakeConnector) Close() error { +func (c *SnowflakeConnector) Close(_ context.Context) error { if c == nil || c.database == nil { return nil } @@ -233,46 +231,49 @@ func (c *SnowflakeConnector) Close() error { return nil } -func (c *SnowflakeConnector) ConnectionActive() error { +func (c *SnowflakeConnector) ConnectionActive(ctx context.Context) error { if c == nil || c.database == nil { return fmt.Errorf("SnowflakeConnector is nil") } // This also checks if database exists - err := c.database.PingContext(c.ctx) + err := c.database.PingContext(ctx) return err } -func (c *SnowflakeConnector) NeedsSetupMetadataTables() bool { +func (c *SnowflakeConnector) NeedsSetupMetadataTables(_ context.Context) bool { return false } -func (c *SnowflakeConnector) SetupMetadataTables() error { +func (c *SnowflakeConnector) SetupMetadataTables(_ context.Context) error { return nil } -func (c *SnowflakeConnector) GetLastOffset(jobName string) (int64, error) { - return c.pgMetadata.FetchLastOffset(c.ctx, jobName) +func (c *SnowflakeConnector) GetLastOffset(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.FetchLastOffset(ctx, jobName) } -func (c *SnowflakeConnector) SetLastOffset(jobName string, offset int64) error { - return c.pgMetadata.UpdateLastOffset(c.ctx, jobName, offset) +func (c *SnowflakeConnector) SetLastOffset(ctx context.Context, jobName string, offset int64) error { + return c.pgMetadata.UpdateLastOffset(ctx, jobName, offset) } -func (c *SnowflakeConnector) GetLastSyncBatchID(jobName string) (int64, error) { - return c.pgMetadata.GetLastBatchID(c.ctx, jobName) +func (c *SnowflakeConnector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.GetLastBatchID(ctx, jobName) } -func (c *SnowflakeConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { - return c.pgMetadata.GetLastNormalizeBatchID(c.ctx, jobName) +func (c *SnowflakeConnector) GetLastNormalizeBatchID(ctx context.Context, jobName string) (int64, error) { + return c.pgMetadata.GetLastNormalizeBatchID(ctx, jobName) } -func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, +func (c *SnowflakeConnector) getDistinctTableNamesInBatch( + ctx context.Context, + flowJobName string, + syncBatchID int64, normalizeBatchID int64, ) ([]string, error) { rawTableIdentifier := getRawTableIdentifier(flowJobName) - rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getDistinctDestinationTableNames, c.rawSchema, + rows, err := c.database.QueryContext(ctx, fmt.Sprintf(getDistinctDestinationTableNames, c.rawSchema, rawTableIdentifier, normalizeBatchID, syncBatchID)) if err != nil { return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err) @@ -296,12 +297,15 @@ func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, sy return destinationTableNames, nil } -func (c *SnowflakeConnector) getTableNameToUnchangedCols(flowJobName string, syncBatchID int64, +func (c *SnowflakeConnector) getTableNameToUnchangedCols( + ctx context.Context, + flowJobName string, + syncBatchID int64, normalizeBatchID int64, ) (map[string][]string, error) { rawTableIdentifier := getRawTableIdentifier(flowJobName) - rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getTableNameToUnchangedColsSQL, c.rawSchema, + rows, err := c.database.QueryContext(ctx, 
fmt.Sprintf(getTableNameToUnchangedColsSQL, c.rawSchema, rawTableIdentifier, normalizeBatchID, syncBatchID)) if err != nil { return nil, fmt.Errorf("error while retrieving table names for normalization: %w", err) @@ -326,6 +330,7 @@ func (c *SnowflakeConnector) getTableNameToUnchangedCols(flowJobName string, syn } func (c *SnowflakeConnector) SetupNormalizedTables( + ctx context.Context, req *protos.SetupNormalizedTableBatchInput, ) (*protos.SetupNormalizedTableBatchOutput, error) { tableExistsMapping := make(map[string]bool) @@ -334,7 +339,7 @@ func (c *SnowflakeConnector) SetupNormalizedTables( if err != nil { return nil, fmt.Errorf("error while parsing table schema and name: %w", err) } - tableAlreadyExists, err := c.checkIfTableExists(normalizedSchemaTable.Schema, normalizedSchemaTable.Table) + tableAlreadyExists, err := c.checkIfTableExists(ctx, normalizedSchemaTable.Schema, normalizedSchemaTable.Table) if err != nil { return nil, fmt.Errorf("error occurred while checking if normalized table exists: %w", err) } @@ -345,12 +350,12 @@ func (c *SnowflakeConnector) SetupNormalizedTables( normalizedTableCreateSQL := generateCreateTableSQLForNormalizedTable( normalizedSchemaTable, tableSchema, req.SoftDeleteColName, req.SyncedAtColName) - _, err = c.database.ExecContext(c.ctx, normalizedTableCreateSQL) + _, err = c.database.ExecContext(ctx, normalizedTableCreateSQL) if err != nil { return nil, fmt.Errorf("[sf] error while creating normalized table: %w", err) } tableExistsMapping[tableIdentifier] = false - utils.RecordHeartbeat(c.ctx, fmt.Sprintf("created table %s", tableIdentifier)) + utils.RecordHeartbeat(ctx, fmt.Sprintf("created table %s", tableIdentifier)) } return &protos.SetupNormalizedTableBatchOutput{ @@ -360,7 +365,9 @@ func (c *SnowflakeConnector) SetupNormalizedTables( // ReplayTableSchemaDeltas changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. 
-func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string, +func (c *SnowflakeConnector) ReplayTableSchemaDeltas( + ctx context.Context, + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { if len(schemaDeltas) == 0 { @@ -390,7 +397,7 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string, return fmt.Errorf("failed to convert column type %s to snowflake type: %w", addedColumn.ColumnType, err) } - _, err = tableSchemaModifyTx.ExecContext(c.ctx, + _, err = tableSchemaModifyTx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN IF NOT EXISTS \"%s\" %s", schemaDelta.DstTableName, strings.ToUpper(addedColumn.ColumnName), sfColtype)) if err != nil { @@ -413,16 +420,16 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string, return nil } -func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { +func (c *SnowflakeConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest) (*model.SyncResponse, error) { rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) c.logger.Info(fmt.Sprintf("pushing records to Snowflake table %s", rawTableIdentifier)) - res, err := c.syncRecordsViaAvro(req, rawTableIdentifier, req.SyncBatchID) + res, err := c.syncRecordsViaAvro(ctx, req, rawTableIdentifier, req.SyncBatchID) if err != nil { return nil, err } - err = c.pgMetadata.FinishBatch(c.ctx, req.FlowJobName, req.SyncBatchID, res.LastSyncedCheckpointID) + err = c.pgMetadata.FinishBatch(ctx, req.FlowJobName, req.SyncBatchID, res.LastSyncedCheckpointID) if err != nil { return nil, err } @@ -431,6 +438,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. } func (c *SnowflakeConnector) syncRecordsViaAvro( + ctx context.Context, req *model.SyncRecordsRequest, rawTableIdentifier string, syncBatchID int64, @@ -449,17 +457,17 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( rawTableIdentifier)), } avroSyncer := NewSnowflakeAvroSyncHandler(qrepConfig, c) - destinationTableSchema, err := c.getTableSchema(qrepConfig.DestinationTableIdentifier) + destinationTableSchema, err := c.getTableSchema(ctx, qrepConfig.DestinationTableIdentifier) if err != nil { return nil, err } - numRecords, err := avroSyncer.SyncRecords(c.ctx, destinationTableSchema, streamRes.Stream, req.FlowJobName) + numRecords, err := avroSyncer.SyncRecords(ctx, destinationTableSchema, streamRes.Stream, req.FlowJobName) if err != nil { return nil, err } - err = c.ReplayTableSchemaDeltas(req.FlowJobName, req.Records.SchemaDeltas) + err = c.ReplayTableSchemaDeltas(ctx, req.FlowJobName, req.Records.SchemaDeltas) if err != nil { return nil, fmt.Errorf("failed to sync schema changes: %w", err) } @@ -479,8 +487,8 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( } // NormalizeRecords normalizes raw table to destination table. 
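Schema deltas are applied inside a transaction, and with this refactor every statement in that transaction reuses the caller's ctx. A rough sketch of that shape under the same assumptions, not the repository's exact code:

package example

import (
	"context"
	"database/sql"
	"fmt"
)

// addColumns applies a batch of added columns in one transaction. Because
// BeginTx and ExecContext both take ctx, cancelling the surrounding activity
// aborts the transaction instead of leaving it open on the warehouse.
func addColumns(ctx context.Context, db *sql.DB, table string, cols map[string]string) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin schema change: %w", err)
	}
	defer tx.Rollback() // no-op once Commit succeeds
	for name, typ := range cols {
		_, err := tx.ExecContext(ctx,
			fmt.Sprintf(`ALTER TABLE %s ADD COLUMN IF NOT EXISTS "%s" %s`, table, name, typ))
		if err != nil {
			return fmt.Errorf("add column %s: %w", name, err)
		}
	}
	return tx.Commit()
}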
-func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { - normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) +func (c *SnowflakeConnector) NormalizeRecords(ctx context.Context, req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { + normBatchID, err := c.GetLastNormalizeBatchID(ctx, req.FlowJobName) if err != nil { return nil, err } @@ -495,6 +503,7 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } destinationTableNames, err := c.getDistinctTableNamesInBatch( + ctx, req.FlowJobName, req.SyncBatchID, normBatchID, @@ -503,13 +512,13 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest return nil, err } - tableNameToUnchangedToastCols, err := c.getTableNameToUnchangedCols(req.FlowJobName, req.SyncBatchID, normBatchID) + tableNameToUnchangedToastCols, err := c.getTableNameToUnchangedCols(ctx, req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err) } var totalRowsAffected int64 = 0 - g, gCtx := errgroup.WithContext(c.ctx) + g, gCtx := errgroup.WithContext(ctx) g.SetLimit(8) // limit parallel merges to 8 for _, destinationTableName := range destinationTableNames { @@ -565,7 +574,7 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest return nil, fmt.Errorf("error while normalizing records: %w", err) } - err = c.pgMetadata.UpdateNormalizeBatchID(c.ctx, req.FlowJobName, req.SyncBatchID) + err = c.pgMetadata.UpdateNormalizeBatchID(ctx, req.FlowJobName, req.SyncBatchID) if err != nil { return nil, err } @@ -577,20 +586,20 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest }, nil } -func (c *SnowflakeConnector) CreateRawTable(req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { - _, err := c.database.ExecContext(c.ctx, fmt.Sprintf(createSchemaSQL, c.rawSchema)) +func (c *SnowflakeConnector) CreateRawTable(ctx context.Context, req *protos.CreateRawTableInput) (*protos.CreateRawTableOutput, error) { + _, err := c.database.ExecContext(ctx, fmt.Sprintf(createSchemaSQL, c.rawSchema)) if err != nil { return nil, err } - createRawTableTx, err := c.database.BeginTx(c.ctx, nil) + createRawTableTx, err := c.database.BeginTx(ctx, nil) if err != nil { return nil, fmt.Errorf("unable to begin transaction for creation of raw table: %w", err) } // there is no easy way to check if a table has the same schema in Snowflake, // so just executing the CREATE TABLE IF NOT EXISTS blindly. 
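NormalizeRecords now derives its errgroup from the passed-in ctx rather than a context captured at construction, so a failed or cancelled merge stops the sibling merges. A condensed sketch of that fan-out, with merge standing in for the per-table work:

package example

import (
	"context"

	"golang.org/x/sync/errgroup"
)

// normalizeAll runs the per-table merges in parallel. The group is derived
// from the caller's ctx, so the first error, or cancellation of the activity,
// cancels gCtx and stops the remaining merges.
func normalizeAll(ctx context.Context, tables []string, merge func(context.Context, string) error) error {
	g, gCtx := errgroup.WithContext(ctx)
	g.SetLimit(8) // mirrors the patch's cap on parallel merges
	for _, table := range tables {
		table := table
		g.Go(func() error {
			return merge(gCtx, table)
		})
	}
	return g.Wait()
}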
rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) - _, err = createRawTableTx.ExecContext(c.ctx, + _, err = createRawTableTx.ExecContext(ctx, fmt.Sprintf(createRawTableSQL, c.rawSchema, rawTableIdentifier)) if err != nil { return nil, fmt.Errorf("unable to create raw table: %w", err) @@ -601,7 +610,7 @@ func (c *SnowflakeConnector) CreateRawTable(req *protos.CreateRawTableInput) (*p } stage := c.getStageNameForJob(req.FlowJobName) - err = c.createStage(stage, &protos.QRepConfig{}) + err = c.createStage(ctx, stage, &protos.QRepConfig{}) if err != nil { return nil, err } @@ -611,13 +620,13 @@ func (c *SnowflakeConnector) CreateRawTable(req *protos.CreateRawTableInput) (*p }, nil } -func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error { - err := c.pgMetadata.DropMetadata(c.ctx, jobName) +func (c *SnowflakeConnector) SyncFlowCleanup(ctx context.Context, jobName string) error { + err := c.pgMetadata.DropMetadata(ctx, jobName) if err != nil { return fmt.Errorf("unable to clear metadata for sync flow cleanup: %w", err) } - syncFlowCleanupTx, err := c.database.BeginTx(c.ctx, nil) + syncFlowCleanupTx, err := c.database.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("unable to begin transaction for sync flow cleanup: %w", err) } @@ -628,7 +637,7 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error { } }() - err = c.dropStage("", jobName) + err = c.dropStage(ctx, "", jobName) if err != nil { return err } @@ -636,9 +645,13 @@ func (c *SnowflakeConnector) SyncFlowCleanup(jobName string) error { return nil } -func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableIdentifier string) (bool, error) { +func (c *SnowflakeConnector) checkIfTableExists( + ctx context.Context, + schemaIdentifier string, + tableIdentifier string, +) (bool, error) { var result pgtype.Bool - err := c.database.QueryRowContext(c.ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier).Scan(&result) + err := c.database.QueryRowContext(ctx, checkIfTableExistsSQL, schemaIdentifier, tableIdentifier).Scan(&result) if err != nil { return false, fmt.Errorf("error while reading result row: %w", err) } @@ -709,8 +722,8 @@ func getRawTableIdentifier(jobName string) string { return fmt.Sprintf("%s_%s", rawTablePrefix, jobName) } -func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { - renameTablesTx, err := c.database.BeginTx(c.ctx, nil) +func (c *SnowflakeConnector) RenameTables(ctx context.Context, req *protos.RenameTablesInput) (*protos.RenameTablesOutput, error) { + renameTablesTx, err := c.database.BeginTx(ctx, nil) if err != nil { return nil, fmt.Errorf("unable to begin transaction for rename tables: %w", err) } @@ -727,10 +740,10 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto c.logger.Info(fmt.Sprintf("setting synced at column for table '%s'...", resyncTblName)) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("setting synced at column for table '%s'...", + activity.RecordHeartbeat(ctx, fmt.Sprintf("setting synced at column for table '%s'...", resyncTblName)) - _, err = renameTablesTx.ExecContext(c.ctx, + _, err = renameTablesTx.ExecContext(ctx, fmt.Sprintf("UPDATE %s SET %s = CURRENT_TIMESTAMP", resyncTblName, *req.SyncedAtColName)) if err != nil { return nil, fmt.Errorf("unable to set synced at column for table %s: %w", resyncTblName, err) @@ -753,9 +766,9 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto 
c.logger.Info(fmt.Sprintf("handling soft-deletes for table '%s'...", dst)) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dst)) + activity.RecordHeartbeat(ctx, fmt.Sprintf("handling soft-deletes for table '%s'...", dst)) - _, err = renameTablesTx.ExecContext(c.ctx, + _, err = renameTablesTx.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s(%s) SELECT %s,true AS %s FROM %s WHERE (%s) NOT IN (SELECT %s FROM %s)", src, fmt.Sprintf("%s,%s", allCols, *req.SoftDeleteColName), allCols, *req.SoftDeleteColName, dst, pkeyCols, pkeyCols, src)) @@ -772,16 +785,16 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto c.logger.Info(fmt.Sprintf("renaming table '%s' to '%s'...", src, dst)) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("renaming table '%s' to '%s'...", src, dst)) + activity.RecordHeartbeat(ctx, fmt.Sprintf("renaming table '%s' to '%s'...", src, dst)) // drop the dst table if exists - _, err = renameTablesTx.ExecContext(c.ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", dst)) + _, err = renameTablesTx.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", dst)) if err != nil { return nil, fmt.Errorf("unable to drop table %s: %w", dst, err) } // rename the src table to dst - _, err = renameTablesTx.ExecContext(c.ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s", src, dst)) + _, err = renameTablesTx.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME TO %s", src, dst)) if err != nil { return nil, fmt.Errorf("unable to rename table %s to %s: %w", src, dst, err) } @@ -799,10 +812,10 @@ func (c *SnowflakeConnector) RenameTables(req *protos.RenameTablesInput) (*proto }, nil } -func (c *SnowflakeConnector) CreateTablesFromExisting(req *protos.CreateTablesFromExistingInput) ( +func (c *SnowflakeConnector) CreateTablesFromExisting(ctx context.Context, req *protos.CreateTablesFromExistingInput) ( *protos.CreateTablesFromExistingOutput, error, ) { - createTablesFromExistingTx, err := c.database.BeginTx(c.ctx, nil) + createTablesFromExistingTx, err := c.database.BeginTx(ctx, nil) if err != nil { return nil, fmt.Errorf("unable to begin transaction for rename tables: %w", err) } @@ -816,10 +829,10 @@ func (c *SnowflakeConnector) CreateTablesFromExisting(req *protos.CreateTablesFr for newTable, existingTable := range req.NewToExistingTableMapping { c.logger.Info(fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) - activity.RecordHeartbeat(c.ctx, fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) + activity.RecordHeartbeat(ctx, fmt.Sprintf("creating table '%s' similar to '%s'", newTable, existingTable)) // rename the src table to dst - _, err = createTablesFromExistingTx.ExecContext(c.ctx, + _, err = createTablesFromExistingTx.ExecContext(ctx, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s LIKE %s", newTable, existingTable)) if err != nil { return nil, fmt.Errorf("unable to create table %s: %w", newTable, err) diff --git a/flow/connectors/sql/query_executor.go b/flow/connectors/sql/query_executor.go index 58e90b6894..14bc6629b7 100644 --- a/flow/connectors/sql/query_executor.go +++ b/flow/connectors/sql/query_executor.go @@ -13,56 +13,53 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/jmoiron/sqlx" "go.temporal.io/sdk/activity" + "go.temporal.io/sdk/log" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" - "github.com/PeerDB-io/peer-flow/shared" ) type SQLQueryExecutor interface { - ConnectionActive() error - Close() error + ConnectionActive(context.Context) 
error + Close(context.Context) error - CreateSchema(schemaName string) error - DropSchema(schemaName string) error - CheckSchemaExists(schemaName string) (bool, error) - RecreateSchema(schemaName string) error + CreateSchema(ctx context.Context, schemaName string) error + DropSchema(ctx context.Context, schemaName string) error + CheckSchemaExists(ctx context.Context, schemaName string) (bool, error) + RecreateSchema(ctx context.Context, schemaName string) error - CreateTable(schema *model.QRecordSchema, schemaName string, tableName string) error - CountRows(schemaName string, tableName string) (int64, error) + CreateTable(ctx context.Context, schema *model.QRecordSchema, schemaName string, tableName string) error + CountRows(ctx context.Context, schemaName string, tableName string) (int64, error) - ExecuteAndProcessQuery(query string, args ...interface{}) (*model.QRecordBatch, error) - NamedExecuteAndProcessQuery(query string, arg interface{}) (*model.QRecordBatch, error) - ExecuteQuery(query string, args ...interface{}) error - NamedExec(query string, arg interface{}) (sql.Result, error) + ExecuteAndProcessQuery(ctx context.Context, query string, args ...interface{}) (*model.QRecordBatch, error) + NamedExecuteAndProcessQuery(ctx context.Context, query string, arg interface{}) (*model.QRecordBatch, error) + ExecuteQuery(ctx context.Context, query string, args ...interface{}) error + NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) } type GenericSQLQueryExecutor struct { - ctx context.Context db *sqlx.DB dbtypeToQValueKind map[string]qvalue.QValueKind qvalueKindToDBType map[qvalue.QValueKind]string - logger slog.Logger + logger log.Logger } func NewGenericSQLQueryExecutor( - ctx context.Context, + logger log.Logger, db *sqlx.DB, dbtypeToQValueKind map[string]qvalue.QValueKind, qvalueKindToDBType map[qvalue.QValueKind]string, ) *GenericSQLQueryExecutor { - flowName, _ := ctx.Value(shared.FlowNameKey).(string) return &GenericSQLQueryExecutor{ - ctx: ctx, db: db, dbtypeToQValueKind: dbtypeToQValueKind, qvalueKindToDBType: qvalueKindToDBType, - logger: *slog.With(slog.String(string(shared.FlowNameKey), flowName)), + logger: logger, } } -func (g *GenericSQLQueryExecutor) ConnectionActive() bool { - err := g.db.PingContext(g.ctx) +func (g *GenericSQLQueryExecutor) ConnectionActive(ctx context.Context) bool { + err := g.db.PingContext(ctx) return err == nil } @@ -70,32 +67,32 @@ func (g *GenericSQLQueryExecutor) Close() error { return g.db.Close() } -func (g *GenericSQLQueryExecutor) CreateSchema(schemaName string) error { - _, err := g.db.ExecContext(g.ctx, "CREATE SCHEMA "+schemaName) +func (g *GenericSQLQueryExecutor) CreateSchema(ctx context.Context, schemaName string) error { + _, err := g.db.ExecContext(ctx, "CREATE SCHEMA "+schemaName) return err } -func (g *GenericSQLQueryExecutor) DropSchema(schemaName string) error { - _, err := g.db.ExecContext(g.ctx, "DROP SCHEMA IF EXISTS "+schemaName+" CASCADE") +func (g *GenericSQLQueryExecutor) DropSchema(ctx context.Context, schemaName string) error { + _, err := g.db.ExecContext(ctx, "DROP SCHEMA IF EXISTS "+schemaName+" CASCADE") return err } // the SQL query this function executes appears to be MySQL/MariaDB specific. 
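With the executor no longer owning a context, the only construction-time dependency it keeps is a logger derived from the incoming ctx (logger.LoggerFromCtx in this repository). A generic sketch of that split, using slog and a hypothetical loggerFromCtx helper rather than the real one:

package example

import (
	"context"
	"log/slog"
)

type ctxKey string

const flowNameKey ctxKey = "flowName"

// loggerFromCtx is a stand-in for the repository's logger.LoggerFromCtx:
// request-scoped fields travel on the context, and only the derived logger
// ends up stored on the executor.
func loggerFromCtx(ctx context.Context) *slog.Logger {
	if flow, ok := ctx.Value(flowNameKey).(string); ok {
		return slog.With(slog.String("flowName", flow))
	}
	return slog.Default()
}

type queryExecutor struct {
	logger *slog.Logger // keeping a logger is fine; a logger is not a context
}

func newQueryExecutor(ctx context.Context) *queryExecutor {
	return &queryExecutor{logger: loggerFromCtx(ctx)}
}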
-func (g *GenericSQLQueryExecutor) CheckSchemaExists(schemaName string) (bool, error) { +func (g *GenericSQLQueryExecutor) CheckSchemaExists(ctx context.Context, schemaName string) (bool, error) { var exists pgtype.Bool // use information schemata to check if schema exists - err := g.db.QueryRowxContext(g.ctx, + err := g.db.QueryRowxContext(ctx, "SELECT EXISTS(SELECT 1 FROM information_schema.schemata WHERE schema_name = $1)", schemaName).Scan(&exists) return exists.Bool, err } -func (g *GenericSQLQueryExecutor) RecreateSchema(schemaName string) error { - err := g.DropSchema(schemaName) +func (g *GenericSQLQueryExecutor) RecreateSchema(ctx context.Context, schemaName string) error { + err := g.DropSchema(ctx, schemaName) if err != nil { return fmt.Errorf("failed to drop schema: %w", err) } - err = g.CreateSchema(schemaName) + err = g.CreateSchema(ctx, schemaName) if err != nil { return fmt.Errorf("failed to create schema: %w", err) } @@ -103,7 +100,7 @@ func (g *GenericSQLQueryExecutor) RecreateSchema(schemaName string) error { return nil } -func (g *GenericSQLQueryExecutor) CreateTable(schema *model.QRecordSchema, schemaName string, tableName string) error { +func (g *GenericSQLQueryExecutor) CreateTable(ctx context.Context, schema *model.QRecordSchema, schemaName string, tableName string) error { fields := make([]string, 0, len(schema.Fields)) for _, field := range schema.Fields { dbType, ok := g.qvalueKindToDBType[field.Type] @@ -115,7 +112,7 @@ func (g *GenericSQLQueryExecutor) CreateTable(schema *model.QRecordSchema, schem command := fmt.Sprintf("CREATE TABLE %s.%s (%s)", schemaName, tableName, strings.Join(fields, ", ")) - _, err := g.db.ExecContext(g.ctx, command) + _, err := g.db.ExecContext(ctx, command) if err != nil { return fmt.Errorf("failed to create table: %w", err) } @@ -123,20 +120,21 @@ func (g *GenericSQLQueryExecutor) CreateTable(schema *model.QRecordSchema, schem return nil } -func (g *GenericSQLQueryExecutor) CountRows(schemaName string, tableName string) (int64, error) { +func (g *GenericSQLQueryExecutor) CountRows(ctx context.Context, schemaName string, tableName string) (int64, error) { var count pgtype.Int8 - err := g.db.QueryRowx("SELECT COUNT(*) FROM " + schemaName + "." + tableName).Scan(&count) + err := g.db.QueryRowxContext(ctx, "SELECT COUNT(*) FROM "+schemaName+"."+tableName).Scan(&count) return count.Int64, err } func (g *GenericSQLQueryExecutor) CountNonNullRows( + ctx context.Context, schemaName string, tableName string, columnName string, ) (int64, error) { var count pgtype.Int8 - err := g.db.QueryRowx("SELECT COUNT(CASE WHEN " + columnName + - " IS NOT NULL THEN 1 END) AS non_null_count FROM " + schemaName + "." + tableName).Scan(&count) + err := g.db.QueryRowxContext(ctx, "SELECT COUNT(CASE WHEN "+columnName+ + " IS NOT NULL THEN 1 END) AS non_null_count FROM "+schemaName+"."+tableName).Scan(&count) return count.Int64, err } @@ -155,7 +153,7 @@ func (g *GenericSQLQueryExecutor) columnTypeToQField(ct *sql.ColumnType) (model. 
}, nil } -func (g *GenericSQLQueryExecutor) processRows(rows *sqlx.Rows) (*model.QRecordBatch, error) { +func (g *GenericSQLQueryExecutor) processRows(ctx context.Context, rows *sqlx.Rows) (*model.QRecordBatch, error) { dbColTypes, err := rows.ColumnTypes() if err != nil { return nil, err @@ -241,7 +239,7 @@ func (g *GenericSQLQueryExecutor) processRows(rows *sqlx.Rows) (*model.QRecordBa totalRowsProcessed += 1 if totalRowsProcessed%heartBeatNumRows == 0 { - activity.RecordHeartbeat(g.ctx, fmt.Sprintf("processed %d rows", totalRowsProcessed)) + activity.RecordHeartbeat(ctx, fmt.Sprintf("processed %d rows", totalRowsProcessed)) } } @@ -258,46 +256,50 @@ func (g *GenericSQLQueryExecutor) processRows(rows *sqlx.Rows) (*model.QRecordBa } func (g *GenericSQLQueryExecutor) ExecuteAndProcessQuery( - query string, args ...interface{}, + ctx context.Context, + query string, + args ...interface{}, ) (*model.QRecordBatch, error) { - rows, err := g.db.QueryxContext(g.ctx, query, args...) + rows, err := g.db.QueryxContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - return g.processRows(rows) + return g.processRows(ctx, rows) } func (g *GenericSQLQueryExecutor) NamedExecuteAndProcessQuery( - query string, arg interface{}, + ctx context.Context, + query string, + arg interface{}, ) (*model.QRecordBatch, error) { - rows, err := g.db.NamedQueryContext(g.ctx, query, arg) + rows, err := g.db.NamedQueryContext(ctx, query, arg) if err != nil { return nil, err } defer rows.Close() - return g.processRows(rows) + return g.processRows(ctx, rows) } -func (g *GenericSQLQueryExecutor) ExecuteQuery(query string, args ...interface{}) error { - _, err := g.db.ExecContext(g.ctx, query, args...) +func (g *GenericSQLQueryExecutor) ExecuteQuery(ctx context.Context, query string, args ...interface{}) error { + _, err := g.db.ExecContext(ctx, query, args...) 
return err } -func (g *GenericSQLQueryExecutor) NamedExec(query string, arg interface{}) (sql.Result, error) { - return g.db.NamedExecContext(g.ctx, query, arg) +func (g *GenericSQLQueryExecutor) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + return g.db.NamedExecContext(ctx, query, arg) } // returns true if any of the columns are null in value -func (g *GenericSQLQueryExecutor) CheckNull(schema string, tableName string, colNames []string) (bool, error) { +func (g *GenericSQLQueryExecutor) CheckNull(ctx context.Context, schema string, tableName string, colNames []string) (bool, error) { var count pgtype.Int8 joinedString := strings.Join(colNames, " is null or ") + " is null" query := fmt.Sprintf("SELECT COUNT(*) FROM %s.%s WHERE %s", schema, tableName, joinedString) - err := g.db.QueryRowxContext(g.ctx, query).Scan(&count) + err := g.db.QueryRowxContext(ctx, query).Scan(&count) if err != nil { return false, err } diff --git a/flow/connectors/sqlserver/qrep.go b/flow/connectors/sqlserver/qrep.go index 33474ef490..af9386cfd5 100644 --- a/flow/connectors/sqlserver/qrep.go +++ b/flow/connectors/sqlserver/qrep.go @@ -2,6 +2,7 @@ package connsqlserver import ( "bytes" + "context" "fmt" "log/slog" "text/template" @@ -16,7 +17,7 @@ import ( ) func (c *SQLServerConnector) GetQRepPartitions( - config *protos.QRepConfig, last *protos.QRepPartition, + ctx context.Context, config *protos.QRepConfig, last *protos.QRepPartition, ) ([]*protos.QRepPartition, error) { if config.WatermarkTable == "" { c.logger.Info("watermark table is empty, doing full table refresh") @@ -70,7 +71,7 @@ func (c *SQLServerConnector) GetQRepPartitions( } } } else { - row := c.db.QueryRow(countQuery) + row := c.db.QueryRowContext(ctx, countQuery) if err = row.Scan(&totalRows); err != nil { return nil, fmt.Errorf("failed to query for total rows: %w", err) } @@ -151,7 +152,9 @@ func (c *SQLServerConnector) GetQRepPartitions( } func (c *SQLServerConnector) PullQRepRecords( - config *protos.QRepConfig, partition *protos.QRepPartition, + ctx context.Context, + config *protos.QRepConfig, + partition *protos.QRepPartition, ) (*model.QRecordBatch, error) { // Build the query to pull records within the range from the source table // Be sure to order the results by the watermark column to ensure consistency across pulls @@ -162,7 +165,7 @@ func (c *SQLServerConnector) PullQRepRecords( if partition.FullTablePartition { // this is a full table partition, so just run the query - return c.ExecuteAndProcessQuery(query) + return c.ExecuteAndProcessQuery(ctx, query) } var rangeStart interface{} @@ -185,7 +188,7 @@ func (c *SQLServerConnector) PullQRepRecords( "endRange": rangeEnd, } - return c.NamedExecuteAndProcessQuery(query, rangeParams) + return c.NamedExecuteAndProcessQuery(ctx, query, rangeParams) } func BuildQuery(query string) (string, error) { diff --git a/flow/connectors/sqlserver/sqlserver.go b/flow/connectors/sqlserver/sqlserver.go index cf99a98d0e..4fea9e3f68 100644 --- a/flow/connectors/sqlserver/sqlserver.go +++ b/flow/connectors/sqlserver/sqlserver.go @@ -16,7 +16,6 @@ import ( type SQLServerConnector struct { peersql.GenericSQLQueryExecutor - ctx context.Context config *protos.SqlServerConfig db *sqlx.DB logger log.Logger @@ -37,20 +36,21 @@ func NewSQLServerConnector(ctx context.Context, config *protos.SqlServerConfig) return nil, err } + logger := logger.LoggerFromCtx(ctx) + genericExecutor := *peersql.NewGenericSQLQueryExecutor( - ctx, db, sqlServerTypeToQValueKindMap, 
qValueKindToSQLServerTypeMap) + logger, db, sqlServerTypeToQValueKindMap, qValueKindToSQLServerTypeMap) return &SQLServerConnector{ GenericSQLQueryExecutor: genericExecutor, - ctx: ctx, config: config, db: db, - logger: logger.LoggerFromCtx(ctx), + logger: logger, }, nil } // Close closes the database connection -func (c *SQLServerConnector) Close() error { +func (c *SQLServerConnector) Close(_ context.Context) error { if c.db != nil { return c.db.Close() } @@ -58,8 +58,8 @@ func (c *SQLServerConnector) Close() error { } // ConnectionActive checks if the connection is still active -func (c *SQLServerConnector) ConnectionActive() error { - if err := c.db.Ping(); err != nil { +func (c *SQLServerConnector) ConnectionActive(ctx context.Context) error { + if err := c.db.PingContext(ctx); err != nil { return err } return nil diff --git a/flow/connectors/utils/catalog/env.go b/flow/connectors/utils/catalog/env.go index 5a12172022..cb73b947cf 100644 --- a/flow/connectors/utils/catalog/env.go +++ b/flow/connectors/utils/catalog/env.go @@ -17,20 +17,20 @@ var ( pool *pgxpool.Pool ) -func GetCatalogConnectionPoolFromEnv() (*pgxpool.Pool, error) { +func GetCatalogConnectionPoolFromEnv(ctx context.Context) (*pgxpool.Pool, error) { var err error poolMutex.Lock() defer poolMutex.Unlock() if pool == nil { catalogConnectionString := genCatalogConnectionString() - pool, err = pgxpool.New(context.Background(), catalogConnectionString) + pool, err = pgxpool.New(ctx, catalogConnectionString) if err != nil { return nil, fmt.Errorf("unable to establish connection with catalog: %w", err) } } - err = pool.Ping(context.Background()) + err = pool.Ping(ctx) if err != nil { return pool, fmt.Errorf("unable to establish connection with catalog: %w", err) } diff --git a/flow/dynamicconf/dynamicconf.go b/flow/dynamicconf/dynamicconf.go index d08ece7078..047b614925 100644 --- a/flow/dynamicconf/dynamicconf.go +++ b/flow/dynamicconf/dynamicconf.go @@ -25,7 +25,7 @@ func dynamicConfKeyExists(ctx context.Context, conn *pgxpool.Pool, key string) b } func dynamicConfUint32(ctx context.Context, key string, defaultValue uint32) uint32 { - conn, err := utils.GetCatalogConnectionPoolFromEnv() + conn, err := utils.GetCatalogConnectionPoolFromEnv(ctx) if err != nil { slog.Error("Failed to get catalog connection pool: %v", err) return defaultValue diff --git a/flow/e2e/postgres/peer_flow_pg_test.go b/flow/e2e/postgres/peer_flow_pg_test.go index 7847759d1e..abc23c4a89 100644 --- a/flow/e2e/postgres/peer_flow_pg_test.go +++ b/flow/e2e/postgres/peer_flow_pg_test.go @@ -60,7 +60,7 @@ func (s PeerFlowE2ETestSuitePG) WaitForSchema( s.t.Helper() e2e.EnvWaitFor(s.t, env, 3*time.Minute, reason, func() bool { s.t.Helper() - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{dstTableName}, }) if err != nil { diff --git a/flow/e2e/postgres/qrep_flow_pg_test.go b/flow/e2e/postgres/qrep_flow_pg_test.go index d7f04b7807..1723afca5d 100644 --- a/flow/e2e/postgres/qrep_flow_pg_test.go +++ b/flow/e2e/postgres/qrep_flow_pg_test.go @@ -207,7 +207,7 @@ func (s PeerFlowE2ETestSuitePG) TestSimpleSlotCreation() { setupError := make(chan error) go func() { - setupError <- s.connector.SetupReplication(signal, setupReplicationInput) + setupError <- s.connector.SetupReplication(context.Background(), signal, setupReplicationInput) }() s.t.Log("waiting for slot creation to complete: ", flowJobName) diff --git 
a/flow/e2e/snowflake/peer_flow_sf_test.go b/flow/e2e/snowflake/peer_flow_sf_test.go index 9fa22f2b2e..f669492d45 100644 --- a/flow/e2e/snowflake/peer_flow_sf_test.go +++ b/flow/e2e/snowflake/peer_flow_sf_test.go @@ -65,7 +65,7 @@ func TestPeerFlowE2ETestSuiteSF(t *testing.T) { } } - err := s.connector.Close() + err := s.connector.Close(context.Background()) if err != nil { s.t.Fatalf("failed to close Snowflake connector: %v", err) } @@ -747,7 +747,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { }, }, } - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{dstTableName}, }) e2e.EnvNoError(s.t, env, err) @@ -790,7 +790,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { }, }, } - output, err = s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err = s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{dstTableName}, }) e2e.EnvNoError(s.t, env, err) @@ -839,7 +839,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { }, }, } - output, err = s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err = s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{dstTableName}, }) e2e.EnvNoError(s.t, env, err) @@ -888,7 +888,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Simple_Schema_Changes_SF() { }, }, } - output, err = s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err = s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{dstTableName}, }) e2e.EnvNoError(s.t, env, err) diff --git a/flow/e2e/snowflake/snowflake_helper.go b/flow/e2e/snowflake/snowflake_helper.go index 1c4b9d2bb9..2d3b16c271 100644 --- a/flow/e2e/snowflake/snowflake_helper.go +++ b/flow/e2e/snowflake/snowflake_helper.go @@ -60,7 +60,7 @@ func NewSnowflakeTestHelper() (*SnowflakeTestHelper, error) { if err != nil { return nil, fmt.Errorf("failed to create Snowflake client: %w", err) } - err = adminClient.ExecuteQuery(fmt.Sprintf("CREATE DATABASE %s", testDatabaseName)) + err = adminClient.ExecuteQuery(context.Background(), fmt.Sprintf("CREATE DATABASE %s", testDatabaseName)) if err != nil { return nil, fmt.Errorf("failed to create Snowflake test database: %w", err) } @@ -99,7 +99,7 @@ func (s *SnowflakeTestHelper) Cleanup() error { if err != nil { return err } - err = s.adminClient.ExecuteQuery(fmt.Sprintf("DROP DATABASE %s", s.testDatabaseName)) + err = s.adminClient.ExecuteQuery(context.Background(), fmt.Sprintf("DROP DATABASE %s", s.testDatabaseName)) if err != nil { return err } @@ -108,12 +108,12 @@ func (s *SnowflakeTestHelper) Cleanup() error { // RunCommand runs the given command. func (s *SnowflakeTestHelper) RunCommand(command string) error { - return s.testClient.ExecuteQuery(command) + return s.testClient.ExecuteQuery(context.Background(), command) } // CountRows(tableName) returns the number of rows in the given table. 
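The e2e helpers have no Temporal activity handing them a context, so every call site in the tests now passes context.Background() explicitly. A minimal sketch of what that looks like in a test, with connector as a stand-in type:

package example

import (
	"context"
	"testing"
)

// connector stands in for any of the refactored connectors; only the shape
// of the call matters here.
type connector struct{}

func (c *connector) ConnectionActive(ctx context.Context) error { return ctx.Err() }

// Tests supply context.Background() at each call, mirroring the updated
// helpers, since there is no activity context to thread through.
func TestConnectionActive(t *testing.T) {
	c := &connector{}
	if err := c.ConnectionActive(context.Background()); err != nil {
		t.Fatalf("connection check failed: %v", err)
	}
}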
func (s *SnowflakeTestHelper) CountRows(tableName string) (int, error) { - res, err := s.testClient.CountRows(s.testSchemaName, tableName) + res, err := s.testClient.CountRows(context.Background(), s.testSchemaName, tableName) if err != nil { return 0, err } @@ -123,7 +123,7 @@ func (s *SnowflakeTestHelper) CountRows(tableName string) (int, error) { // CountRows(tableName) returns the non-null number of rows in the given table. func (s *SnowflakeTestHelper) CountNonNullRows(tableName string, columnName string) (int, error) { - res, err := s.testClient.CountNonNullRows(s.testSchemaName, tableName, columnName) + res, err := s.testClient.CountNonNullRows(context.Background(), s.testSchemaName, tableName, columnName) if err != nil { return 0, err } @@ -132,20 +132,20 @@ func (s *SnowflakeTestHelper) CountNonNullRows(tableName string, columnName stri } func (s *SnowflakeTestHelper) CheckNull(tableName string, colNames []string) (bool, error) { - return s.testClient.CheckNull(s.testSchemaName, tableName, colNames) + return s.testClient.CheckNull(context.Background(), s.testSchemaName, tableName, colNames) } func (s *SnowflakeTestHelper) ExecuteAndProcessQuery(query string) (*model.QRecordBatch, error) { - return s.testClient.ExecuteAndProcessQuery(query) + return s.testClient.ExecuteAndProcessQuery(context.Background(), query) } func (s *SnowflakeTestHelper) CreateTable(tableName string, schema *model.QRecordSchema) error { - return s.testClient.CreateTable(schema, s.testSchemaName, tableName) + return s.testClient.CreateTable(context.Background(), schema, s.testSchemaName, tableName) } // runs a query that returns an int result func (s *SnowflakeTestHelper) RunIntQuery(query string) (int, error) { - rows, err := s.testClient.ExecuteAndProcessQuery(query) + rows, err := s.testClient.ExecuteAndProcessQuery(context.Background(), query) if err != nil { return 0, err } @@ -179,7 +179,7 @@ func (s *SnowflakeTestHelper) RunIntQuery(query string) (int, error) { // runs a query that returns an int result func (s *SnowflakeTestHelper) checkSyncedAt(query string) error { - recordBatch, err := s.testClient.ExecuteAndProcessQuery(query) + recordBatch, err := s.testClient.ExecuteAndProcessQuery(context.Background(), query) if err != nil { return err } diff --git a/flow/e2e/snowflake/snowflake_schema_delta_test.go b/flow/e2e/snowflake/snowflake_schema_delta_test.go index 4201aff590..03246ae617 100644 --- a/flow/e2e/snowflake/snowflake_schema_delta_test.go +++ b/flow/e2e/snowflake/snowflake_schema_delta_test.go @@ -50,7 +50,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestSimpleAddColumn() { err := s.sfTestHelper.RunCommand(fmt.Sprintf("CREATE TABLE %s(ID TEXT PRIMARY KEY)", tableName)) require.NoError(s.t, err) - err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: []*protos.DeltaAddedColumn{{ @@ -60,7 +60,7 @@ func (s SnowflakeSchemaDeltaTestSuite) TestSimpleAddColumn() { }}) require.NoError(s.t, err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -157,14 +157,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddAllColumnTypes() { } } - err = 
s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) require.NoError(s.t, err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -236,14 +236,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddTrickyColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) require.NoError(s.t, err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -290,14 +290,14 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() { } } - err = s.connector.ReplayTableSchemaDeltas("schema_delta_flow", []*protos.TableSchemaDelta{{ + err = s.connector.ReplayTableSchemaDeltas(context.Background(), "schema_delta_flow", []*protos.TableSchemaDelta{{ SrcTableName: tableName, DstTableName: tableName, AddedColumns: addedColumns, }}) require.NoError(s.t, err) - output, err := s.connector.GetTableSchema(&protos.GetTableSchemaBatchInput{ + output, err := s.connector.GetTableSchema(context.Background(), &protos.GetTableSchemaBatchInput{ TableIdentifiers: []string{tableName}, }) require.NoError(s.t, err) @@ -307,6 +307,6 @@ func (s SnowflakeSchemaDeltaTestSuite) TestAddWhitespaceColumnNames() { func TestSnowflakeSchemaDeltaTestSuite(t *testing.T) { e2eshared.RunSuite(t, setupSchemaDeltaSuite, func(s SnowflakeSchemaDeltaTestSuite) { require.NoError(s.t, s.sfTestHelper.Cleanup()) - require.NoError(s.t, s.connector.Close()) + require.NoError(s.t, s.connector.Close(context.Background())) }) } diff --git a/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go b/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go index f16128d9f5..e94641ab92 100644 --- a/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go +++ b/flow/e2e/sqlserver/qrep_flow_sqlserver_test.go @@ -103,6 +103,7 @@ func (s PeerFlowE2ETestSuiteSQLServer) insertRowsIntoSQLServerTable(tableName st params["status"] = 1 _, err := s.sqlsHelper.E.NamedExec( + context.Background(), "INSERT INTO "+schemaQualified+" (id, card_id, v_from, price, status) VALUES (:id, :card_id, :v_from, :price, :status)", params, ) diff --git a/flow/e2e/sqlserver/sqlserver_helper.go b/flow/e2e/sqlserver/sqlserver_helper.go index 749367d994..15ade1796f 100644 --- a/flow/e2e/sqlserver/sqlserver_helper.go +++ b/flow/e2e/sqlserver/sqlserver_helper.go @@ -41,7 +41,7 @@ func NewSQLServerHelper(name string) (*SQLServerHelper, error) { return nil, err } - connErr := connector.ConnectionActive() + connErr := connector.ConnectionActive(context.Background()) if connErr != nil { return nil, fmt.Errorf("invalid connection configs: %v", connErr) } @@ -52,7 +52,7 @@ func NewSQLServerHelper(name string) (*SQLServerHelper, error) { } testSchema := fmt.Sprintf("e2e_test_%d", rndNum) - err = connector.CreateSchema(testSchema) + err = 
connector.CreateSchema(context.Background(), testSchema) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func NewSQLServerHelper(name string) (*SQLServerHelper, error) { } func (h *SQLServerHelper) CreateTable(schema *model.QRecordSchema, tableName string) error { - err := h.E.CreateTable(schema, h.SchemaName, tableName) + err := h.E.CreateTable(context.Background(), schema, h.SchemaName, tableName) if err != nil { return err } @@ -87,14 +87,14 @@ func (h *SQLServerHelper) GetPeer() *protos.Peer { func (h *SQLServerHelper) CleanUp() error { for _, tbl := range h.tables { - err := h.E.ExecuteQuery(fmt.Sprintf("DROP TABLE %s.%s", h.SchemaName, tbl)) + err := h.E.ExecuteQuery(context.Background(), fmt.Sprintf("DROP TABLE %s.%s", h.SchemaName, tbl)) if err != nil { return err } } if h.SchemaName != "" { - return h.E.ExecuteQuery(fmt.Sprintf("DROP SCHEMA %s", h.SchemaName)) + return h.E.ExecuteQuery(context.Background(), fmt.Sprintf("DROP SCHEMA %s", h.SchemaName)) } return nil diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index e412335a15..2637249ed9 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -98,6 +98,7 @@ func GetPgRows(conn *pgx.Conn, suffix string, table string, cols string) (*model pgQueryExecutor.SetTestEnv(true) return pgQueryExecutor.ExecuteAndProcessQuery( + context.Background(), fmt.Sprintf(`SELECT %s FROM e2e_test_%s."%s" ORDER BY id`, cols, suffix, table), ) } diff --git a/flow/model/model.go b/flow/model/model.go index 3213a72434..14de42a44e 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -51,8 +51,6 @@ type PullRecordsRequest struct { RelationMessageMapping RelationMessageMapping // record batch for pushing changes into RecordStream *CDCRecordStream - // last offset may be forwarded while processing records - SetLastOffset func(int64) error } type Record interface {