diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 73853c83d7..6d209a364e 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -11,7 +11,6 @@ import ( connpostgres "github.com/PeerDB-io/peer-flow/connectors/postgres" connsnowflake "github.com/PeerDB-io/peer-flow/connectors/snowflake" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/connectors/utils/monitoring" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -19,6 +18,7 @@ import ( "github.com/jackc/pglogrepl" log "github.com/sirupsen/logrus" "go.temporal.io/sdk/activity" + "golang.org/x/sync/errgroup" ) // CheckConnectionResult is the result of a CheckConnection call. @@ -165,11 +165,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, ctx = context.WithValue(ctx, shared.EnableMetricsKey, a.EnableMetrics) ctx = context.WithValue(ctx, shared.CDCMirrorMonitorKey, a.CatalogMirrorMonitor) - srcConn, err := connectors.GetCDCPullConnector(ctx, conn.Source) - if err != nil { - return nil, fmt.Errorf("failed to get source connector: %w", err) - } - defer connectors.CloseConnector(srcConn) dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination) if err != nil { return nil, fmt.Errorf("failed to get destination connector: %w", err) @@ -196,28 +191,40 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, idleTimeout := utils.GetEnvInt("PEERDB_CDC_IDLE_TIMEOUT_SECONDS", 10) + recordBatch := model.NewCDCRecordStream() + startTime := time.Now() - recordsWithTableSchemaDelta, err := srcConn.PullRecords(&model.PullRecordsRequest{ - FlowJobName: input.FlowConnectionConfigs.FlowJobName, - SrcTableIDNameMapping: input.FlowConnectionConfigs.SrcTableIdNameMapping, - TableNameMapping: tblNameMapping, - LastSyncState: input.LastSyncState, - MaxBatchSize: uint32(input.SyncFlowOptions.BatchSize), - IdleTimeout: time.Duration(idleTimeout) * time.Second, - TableNameSchemaMapping: input.FlowConnectionConfigs.TableNameSchemaMapping, - OverridePublicationName: input.FlowConnectionConfigs.PublicationName, - OverrideReplicationSlotName: input.FlowConnectionConfigs.ReplicationSlotName, - RelationMessageMapping: input.RelationMessageMapping, - }) + + errGroup, errCtx := errgroup.WithContext(ctx) + srcConn, err := connectors.GetCDCPullConnector(errCtx, conn.Source) if err != nil { - return nil, fmt.Errorf("failed to pull records: %w", err) + return nil, fmt.Errorf("failed to get source connector: %w", err) } - recordBatch := recordsWithTableSchemaDelta.RecordBatch + defer connectors.CloseConnector(srcConn) + + // start a goroutine to pull records from the source + errGroup.Go(func() error { + return srcConn.PullRecords(&model.PullRecordsRequest{ + FlowJobName: input.FlowConnectionConfigs.FlowJobName, + SrcTableIDNameMapping: input.FlowConnectionConfigs.SrcTableIdNameMapping, + TableNameMapping: tblNameMapping, + LastSyncState: input.LastSyncState, + MaxBatchSize: uint32(input.SyncFlowOptions.BatchSize), + IdleTimeout: time.Duration(idleTimeout) * time.Second, + TableNameSchemaMapping: input.FlowConnectionConfigs.TableNameSchemaMapping, + OverridePublicationName: input.FlowConnectionConfigs.PublicationName, + OverrideReplicationSlotName: input.FlowConnectionConfigs.ReplicationSlotName, + RelationMessageMapping: input.RelationMessageMapping, + RecordStream: recordBatch, + }) + }) - pullRecordWithCount := fmt.Sprintf("pulled %d records", len(recordBatch.Records)) 
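The StartFlow hunk above is the heart of this change: PullRecords no longer returns a fully buffered batch, it runs inside an errgroup goroutine and writes into a shared record stream that the destination connector drains concurrently. Below is a minimal, self-contained sketch of that producer/consumer shape, using a plain channel as a stand-in for model.CDCRecordStream; runFlow and every name in it are illustrative, not PeerDB APIs.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func runFlow(ctx context.Context) error {
	// stand-in for model.NewCDCRecordStream()
	records := make(chan string, 1024)

	errGroup, errCtx := errgroup.WithContext(ctx)

	// pull goroutine, analogous to srcConn.PullRecords writing into
	// req.RecordStream and closing it when the pull ends
	errGroup.Go(func() error {
		defer close(records)
		for i := 0; i < 3; i++ {
			select {
			case records <- fmt.Sprintf("record-%d", i):
			case <-errCtx.Done():
				return errCtx.Err()
			}
		}
		return nil
	})

	// the calling goroutine syncs, analogous to dstConn.SyncRecords
	// ranging over GetRecords() until the stream is closed
	for rec := range records {
		fmt.Println("synced", rec)
	}

	// surface any pull-side failure, mirroring errGroup.Wait() after the sync
	return errGroup.Wait()
}

func main() {
	if err := runFlow(context.Background()); err != nil {
		fmt.Println("flow failed:", err)
	}
}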
- activity.RecordHeartbeat(ctx, pullRecordWithCount) + hasRecords := !recordBatch.WaitAndCheckEmpty() + log.WithFields(log.Fields{ + "flowName": input.FlowConnectionConfigs.FlowJobName, + }).Infof("the current sync flow has records: %v", hasRecords) - if a.CatalogMirrorMonitor.IsActive() && len(recordBatch.Records) > 0 { + if a.CatalogMirrorMonitor.IsActive() && hasRecords { syncBatchID, err := dstConn.GetLastSyncBatchID(input.FlowConnectionConfigs.FlowJobName) if err != nil && conn.Destination.Type != protos.DBType_EVENTHUB { return nil, err @@ -226,9 +233,9 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, err = a.CatalogMirrorMonitor.AddCDCBatchForFlow(ctx, input.FlowConnectionConfigs.FlowJobName, monitoring.CDCBatchInfo{ BatchID: syncBatchID + 1, - RowsInBatch: uint32(len(recordBatch.Records)), - BatchStartLSN: pglogrepl.LSN(recordBatch.FirstCheckPointID), - BatchEndlSN: pglogrepl.LSN(recordBatch.LastCheckPointID), + RowsInBatch: 0, + BatchStartLSN: pglogrepl.LSN(recordBatch.GetFirstCheckpoint()), + BatchEndlSN: 0, StartTime: startTime, }) if err != nil { @@ -236,24 +243,18 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, } } - pullDuration := time.Since(startTime) - numRecords := len(recordBatch.Records) - log.WithFields(log.Fields{ - "flowName": input.FlowConnectionConfigs.FlowJobName, - }).Infof("pulled %d records in %d seconds\n", numRecords, int(pullDuration.Seconds())) - activity.RecordHeartbeat(ctx, fmt.Sprintf("pulled %d records", numRecords)) + if !hasRecords { + // wait for the pull goroutine to finish + err = errGroup.Wait() + if err != nil { + return nil, fmt.Errorf("failed to pull records: %w", err) + } - if numRecords == 0 { - log.WithFields(log.Fields{ - "flowName": input.FlowConnectionConfigs.FlowJobName, - }).Info("no records to push") - metrics.LogSyncMetrics(ctx, input.FlowConnectionConfigs.FlowJobName, 0, 1) - metrics.LogNormalizeMetrics(ctx, input.FlowConnectionConfigs.FlowJobName, 0, 1, 0) - metrics.LogCDCRawThroughputMetrics(ctx, input.FlowConnectionConfigs.FlowJobName, 0) - return &model.SyncResponse{ - RelationMessageMapping: recordsWithTableSchemaDelta.RelationMessageMapping, - TableSchemaDeltas: recordsWithTableSchemaDelta.TableSchemaDeltas, - }, nil + log.WithFields(log.Fields{"flowName": input.FlowConnectionConfigs.FlowJobName}).Info("no records to push") + syncResponse := &model.SyncResponse{} + syncResponse.RelationMessageMapping = <-recordBatch.RelationMessageMapping + syncResponse.TableSchemaDeltas = recordBatch.WaitForSchemaDeltas() + return syncResponse, nil } shutdown := utils.HeartbeatRoutine(ctx, 10*time.Second, func() string { @@ -279,14 +280,35 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, return nil, fmt.Errorf("failed to push records: %w", err) } + err = errGroup.Wait() + if err != nil { + return nil, fmt.Errorf("failed to pull records: %w", err) + } + + numRecords := res.NumRecordsSynced syncDuration := time.Since(syncStartTime) log.WithFields(log.Fields{ "flowName": input.FlowConnectionConfigs.FlowJobName, }).Infof("pushed %d records in %d seconds\n", numRecords, int(syncDuration.Seconds())) + lastCheckpoint, err := recordBatch.GetLastCheckpoint() + if err != nil { + return nil, fmt.Errorf("failed to get last checkpoint: %w", err) + } + + err = a.CatalogMirrorMonitor.UpdateNumRowsAndEndLSNForCDCBatch( + ctx, + input.FlowConnectionConfigs.FlowJobName, + res.CurrentSyncBatchID, + uint32(numRecords), + pglogrepl.LSN(lastCheckpoint), + ) + if err != nil { + return nil, err + } + err = a.CatalogMirrorMonitor. 
- UpdateLatestLSNAtTargetForCDCFlow(ctx, input.FlowConnectionConfigs.FlowJobName, - pglogrepl.LSN(recordBatch.LastCheckPointID)) + UpdateLatestLSNAtTargetForCDCFlow(ctx, input.FlowConnectionConfigs.FlowJobName, pglogrepl.LSN(lastCheckpoint)) if err != nil { return nil, err } @@ -300,15 +322,12 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, if err != nil { return nil, err } - res.TableSchemaDeltas = recordsWithTableSchemaDelta.TableSchemaDeltas - res.RelationMessageMapping = recordsWithTableSchemaDelta.RelationMessageMapping + res.TableSchemaDeltas = recordBatch.WaitForSchemaDeltas() + res.RelationMessageMapping = <-recordBatch.RelationMessageMapping pushedRecordsWithCount := fmt.Sprintf("pushed %d records", numRecords) activity.RecordHeartbeat(ctx, pushedRecordsWithCount) - metrics.LogCDCRawThroughputMetrics(ctx, input.FlowConnectionConfigs.FlowJobName, - float64(numRecords)/(pullDuration.Seconds()+syncDuration.Seconds())) - return res, nil } diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index a00759ee09..e4594688ef 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -13,7 +13,6 @@ import ( "cloud.google.com/go/bigquery" "cloud.google.com/go/storage" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -432,17 +431,9 @@ func (r StagingBQRecord) Save() (map[string]bigquery.Value, string, error) { // currently only supports inserts,updates and deletes // more record types will be added in the future. func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { - if len(req.Records.Records) == 0 { - return &model.SyncResponse{ - FirstSyncedCheckPointID: 0, - LastSyncedCheckPointID: 0, - NumRecordsSynced: 0, - }, nil - } - rawTableName := c.getRawTableName(req.FlowJobName) - log.Printf("pushing %d records to %s.%s", len(req.Records.Records), c.datasetID, rawTableName) + log.Printf("pushing records to %s.%s...", c.datasetID, rawTableName) // generate a sequential number for the last synced batch // this sequence will be used to keep track of records that are normalized @@ -486,9 +477,9 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, tableNameRowsMapping := make(map[string]uint32) first := true var firstCP int64 = 0 - lastCP := req.Records.LastCheckPointID + // loop over req.Records - for _, record := range req.Records.Records { + for record := range req.Records.GetRecords() { switch r := record.(type) { case *model.InsertRecord: // create the 3 required fields @@ -586,14 +577,6 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, } numRecords := len(records) - if numRecords == 0 { - return &model.SyncResponse{ - FirstSyncedCheckPointID: 0, - LastSyncedCheckPointID: 0, - NumRecordsSynced: 0, - }, nil - } - // insert the records into the staging table stagingInserter := stagingTable.Inserter() stagingInserter.IgnoreUnknownValues = true @@ -613,6 +596,11 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, } } + lastCP, err := req.Records.GetLastCheckpoint() + if err != nil { + return nil, fmt.Errorf("failed to get last checkpoint: %v", err) + } + // we have to do the following things in a transaction // 1. append the records in the staging table to the raw table. // 2. 
execute the update metadata query to store the last committed watermark. @@ -629,13 +617,11 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, stmts = append(stmts, appendStmt) stmts = append(stmts, updateMetadataStmt) stmts = append(stmts, "COMMIT TRANSACTION;") - startTime := time.Now() _, err = c.client.Query(strings.Join(stmts, "\n")).Read(c.ctx) if err != nil { return nil, fmt.Errorf("failed to execute statements in a transaction: %v", err) } - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(numRecords), time.Since(startTime)) log.Printf("pushed %d records to %s.%s", numRecords, c.datasetID, rawTableName) return &model.SyncResponse{ @@ -647,13 +633,15 @@ func (c *BigQueryConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, }, nil } -func (c *BigQueryConnector) syncRecordsViaAvro(req *model.SyncRecordsRequest, - rawTableName string, syncBatchID int64) (*model.SyncResponse, error) { +func (c *BigQueryConnector) syncRecordsViaAvro( + req *model.SyncRecordsRequest, + rawTableName string, + syncBatchID int64, +) (*model.SyncResponse, error) { tableNameRowsMapping := make(map[string]uint32) first := true var firstCP int64 = 0 - lastCP := req.Records.LastCheckPointID - recordStream := model.NewQRecordStream(len(req.Records.Records)) + recordStream := model.NewQRecordStream(1 << 20) err := recordStream.SetSchema(&model.QRecordSchema{ Fields: []*model.QField{ { @@ -713,7 +701,7 @@ func (c *BigQueryConnector) syncRecordsViaAvro(req *model.SyncRecordsRequest, } // loop over req.Records - for _, record := range req.Records.Records { + for record := range req.Records.GetRecords() { var entries [10]qvalue.QValue switch r := record.(type) { case *model.InsertRecord: @@ -843,21 +831,23 @@ func (c *BigQueryConnector) syncRecordsViaAvro(req *model.SyncRecordsRequest, } } - startTime := time.Now() - close(recordStream.Records) avroSync := NewQRepAvroSyncMethod(c, req.StagingPath) rawTableMetadata, err := c.client.Dataset(c.datasetID).Table(rawTableName).Metadata(c.ctx) if err != nil { return nil, fmt.Errorf("failed to get metadata of destination table: %v", err) } + lastCP, err := req.Records.GetLastCheckpoint() + if err != nil { + return nil, fmt.Errorf("failed to get last checkpoint: %v", err) + } + numRecords, err := avroSync.SyncRecords(rawTableName, req.FlowJobName, lastCP, rawTableMetadata, syncBatchID, recordStream) if err != nil { return nil, fmt.Errorf("failed to sync records via avro : %v", err) } - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(numRecords), time.Since(startTime)) log.Printf("pushed %d records to %s.%s", numRecords, c.datasetID, rawTableName) return &model.SyncResponse{ diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index db9c6ef6c0..e01eea95c4 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -10,7 +10,6 @@ import ( "cloud.google.com/go/bigquery" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -150,13 +149,10 @@ func (s *QRepAvroSyncMethod) SyncQRepRecords( stmts = append(stmts, insertMetadataStmt) stmts = append(stmts, "COMMIT TRANSACTION;") // Execute the statements in a transaction - syncRecordsStartTime := time.Now() _, err = bqClient.Query(strings.Join(stmts, "\n")).Read(s.connector.ctx) if err != nil { return -1, 
fmt.Errorf("failed to execute statements in a transaction: %v", err) } - metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, - int64(numRecords), time.Since(syncRecordsStartTime)) // drop the staging table if err := bqClient.Dataset(datasetID).Table(stagingTable).Delete(s.connector.ctx); err != nil { diff --git a/flow/connectors/bigquery/qrep_sync_method.go b/flow/connectors/bigquery/qrep_sync_method.go index 2a8eb0f5cd..8a9e42a3b5 100644 --- a/flow/connectors/bigquery/qrep_sync_method.go +++ b/flow/connectors/bigquery/qrep_sync_method.go @@ -7,7 +7,6 @@ import ( "time" "cloud.google.com/go/bigquery" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" log "github.com/sirupsen/logrus" @@ -106,8 +105,6 @@ func (s *QRepStagingTableSync) SyncQRepRecords( if err != nil { return -1, fmt.Errorf("failed to insert records into staging table: %v", err) } - metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, int64(len(valueSaverRecords)), - time.Since(startTime)) // Copy the records into the destination table in a transaction. // append all the statements to one list diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 3f1d57495f..0612514dd7 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + log "github.com/sirupsen/logrus" connbigquery "github.com/PeerDB-io/peer-flow/connectors/bigquery" @@ -37,7 +38,7 @@ type CDCPullConnector interface { // PullRecords pulls records from the source, and returns a RecordBatch. // This method should be idempotent, and should be able to be called multiple times with the same request. - PullRecords(req *model.PullRecordsRequest) (*model.RecordsWithTableSchemaDelta, error) + PullRecords(req *model.PullRecordsRequest) error // PullFlowCleanup drops both the Postgres publication and replication slot, as a part of DROP MIRROR PullFlowCleanup(jobName string) error diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go index 953f44f197..30be24eba6 100644 --- a/flow/connectors/eventhub/eventhub.go +++ b/flow/connectors/eventhub/eventhub.go @@ -12,7 +12,6 @@ import ( azeventhubs "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs" metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" cmap "github.com/orcaman/concurrent-map/v2" @@ -121,25 +120,28 @@ func (c *EventHubConnector) updateLastOffset(jobName string, offset int64) error return nil } +// returns the number of records synced func (c *EventHubConnector) processBatch( flowJobName string, - batch *model.RecordBatch, + batch *model.CDCRecordStream, eventsPerBatch int, maxParallelism int64, -) error { +) (uint32, error) { ctx := context.Background() tableNameRowsMapping := cmap.New[uint32]() batchPerTopic := NewHubBatches(c.hubManager) toJSONOpts := model.NewToJSONOptions(c.config.UnnestColumns) - for i, record := range batch.Records { + numRecords := 0 + for record := range batch.GetRecords() { + numRecords++ json, err := record.GetItems().ToJSONWithOpts(toJSONOpts) if err != nil { log.WithFields(log.Fields{ "flowName": flowJobName, }).Infof("failed to convert record to json: %v", err) - return err + return 0, err } flushBatch := func() error { @@ -159,7 +161,7 @@ func (c 
*EventHubConnector) processBatch( log.WithFields(log.Fields{ "flowName": flowJobName, }).Infof("failed to get topic name: %v", err) - return err + return 0, err } err = batchPerTopic.AddEvent(ctx, topicName, json) @@ -167,13 +169,13 @@ func (c *EventHubConnector) processBatch( log.WithFields(log.Fields{ "flowName": flowJobName, }).Infof("failed to add event to batch: %v", err) - return err + return 0, err } - if (i+1)%eventsPerBatch == 0 { + if (numRecords)%eventsPerBatch == 0 { err := flushBatch() if err != nil { - return err + return 0, err } } } @@ -181,15 +183,14 @@ func (c *EventHubConnector) processBatch( if batchPerTopic.Len() > 0 { err := c.sendEventBatch(ctx, batchPerTopic, maxParallelism, flowJobName, tableNameRowsMapping) if err != nil { - return err + return 0, err } } - rowsSynced := len(batch.Records) log.WithFields(log.Fields{ "flowName": flowJobName, - }).Infof("[total] successfully sent %d records to event hub", rowsSynced) - return nil + }).Infof("[total] successfully sent %d records to event hub", numRecords) + return uint32(numRecords), nil } func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { @@ -212,29 +213,34 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S } var err error - startTime := time.Now() - batch := req.Records + var numRecords uint32 // if env var PEERDB_BETA_EVENTHUB_PUSH_ASYNC=true // we kick off processBatch in a goroutine and return immediately. // otherwise, we block until processBatch is done. if utils.GetEnvBool("PEERDB_BETA_EVENTHUB_PUSH_ASYNC", false) { go func() { - err = c.processBatch(req.FlowJobName, batch, eventsPerBatch, maxParallelism) + numRecords, err = c.processBatch(req.FlowJobName, batch, eventsPerBatch, maxParallelism) if err != nil { log.Errorf("[async] failed to process batch: %v", err) } }() } else { - err = c.processBatch(req.FlowJobName, batch, eventsPerBatch, maxParallelism) + numRecords, err = c.processBatch(req.FlowJobName, batch, eventsPerBatch, maxParallelism) if err != nil { log.Errorf("failed to process batch: %v", err) return nil, err } } - err = c.updateLastOffset(req.FlowJobName, batch.LastCheckPointID) + lastCheckpoint, err := req.Records.GetLastCheckpoint() + if err != nil { + log.Errorf("failed to get last checkpoint: %v", err) + return nil, err + } + + err = c.updateLastOffset(req.FlowJobName, lastCheckpoint) if err != nil { log.Errorf("failed to update last offset: %v", err) return nil, err @@ -245,12 +251,10 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S return nil, err } - rowsSynced := int64(len(batch.Records)) - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, rowsSynced, time.Since(startTime)) - metrics.LogNormalizeMetrics(c.ctx, req.FlowJobName, rowsSynced, time.Since(startTime), rowsSynced) + rowsSynced := int64(numRecords) return &model.SyncResponse{ - FirstSyncedCheckPointID: batch.FirstCheckPointID, - LastSyncedCheckPointID: batch.LastCheckPointID, + FirstSyncedCheckPointID: batch.GetFirstCheckpoint(), + LastSyncedCheckPointID: lastCheckpoint, NumRecordsSynced: rowsSynced, TableNameRowsMapping: make(map[string]uint32), }, nil diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index b087e9d0f3..4c53717f02 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -103,8 +103,7 @@ func getChildToParentRelIdMap(ctx context.Context, pool *pgxpool.Pool) (map[uint } // PullRecords pulls records from the cdc stream -func (p 
*PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) ( - *model.RecordsWithTableSchemaDelta, error) { +func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) error { // setup options pluginArguments := []string{ "proto_version '1'", @@ -114,7 +113,7 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) ( pubOpt := fmt.Sprintf("publication_names '%s'", p.publication) pluginArguments = append(pluginArguments, pubOpt) } else { - return nil, fmt.Errorf("publication name is not set") + return fmt.Errorf("publication name is not set") } replicationOpts := pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments} @@ -123,7 +122,7 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) ( // create replication connection replicationConn, err := p.replPool.Acquire(p.ctx) if err != nil { - return nil, fmt.Errorf("error acquiring connection for replication: %w", err) + return fmt.Errorf("error acquiring connection for replication: %w", err) } defer replicationConn.Release() @@ -135,7 +134,7 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) ( sysident, err := pglogrepl.IdentifySystem(p.ctx, pgConn) if err != nil { - return nil, fmt.Errorf("IdentifySystem failed: %w", err) + return fmt.Errorf("IdentifySystem failed: %w", err) } log.Debugf("SystemID: %s, Timeline: %d, XLogPos: %d, DBName: %s", sysident.SystemID, sysident.Timeline, sysident.XLogPos, sysident.DBName) @@ -149,13 +148,13 @@ func (p *PostgresCDCSource) PullRecords(req *model.PullRecordsRequest) ( err = pglogrepl.StartReplication(p.ctx, pgConn, replicationSlot, p.startLSN, replicationOpts) if err != nil { - return nil, fmt.Errorf("error starting replication at startLsn - %d: %w", p.startLSN, err) + return fmt.Errorf("error starting replication at startLsn - %d: %w", p.startLSN, err) } log.WithFields(log.Fields{ "flowName": req.FlowJobName, }).Infof("started replication on slot %s at startLSN: %d", p.slot, p.startLSN) - return p.consumeStream(pgConn, req, p.startLSN) + return p.consumeStream(pgConn, req, p.startLSN, req.RecordStream) } // start consuming the cdc stream @@ -163,19 +162,8 @@ func (p *PostgresCDCSource) consumeStream( conn *pgconn.PgConn, req *model.PullRecordsRequest, clientXLogPos pglogrepl.LSN, -) (*model.RecordsWithTableSchemaDelta, error) { - // TODO (kaushik): take into consideration the MaxBatchSize - // parameters in the original request. 
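The old signature above bundled four results into *model.RecordsWithTableSchemaDelta; the hunks that follow redistribute all of them onto the stream. A quick correspondence map, written as Go comments and using only names visible at the call sites in this diff:

// Where each piece of the removed return value now travels:
//
//   RecordBatch.Records           -> records.AddRecord(rec), consumed via
//                                    records.GetRecords()
//   RecordBatch.FirstCheckPointID -> records.GetFirstCheckpoint()
//   RecordBatch.LastCheckPointID  -> records.UpdateLatestCheckpoint(lsn),
//                                    read back via records.GetLastCheckpoint()
//   TableSchemaDeltas             -> records.SchemaDeltas channel, drained by
//                                    WaitForSchemaDeltas()
//   RelationMessageMapping        -> records.RelationMessageMapping channel,
//                                    sent once in consumeStream's deferred func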
- records := &model.RecordBatch{ - Records: make([]model.Record, 0), - TablePKeyLastSeen: make(map[model.TableWithPkey]int), - } - result := &model.RecordsWithTableSchemaDelta{ - RecordBatch: records, - TableSchemaDeltas: nil, - RelationMessageMapping: p.relationMessageMapping, - } - + records *model.CDCRecordStream, +) error { standbyMessageTimeout := req.IdleTimeout nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout) @@ -197,10 +185,17 @@ func (p *PostgresCDCSource) consumeStream( } var standByLastLogged time.Time + localRecords := make([]model.Record, 0) + defer func() { + if len(localRecords) == 0 { + records.SignalAsEmpty() + } + records.RelationMessageMapping <- &p.relationMessageMapping + }() shutdown := utils.HeartbeatRoutine(p.ctx, 10*time.Second, func() string { jobName := req.FlowJobName - currRecords := len(records.Records) + currRecords := len(localRecords) return fmt.Sprintf("pulling records for job - %s, currently have %d records", jobName, currRecords) }) @@ -208,20 +203,29 @@ func (p *PostgresCDCSource) consumeStream( shutdown <- true }() - firstProcessed := false + tablePKeyLastSeen := make(map[model.TableWithPkey]int) + + addRecord := func(rec model.Record) { + records.AddRecord(rec) + localRecords = append(localRecords, rec) + + if len(localRecords) == 1 { + records.SignalAsNotEmpty() + } + } for { if time.Now().After(nextStandbyMessageDeadline) || - (len(records.Records) >= int(req.MaxBatchSize)) { + (len(localRecords) >= int(req.MaxBatchSize)) { // Update XLogPos to the last processed position, we can only confirm // that this is the last row committed on the destination. err := pglogrepl.SendStandbyStatusUpdate(p.ctx, conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: consumedXLogPos}) if err != nil { - return nil, fmt.Errorf("SendStandbyStatusUpdate failed: %w", err) + return fmt.Errorf("SendStandbyStatusUpdate failed: %w", err) } - numRowsProcessedMessage := fmt.Sprintf("processed %d rows", len(records.Records)) + numRowsProcessedMessage := fmt.Sprintf("processed %d rows", len(localRecords)) if time.Since(standByLastLogged) > 10*time.Second { log.Infof("Sent Standby status message. 
%s", numRowsProcessedMessage) @@ -230,8 +234,8 @@ func (p *PostgresCDCSource) consumeStream( nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout) - if !p.commitLock && (len(records.Records) >= int(req.MaxBatchSize)) { - return result, nil + if !p.commitLock && (len(localRecords) >= int(req.MaxBatchSize)) { + return nil } } @@ -240,15 +244,15 @@ func (p *PostgresCDCSource) consumeStream( cancel() if err != nil && !p.commitLock { if pgconn.Timeout(err) { - log.Infof("Idle timeout reached, returning currently accumulated records") - return result, nil + log.Infof("Idle timeout reached, returning currently accumulated records - %d", len(localRecords)) + return nil } else { - return nil, fmt.Errorf("ReceiveMessage failed: %w", err) + return fmt.Errorf("ReceiveMessage failed: %w", err) } } if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { - return nil, fmt.Errorf("received Postgres WAL error: %+v", errMsg) + return fmt.Errorf("received Postgres WAL error: %+v", errMsg) } msg, ok := rawMsg.(*pgproto3.CopyData) @@ -261,7 +265,7 @@ func (p *PostgresCDCSource) consumeStream( case pglogrepl.PrimaryKeepaliveMessageByteID: pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) if err != nil { - return nil, fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err) + return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err) } log.Debugf("Primary Keepalive Message => ServerWALEnd: %s ServerTime: %s ReplyRequested: %t", @@ -277,7 +281,7 @@ func (p *PostgresCDCSource) consumeStream( case pglogrepl.XLogDataByteID: xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) if err != nil { - return nil, fmt.Errorf("ParseXLogData failed: %w", err) + return fmt.Errorf("ParseXLogData failed: %w", err) } log.Debugf("XLogData => WALStart %s ServerWALEnd %s ServerTime %s\n", @@ -285,13 +289,9 @@ func (p *PostgresCDCSource) consumeStream( rec, err := p.processMessage(records, xld) if err != nil { - return nil, fmt.Errorf("error processing message: %w", err) + return fmt.Errorf("error processing message: %w", err) } - if !firstProcessed { - firstProcessed = true - records.FirstCheckPointID = int64(xld.WALStart) - } if rec != nil { tableName := rec.GetTableName() switch r := rec.(type) { @@ -301,71 +301,71 @@ func (p *PostgresCDCSource) consumeStream( // will change in future isFullReplica := req.TableNameSchemaMapping[tableName].IsReplicaIdentityFull if isFullReplica { - records.Records = append(records.Records, rec) + addRecord(rec) } else { compositePKeyString, err := p.compositePKeyToString(req, rec) if err != nil { - return nil, err + return err } tablePkeyVal := model.TableWithPkey{ TableName: tableName, PkeyColVal: compositePKeyString, } - _, ok := records.TablePKeyLastSeen[tablePkeyVal] + _, ok := tablePKeyLastSeen[tablePkeyVal] if !ok { - records.Records = append(records.Records, rec) - records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1 + addRecord(rec) + tablePKeyLastSeen[tablePkeyVal] = len(localRecords) - 1 } else { - oldRec := records.Records[records.TablePKeyLastSeen[tablePkeyVal]] + oldRec := localRecords[tablePKeyLastSeen[tablePkeyVal]] // iterate through unchanged toast cols and set them in new record updatedCols := r.NewItems.UpdateIfNotExists(oldRec.GetItems()) for _, col := range updatedCols { delete(r.UnchangedToastColumns, col) } - records.Records = append(records.Records, rec) - records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1 + addRecord(rec) + tablePKeyLastSeen[tablePkeyVal] = len(localRecords) - 1 } } case *model.InsertRecord: 
isFullReplica := req.TableNameSchemaMapping[tableName].IsReplicaIdentityFull if isFullReplica { - records.Records = append(records.Records, rec) + addRecord(rec) } else { compositePKeyString, err := p.compositePKeyToString(req, rec) if err != nil { - return nil, err + return err } tablePkeyVal := model.TableWithPkey{ TableName: tableName, PkeyColVal: compositePKeyString, } - records.Records = append(records.Records, rec) + addRecord(rec) // all columns will be set in insert record, so add it to the map - records.TablePKeyLastSeen[tablePkeyVal] = len(records.Records) - 1 + tablePKeyLastSeen[tablePkeyVal] = len(localRecords) - 1 } case *model.DeleteRecord: - records.Records = append(records.Records, rec) + addRecord(rec) case *model.RelationRecord: tableSchemaDelta := r.TableSchemaDelta if len(tableSchemaDelta.AddedColumns) > 0 { log.Infof("Detected schema change for table %s, addedColumns: %v", tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedColumns) - result.TableSchemaDeltas = append(result.TableSchemaDeltas, tableSchemaDelta) + records.SchemaDeltas <- tableSchemaDelta } } } if xld.WALStart > clientXLogPos { clientXLogPos = xld.WALStart - records.LastCheckPointID = int64(clientXLogPos) + records.UpdateLatestCheckpoint(int64(clientXLogPos)) } } } } -func (p *PostgresCDCSource) processMessage(batch *model.RecordBatch, xld pglogrepl.XLogData) (model.Record, error) { +func (p *PostgresCDCSource) processMessage(batch *model.CDCRecordStream, xld pglogrepl.XLogData) (model.Record, error) { logicalMsg, err := pglogrepl.Parse(xld.WALData) if err != nil { return nil, fmt.Errorf("error parsing logical message: %w", err) @@ -386,7 +386,7 @@ func (p *PostgresCDCSource) processMessage(batch *model.RecordBatch, xld pglogre // for a commit message, update the last checkpoint id for the record batch. log.Debugf("CommitMessage => CommitLSN: %v, TransactionEndLSN: %v", msg.CommitLSN, msg.TransactionEndLSN) - batch.LastCheckPointID = int64(xld.WALStart) + batch.UpdateLatestCheckpoint(int64(xld.WALStart)) p.commitLock = false case *pglogrepl.RelationMessage: // treat all relation messages as correponding to parent if partitioned. diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index b312115536..6c06cb60c4 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -8,7 +8,6 @@ import ( "time" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/connectors/utils/monitoring" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -19,7 +18,6 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" ) // PostgresConnector is a Connector implementation for Postgres. @@ -176,7 +174,11 @@ func (c *PostgresConnector) GetLastOffset(jobName string) (*protos.LastSyncState } // PullRecords pulls records from the source. 
-func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) (*model.RecordsWithTableSchemaDelta, error) { +func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) error { + defer func() { + req.RecordStream.Close() + }() + // Slotname would be the job name prefixed with "peerflow_slot_" slotName := fmt.Sprintf("peerflow_slot_%s", req.FlowJobName) if req.OverrideReplicationSlotName != "" { @@ -192,7 +194,7 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) (*model.R // Check if the replication slot and publication exist exists, err := c.checkSlotAndPublication(slotName, publicationName) if err != nil { - return nil, fmt.Errorf("error checking for replication slot and publication: %w", err) + return fmt.Errorf("error checking for replication slot and publication: %w", err) } if !exists.PublicationExists { @@ -206,7 +208,7 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) (*model.R log.WithFields(log.Fields{ "flowName": req.FlowJobName, }).Warnf("slot %s does not exist", slotName) - return nil, fmt.Errorf("replication slot %s does not exist", slotName) + return fmt.Errorf("replication slot %s does not exist", slotName) } log.WithFields(log.Fields{ @@ -223,36 +225,27 @@ func (c *PostgresConnector) PullRecords(req *model.PullRecordsRequest) (*model.R RelationMessageMapping: req.RelationMessageMapping, }, c.customTypesMapping) if err != nil { - return nil, fmt.Errorf("failed to create cdc source: %w", err) + return fmt.Errorf("failed to create cdc source: %w", err) } - startTime := time.Now() - recordsWithSchemaDelta, err := cdc.PullRecords(req) + err = cdc.PullRecords(req) if err != nil { - return nil, err + return err } - totalRecordsAtSource, err := c.getApproxTableCounts(maps.Keys(req.TableNameMapping)) - if err != nil { - return nil, err - } - metrics.LogPullMetrics(c.ctx, req.FlowJobName, recordsWithSchemaDelta.RecordBatch, - totalRecordsAtSource, time.Since(startTime)) - if len(recordsWithSchemaDelta.RecordBatch.Records) > 0 { - cdcMirrorMonitor, ok := c.ctx.Value(shared.CDCMirrorMonitorKey).(*monitoring.CatalogMirrorMonitor) - if ok { - latestLSN, err := c.getCurrentLSN() - if err != nil { - return nil, err - } - err = cdcMirrorMonitor.UpdateLatestLSNAtSourceForCDCFlow(c.ctx, req.FlowJobName, latestLSN) - if err != nil { - return nil, err - } + cdcMirrorMonitor, ok := c.ctx.Value(shared.CDCMirrorMonitorKey).(*monitoring.CatalogMirrorMonitor) + if ok { + latestLSN, err := c.getCurrentLSN() + if err != nil { + return fmt.Errorf("failed to get current LSN: %w", err) + } + err = cdcMirrorMonitor.UpdateLatestLSNAtSourceForCDCFlow(c.ctx, req.FlowJobName, latestLSN) + if err != nil { + return fmt.Errorf("failed to update latest LSN at source for CDC flow: %w", err) } } - return recordsWithSchemaDelta, nil + return nil } // SyncRecords pushes records to the destination. 
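Both halves of the new flow are visible at this point: the Postgres connector produces into req.RecordStream and closes it, while destination connectors range over GetRecords(). The stream type itself is defined outside this section, so the following is only a sketch of the shape its call sites imply; Record, TableSchemaDelta, and RelationMessageMapping are local stand-ins, GetFirstCheckpoint is omitted, and the real type presumably synchronizes the checkpoint fields rather than relying on call ordering as this sketch does.

package model

import "errors"

// Stand-ins so the sketch compiles on its own; the real definitions live
// elsewhere in the model and generated/protos packages.
type Record interface{ GetTableName() string }
type TableSchemaDelta struct{ SrcTableName string }
type RelationMessageMapping map[uint32]string

type CDCRecordStream struct {
	records chan Record
	// SchemaDeltas receives a delta per relation message with added columns
	// (see the RelationRecord case in cdc.go above).
	SchemaDeltas chan *TableSchemaDelta
	// RelationMessageMapping is sent exactly once, from consumeStream's
	// deferred function, whether or not any records were pulled.
	RelationMessageMapping chan *RelationMessageMapping

	emptySignal    chan bool
	lastCheckpoint int64
	checkpointSet  bool
}

func NewCDCRecordStream() *CDCRecordStream {
	return &CDCRecordStream{
		records:                make(chan Record, 1<<17),
		SchemaDeltas:           make(chan *TableSchemaDelta, 1<<10),
		RelationMessageMapping: make(chan *RelationMessageMapping, 1),
		emptySignal:            make(chan bool, 1),
	}
}

func (s *CDCRecordStream) AddRecord(r Record)        { s.records <- r }
func (s *CDCRecordStream) GetRecords() <-chan Record { return s.records }

// Close ends the stream so consumers ranging over GetRecords() drain and
// stop; this is why PullRecords defers req.RecordStream.Close().
func (s *CDCRecordStream) Close() { close(s.records) }

// UpdateLatestCheckpoint is called by the producer on commit messages.
func (s *CDCRecordStream) UpdateLatestCheckpoint(lsn int64) {
	if lsn > s.lastCheckpoint {
		s.lastCheckpoint = lsn
	}
	s.checkpointSet = true
}

// GetLastCheckpoint errors until a commit has been observed; callers in
// this diff only ask for it once syncing is done.
func (s *CDCRecordStream) GetLastCheckpoint() (int64, error) {
	if !s.checkpointSet {
		return 0, errors.New("no checkpoint recorded yet")
	}
	return s.lastCheckpoint, nil
}

func (s *CDCRecordStream) SignalAsEmpty()    { s.emptySignal <- true }
func (s *CDCRecordStream) SignalAsNotEmpty() { s.emptySignal <- false }

// WaitAndCheckEmpty blocks until the producer signals whether any record
// arrived, which is what lets StartFlow branch before calling SyncRecords.
func (s *CDCRecordStream) WaitAndCheckEmpty() bool { return <-s.emptySignal }

// WaitForSchemaDeltas collects pending deltas; callers invoke it after the
// pull goroutine has finished, so a non-blocking drain suffices here.
func (s *CDCRecordStream) WaitForSchemaDeltas() []*TableSchemaDelta {
	deltas := make([]*TableSchemaDelta, 0)
	for {
		select {
		case d := <-s.SchemaDeltas:
			deltas = append(deltas, d)
		default:
			return deltas
		}
	}
}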
@@ -260,7 +253,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) log.WithFields(log.Fields{ "flowName": req.FlowJobName, - }).Printf("pushing %d records to Postgres table %s via COPY", len(req.Records.Records), rawTableIdentifier) + }).Printf("pushing records to Postgres table %s via COPY", rawTableIdentifier) syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName) if err != nil { @@ -272,9 +265,8 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S first := true var firstCP int64 = 0 - lastCP := req.Records.LastCheckPointID - for _, record := range req.Records.Records { + for record := range req.Records.GetRecords() { switch typedRecord := record.(type) { case *model.InsertRecord: itemsJSON, err := typedRecord.Items.ToJSON() @@ -362,7 +354,6 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S } }() - startTime := time.Now() syncedRecordsCount, err := syncRecordsTx.CopyFrom(c.ctx, pgx.Identifier{internalSchema, rawTableIdentifier}, []string{"_peerdb_uid", "_peerdb_timestamp", "_peerdb_destination_table_name", "_peerdb_data", "_peerdb_record_type", "_peerdb_match_data", "_peerdb_batch_id", "_peerdb_unchanged_toast_columns"}, @@ -374,12 +365,16 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S return nil, fmt.Errorf("error syncing records: expected %d records to be synced, but %d were synced", len(records), syncedRecordsCount) } - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, syncedRecordsCount, time.Since(startTime)) log.WithFields(log.Fields{ "flowName": req.FlowJobName, }).Printf("synced %d records to Postgres table %s via COPY", syncedRecordsCount, rawTableIdentifier) + lastCP, err := req.Records.GetLastCheckpoint() + if err != nil { + return nil, fmt.Errorf("error getting last checkpoint: %w", err) + } + // updating metadata with new offset and syncBatchID err = c.updateSyncMetadata(req.FlowJobName, lastCP, syncBatchID, syncRecordsTx) if err != nil { @@ -461,7 +456,6 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) }) } } - startTime := time.Now() if mergeStatementsBatch.Len() > 0 { mergeResults := normalizeRecordsTx.SendBatch(c.ctx, mergeStatementsBatch) err = mergeResults.Close() @@ -472,14 +466,6 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) log.WithFields(log.Fields{ "flowName": req.FlowJobName, }).Infof("normalized %d records", totalRowsAffected) - if totalRowsAffected > 0 { - totalRowsAtTarget, err := c.getApproxTableCounts(maps.Keys(unchangedToastColsMap)) - if err != nil { - return nil, err - } - metrics.LogNormalizeMetrics(c.ctx, req.FlowJobName, int64(totalRowsAffected), - time.Since(startTime), totalRowsAtTarget) - } // updating metadata with new normalizeBatchID err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID, normalizeRecordsTx) diff --git a/flow/connectors/postgres/postgres_cdc_test.go b/flow/connectors/postgres/postgres_cdc_test.go deleted file mode 100644 index 439cb8d2c6..0000000000 --- a/flow/connectors/postgres/postgres_cdc_test.go +++ /dev/null @@ -1,832 +0,0 @@ -package connpostgres - -import ( - "context" - "fmt" - "math/rand" - "testing" - "time" - - "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/generated/protos" - "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/model/qvalue" - "github.com/jackc/pgx/v5" - 
"github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -type PostgresCDCTestSuite struct { - suite.Suite - connector *PostgresConnector -} - -func (suite *PostgresCDCTestSuite) failTestError(err error) { - if err != nil { - suite.FailNow(err.Error()) - } -} - -func (suite *PostgresCDCTestSuite) dropTable(tableName string) { - _, err := suite.connector.pool.Exec(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) - suite.failTestError(err) -} - -func (suite *PostgresCDCTestSuite) insertSimpleRecords(srcTableName string) { - _, err := suite.connector.pool.Exec(context.Background(), - fmt.Sprintf("INSERT INTO %s(id, name) VALUES (2, 'quick'), (4, 'brown'), (8, 'fox')", srcTableName)) - suite.failTestError(err) -} - -func (suite *PostgresCDCTestSuite) validateInsertedSimpleRecords(records []model.Record, srcTableName string, - dstTableName string) { - suite.Equal(3, len(records)) - model.NewRecordItemWithData([]string{"id", "name"}, - []*qvalue.QValue{ - {Kind: qvalue.QValueKindInt32, Value: int32(2)}, - {Kind: qvalue.QValueKindString, Value: "quick"}}) - matchData := []*model.RecordItems{ - model.NewRecordItemWithData([]string{"id", "name"}, - []*qvalue.QValue{ - {Kind: qvalue.QValueKindInt32, Value: int32(2)}, - {Kind: qvalue.QValueKindString, Value: "quick"}}), - model.NewRecordItemWithData([]string{"id", "name"}, - []*qvalue.QValue{ - {Kind: qvalue.QValueKindInt32, Value: int32(4)}, - {Kind: qvalue.QValueKindString, Value: "brown"}}), - model.NewRecordItemWithData([]string{"id", "name"}, - []*qvalue.QValue{ - {Kind: qvalue.QValueKindInt32, Value: int32(8)}, - {Kind: qvalue.QValueKindString, Value: "fox"}}), - } - for idx, record := range records { - suite.IsType(&model.InsertRecord{}, record) - insertRecord := record.(*model.InsertRecord) - suite.Equal(srcTableName, insertRecord.SourceTableName) - suite.Equal(dstTableName, insertRecord.DestinationTableName) - suite.Equal(matchData[idx], insertRecord.Items) - } -} - -func (suite *PostgresCDCTestSuite) mutateSimpleRecords(srcTableName string) { - mutateRecordsTx, err := suite.connector.pool.Begin(context.Background()) - suite.failTestError(err) - defer func() { - err := mutateRecordsTx.Rollback(context.Background()) - if err != pgx.ErrTxClosed { - suite.failTestError(err) - } - }() - - _, err = mutateRecordsTx.Exec(context.Background(), - fmt.Sprintf("UPDATE %s SET name = 'slow' WHERE id = 2", srcTableName)) - suite.failTestError(err) - _, err = mutateRecordsTx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE id = 8", srcTableName)) - suite.failTestError(err) - err = mutateRecordsTx.Commit(context.Background()) - suite.failTestError(err) -} - -func (suite *PostgresCDCTestSuite) validateSimpleMutatedRecords(records []model.Record, srcTableName string, - dstTableName string) { - suite.Equal(2, len(records)) - - suite.IsType(&model.UpdateRecord{}, records[0]) - updateRecord := records[0].(*model.UpdateRecord) - suite.Equal(srcTableName, updateRecord.SourceTableName) - suite.Equal(dstTableName, updateRecord.DestinationTableName) - suite.Equal(model.NewRecordItemWithData([]string{}, []*qvalue.QValue{}), updateRecord.OldItems) - - items := model.NewRecordItemWithData([]string{"id", "name"}, - []*qvalue.QValue{ - {Kind: qvalue.QValueKindInt32, Value: int32(2)}, - {Kind: qvalue.QValueKindString, Value: "slow"}}) - suite.Equal(items, updateRecord.NewItems) - - suite.IsType(&model.DeleteRecord{}, records[1]) - deleteRecord := records[1].(*model.DeleteRecord) - suite.Equal(srcTableName, 
deleteRecord.SourceTableName) - suite.Equal(dstTableName, deleteRecord.DestinationTableName) - items = model.NewRecordItemWithData([]string{"id", "name"}, - []*qvalue.QValue{ - {Kind: qvalue.QValueKindInt32, Value: int32(8)}, - {Kind: qvalue.QValueKindInvalid, Value: nil}}) - suite.Equal(items, deleteRecord.Items) -} - -func (suite *PostgresCDCTestSuite) randBytea(n int) []byte { - b := make([]byte, n) - //nolint:gosec - _, err := rand.Read(b) - suite.failTestError(err) - return b -} - -func (suite *PostgresCDCTestSuite) randString(n int) string { - const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - - b := make([]byte, n) - for i := range b { - //nolint:gosec - b[i] = letterBytes[rand.Intn(len(letterBytes))] - } - return string(b) -} - -func (suite *PostgresCDCTestSuite) insertToastRecords(srcTableName string) { - insertRecordsTx, err := suite.connector.pool.Begin(context.Background()) - suite.failTestError(err) - defer func() { - err := insertRecordsTx.Rollback(context.Background()) - if err != pgx.ErrTxClosed { - suite.failTestError(err) - } - }() - - for i := 0; i < 4; i++ { - _, err := insertRecordsTx.Exec(context.Background(), - fmt.Sprintf("INSERT INTO %s(n_t, lz4_t, n_b, lz4_b) VALUES ($1, $2, $3, $4)", srcTableName), - suite.randString(32768), suite.randString(32768), suite.randBytea(32768), suite.randBytea(32768)) - suite.failTestError(err) - } - - err = insertRecordsTx.Commit(context.Background()) - suite.failTestError(err) -} - -func (suite *PostgresCDCTestSuite) validateInsertedToastRecords(records []model.Record, srcTableName string, - dstTableName string) { - suite.Equal(4, len(records)) - for idx, record := range records { - suite.IsType(&model.InsertRecord{}, record) - insertRecord := record.(*model.InsertRecord) - suite.Equal(srcTableName, insertRecord.SourceTableName) - suite.Equal(dstTableName, insertRecord.DestinationTableName) - suite.Equal(5, insertRecord.Items.Len()) - - idVal, err := insertRecord.Items.GetValueByColName("id") - suite.NoError(err, "Error fetching id") - - n_tVal, err := insertRecord.Items.GetValueByColName("n_t") - suite.NoError(err, "Error fetching n_t") - - lz4_tVal, err := insertRecord.Items.GetValueByColName("lz4_t") - suite.NoError(err, "Error fetching lz4_t") - - n_bVal, err := insertRecord.Items.GetValueByColName("n_b") - suite.NoError(err, "Error fetching n_b") - - lz4_bVal, err := insertRecord.Items.GetValueByColName("lz4_b") - suite.NoError(err, "Error fetching lz4_b") - - // Perform the actual value checks - suite.Equal(int32(idx+1), idVal.Value.(int32)) - suite.Equal(32768, len(n_tVal.Value.(string))) - suite.Equal(32768, len(lz4_tVal.Value.(string))) - suite.Equal(32768, len(n_bVal.Value.([]byte))) - suite.Equal(32768, len(lz4_bVal.Value.([]byte))) - } -} - -func (suite *PostgresCDCTestSuite) mutateToastRecords(srcTableName string) { - mutateRecordsTx, err := suite.connector.pool.Begin(context.Background()) - suite.failTestError(err) - defer func() { - err := mutateRecordsTx.Rollback(context.Background()) - if err != pgx.ErrTxClosed { - suite.failTestError(err) - } - }() - - _, err = mutateRecordsTx.Exec(context.Background(), fmt.Sprintf("UPDATE %s SET n_t = $1 WHERE id = 1", - srcTableName), - suite.randString(65536)) - suite.failTestError(err) - _, err = mutateRecordsTx.Exec(context.Background(), - fmt.Sprintf("UPDATE %s SET lz4_t = $1, n_b = $2, lz4_b = $3 WHERE id = 3", srcTableName), - suite.randString(65536), suite.randBytea(65536), suite.randBytea(65536)) - suite.failTestError(err) - _, err = 
mutateRecordsTx.Exec(context.Background(), - fmt.Sprintf("UPDATE %s SET n_t = $1, lz4_t = $2, n_b = $3, lz4_b = $4 WHERE id = 4", srcTableName), - suite.randString(65536), suite.randString(65536), suite.randBytea(65536), suite.randBytea(65536)) - suite.failTestError(err) - _, err = mutateRecordsTx.Exec(context.Background(), - fmt.Sprintf("DELETE FROM %s WHERE id = 3", srcTableName)) - suite.failTestError(err) - - err = mutateRecordsTx.Commit(context.Background()) - suite.failTestError(err) -} - -func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.Record, srcTableName string, - dstTableName string) { - suite.Equal(4, len(records)) - - suite.IsType(&model.UpdateRecord{}, records[0]) - updateRecord := records[0].(*model.UpdateRecord) - suite.Equal(srcTableName, updateRecord.SourceTableName) - suite.Equal(dstTableName, updateRecord.DestinationTableName) - items := updateRecord.NewItems - suite.Equal(2, items.Len()) - v, err := items.GetValueByColName("id") - suite.NoError(err, "Error fetching id") - suite.Equal(int32(1), v.Value.(int32)) - v, err = items.GetValueByColName("n_t") - suite.NoError(err, "Error fetching n_t") - suite.Equal(qvalue.QValueKindString, v.Kind) - suite.Equal(65536, len(v.Value.(string))) - suite.Equal(3, len(updateRecord.UnchangedToastColumns)) - suite.Contains(updateRecord.UnchangedToastColumns, "lz4_t") - suite.Contains(updateRecord.UnchangedToastColumns, "n_b") - suite.Contains(updateRecord.UnchangedToastColumns, "lz4_b") - suite.IsType(&model.UpdateRecord{}, records[1]) - updateRecord = records[1].(*model.UpdateRecord) - suite.Equal(srcTableName, updateRecord.SourceTableName) - suite.Equal(dstTableName, updateRecord.DestinationTableName) - - items = updateRecord.NewItems - suite.Equal(4, items.Len()) - v = items.GetColumnValue("id") - suite.Equal(qvalue.QValueKindInt32, v.Kind) - suite.Equal(int32(3), v.Value.(int32)) - v = items.GetColumnValue("lz4_t") - suite.Equal(qvalue.QValueKindString, v.Kind) - suite.Equal(65536, len(v.Value.(string))) - v = items.GetColumnValue("n_b") - suite.Equal(qvalue.QValueKindBytes, v.Kind) - suite.Equal(65536, len(v.Value.([]byte))) - v = items.GetColumnValue("lz4_b") - suite.Equal(qvalue.QValueKindBytes, v.Kind) - suite.Equal(65536, len(v.Value.([]byte))) - suite.Equal(1, len(updateRecord.UnchangedToastColumns)) - suite.Contains(updateRecord.UnchangedToastColumns, "n_t") - // Test case for records[2] - suite.IsType(&model.UpdateRecord{}, records[2]) - updateRecord = records[2].(*model.UpdateRecord) - suite.Equal(srcTableName, updateRecord.SourceTableName) - suite.Equal(dstTableName, updateRecord.DestinationTableName) - - items = updateRecord.NewItems - suite.Equal(5, items.Len()) - v = items.GetColumnValue("id") - suite.Equal(int32(4), v.Value.(int32)) - suite.Equal(qvalue.QValueKindString, items.GetColumnValue("n_t").Kind) - suite.Equal(65536, len(items.GetColumnValue("n_t").Value.(string))) - suite.Equal(qvalue.QValueKindString, items.GetColumnValue("lz4_t").Kind) - suite.Equal(65536, len(items.GetColumnValue("lz4_t").Value.(string))) - suite.Equal(qvalue.QValueKindBytes, items.GetColumnValue("n_b").Kind) - suite.Equal(65536, len(items.GetColumnValue("n_b").Value.([]byte))) - suite.Equal(qvalue.QValueKindBytes, items.GetColumnValue("lz4_b").Kind) - suite.Equal(65536, len(items.GetColumnValue("lz4_b").Value.([]byte))) - suite.Equal(0, len(updateRecord.UnchangedToastColumns)) - - // Test case for records[3] - suite.IsType(&model.DeleteRecord{}, records[3]) - deleteRecord := records[3].(*model.DeleteRecord) - 
suite.Equal(srcTableName, deleteRecord.SourceTableName) - suite.Equal(dstTableName, deleteRecord.DestinationTableName) - items = deleteRecord.Items - suite.Equal(5, items.Len()) - suite.Equal(int32(3), items.GetColumnValue("id").Value.(int32)) - suite.Equal(qvalue.QValueKindInvalid, items.GetColumnValue("n_t").Kind) - suite.Nil(items.GetColumnValue("n_t").Value) - suite.Equal(qvalue.QValueKindInvalid, items.GetColumnValue("lz4_t").Kind) - suite.Nil(items.GetColumnValue("lz4_t").Value) - suite.Equal(qvalue.QValueKindInvalid, items.GetColumnValue("n_b").Kind) - suite.Nil(items.GetColumnValue("n_b").Value) - suite.Equal(qvalue.QValueKindInvalid, items.GetColumnValue("lz4_b").Kind) - suite.Nil(items.GetColumnValue("lz4_b").Value) -} - -func (suite *PostgresCDCTestSuite) SetupSuite() { - var err error - suite.connector, err = NewPostgresConnector(context.Background(), &protos.PostgresConfig{ - Host: "localhost", - Port: 7132, - User: "postgres", - Password: "postgres", - Database: "postgres", - }) - suite.failTestError(err) - - setupTx, err := suite.connector.pool.Begin(context.Background()) - suite.failTestError(err) - defer func() { - err := setupTx.Rollback(context.Background()) - if err != pgx.ErrTxClosed { - suite.failTestError(err) - } - }() - _, err = setupTx.Exec(context.Background(), "DROP SCHEMA IF EXISTS pgpeer_test CASCADE") - suite.failTestError(err) - _, err = setupTx.Exec(context.Background(), "CREATE SCHEMA pgpeer_test") - suite.failTestError(err) - err = setupTx.Commit(context.Background()) - suite.failTestError(err) -} - -func (suite *PostgresCDCTestSuite) TearDownSuite() { - teardownTx, err := suite.connector.pool.Begin(context.Background()) - suite.failTestError(err) - defer func() { - err := teardownTx.Rollback(context.Background()) - if err != pgx.ErrTxClosed { - suite.failTestError(err) - } - }() - _, err = teardownTx.Exec(context.Background(), "DROP SCHEMA IF EXISTS pgpeer_test CASCADE") - suite.failTestError(err) - err = teardownTx.Commit(context.Background()) - suite.failTestError(err) - - suite.True(suite.connector.ConnectionActive()) - err = suite.connector.Close() - suite.failTestError(err) - suite.False(suite.connector.ConnectionActive()) -} - -func (suite *PostgresCDCTestSuite) TestParseSchemaTable() { - schemaTest1, err := utils.ParseSchemaTable("schema") - suite.Nil(schemaTest1) - suite.NotNil(err) - - schemaTest2, err := utils.ParseSchemaTable("schema.table") - suite.Equal(&utils.SchemaTable{ - Schema: "schema", - Table: "table", - }, schemaTest2) - suite.Equal("\"schema\".\"table\"", schemaTest2.String()) - suite.Nil(err) - - schemaTest3, err := utils.ParseSchemaTable("database.schema.table") - suite.Nil(schemaTest3) - suite.NotNil(err) -} - -func (suite *PostgresCDCTestSuite) TestErrorForInvalidConfig() { - connector, err := NewPostgresConnector(context.Background(), &protos.PostgresConfig{ - Host: "fakehost", - Port: 0, - User: "fakeuser", - Password: "fakepassword", - Database: "fakedatabase", - }) - suite.Nil(connector) - suite.NotNil(err) -} - -// intended to test how activities react to a table that does not exist. 
-func (suite *PostgresCDCTestSuite) TestErrorForTableNotExist() { - nonExistentFlowName := "non_existent_table_testing" - nonExistentFlowSrcTableName := "pgpeer_test.non_existent_table" - nonExistentFlowDstTableName := "non_existent_table_dst" - - ensurePullabilityOutput, err := suite.connector.EnsurePullability(&protos.EnsurePullabilityBatchInput{ - FlowJobName: nonExistentFlowName, - SourceTableIdentifiers: []string{nonExistentFlowSrcTableName}, - PeerConnectionConfig: nil, // not used by the connector itself. - }) - suite.Nil(ensurePullabilityOutput) - suite.Errorf(err, "error getting relation ID for table %s: no rows in result set", nonExistentFlowSrcTableName) - - tableNameMapping := map[string]string{ - nonExistentFlowSrcTableName: nonExistentFlowDstTableName, - } - relationMessageMapping := make(model.RelationMessageMapping) - - getTblSchemaInput := &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{nonExistentFlowSrcTableName}, - PeerConnectionConfig: nil, - } - - tableSchema, err := suite.connector.GetTableSchema(getTblSchemaInput) - suite.Errorf(err, "error getting relation ID for table %s: no rows in result set", nonExistentFlowSrcTableName) - suite.Nil(tableSchema) - tableNameSchemaMapping := make(map[string]*protos.TableSchema) - tableNameSchemaMapping[nonExistentFlowDstTableName] = &protos.TableSchema{ - TableIdentifier: nonExistentFlowSrcTableName, - Columns: map[string]string{ - "id": string(qvalue.QValueKindInt32), - "name": string(qvalue.QValueKindString), - }, - PrimaryKeyColumns: []string{"id"}, - } - - err = suite.connector.PullFlowCleanup(nonExistentFlowName) - suite.Nil(err) - - // creating table and the replication slots for it, and dropping before pull records. - _, err = suite.connector.pool.Exec(context.Background(), - fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s(id INT PRIMARY KEY, name TEXT)", nonExistentFlowSrcTableName)) - suite.failTestError(err) - ensurePullabilityOutput, err = suite.connector.EnsurePullability(&protos.EnsurePullabilityBatchInput{ - FlowJobName: nonExistentFlowName, - SourceTableIdentifiers: []string{nonExistentFlowSrcTableName}, - PeerConnectionConfig: nil, // not used by the connector itself. - }) - suite.failTestError(err) - tableRelID := ensurePullabilityOutput.TableIdentifierMapping[nonExistentFlowSrcTableName]. - GetPostgresTableIdentifier().RelId - relIDTableNameMapping := map[uint32]string{ - tableRelID: nonExistentFlowSrcTableName, - } - err = suite.connector.SetupReplication(nil, &protos.SetupReplicationInput{ - FlowJobName: nonExistentFlowName, - TableNameMapping: tableNameMapping, - PeerConnectionConfig: nil, // not used by the connector itself. 
- }) - suite.failTestError(err) - suite.dropTable(nonExistentFlowSrcTableName) - recordsWithSchemaDelta, err := suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: nonExistentFlowName, - LastSyncState: nil, - IdleTimeout: 5 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.Nil(recordsWithSchemaDelta) - suite.Errorf( - err, - "error while closing statement batch: ERROR: relation \"%s\" does not exist (SQLSTATE 42P01)", - nonExistentFlowSrcTableName) - - err = suite.connector.PullFlowCleanup(nonExistentFlowName) - suite.failTestError(err) -} - -func (suite *PostgresCDCTestSuite) TestSimpleHappyFlow() { - simpleHappyFlowName := "simple_happy_flow_testing_flow" - simpleHappyFlowSrcTableName := "pgpeer_test.simple_table" - simpleHappyFlowDstTableName := "simple_table_dst" - - _, err := suite.connector.pool.Exec(context.Background(), - fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s(id INT PRIMARY KEY, name TEXT)", simpleHappyFlowSrcTableName)) - suite.failTestError(err) - - ensurePullabilityOutput, err := suite.connector.EnsurePullability(&protos.EnsurePullabilityBatchInput{ - FlowJobName: simpleHappyFlowName, - SourceTableIdentifiers: []string{simpleHappyFlowSrcTableName}, - PeerConnectionConfig: nil, // not used by the connector itself. - }) - suite.failTestError(err) - tableRelID := ensurePullabilityOutput.TableIdentifierMapping[simpleHappyFlowSrcTableName]. - GetPostgresTableIdentifier().RelId - - relIDTableNameMapping := map[uint32]string{ - tableRelID: simpleHappyFlowSrcTableName, - } - tableNameMapping := map[string]string{ - simpleHappyFlowSrcTableName: simpleHappyFlowDstTableName, - } - relationMessageMapping := make(model.RelationMessageMapping) - - err = suite.connector.SetupReplication(nil, &protos.SetupReplicationInput{ - FlowJobName: simpleHappyFlowName, - TableNameMapping: tableNameMapping, - PeerConnectionConfig: nil, // not used by the connector itself. - }) - suite.failTestError(err) - - tableNameSchemaMapping := make(map[string]*protos.TableSchema) - - getTblSchemaInput := &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{simpleHappyFlowSrcTableName}, - PeerConnectionConfig: nil, - } - tableNameSchema, err := suite.connector.GetTableSchema(getTblSchemaInput) - suite.failTestError(err) - suite.Equal(&protos.GetTableSchemaBatchOutput{ - TableNameSchemaMapping: map[string]*protos.TableSchema{ - simpleHappyFlowSrcTableName: { - TableIdentifier: simpleHappyFlowSrcTableName, - Columns: map[string]string{ - "id": string(qvalue.QValueKindInt32), - "name": string(qvalue.QValueKindString), - }, - PrimaryKeyColumns: []string{"id"}, - }, - }}, tableNameSchema) - tableNameSchemaMapping[simpleHappyFlowDstTableName] = - tableNameSchema.TableNameSchemaMapping[simpleHappyFlowSrcTableName] - - // pulling with no recordsWithSchemaDelta. 
- recordsWithSchemaDelta, err := suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: simpleHappyFlowName, - LastSyncState: nil, - IdleTimeout: 5 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.failTestError(err) - suite.Equal(0, len(recordsWithSchemaDelta.RecordBatch.Records)) - suite.Nil(recordsWithSchemaDelta.TableSchemaDeltas) - suite.Equal(int64(0), recordsWithSchemaDelta.RecordBatch.FirstCheckPointID) - suite.Equal(int64(0), recordsWithSchemaDelta.RecordBatch.LastCheckPointID) - relationMessageMapping = recordsWithSchemaDelta.RelationMessageMapping - - // pulling after inserting records. - suite.insertSimpleRecords(simpleHappyFlowSrcTableName) - recordsWithSchemaDelta, err = suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: simpleHappyFlowName, - LastSyncState: nil, - IdleTimeout: 5 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.failTestError(err) - suite.Nil(recordsWithSchemaDelta.TableSchemaDeltas) - suite.validateInsertedSimpleRecords(recordsWithSchemaDelta.RecordBatch.Records, - simpleHappyFlowSrcTableName, simpleHappyFlowDstTableName) - suite.Greater(recordsWithSchemaDelta.RecordBatch.FirstCheckPointID, int64(0)) - suite.GreaterOrEqual(recordsWithSchemaDelta.RecordBatch.LastCheckPointID, - recordsWithSchemaDelta.RecordBatch.FirstCheckPointID) - currentCheckPointID := recordsWithSchemaDelta.RecordBatch.LastCheckPointID - relationMessageMapping = recordsWithSchemaDelta.RelationMessageMapping - - // pulling after mutating records. 
- suite.mutateSimpleRecords(simpleHappyFlowSrcTableName) - recordsWithSchemaDelta, err = suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: simpleHappyFlowName, - LastSyncState: &protos.LastSyncState{ - Checkpoint: recordsWithSchemaDelta.RecordBatch.LastCheckPointID, - LastSyncedAt: nil, - }, - IdleTimeout: 5 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.failTestError(err) - suite.Nil(recordsWithSchemaDelta.TableSchemaDeltas) - suite.validateSimpleMutatedRecords(recordsWithSchemaDelta.RecordBatch.Records, - simpleHappyFlowSrcTableName, simpleHappyFlowDstTableName) - suite.GreaterOrEqual(recordsWithSchemaDelta.RecordBatch.FirstCheckPointID, currentCheckPointID) - suite.GreaterOrEqual(recordsWithSchemaDelta.RecordBatch.LastCheckPointID, - recordsWithSchemaDelta.RecordBatch.FirstCheckPointID) - - err = suite.connector.PullFlowCleanup(simpleHappyFlowName) - suite.failTestError(err) - - suite.dropTable(simpleHappyFlowSrcTableName) -} - -func (suite *PostgresCDCTestSuite) TestAllTypesHappyFlow() { - allTypesHappyFlowName := "all_types_happy_flow_testing" - allTypesHappyFlowSrcTableName := "pgpeer_test.all_types_table" - allTypesHappyFlowDstTableName := "all_types_table_dst" - - _, err := suite.connector.pool.Exec(context.Background(), - fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s(id BIGINT PRIMARY KEY, - c1 BIGINT, c2 BIT, c3 VARBIT, c4 BOOLEAN, c6 BYTEA, c7 CHARACTER, c8 VARCHAR, - c9 CIDR, c11 DATE, c12 FLOAT, c13 DOUBLE PRECISION, c14 INET, c15 INTEGER, - c16 INTERVAL, c17 JSON, c18 JSONB, c21 MACADDR, c22 MONEY, c23 DECIMAL, c24 OID, c28 REAL, - c29 SMALLINT, c30 SMALLSERIAL, c31 SERIAL, c32 TEXT, c33 TIMESTAMP, c34 TIMESTAMPTZ, - c35 TIME, c36 TIMETZ, c37 TSQUERY, c38 TSVECTOR, c39 TXID_SNAPSHOT, c40 UUID, c41 XML)`, - allTypesHappyFlowSrcTableName)) - suite.failTestError(err) - - ensurePullabilityOutput, err := suite.connector.EnsurePullability(&protos.EnsurePullabilityBatchInput{ - FlowJobName: allTypesHappyFlowName, - SourceTableIdentifiers: []string{allTypesHappyFlowSrcTableName}, - PeerConnectionConfig: nil, // not used by the connector itself. - }) - suite.failTestError(err) - tableRelID := ensurePullabilityOutput.TableIdentifierMapping[allTypesHappyFlowSrcTableName]. - GetPostgresTableIdentifier().RelId - relationMessageMapping := make(model.RelationMessageMapping) - - relIDTableNameMapping := map[uint32]string{ - tableRelID: allTypesHappyFlowSrcTableName, - } - tableNameMapping := map[string]string{ - allTypesHappyFlowSrcTableName: allTypesHappyFlowDstTableName, - } - err = suite.connector.SetupReplication(nil, &protos.SetupReplicationInput{ - FlowJobName: allTypesHappyFlowName, - TableNameMapping: tableNameMapping, - PeerConnectionConfig: nil, // not used by the connector itself. 
- }) - suite.failTestError(err) - - tableNameSchemaMapping := make(map[string]*protos.TableSchema) - getTblSchemaInput := &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{allTypesHappyFlowSrcTableName}, - PeerConnectionConfig: nil, - } - tableNameSchema, err := suite.connector.GetTableSchema(getTblSchemaInput) - suite.failTestError(err) - suite.Equal(&protos.GetTableSchemaBatchOutput{ - TableNameSchemaMapping: map[string]*protos.TableSchema{ - allTypesHappyFlowSrcTableName: { - TableIdentifier: allTypesHappyFlowSrcTableName, - Columns: map[string]string{ - "id": string(qvalue.QValueKindInt64), - "c1": string(qvalue.QValueKindInt64), - "c2": string(qvalue.QValueKindBit), - "c3": string(qvalue.QValueKindBit), - "c4": string(qvalue.QValueKindBoolean), - "c6": string(qvalue.QValueKindBytes), - "c7": string(qvalue.QValueKindString), - "c8": string(qvalue.QValueKindString), - "c9": string(qvalue.QValueKindString), - "c11": string(qvalue.QValueKindDate), - "c12": string(qvalue.QValueKindFloat64), - "c13": string(qvalue.QValueKindFloat64), - "c14": string(qvalue.QValueKindString), - "c15": string(qvalue.QValueKindInt32), - "c16": string(qvalue.QValueKindString), - "c17": string(qvalue.QValueKindJSON), - "c18": string(qvalue.QValueKindJSON), - "c21": string(qvalue.QValueKindString), - "c22": string(qvalue.QValueKindString), - "c23": string(qvalue.QValueKindNumeric), - "c24": string(qvalue.QValueKindString), - "c28": string(qvalue.QValueKindFloat32), - "c29": string(qvalue.QValueKindInt16), - "c30": string(qvalue.QValueKindInt16), - "c31": string(qvalue.QValueKindInt32), - "c32": string(qvalue.QValueKindString), - "c33": string(qvalue.QValueKindTimestamp), - "c34": string(qvalue.QValueKindTimestampTZ), - "c35": string(qvalue.QValueKindTime), - "c36": string(qvalue.QValueKindTimeTZ), - "c37": string(qvalue.QValueKindString), - "c38": string(qvalue.QValueKindString), - "c39": string(qvalue.QValueKindString), - "c40": string(qvalue.QValueKindUUID), - "c41": string(qvalue.QValueKindString), - }, - PrimaryKeyColumns: []string{"id"}, - }, - }, - }, tableNameSchema) - tableNameSchemaMapping[allTypesHappyFlowDstTableName] = - tableNameSchema.TableNameSchemaMapping[allTypesHappyFlowSrcTableName] - - _, err = suite.connector.pool.Exec(context.Background(), - fmt.Sprintf(`INSERT INTO %s SELECT 2, 2, b'1', b'101', - true, $1, 's', 'test', '1.1.10.2'::cidr, - CURRENT_DATE, 1.23, 1.234, '192.168.1.5'::inet, 1, - '5 years 2 months 29 days 1 minute 2 seconds 200 milliseconds 20000 microseconds'::interval, - '{"sai":1}'::json, '{"sai":1}'::jsonb, '08:00:2b:01:02:03'::macaddr, - 1.2, 1.23, 4::oid, 1.23, 1, 1, 1, 'test', now(), now(), now()::time, now()::timetz, - 'fat & rat'::tsquery, 'a fat cat sat on a mat and ate a fat rat'::tsvector, - txid_current_snapshot(), '66073c38-b8df-4bdb-bbca-1c97596b8940'::uuid, xmlcomment('hello')`, - allTypesHappyFlowSrcTableName), - suite.randBytea(32)) - suite.failTestError(err) - records, err := suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: allTypesHappyFlowName, - LastSyncState: nil, - IdleTimeout: 5 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.failTestError(err) - require.Equal(suite.T(), 1, len(records.RecordBatch.Records)) - - items := records.RecordBatch.Records[0].GetItems() - numCols := items.Len() - if numCols != 35 { - jsonStr, err := items.ToJSON() 
- suite.failTestError(err) - fmt.Printf("record batch json: %s\n", jsonStr) - suite.FailNow("expected 35 columns, got %d", numCols) - } - - err = suite.connector.PullFlowCleanup(allTypesHappyFlowName) - suite.failTestError(err) - - suite.dropTable(allTypesHappyFlowSrcTableName) -} - -func (suite *PostgresCDCTestSuite) TestToastHappyFlow() { - toastHappyFlowName := "toast_happy_flow_testing" - toastHappyFlowSrcTableName := "pgpeer_test.toast_table" - toastHappyFlowDstTableName := "toast_table_dst" - - _, err := suite.connector.pool.Exec(context.Background(), - fmt.Sprintf(`CREATE TABLE %s(id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - n_t TEXT, lz4_t TEXT COMPRESSION LZ4, n_b BYTEA, lz4_b BYTEA COMPRESSION LZ4)`, toastHappyFlowSrcTableName)) - suite.failTestError(err) - - ensurePullabilityOutput, err := suite.connector.EnsurePullability(&protos.EnsurePullabilityBatchInput{ - FlowJobName: toastHappyFlowName, - SourceTableIdentifiers: []string{toastHappyFlowSrcTableName}, - PeerConnectionConfig: nil, // not used by the connector itself. - }) - suite.failTestError(err) - tableRelID := ensurePullabilityOutput.TableIdentifierMapping[toastHappyFlowSrcTableName]. - GetPostgresTableIdentifier().RelId - - relIDTableNameMapping := map[uint32]string{ - tableRelID: toastHappyFlowSrcTableName, - } - tableNameMapping := map[string]string{ - toastHappyFlowSrcTableName: toastHappyFlowDstTableName, - } - relationMessageMapping := make(model.RelationMessageMapping) - - err = suite.connector.SetupReplication(nil, &protos.SetupReplicationInput{ - FlowJobName: toastHappyFlowName, - TableNameMapping: tableNameMapping, - PeerConnectionConfig: nil, // not used by the connector itself. - }) - suite.failTestError(err) - - tableNameSchemaMapping := make(map[string]*protos.TableSchema) - getTblSchemaInput := &protos.GetTableSchemaBatchInput{ - TableIdentifiers: []string{toastHappyFlowSrcTableName}, - PeerConnectionConfig: nil, - } - tableNameSchema, err := suite.connector.GetTableSchema(getTblSchemaInput) - suite.failTestError(err) - suite.Equal(&protos.GetTableSchemaBatchOutput{ - TableNameSchemaMapping: map[string]*protos.TableSchema{ - toastHappyFlowSrcTableName: { - TableIdentifier: toastHappyFlowSrcTableName, - Columns: map[string]string{ - "id": string(qvalue.QValueKindInt32), - "n_t": string(qvalue.QValueKindString), - "lz4_t": string(qvalue.QValueKindString), - "n_b": string(qvalue.QValueKindBytes), - "lz4_b": string(qvalue.QValueKindBytes), - }, - PrimaryKeyColumns: []string{"id"}, - }, - }}, tableNameSchema) - tableNameSchemaMapping[toastHappyFlowDstTableName] = - tableNameSchema.TableNameSchemaMapping[toastHappyFlowSrcTableName] - - suite.insertToastRecords(toastHappyFlowSrcTableName) - _, err = suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: toastHappyFlowName, - LastSyncState: nil, - IdleTimeout: 10 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.failTestError(err) - recordsWithSchemaDelta, err := suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: toastHappyFlowName, - LastSyncState: nil, - IdleTimeout: 10 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.failTestError(err) - 
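The handoff these (now-removed) tests pin down carries over to the streaming refactor: each follow-up pull resumes from the last checkpoint of the previous batch. A minimal sketch of that handoff, assuming prevLastCheckpointID holds the prior batch's last checkpoint:

	// Resume the next pull strictly after the previously synced LSN.
	nextReq := &model.PullRecordsRequest{
		FlowJobName: "toast_happy_flow_testing",
		LastSyncState: &protos.LastSyncState{
			Checkpoint:   prevLastCheckpointID, // assumed: captured from the prior batch
			LastSyncedAt: nil,
		},
		IdleTimeout:  10 * time.Second,
		MaxBatchSize: 100,
	}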
suite.Nil(recordsWithSchemaDelta.TableSchemaDeltas) - suite.validateInsertedToastRecords(recordsWithSchemaDelta.RecordBatch.Records, - toastHappyFlowSrcTableName, toastHappyFlowDstTableName) - suite.Greater(recordsWithSchemaDelta.RecordBatch.FirstCheckPointID, int64(0)) - suite.GreaterOrEqual(recordsWithSchemaDelta.RecordBatch.LastCheckPointID, - recordsWithSchemaDelta.RecordBatch.FirstCheckPointID) - relationMessageMapping = recordsWithSchemaDelta.RelationMessageMapping - - suite.mutateToastRecords(toastHappyFlowSrcTableName) - recordsWithSchemaDelta, err = suite.connector.PullRecords(&model.PullRecordsRequest{ - FlowJobName: toastHappyFlowName, - LastSyncState: &protos.LastSyncState{ - Checkpoint: recordsWithSchemaDelta.RecordBatch.LastCheckPointID, - LastSyncedAt: nil, - }, - IdleTimeout: 10 * time.Second, - MaxBatchSize: 100, - SrcTableIDNameMapping: relIDTableNameMapping, - TableNameMapping: tableNameMapping, - TableNameSchemaMapping: tableNameSchemaMapping, - RelationMessageMapping: relationMessageMapping, - }) - suite.failTestError(err) - suite.validateMutatedToastRecords(recordsWithSchemaDelta.RecordBatch.Records, toastHappyFlowSrcTableName, - toastHappyFlowDstTableName) - - err = suite.connector.PullFlowCleanup(toastHappyFlowName) - suite.failTestError(err) - - suite.dropTable(toastHappyFlowSrcTableName) -} - -func TestPostgresCDCTestSuite(t *testing.T) { - suite.Run(t, new(PostgresCDCTestSuite)) -} diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index e1dd47c314..2287b81158 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -7,7 +7,6 @@ import ( "time" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" partition_utils "github.com/PeerDB-io/peer-flow/connectors/utils/partition" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -426,11 +425,6 @@ func (c *PostgresConnector) PullQRepRecords( return nil, err } - totalRecordsAtSource, err := c.getApproxTableCounts([]string{config.WatermarkTable}) - if err != nil { - return nil, err - } - metrics.LogQRepPullMetrics(c.ctx, config.FlowJobName, int(records.NumRecords), totalRecordsAtSource) return records, nil } @@ -503,11 +497,6 @@ func (c *PostgresConnector) PullQRepRecordStream( return 0, err } - totalRecordsAtSource, err := c.getApproxTableCounts([]string{config.WatermarkTable}) - if err != nil { - return 0, err - } - metrics.LogQRepPullMetrics(c.ctx, config.FlowJobName, numRecords, totalRecordsAtSource) log.WithFields(log.Fields{ "partition": partition.PartitionId, }).Infof("pulled %d records for flow job %s", numRecords, config.FlowJobName) diff --git a/flow/connectors/postgres/qrep_sync_method.go b/flow/connectors/postgres/qrep_sync_method.go index 84b7bb8949..f1e818f397 100644 --- a/flow/connectors/postgres/qrep_sync_method.go +++ b/flow/connectors/postgres/qrep_sync_method.go @@ -7,7 +7,6 @@ import ( "time" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" util "github.com/PeerDB-io/peer-flow/utils" @@ -70,7 +69,6 @@ func (s *QRepStagingTableSync) SyncQRepRecords( // Step 2: Insert records into the destination table. 
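Step 2 works by adapting the record stream to the COPY protocol; a rough sketch of the mechanism, assuming the copy source satisfies pgx.CopyFromSource, and with hypothetical table and column names standing in for the real ones:

	// Stream-backed COPY into the destination table.
	copySrc := model.NewQRecordBatchCopyFromSource(stream)
	rowsCopied, err := tx.CopyFrom(
		context.Background(),
		pgx.Identifier{"public", "dst_table"}, // hypothetical destination
		[]string{"id", "name"},                // hypothetical column list
		copySrc,
	)
	if err != nil {
		return -1, fmt.Errorf("failed to copy records: %w", err)
	}
	log.Infof("copied %d rows", rowsCopied)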
copySource := model.NewQRecordBatchCopyFromSource(stream) - syncRecordsStartTime := time.Now() var numRowsSynced int64 if writeMode == nil || @@ -163,7 +161,11 @@ func (s *QRepStagingTableSync) SyncQRepRecords( } } - metrics.LogQRepSyncMetrics(s.connector.ctx, flowJobName, numRowsSynced, time.Since(syncRecordsStartTime)) + log.WithFields(log.Fields{ + "flowName": flowJobName, + "partitionID": partitionID, + "destinationTable": dstTableName, + }).Infof("pushed %d records to %s", numRowsSynced, dstTableName) // marshal the partition to json using protojson pbytes, err := protojson.Marshal(partition) @@ -171,7 +173,6 @@ func (s *QRepStagingTableSync) SyncQRepRecords( return -1, fmt.Errorf("failed to marshal partition to json: %v", err) } - normalizeRecordsStartTime := time.Now() insertMetadataStmt := fmt.Sprintf( "INSERT INTO %s VALUES ($1, $2, $3, $4, $5);", qRepMetadataTableName, @@ -181,7 +182,7 @@ func (s *QRepStagingTableSync) SyncQRepRecords( "partitionID": partitionID, "destinationTable": dstTableName, }).Infof("Executing transaction inside Qrep sync") - rows, err := tx.Exec( + _, err = tx.Exec( context.Background(), insertMetadataStmt, flowJobName, @@ -199,13 +200,6 @@ func (s *QRepStagingTableSync) SyncQRepRecords( return -1, fmt.Errorf("failed to commit transaction: %v", err) } - totalRecordsAtTarget, err := s.connector.getApproxTableCounts([]string{dstTableName.String()}) - if err != nil { - return -1, fmt.Errorf("failed to get total records at target: %v", err) - } - metrics.LogQRepNormalizeMetrics(s.connector.ctx, flowJobName, rows.RowsAffected(), - time.Since(normalizeRecordsStartTime), totalRecordsAtTarget) - numRowsInserted := copySource.NumRecords() log.WithFields(log.Fields{ "flowName": flowJobName, diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go index 51e163f13e..fb6ed77d4b 100644 --- a/flow/connectors/s3/s3.go +++ b/flow/connectors/s3/s3.go @@ -3,11 +3,9 @@ package conns3 import ( "context" "fmt" - "time" metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" "github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/aws/aws-sdk-go/service/s3" @@ -136,32 +134,18 @@ func (c *S3Connector) updateLastOffset(jobName string, offset int64) error { } func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { - if len(req.Records.Records) == 0 { - return &model.SyncResponse{ - FirstSyncedCheckPointID: 0, - LastSyncedCheckPointID: 0, - NumRecordsSynced: 0, - }, nil - } - syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err) } syncBatchID = syncBatchID + 1 - lastCP := req.Records.LastCheckPointID tableNameRowsMapping := make(map[string]uint32) - streamRes, err := utils.RecordsToRawTableStream(model.RecordsToStreamRequest{ - Records: req.Records.Records, - TableMapping: tableNameRowsMapping, - CP: 0, - BatchID: syncBatchID, - }) + streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID) + streamRes, err := utils.RecordsToRawTableStream(streamReq) if err != nil { return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err) } - firstCP := streamRes.CP recordStream := streamRes.Stream qrepConfig := &protos.QRepConfig{ FlowJobName: req.FlowJobName, @@ -170,14 +154,18 @@ func (c *S3Connector) SyncRecords(req 
*model.SyncRecordsRequest) (*model.SyncRes partition := &protos.QRepPartition{ PartitionId: fmt.Sprint(syncBatchID), } - startTime := time.Now() - close(recordStream.Records) numRecords, err := c.SyncQRepRecords(qrepConfig, partition, recordStream) if err != nil { return nil, err } + log.Infof("Synced %d records", numRecords) + + lastCheckpoint, err := req.Records.GetLastCheckpoint() + if err != nil { + return nil, fmt.Errorf("failed to get last checkpoint: %w", err) + } - err = c.updateLastOffset(req.FlowJobName, lastCP) + err = c.updateLastOffset(req.FlowJobName, lastCheckpoint) if err != nil { log.Errorf("failed to update last offset for s3 cdc: %v", err) return nil, err @@ -187,10 +175,10 @@ func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncRes log.Errorf("%v", err) return nil, err } - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(numRecords), time.Since(startTime)) + return &model.SyncResponse{ - FirstSyncedCheckPointID: firstCP, - LastSyncedCheckPointID: lastCP, + FirstSyncedCheckPointID: req.Records.GetFirstCheckpoint(), + LastSyncedCheckPointID: lastCheckpoint, NumRecordsSynced: int64(numRecords), TableNameRowsMapping: tableNameRowsMapping, }, nil diff --git a/flow/connectors/snowflake/qrep_avro_sync.go b/flow/connectors/snowflake/qrep_avro_sync.go index 8202d769fc..eac4dda9e9 100644 --- a/flow/connectors/snowflake/qrep_avro_sync.go +++ b/flow/connectors/snowflake/qrep_avro_sync.go @@ -9,7 +9,6 @@ import ( "github.com/PeerDB-io/peer-flow/connectors/utils" avro "github.com/PeerDB-io/peer-flow/connectors/utils/avro" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" util "github.com/PeerDB-io/peer-flow/utils" @@ -148,7 +147,6 @@ func (s *SnowflakeAvroSyncMethod) SyncQRepRecords( stage := s.connector.getStageNameForJob(config.FlowJobName) - putFileStartTime := time.Now() err = s.putFileToStage(localFilePath, stage) if err != nil { return 0, err @@ -157,8 +155,6 @@ func (s *SnowflakeAvroSyncMethod) SyncQRepRecords( "flowName": config.FlowJobName, "partitionID": partition.PartitionId, }).Infof("Put file to stage in Avro sync for snowflake") - metrics.LogQRepSyncMetrics(s.connector.ctx, config.FlowJobName, int64(numRecords), - time.Since(putFileStartTime)) err = s.insertMetadata(partition, config.FlowJobName, startTime) if err != nil { @@ -515,8 +511,10 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( if err != nil { return err } - metrics.LogQRepNormalizeMetrics(s.connector.ctx, flowJobName, rowCount, time.Since(startTime), - totalRowsAtTarget) + log.WithFields(log.Fields{ + "flowName": flowJobName, + }).Infof("merged %d rows into destination table %s, total rows at target: %d", + rowCount, s.dstTableName, totalRowsAtTarget) } else { log.WithFields(log.Fields{ "flowName": flowJobName, @@ -525,7 +523,7 @@ func (s *SnowflakeAvroWriteHandler) HandleUpsertMode( log.WithFields(log.Fields{ "flowName": flowJobName, - }).Infof("merged data from temp table %s into destination table %s", - tempTableName, s.dstTableName) + }).Infof("merged data from temp table %s into destination table %s, time taken %v", + tempTableName, s.dstTableName, time.Since(startTime)) return nil } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 2d96cf314e..3d89d5b5f6 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -11,7 +11,6 @@ import ( "time" 
"github.com/PeerDB-io/peer-flow/connectors/utils" - "github.com/PeerDB-io/peer-flow/connectors/utils/metrics" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/model/qvalue" @@ -478,16 +477,8 @@ func (c *SnowflakeConnector) ReplayTableSchemaDeltas(flowJobName string, } func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) { - if len(req.Records.Records) == 0 { - return &model.SyncResponse{ - FirstSyncedCheckPointID: 0, - LastSyncedCheckPointID: 0, - NumRecordsSynced: 0, - }, nil - } - rawTableIdentifier := getRawTableIdentifier(req.FlowJobName) - log.Printf("pushing %d records to Snowflake table %s", len(req.Records.Records), rawTableIdentifier) + log.Infof("pushing records to Snowflake table %s", rawTableIdentifier) syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName) if err != nil { @@ -497,6 +488,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. var res *model.SyncResponse if req.SyncMode == protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO { + log.Infof("sync mode is for flow %s is AVRO", req.FlowJobName) res, err = c.syncRecordsViaAvro(req, rawTableIdentifier, syncBatchID) if err != nil { return nil, err @@ -520,6 +512,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. }() if req.SyncMode == protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT { + log.Infof("sync mode is for flow %s is MULTI_INSERT", req.FlowJobName) res, err = c.syncRecordsViaSQL(req, rawTableIdentifier, syncBatchID, syncRecordsTx) if err != nil { return nil, err @@ -547,9 +540,8 @@ func (c *SnowflakeConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, ra first := true var firstCP int64 = 0 - lastCP := req.Records.LastCheckPointID - for _, record := range req.Records.Records { + for record := range req.Records.GetRecords() { switch typedRecord := record.(type) { case *model.InsertRecord: // json.Marshal converts bytes in Hex automatically to BASE64 string. @@ -622,10 +614,8 @@ func (c *SnowflakeConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, ra // inserting records into raw table. 
numRecords := len(records) - startTime := time.Now() for begin := 0; begin < numRecords; begin += syncRecordsChunkSize { end := begin + syncRecordsChunkSize - if end > numRecords { end = numRecords } @@ -634,33 +624,33 @@ func (c *SnowflakeConnector) syncRecordsViaSQL(req *model.SyncRecordsRequest, ra return nil, err } } - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(numRecords), time.Since(startTime)) + + lastCheckpoint, err := req.Records.GetLastCheckpoint() + if err != nil { + return nil, err + } return &model.SyncResponse{ FirstSyncedCheckPointID: firstCP, - LastSyncedCheckPointID: lastCP, + LastSyncedCheckPointID: lastCheckpoint, NumRecordsSynced: int64(len(records)), CurrentSyncBatchID: syncBatchID, TableNameRowsMapping: tableNameRowsMapping, }, nil } -func (c *SnowflakeConnector) syncRecordsViaAvro(req *model.SyncRecordsRequest, rawTableIdentifier string, - syncBatchID int64) (*model.SyncResponse, error) { - - lastCP := req.Records.LastCheckPointID +func (c *SnowflakeConnector) syncRecordsViaAvro( + req *model.SyncRecordsRequest, + rawTableIdentifier string, + syncBatchID int64, +) (*model.SyncResponse, error) { tableNameRowsMapping := make(map[string]uint32) - streamRes, err := utils.RecordsToRawTableStream(model.RecordsToStreamRequest{ - Records: req.Records.Records, - TableMapping: tableNameRowsMapping, - CP: 0, - BatchID: syncBatchID, - }) + streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID) + streamRes, err := utils.RecordsToRawTableStream(streamReq) if err != nil { return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err) } - firstCP := streamRes.CP - recordStream := streamRes.Stream + qrepConfig := &protos.QRepConfig{ StagingPath: "", FlowJobName: req.FlowJobName, @@ -673,17 +663,20 @@ func (c *SnowflakeConnector) syncRecordsViaAvro(req *model.SyncRecordsRequest, r return nil, err } - startTime := time.Now() - close(recordStream.Records) - numRecords, err := avroSyncer.SyncRecords(destinationTableSchema, recordStream, req.FlowJobName) + numRecords, err := avroSyncer.SyncRecords(destinationTableSchema, streamRes.Stream, req.FlowJobName) if err != nil { return nil, err } - metrics.LogSyncMetrics(c.ctx, req.FlowJobName, int64(numRecords), time.Since(startTime)) + + lastCheckpoint, err := req.Records.GetLastCheckpoint() + if err != nil { + return nil, err + } + return &model.SyncResponse{ - FirstSyncedCheckPointID: firstCP, - LastSyncedCheckPointID: lastCP, - NumRecordsSynced: int64(len(req.Records.Records)), + FirstSyncedCheckPointID: req.Records.GetFirstCheckpoint(), + LastSyncedCheckPointID: lastCheckpoint, + NumRecordsSynced: int64(numRecords), CurrentSyncBatchID: syncBatchID, TableNameRowsMapping: tableNameRowsMapping, }, nil @@ -744,7 +737,6 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest }() var totalRowsAffected int64 = 0 - startTime := time.Now() // execute merge statements per table that uses CTEs to merge data into the normalized table for _, destinationTableName := range destinationTableNames { rowsAffected, err := c.generateAndExecuteMergeStatement( @@ -759,14 +751,7 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } totalRowsAffected += rowsAffected } - if totalRowsAffected > 0 { - totalRowsAtTarget, err := c.getTableCounts(destinationTableNames) - if err != nil { - return nil, err - } - metrics.LogNormalizeMetrics(c.ctx, req.FlowJobName, totalRowsAffected, time.Since(startTime), - totalRowsAtTarget) - } + // 
updating metadata with new normalizeBatchID err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID, normalizeRecordsTx) if err != nil { diff --git a/flow/connectors/utils/metrics/metrics.go b/flow/connectors/utils/metrics/metrics.go deleted file mode 100644 index 02f30fd366..0000000000 --- a/flow/connectors/utils/metrics/metrics.go +++ /dev/null @@ -1,132 +0,0 @@ -package metrics - -import ( - "context" - "fmt" - "time" - - "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/shared" - "go.temporal.io/sdk/activity" -) - -func LogPullMetrics( - ctx context.Context, - flowJobName string, - recordBatch *model.RecordBatch, - totalRecordsAtSource int64, - duration time.Duration, -) { - if ctx.Value(shared.EnableMetricsKey) != true { - return - } - - metricsHandler := activity.GetMetricsHandler(ctx) - insertRecordsPulledGauge := metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.insert_records_pulled", flowJobName)) - updateRecordsPulledGauge := metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.update_records_pulled", flowJobName)) - deleteRecordsPulledGauge := metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.delete_records_pulled", flowJobName)) - totalRecordsPulledGauge := metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.total_records_pulled", flowJobName)) - totalRecordsAtSourceGauge := metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.records_at_source", flowJobName)) - - insertRecords, updateRecords, deleteRecords := 0, 0, 0 - for _, record := range recordBatch.Records { - switch record.(type) { - case *model.InsertRecord: - insertRecords++ - case *model.UpdateRecord: - updateRecords++ - case *model.DeleteRecord: - deleteRecords++ - } - } - - insertRecordsPulledGauge.Update(float64(insertRecords) / duration.Seconds()) - updateRecordsPulledGauge.Update(float64(updateRecords) / duration.Seconds()) - deleteRecordsPulledGauge.Update(float64(deleteRecords) / duration.Seconds()) - totalRecordsPulledGauge.Update(float64(len(recordBatch.Records)) / duration.Seconds()) - totalRecordsAtSourceGauge.Update(float64(totalRecordsAtSource)) -} - -func LogSyncMetrics(ctx context.Context, flowJobName string, recordsCount int64, duration time.Duration) { - if ctx.Value(shared.EnableMetricsKey) != true { - return - } - - metricsHandler := activity.GetMetricsHandler(ctx) - recordsSyncedPerSecondGauge := - metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.records_synced_per_second", flowJobName)) - recordsSyncedPerSecondGauge.Update(float64(recordsCount) / duration.Seconds()) -} - -func LogNormalizeMetrics( - ctx context.Context, - flowJobName string, - recordsCount int64, - duration time.Duration, - totalRecordsAtTarget int64, -) { - if ctx.Value(shared.EnableMetricsKey) != true { - return - } - - metricsHandler := activity.GetMetricsHandler(ctx) - recordsNormalizedPerSecondGauge := - metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.records_normalized_per_second", flowJobName)) - totalRecordsAtTargetGauge := - metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.records_at_target", flowJobName)) - - recordsNormalizedPerSecondGauge.Update(float64(recordsCount) / duration.Seconds()) - totalRecordsAtTargetGauge.Update(float64(totalRecordsAtTarget)) -} - -func LogQRepPullMetrics(ctx context.Context, flowJobName string, - numRecords int, totalRecordsAtSource int64) { - if ctx.Value(shared.EnableMetricsKey) != true { - return - } - - metricsHandler := activity.GetMetricsHandler(ctx) - totalRecordsPulledGauge := metricsHandler.Gauge(fmt.Sprintf("qrepflow.%s.total_records_pulled", flowJobName)) - totalRecordsAtSourceGauge := 
metricsHandler.Gauge(fmt.Sprintf("qrepflow.%s.records_at_source", flowJobName)) - - totalRecordsPulledGauge.Update(float64(numRecords)) - totalRecordsAtSourceGauge.Update(float64(totalRecordsAtSource)) -} - -func LogQRepSyncMetrics(ctx context.Context, flowJobName string, recordsCount int64, duration time.Duration) { - if ctx.Value(shared.EnableMetricsKey) != true { - return - } - - metricsHandler := activity.GetMetricsHandler(ctx) - recordsSyncedPerSecondGauge := - metricsHandler.Gauge(fmt.Sprintf("qrepflow.%s.records_synced_per_second", flowJobName)) - recordsSyncedPerSecondGauge.Update(float64(recordsCount) / duration.Seconds()) -} - -func LogQRepNormalizeMetrics(ctx context.Context, flowJobName string, - normalizedRecordsCount int64, duration time.Duration, totalRecordsAtTarget int64) { - if ctx.Value(shared.EnableMetricsKey) != true { - return - } - - metricsHandler := activity.GetMetricsHandler(ctx) - recordsSyncedPerSecondGauge := - metricsHandler.Gauge(fmt.Sprintf("qrepflow.%s.records_normalized_per_second", flowJobName)) - totalRecordsAtTargetGauge := - metricsHandler.Gauge(fmt.Sprintf("qrepflow.%s.records_at_target", flowJobName)) - - recordsSyncedPerSecondGauge.Update(float64(normalizedRecordsCount) / duration.Seconds()) - totalRecordsAtTargetGauge.Update(float64(totalRecordsAtTarget)) -} - -func LogCDCRawThroughputMetrics(ctx context.Context, flowJobName string, throughput float64) { - if ctx.Value(shared.EnableMetricsKey) != true { - return - } - - metricsHandler := activity.GetMetricsHandler(ctx) - totalThroughputGauge := - metricsHandler.Gauge(fmt.Sprintf("cdcflow.%s.records_throughput", flowJobName)) - totalThroughputGauge.Update(throughput) -} diff --git a/flow/connectors/utils/monitoring/monitoring.go b/flow/connectors/utils/monitoring/monitoring.go index 8c5e9a22e7..ed63d30f74 100644 --- a/flow/connectors/utils/monitoring/monitoring.go +++ b/flow/connectors/utils/monitoring/monitoring.go @@ -67,7 +67,7 @@ func (c *CatalogMirrorMonitor) UpdateLatestLSNAtSourceForCDCFlow(ctx context.Con "UPDATE peerdb_stats.cdc_flows SET latest_lsn_at_source=$1 WHERE flow_name=$2", uint64(latestLSNAtSource), flowJobName) if err != nil { - return fmt.Errorf("error while updating flow in cdc_flows: %w", err) + return fmt.Errorf("[source] error while updating flow in cdc_flows: %w", err) } return nil } @@ -82,7 +82,7 @@ func (c *CatalogMirrorMonitor) UpdateLatestLSNAtTargetForCDCFlow(ctx context.Con "UPDATE peerdb_stats.cdc_flows SET latest_lsn_at_target=$1 WHERE flow_name=$2", uint64(latestLSNAtTarget), flowJobName) if err != nil { - return fmt.Errorf("error while updating flow in cdc_flows: %w", err) + return fmt.Errorf("[target] error while updating flow in cdc_flows: %w", err) } return nil } @@ -104,8 +104,32 @@ func (c *CatalogMirrorMonitor) AddCDCBatchForFlow(ctx context.Context, flowJobNa return nil } -func (c *CatalogMirrorMonitor) UpdateEndTimeForCDCBatch(ctx context.Context, flowJobName string, - batchID int64) error { +// update num records and end-lsn for a cdc batch +func (c *CatalogMirrorMonitor) UpdateNumRowsAndEndLSNForCDCBatch( + ctx context.Context, + flowJobName string, + batchID int64, + numRows uint32, + batchEndLSN pglogrepl.LSN, +) error { + if c == nil || c.catalogConn == nil { + return nil + } + + _, err := c.catalogConn.Exec(ctx, + "UPDATE peerdb_stats.cdc_batches SET rows_in_batch=$1,batch_end_lsn=$2 WHERE flow_name=$3 AND batch_id=$4", + numRows, uint64(batchEndLSN), flowJobName, batchID) + if err != nil { + return fmt.Errorf("error while updating batch in 
cdc_batch: %w", err) + } + return nil +} + +func (c *CatalogMirrorMonitor) UpdateEndTimeForCDCBatch( + ctx context.Context, + flowJobName string, + batchID int64, +) error { if c == nil || c.catalogConn == nil { return nil } diff --git a/flow/connectors/utils/stream.go b/flow/connectors/utils/stream.go index c16020578a..9359c1565f 100644 --- a/flow/connectors/utils/stream.go +++ b/flow/connectors/utils/stream.go @@ -9,8 +9,8 @@ import ( "github.com/google/uuid" ) -func RecordsToRawTableStream(req model.RecordsToStreamRequest) (*model.RecordsToStreamResponse, error) { - recordStream := model.NewQRecordStream(len(req.Records)) +func RecordsToRawTableStream(req *model.RecordsToStreamRequest) (*model.RecordsToStreamResponse, error) { + recordStream := model.NewQRecordStream(1 << 16) err := recordStream.SetSchema(&model.QRecordSchema{ Fields: []*model.QField{ { @@ -59,131 +59,142 @@ func RecordsToRawTableStream(req model.RecordsToStreamRequest) (*model.RecordsTo return nil, err } - first := true - firstCP := req.CP - for _, record := range req.Records { - var entries [8]qvalue.QValue - switch typedRecord := record.(type) { - case *model.InsertRecord: - // json.Marshal converts bytes in Hex automatically to BASE64 string. - itemsJSON, err := typedRecord.Items.ToJSON() - if err != nil { - return nil, fmt.Errorf("failed to serialize insert record items to JSON: %w", err) - } + go func() { + for record := range req.GetRecords() { + qRecordOrError := recordToQRecordOrError(req.TableMapping, req.BatchID, record) + recordStream.Records <- qRecordOrError + } - // add insert record to the raw table - entries[2] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: typedRecord.DestinationTableName, - } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: itemsJSON, - } - entries[4] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: 0, - } - entries[5] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: "", - } - entries[7] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: "", - } - req.TableMapping[typedRecord.DestinationTableName] += 1 - case *model.UpdateRecord: - newItemsJSON, err := typedRecord.NewItems.ToJSON() - if err != nil { - return nil, fmt.Errorf("failed to serialize update record new items to JSON: %w", err) - } - oldItemsJSON, err := typedRecord.OldItems.ToJSON() - if err != nil { - return nil, fmt.Errorf("failed to serialize update record old items to JSON: %w", err) - } + close(recordStream.Records) + }() - entries[2] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: typedRecord.DestinationTableName, - } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: newItemsJSON, - } - entries[4] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: 1, - } - entries[5] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: oldItemsJSON, - } - entries[7] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: KeysToString(typedRecord.UnchangedToastColumns), - } - req.TableMapping[typedRecord.DestinationTableName] += 1 - case *model.DeleteRecord: - itemsJSON, err := typedRecord.Items.ToJSON() - if err != nil { - return nil, fmt.Errorf("failed to serialize delete record items to JSON: %w", err) - } + return &model.RecordsToStreamResponse{ + Stream: recordStream, + }, nil +} - // append delete record to the raw table - entries[2] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: typedRecord.DestinationTableName, - } - entries[3] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: itemsJSON, +func 
recordToQRecordOrError(tableMapping map[string]uint32, batchID int64, record model.Record) *model.QRecordOrError { + var entries [8]qvalue.QValue + switch typedRecord := record.(type) { + case *model.InsertRecord: + // json.Marshal converts bytes in Hex automatically to BASE64 string. + itemsJSON, err := typedRecord.Items.ToJSON() + if err != nil { + return &model.QRecordOrError{ + Err: fmt.Errorf("failed to serialize insert record items to JSON: %w", err), } - entries[4] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: 2, - } - entries[5] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: itemsJSON, + } + + // add insert record to the raw table + entries[2] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: typedRecord.DestinationTableName, + } + entries[3] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: itemsJSON, + } + entries[4] = qvalue.QValue{ + Kind: qvalue.QValueKindInt64, + Value: 0, + } + entries[5] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: "", + } + entries[7] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: "", + } + tableMapping[typedRecord.DestinationTableName] += 1 + case *model.UpdateRecord: + newItemsJSON, err := typedRecord.NewItems.ToJSON() + if err != nil { + return &model.QRecordOrError{ + Err: fmt.Errorf("failed to serialize update record new items to JSON: %w", err), } - entries[7] = qvalue.QValue{ - Kind: qvalue.QValueKindString, - Value: "", + } + oldItemsJSON, err := typedRecord.OldItems.ToJSON() + if err != nil { + return &model.QRecordOrError{ + Err: fmt.Errorf("failed to serialize update record old items to JSON: %w", err), } - req.TableMapping[typedRecord.DestinationTableName] += 1 - default: - return nil, fmt.Errorf("record type %T not supported", typedRecord) } - if first { - firstCP = record.GetCheckPointID() - first = false + entries[2] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: typedRecord.DestinationTableName, } - - entries[0] = qvalue.QValue{ + entries[3] = qvalue.QValue{ Kind: qvalue.QValueKindString, - Value: uuid.New().String(), + Value: newItemsJSON, } - entries[1] = qvalue.QValue{ + entries[4] = qvalue.QValue{ Kind: qvalue.QValueKindInt64, - Value: time.Now().UnixNano(), + Value: 1, } - entries[6] = qvalue.QValue{ - Kind: qvalue.QValueKindInt64, - Value: req.BatchID, + entries[5] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: oldItemsJSON, + } + entries[7] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: KeysToString(typedRecord.UnchangedToastColumns), + } + tableMapping[typedRecord.DestinationTableName] += 1 + case *model.DeleteRecord: + itemsJSON, err := typedRecord.Items.ToJSON() + if err != nil { + return &model.QRecordOrError{ + Err: fmt.Errorf("failed to serialize delete record items to JSON: %w", err), + } } - recordStream.Records <- &model.QRecordOrError{ - Record: &model.QRecord{ - NumEntries: 8, - Entries: entries[:], - }, + // append delete record to the raw table + entries[2] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: typedRecord.DestinationTableName, + } + entries[3] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: itemsJSON, + } + entries[4] = qvalue.QValue{ + Kind: qvalue.QValueKindInt64, + Value: 2, + } + entries[5] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: itemsJSON, + } + entries[7] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: "", + } + tableMapping[typedRecord.DestinationTableName] += 1 + default: + return &model.QRecordOrError{ + Err: fmt.Errorf("unknown record type: %T", 
typedRecord), } } - return &model.RecordsToStreamResponse{ - Stream: recordStream, - CP: firstCP, - }, nil + entries[0] = qvalue.QValue{ + Kind: qvalue.QValueKindString, + Value: uuid.New().String(), + } + entries[1] = qvalue.QValue{ + Kind: qvalue.QValueKindInt64, + Value: time.Now().UnixNano(), + } + entries[6] = qvalue.QValue{ + Kind: qvalue.QValueKindInt64, + Value: batchID, + } + + return &model.QRecordOrError{ + Record: &model.QRecord{ + NumEntries: 8, + Entries: entries[:], + }, + } } diff --git a/flow/model/model.go b/flow/model/model.go index 08aa786f16..754277ea6f 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math/big" + "sync" "time" "github.com/PeerDB-io/peer-flow/generated/protos" @@ -32,6 +33,8 @@ type PullRecordsRequest struct { OverrideReplicationSlotName string // for supporting schema changes RelationMessageMapping RelationMessageMapping + // record batch for pushing changes into + RecordStream *CDCRecordStream } type Record interface { @@ -281,19 +284,107 @@ type TableWithPkey struct { PkeyColVal string } -type RecordBatch struct { +type CDCRecordStream struct { // Records are a list of json objects. - Records []Record - // FirstCheckPointID is the first ID that was pulled. - FirstCheckPointID int64 - // LastCheckPointID is the last ID of the commit that corresponds to this batch. - LastCheckPointID int64 - //TablePkey to record index mapping - TablePKeyLastSeen map[TableWithPkey]int + records chan Record + // Schema changes from the slot + SchemaDeltas chan *protos.TableSchemaDelta + // Relation message mapping + RelationMessageMapping chan *RelationMessageMapping + // Mutex for synchronizing access to the checkpoint fields + checkpointMutex sync.Mutex + // firstCheckPointID is the first ID of the commit that corresponds to this batch. + firstCheckPointID int64 + // Indicates if the last checkpoint has been set. + lastCheckpointSet bool + // lastCheckPointID is the last ID of the commit that corresponds to this batch. + lastCheckPointID int64 + // empty signal to indicate if the records are going to be empty or not. + emptySignal chan bool +} + +func NewCDCRecordStream() *CDCRecordStream { + return &CDCRecordStream{ + records: make(chan Record, 1<<18), + // TODO (kaushik): more than 1024 schema deltas can cause problems! 
+ SchemaDeltas: make(chan *protos.TableSchemaDelta, 1<<10), + emptySignal: make(chan bool, 1), + RelationMessageMapping: make(chan *RelationMessageMapping, 1), + lastCheckpointSet: false, + lastCheckPointID: 0, + firstCheckPointID: 0, + } +} + +func (r *CDCRecordStream) UpdateLatestCheckpoint(val int64) { + r.checkpointMutex.Lock() + defer r.checkpointMutex.Unlock() + + if r.firstCheckPointID == 0 { + r.firstCheckPointID = val + } + + if val > r.lastCheckPointID { + r.lastCheckPointID = val + } +} + +func (r *CDCRecordStream) GetFirstCheckpoint() int64 { + r.checkpointMutex.Lock() + defer r.checkpointMutex.Unlock() + + return r.firstCheckPointID +} + +func (r *CDCRecordStream) GetLastCheckpoint() (int64, error) { + r.checkpointMutex.Lock() + defer r.checkpointMutex.Unlock() + + if !r.lastCheckpointSet { + return 0, errors.New("last checkpoint not set, stream is still active") + } + return r.lastCheckPointID, nil +} + +func (r *CDCRecordStream) AddRecord(record Record) { + r.records <- record +} + +func (r *CDCRecordStream) SignalAsEmpty() { + r.emptySignal <- true +} + +func (r *CDCRecordStream) SignalAsNotEmpty() { + r.emptySignal <- false +} + +func (r *CDCRecordStream) WaitAndCheckEmpty() bool { + isEmpty := <-r.emptySignal + return isEmpty +} + +func (r *CDCRecordStream) WaitForSchemaDeltas() []*protos.TableSchemaDelta { + schemaDeltas := make([]*protos.TableSchemaDelta, 0) + for delta := range r.SchemaDeltas { + schemaDeltas = append(schemaDeltas, delta) + } + return schemaDeltas +} + +func (r *CDCRecordStream) Close() { + close(r.emptySignal) + close(r.records) + close(r.SchemaDeltas) + close(r.RelationMessageMapping) + r.lastCheckpointSet = true +} + +func (r *CDCRecordStream) GetRecords() chan Record { + return r.records } type SyncRecordsRequest struct { - Records *RecordBatch + Records *CDCRecordStream // FlowJobName is the name of the flow job. FlowJobName string // SyncMode to use for pushing raw records @@ -325,7 +416,7 @@ type SyncResponse struct { // to be carried to parent WorkFlow TableSchemaDeltas []*protos.TableSchemaDelta // to be stored in state for future PullFlows - RelationMessageMapping RelationMessageMapping + RelationMessageMapping *RelationMessageMapping } type NormalizeResponse struct { @@ -335,13 +426,6 @@ type NormalizeResponse struct { EndBatchID int64 } -// sync all the records normally, then apply the schema delta after NormalizeFlow. -type RecordsWithTableSchemaDelta struct { - RecordBatch *RecordBatch - TableSchemaDeltas []*protos.TableSchemaDelta - RelationMessageMapping RelationMessageMapping -} - // being clever and passing the delta back as a regular record instead of heavy CDC refactoring. 
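Taken together, these methods define a single-producer, single-consumer protocol: the puller adds records and advances checkpoints, and Close() both ends the channels and makes GetLastCheckpoint() answerable, which is why calling it on a still-active stream returns an error. A rough lifecycle sketch, with someRecord and commitLSN as placeholders:

	stream := model.NewCDCRecordStream()

	// Producer (pull) side.
	go func() {
		stream.SignalAsNotEmpty()
		stream.AddRecord(someRecord)             // placeholder model.Record
		stream.UpdateLatestCheckpoint(commitLSN) // tracks both first and last checkpoint
		stream.Close()                           // closes channels, finalizes the last checkpoint
	}()

	// Consumer (sync) side.
	for record := range stream.GetRecords() {
		_ = record // push to the destination
	}
	lastCheckpoint, err := stream.GetLastCheckpoint() // valid only after Close()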
type RelationRecord struct { CheckPointID int64 diff --git a/flow/model/qrecord_stream.go b/flow/model/qrecord_stream.go index a338ace4f4..721ab58c6c 100644 --- a/flow/model/qrecord_stream.go +++ b/flow/model/qrecord_stream.go @@ -20,15 +20,29 @@ type QRecordStream struct { } type RecordsToStreamRequest struct { - Records []Record + records chan Record TableMapping map[string]uint32 - CP int64 BatchID int64 } +func NewRecordsToStreamRequest( + records chan Record, + tableMapping map[string]uint32, + batchID int64, +) *RecordsToStreamRequest { + return &RecordsToStreamRequest{ + records: records, + TableMapping: tableMapping, + BatchID: batchID, + } +} + +func (r *RecordsToStreamRequest) GetRecords() chan Record { + return r.records +} + type RecordsToStreamResponse struct { Stream *QRecordStream - CP int64 } func NewQRecordStream(buffer int) *QRecordStream { diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index 2eef9a706b..9ac0ade826 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -69,7 +69,7 @@ type CDCFlowState struct { NormalizeFlowErrors error // Global mapping of relation IDs to RelationMessages sent as a part of logical replication. // Needed to support schema changes. - RelationMessageMapping model.RelationMessageMapping + RelationMessageMapping *model.RelationMessageMapping } // returns a new empty PeerFlowState @@ -83,7 +83,7 @@ func NewCDCFlowState() *CDCFlowState { SyncFlowErrors: nil, NormalizeFlowErrors: nil, // WORKAROUND: empty maps are protobufed into nil maps for reasons beyond me - RelationMessageMapping: model.RelationMessageMapping{ + RelationMessageMapping: &model.RelationMessageMapping{ 0: &protos.RelationMessage{ RelationId: 0, RelationName: "protobuf_workaround", @@ -315,7 +315,7 @@ func CDCFlowWorkflowWithConfig( }, } ctx = workflow.WithChildOptions(ctx, childSyncFlowOpts) - syncFlowOptions.RelationMessageMapping = state.RelationMessageMapping + syncFlowOptions.RelationMessageMapping = *state.RelationMessageMapping childSyncFlowFuture := workflow.ExecuteChildWorkflow( ctx, SyncFlowWorkflow,