Commit 87bdc8c
Pass sync batch id to SyncRecords (#1131)
No need to query it twice

CatalogPool is always present, so remove the nil check
serprex authored Jan 23, 2024
1 parent 5b3fe7b commit 87bdc8c
Showing 7 changed files with 37 additions and 66 deletions.
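
In short: the flow activity now resolves the batch ID once, records the monitoring entry, and hands the same ID to the connector through SyncRecordsRequest, instead of every connector calling GetLastSyncBatchID again on its own. A minimal, self-contained Go sketch of the new calling pattern follows; the trimmed-down types and the fake connector are illustrative stand-ins for PeerDB's real ones, not its actual API.

package main

import "fmt"

// SyncRecordsRequest mirrors the shape of the updated request struct:
// the caller supplies the batch ID up front.
type SyncRecordsRequest struct {
    SyncBatchID int64
    FlowJobName string
}

// Connector is a simplified stand-in for the CDC sync connector interface.
type Connector interface {
    GetLastSyncBatchID(flowJobName string) (int64, error)
    SyncRecords(req *SyncRecordsRequest) error
}

// startFlow shows the new pattern: query the last batch ID once, increment it
// for the batch about to be written, then pass that same ID to SyncRecords.
func startFlow(dst Connector, flowJobName string) error {
    syncBatchID, err := dst.GetLastSyncBatchID(flowJobName)
    if err != nil {
        return err
    }
    syncBatchID += 1 // this sync writes the next batch

    // ...monitoring metadata is recorded here with BatchID: syncBatchID...

    return dst.SyncRecords(&SyncRecordsRequest{
        SyncBatchID: syncBatchID,
        FlowJobName: flowJobName,
    })
}

// fakeConnector stands in for a destination; the real connectors previously
// repeated the GetLastSyncBatchID-plus-increment dance inside SyncRecords.
type fakeConnector struct{ last int64 }

func (f *fakeConnector) GetLastSyncBatchID(string) (int64, error) { return f.last, nil }

func (f *fakeConnector) SyncRecords(req *SyncRecordsRequest) error {
    fmt.Println("syncing batch", req.SyncBatchID)
    return nil
}

func main() {
    _ = startFlow(&fakeConnector{last: 41}, "my_flow") // prints: syncing batch 42
}

The per-connector copies of that query-and-increment logic are what the diffs below delete.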
41 changes: 21 additions & 20 deletions flow/activities/flowable.go
@@ -252,24 +252,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,

     hasRecords := !recordBatch.WaitAndCheckEmpty()
     slog.InfoContext(ctx, fmt.Sprintf("the current sync flow has records: %v", hasRecords))
-    if a.CatalogPool != nil && hasRecords {
-        syncBatchID, err := dstConn.GetLastSyncBatchID(flowName)
-        if err != nil && conn.Destination.Type != protos.DBType_EVENTHUB {
-            return nil, err
-        }
-
-        err = monitoring.AddCDCBatchForFlow(ctx, a.CatalogPool, flowName,
-            monitoring.CDCBatchInfo{
-                BatchID:     syncBatchID + 1,
-                RowsInBatch: 0,
-                BatchEndlSN: 0,
-                StartTime:   startTime,
-            })
-        if err != nil {
-            a.Alerter.LogFlowError(ctx, flowName, err)
-            return nil, err
-        }
-    }

     if !hasRecords {
         // wait for the pull goroutine to finish
@@ -292,8 +274,27 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
         }, nil
     }

+    syncBatchID, err := dstConn.GetLastSyncBatchID(flowName)
+    if err != nil && conn.Destination.Type != protos.DBType_EVENTHUB {
+        return nil, err
+    }
+    syncBatchID += 1
+
+    err = monitoring.AddCDCBatchForFlow(ctx, a.CatalogPool, flowName,
+        monitoring.CDCBatchInfo{
+            BatchID:     syncBatchID,
+            RowsInBatch: 0,
+            BatchEndlSN: 0,
+            StartTime:   startTime,
+        })
+    if err != nil {
+        a.Alerter.LogFlowError(ctx, flowName, err)
+        return nil, err
+    }
+
     syncStartTime := time.Now()
     res, err := dstConn.SyncRecords(&model.SyncRecordsRequest{
+        SyncBatchID:   syncBatchID,
         Records:       recordBatch,
         FlowJobName:   input.FlowConnectionConfigs.FlowJobName,
         TableMappings: input.FlowConnectionConfigs.TableMappings,
@@ -377,13 +378,13 @@ func (a *FlowableActivity) StartNormalize(
     if errors.Is(err, connectors.ErrUnsupportedFunctionality) {
         dstConn, err := connectors.GetCDCSyncConnector(ctx, conn.Destination)
         if err != nil {
-            return nil, fmt.Errorf("failed to get connector: %v", err)
+            return nil, fmt.Errorf("failed to get connector: %w", err)
         }
         defer connectors.CloseConnector(dstConn)

         lastSyncBatchID, err := dstConn.GetLastSyncBatchID(input.FlowConnectionConfigs.FlowJobName)
         if err != nil {
-            return nil, fmt.Errorf("failed to get last sync batch ID: %v", err)
+            return nil, fmt.Errorf("failed to get last sync batch ID: %w", err)
         }

         err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
10 changes: 1 addition & 9 deletions flow/connectors/bigquery/bigquery.go
@@ -479,15 +479,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {

     c.logger.Info(fmt.Sprintf("pushing records to %s.%s...", c.datasetID, rawTableName))

-    // generate a sequential number for last synced batch this sequence will be
-    // used to keep track of records that are normalized in NormalizeFlowWorkflow
-    syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-    if err != nil {
-        return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err)
-    }
-    syncBatchID += 1
-
-    res, err := c.syncRecordsViaAvro(req, rawTableName, syncBatchID)
+    res, err := c.syncRecordsViaAvro(req, rawTableName, req.SyncBatchID)
     if err != nil {
         return nil, err
     }
10 changes: 2 additions & 8 deletions flow/connectors/eventhub/eventhub.go
@@ -250,16 +250,10 @@ func (c *EventHubConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
         return nil, err
     }

-    rowsSynced := int64(numRecords)
-    syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-    if err != nil {
-        c.logger.Error("failed to get last sync batch id", slog.Any("error", err))
-    }
-
     return &model.SyncResponse{
-        CurrentSyncBatchID:     syncBatchID,
+        CurrentSyncBatchID:     req.SyncBatchID,
         LastSyncedCheckpointID: lastCheckpoint,
-        NumRecordsSynced:       rowsSynced,
+        NumRecordsSynced:       int64(numRecords),
         TableNameRowsMapping:   make(map[string]uint32),
         TableSchemaDeltas:      req.Records.WaitForSchemaDeltas(req.TableMappings),
     }, nil
17 changes: 6 additions & 11 deletions flow/connectors/postgres/postgres.go
@@ -273,11 +273,6 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
     rawTableIdentifier := getRawTableIdentifier(req.FlowJobName)
     c.logger.Info(fmt.Sprintf("pushing records to Postgres table %s via COPY", rawTableIdentifier))

-    syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-    if err != nil {
-        return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err)
-    }
-    syncBatchID += 1
     records := make([][]interface{}, 0)
     tableNameRowsMapping := make(map[string]uint32)

@@ -299,7 +294,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
                 itemsJSON,
                 0,
                 "{}",
-                syncBatchID,
+                req.SyncBatchID,
                 "",
             })
             tableNameRowsMapping[typedRecord.DestinationTableName] += 1
@@ -326,7 +321,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
                 newItemsJSON,
                 1,
                 oldItemsJSON,
-                syncBatchID,
+                req.SyncBatchID,
                 utils.KeysToString(typedRecord.UnchangedToastColumns),
             })
             tableNameRowsMapping[typedRecord.DestinationTableName] += 1
@@ -346,7 +341,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
                 itemsJSON,
                 2,
                 itemsJSON,
-                syncBatchID,
+                req.SyncBatchID,
                 "",
             })
             tableNameRowsMapping[typedRecord.DestinationTableName] += 1
@@ -356,7 +351,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
     }

     tableSchemaDeltas := req.Records.WaitForSchemaDeltas(req.TableMappings)
-    err = c.ReplayTableSchemaDeltas(req.FlowJobName, tableSchemaDeltas)
+    err := c.ReplayTableSchemaDeltas(req.FlowJobName, tableSchemaDeltas)
     if err != nil {
         return nil, fmt.Errorf("failed to sync schema changes: %w", err)
     }
@@ -402,7 +397,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
     }

     // updating metadata with new offset and syncBatchID
-    err = c.updateSyncMetadata(req.FlowJobName, lastCP, syncBatchID, syncRecordsTx)
+    err = c.updateSyncMetadata(req.FlowJobName, lastCP, req.SyncBatchID, syncRecordsTx)
     if err != nil {
         return nil, err
     }
@@ -415,7 +410,7 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
     return &model.SyncResponse{
         LastSyncedCheckpointID: lastCP,
         NumRecordsSynced:       int64(len(records)),
-        CurrentSyncBatchID:     syncBatchID,
+        CurrentSyncBatchID:     req.SyncBatchID,
         TableNameRowsMapping:   tableNameRowsMapping,
         TableSchemaDeltas:      tableSchemaDeltas,
     }, nil
10 changes: 2 additions & 8 deletions flow/connectors/s3/s3.go
@@ -184,14 +184,8 @@ func (c *S3Connector) SetLastOffset(jobName string, offset int64) error {
 }

 func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
-    syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-    if err != nil {
-        return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err)
-    }
-    syncBatchID += 1
-
     tableNameRowsMapping := make(map[string]uint32)
-    streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, syncBatchID)
+    streamReq := model.NewRecordsToStreamRequest(req.Records.GetRecords(), tableNameRowsMapping, req.SyncBatchID)
     streamRes, err := utils.RecordsToRawTableStream(streamReq)
     if err != nil {
         return nil, fmt.Errorf("failed to convert records to raw table stream: %w", err)
@@ -202,7 +196,7 @@ func (c *S3Connector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
         DestinationTableIdentifier: fmt.Sprintf("raw_table_%s", req.FlowJobName),
     }
     partition := &protos.QRepPartition{
-        PartitionId: strconv.FormatInt(syncBatchID, 10),
+        PartitionId: strconv.FormatInt(req.SyncBatchID, 10),
     }
     numRecords, err := c.SyncQRepRecords(qrepConfig, partition, recordStream)
     if err != nil {
12 changes: 3 additions & 9 deletions flow/connectors/snowflake/snowflake.go
@@ -485,13 +485,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
     rawTableIdentifier := getRawTableIdentifier(req.FlowJobName)
     c.logger.Info(fmt.Sprintf("pushing records to Snowflake table %s", rawTableIdentifier))

-    syncBatchID, err := c.GetLastSyncBatchID(req.FlowJobName)
-    if err != nil {
-        return nil, fmt.Errorf("failed to get previous syncBatchID: %w", err)
-    }
-    syncBatchID += 1
-
-    res, err := c.syncRecordsViaAvro(req, rawTableIdentifier, syncBatchID)
+    res, err := c.syncRecordsViaAvro(req, rawTableIdentifier, req.SyncBatchID)
     if err != nil {
         return nil, err
     }
@@ -506,12 +500,12 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.SyncResponse, error) {
         deferErr := syncRecordsTx.Rollback()
         if deferErr != sql.ErrTxDone && deferErr != nil {
             c.logger.Error("error while rolling back transaction for SyncRecords: %v",
-                slog.Any("error", deferErr), slog.Int64("syncBatchID", syncBatchID))
+                slog.Any("error", deferErr), slog.Int64("syncBatchID", req.SyncBatchID))
         }
     }()

     // updating metadata with new offset and syncBatchID
-    err = c.updateSyncMetadata(req.FlowJobName, res.LastSyncedCheckpointID, syncBatchID, syncRecordsTx)
+    err = c.updateSyncMetadata(req.FlowJobName, res.LastSyncedCheckpointID, req.SyncBatchID, syncRecordsTx)
     if err != nil {
         return nil, err
     }
3 changes: 2 additions & 1 deletion flow/model/model.go
@@ -526,7 +526,8 @@ type SyncAndNormalizeBatchID struct {
 }

 type SyncRecordsRequest struct {
-    Records *CDCRecordStream
+    SyncBatchID int64
+    Records     *CDCRecordStream
     // FlowJobName is the name of the flow job.
     FlowJobName string
     // SyncMode to use for pushing raw records
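
Seen from the connector side, the struct change means SyncRecords implementations consume req.SyncBatchID rather than re-deriving it. A hedged sketch of that shape; toyConnector and these trimmed-down structs are illustrative only, not PeerDB's real types:

package main

import "fmt"

// Simplified mirror of the updated request and response; the real structs
// carry more fields (Records, TableMappings, checkpoints, and so on).
type SyncRecordsRequest struct {
    SyncBatchID int64
    FlowJobName string
}

type SyncResponse struct {
    CurrentSyncBatchID int64
}

// toyConnector is hypothetical; the Postgres, Snowflake, BigQuery, S3, and
// EventHub connectors all follow this pattern after the commit.
type toyConnector struct{}

// SyncRecords no longer calls GetLastSyncBatchID itself: it tags raw records
// with req.SyncBatchID and echoes the same ID back in the response.
func (c *toyConnector) SyncRecords(req *SyncRecordsRequest) (*SyncResponse, error) {
    fmt.Printf("writing raw rows for %s under batch %d\n", req.FlowJobName, req.SyncBatchID)
    return &SyncResponse{CurrentSyncBatchID: req.SyncBatchID}, nil
}

func main() {
    res, _ := (&toyConnector{}).SyncRecords(&SyncRecordsRequest{SyncBatchID: 42, FlowJobName: "my_flow"})
    fmt.Println("current batch:", res.CurrentSyncBatchID)
}

The EventHub connector gains the most here: its old best-effort GetLastSyncBatchID call merely logged failures, so a metadata error could have reported batch 0, whereas the passed-in ID is always the one the monitoring entry was created under.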
