diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index f9830d479d..6872333571 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -251,24 +251,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 	hasRecords := !recordBatch.WaitAndCheckEmpty()
 	slog.InfoContext(ctx, fmt.Sprintf("the current sync flow has records: %v", hasRecords))
 
-	if a.CatalogPool != nil && hasRecords {
-		syncBatchID, err := dstConn.GetLastSyncBatchID(flowName)
-		if err != nil && conn.Destination.Type != protos.DBType_EVENTHUB {
-			return nil, err
-		}
-
-		err = monitoring.AddCDCBatchForFlow(ctx, a.CatalogPool, flowName,
-			monitoring.CDCBatchInfo{
-				BatchID:     syncBatchID + 1,
-				RowsInBatch: 0,
-				BatchEndlSN: 0,
-				StartTime:   startTime,
-			})
-		if err != nil {
-			a.Alerter.LogFlowError(ctx, flowName, err)
-			return nil, err
-		}
-	}
 
 	if !hasRecords {
 		// wait for the pull goroutine to finish
@@ -291,8 +273,26 @@
 		}, nil
 	}
 
-	syncStartTime := time.Now()
+	syncBatchID, err := dstConn.GetLastSyncBatchID(flowName)
+	if err != nil && conn.Destination.Type != protos.DBType_EVENTHUB {
+		return nil, err
+	}
+	syncBatchID += 1
+
+	err = monitoring.AddCDCBatchForFlow(ctx, a.CatalogPool, flowName,
+		monitoring.CDCBatchInfo{
+			BatchID:     syncBatchID,
+			RowsInBatch: 0,
+			BatchEndlSN: 0,
+			StartTime:   startTime,
+		})
+	if err != nil {
+		a.Alerter.LogFlowError(ctx, flowName, err)
+		return nil, err
+	}
+
 	res, err := dstConn.SyncRecords(&model.SyncRecordsRequest{
+		SyncBatchID:   syncBatchID,
 		Records:       recordBatch,
 		FlowJobName:   input.FlowConnectionConfigs.FlowJobName,
 		TableMappings: input.FlowConnectionConfigs.TableMappings,
@@ -376,13 +376,13 @@ func (a *FlowableActivity) StartNormalize(
 	if errors.Is(err, connectors.ErrUnsupportedFunctionality) {
 		dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination)
 		if err != nil {
-			return nil, fmt.Errorf("failed to get connector: %v", err)
+			return nil, fmt.Errorf("failed to get connector: %w", err)
 		}
 		defer connectors.CloseConnector(dstConn)
 
 		lastSyncBatchID, err := dstConn.GetLastSyncBatchID(input.FlowConnectionConfigs.FlowJobName)
 		if err != nil {
-			return nil, fmt.Errorf("failed to get last sync batch ID: %v", err)
+			return nil, fmt.Errorf("failed to get last sync batch ID: %w", err)
 		}
 
 		err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 3d7e27d863..a2987a8c1e 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -479,15 +479,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
 
 	c.logger.Info(fmt.Sprintf("pushing records to %s.%s...", c.datasetID, rawTableName))
 
-	// generate a sequential number for last synced batch this sequence will be
-	// used to keep track of records that are normalized in NormalizeFlowWorkflow
-	syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-	if err != nil {
-		return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err)
-	}
-	syncBatchID += 1
-
-	res, err := c.syncRecordsViaAvro(req, rawTableName, syncBatchID)
+	res, err := c.syncRecordsViaAvro(req, rawTableName, req.SyncBatchID)
 	if err != nil {
 		return nil, err
 	}
diff --git a/flow/connectors/eventhub/eventhub.go b/flow/connectors/eventhub/eventhub.go
index f60cb547c7..70a580c038 100644
--- a/flow/connectors/eventhub/eventhub.go
+++ b/flow/connectors/eventhub/eventhub.go
@@ -249,16 +249,10 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
 		return nil, err
 	}
 
-	rowsSynced := int64(numRecords)
-	syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-	if err != nil {
-		c.logger.Error("failed to get last sync batch id", slog.Any("error", err))
-	}
-
 	return &model.SyncResponse{
-		CurrentSyncBatchID:     syncBatchID,
+		CurrentSyncBatchID:     req.SyncBatchID,
 		LastSyncedCheckPointID: lastCheckpoint,
-		NumRecordsSynced:       rowsSynced,
+		NumRecordsSynced:       int64(numRecords),
 		TableNameRowsMapping:   make(map[string]uint32),
 		TableSchemaDeltas:      req.Records.WaitForSchemaDeltas(req.TableMappings),
 	}, nil
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 98b4382df6..b4f5c0663d 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -272,11 +272,6 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
 
 	rawTableIdentifier := getRawTableIdentifier(req.FlowJobName)
 	c.logger.Info(fmt.Sprintf("pushing records to Postgres table %s via COPY", rawTableIdentifier))
-	syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-	if err != nil {
-		return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err)
-	}
-	syncBatchID += 1
 
 	records := make([][]interface{}, 0)
 	tableNameRowsMapping := make(map[string]uint32)
@@ -298,7 +293,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
 				itemsJSON,
 				0,
 				"{}",
-				syncBatchID,
+				req.SyncBatchID,
 				"",
 			})
 			tableNameRowsMapping[typedRecord.DestinationTableName] += 1
@@ -325,7 +320,7 @@
 				newItemsJSON,
 				1,
 				oldItemsJSON,
-				syncBatchID,
+				req.SyncBatchID,
 				utils.KeysToString(typedRecord.UnchangedToastColumns),
 			})
 			tableNameRowsMapping[typedRecord.DestinationTableName] += 1
@@ -345,7 +340,7 @@
 				itemsJSON,
 				2,
 				itemsJSON,
-				syncBatchID,
+				req.SyncBatchID,
 				"",
 			})
 			tableNameRowsMapping[typedRecord.DestinationTableName] += 1
@@ -355,7 +350,7 @@
 	}
 
 	tableSchemaDeltas := req.Records.WaitForSchemaDeltas(req.TableMappings)
-	err = c.ReplayTableSchemaDeltas(req.FlowJobName, tableSchemaDeltas)
+	err := c.ReplayTableSchemaDeltas(req.FlowJobName, tableSchemaDeltas)
 	if err != nil {
 		return nil, fmt.Errorf("failed to sync schema changes: %w", err)
 	}
@@ -401,7 +396,7 @@
 	}
 
 	// updating metadata with new offset and syncBatchID
-	err = c.updateSyncMetadata(req.FlowJobName, lastCP, syncBatchID, syncRecordsTx)
+	err = c.updateSyncMetadata(req.FlowJobName, lastCP, req.SyncBatchID, syncRecordsTx)
 	if err != nil {
 		return nil, err
 	}
@@ -414,7 +409,7 @@
 	return &model.SyncResponse{
 		LastSyncedCheckPointID: lastCP,
 		NumRecordsSynced:       int64(len(records)),
-		CurrentSyncBatchID:     syncBatchID,
+		CurrentSyncBatchID:     req.SyncBatchID,
 		TableNameRowsMapping:   tableNameRowsMapping,
 		TableSchemaDeltas:      tableSchemaDeltas,
 	}, nil
diff --git a/flow/connectors/s3/s3.go b/flow/connectors/s3/s3.go
index d24175ed45..992903d24a 100644
--- a/flow/connectors/s3/s3.go
+++ b/flow/connectors/s3/s3.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"log/slog"
+	"strconv"
 	"strings"
 	"time"
 
@@ -182,14 +183,8 @@ func (c *S3Connector) SetLastOffset(jobName string, offset int64) error {
 }
 
 func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
-	syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-	if err != nil {
-		return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err)
-	}
-	syncBatchID += 1
-
 	tableNameRowsMapping := make(map[string]uint32)
-	streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID)
+	streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, req.SyncBatchID)
 	streamRes, err := utils.RecordsToRawTableStream(streamReq)
 	if err != nil {
 		return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err)
@@ -200,7 +195,7 @@ func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncRes
 		DestinationTableIdentifier: fmt.Sprintf("raw_table_%s", req.FlowJobName),
 	}
 	partition := &protos.QRepPartition{
-		PartitionId: fmt.Sprint(syncBatchID),
+		PartitionId: strconv.FormatInt(req.SyncBatchID, 10),
 	}
 	numRecords, err := c.SyncQRepRecords(qrepConfig, partition, recordStream)
 	if err != nil {
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index 23b335f8f8..de048f42b7 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -484,13 +484,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.
 	rawTableIdentifier := getRawTableIdentifier(req.FlowJobName)
 	c.logger.Info(fmt.Sprintf("pushing records to Snowflake table %s", rawTableIdentifier))
 
-	syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-	if err != nil {
-		return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err)
-	}
-	syncBatchID += 1
-
-	res, err := c.syncRecordsViaAvro(req, rawTableIdentifier, syncBatchID)
+	res, err := c.syncRecordsViaAvro(req, rawTableIdentifier, req.SyncBatchID)
 	if err != nil {
 		return nil, err
 	}
@@ -505,12 +499,12 @@
 		deferErr := syncRecordsTx.Rollback()
 		if deferErr != sql.ErrTxDone && deferErr != nil {
 			c.logger.Error("error while rolling back transaction for SyncRecords: %v",
-				slog.Any("error", deferErr), slog.Int64("syncBatchID", syncBatchID))
+				slog.Any("error", deferErr), slog.Int64("syncBatchID", req.SyncBatchID))
 		}
 	}()
 
 	// updating metadata with new offset and syncBatchID
-	err = c.updateSyncMetadata(req.FlowJobName, res.LastSyncedCheckPointID, syncBatchID, syncRecordsTx)
+	err = c.updateSyncMetadata(req.FlowJobName, res.LastSyncedCheckPointID, req.SyncBatchID, syncRecordsTx)
 	if err != nil {
 		return nil, err
 	}
diff --git a/flow/model/model.go b/flow/model/model.go
index 524a1ce83f..9bc60fb52b 100644
--- a/flow/model/model.go
+++ b/flow/model/model.go
@@ -525,7 +525,8 @@ type SyncAndNormalizeBatchID struct {
 }
 
 type SyncRecordsRequest struct {
-	Records *CDCRecordStream
+	SyncBatchID int64
+	Records     *CDCRecordStream
 	// FlowJobName is the name of the flow job.
 	FlowJobName string
 	// SyncMode to use for pushing raw records
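
Note: the net effect of this diff is that the sync batch ID is computed exactly once in StartFlow and threaded through SyncRecordsRequest, so each connector reads req.SyncBatchID instead of calling GetLastSyncBatchID against its own metadata. A minimal sketch of the resulting call pattern, condensed from the StartFlow hunk above (the EventHub error special case and the monitoring call are omitted here for brevity):

	// Compute the next batch ID once, on the activity side.
	syncBatchID, err := dstConn.GetLastSyncBatchID(flowName)
	if err != nil {
		return nil, err
	}
	syncBatchID += 1 // this sync runs as the next batch

	// Every connector's SyncRecords now consumes the ID from the request
	// rather than re-querying it.
	res, err := dstConn.SyncRecords(&model.SyncRecordsRequest{
		SyncBatchID: syncBatchID,
		Records:     recordBatch,
		FlowJobName: flowName,
	})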