From 9ffb79c5138664273a507a259e14bba656e150b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Mon, 25 Dec 2023 15:50:37 +0000 Subject: [PATCH] GetLastSyncAndNormalizeBatchID: return struct instead of two int64s --- flow/connectors/bigquery/bigquery.go | 33 ++++++++++++++------------ flow/connectors/snowflake/snowflake.go | 33 ++++++++++++++------------ flow/model/model.go | 5 ++++ 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 725a44f65b..fe82fa1d49 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -382,21 +382,21 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { } } -func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (int64, int64, error) { +func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) { query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) q := c.client.Query(query) it, err := q.Read(c.ctx) if err != nil { err = fmt.Errorf("failed to run query %s on BigQuery:\n %w", query, err) - return -1, -1, err + return model.SyncAndNormalizeBatchID{}, err } var row []bigquery.Value err = it.Next(&row) if err != nil { c.logger.Info("no row found for job") - return 0, 0, nil + return model.SyncAndNormalizeBatchID{}, nil } syncBatchID := int64(0) @@ -407,7 +407,10 @@ func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (int6 if row[1] != nil { normBatchID = row[1].(int64) } - return syncBatchID, normBatchID, nil + return model.SyncAndNormalizeBatchID{ + SyncBatchID: syncBatchID, + NormalizeBatchID: normBatchID, + }, nil } func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, @@ -749,7 +752,7 @@ func (c *BigQueryConnector) syncRecordsViaAvro( func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { rawTableName := c.getRawTableName(req.FlowJobName) - syncBatchID, normalizeBatchID, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err) } @@ -760,20 +763,20 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } // if job is not yet found in the peerdb_mirror_jobs_table // OR sync is lagging end normalize - if !hasJob || normalizeBatchID >= syncBatchID { + if !hasJob || batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { c.logger.Info("waiting for sync to catch up, so finishing") return &model.NormalizeResponse{ Done: false, - StartBatchID: normalizeBatchID, - EndBatchID: syncBatchID, + StartBatchID: batchIDs.NormalizeBatchID, + EndBatchID: batchIDs.SyncBatchID, }, nil } - distinctTableNames, err := c.getDistinctTableNamesInBatch(req.FlowJobName, syncBatchID, normalizeBatchID) + distinctTableNames, err := c.getDistinctTableNamesInBatch(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) if err != nil { return nil, fmt.Errorf("couldn't get distinct table names to normalize: %w", err) } - tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, syncBatchID, normalizeBatchID) + tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) if err != nil { return nil, fmt.Errorf("couldn't get tablename to unchanged cols mapping: %w", err) } @@ -789,8 +792,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) NormalizedTable: tableName, RawTable: rawTableName, NormalizedTableSchema: c.tableNameSchemaMapping[tableName], - SyncBatchID: syncBatchID, - NormalizeBatchID: normalizeBatchID, + SyncBatchID: batchIDs.SyncBatchID, + NormalizeBatchID: batchIDs.NormalizeBatchID, UnchangedToastColumns: tableNametoUnchangedToastCols[tableName], peerdbCols: &protos.PeerDBColumns{ SoftDeleteColName: req.SoftDeleteColName, @@ -805,7 +808,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) // update metadata to make the last normalized batch id to the recent last sync batch id. updateMetadataStmt := fmt.Sprintf( "UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';", - c.datasetID, MirrorJobsTable, syncBatchID, req.FlowJobName) + c.datasetID, MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName) stmts = append(stmts, updateMetadataStmt) query := strings.Join(stmts, "\n") @@ -816,8 +819,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) return &model.NormalizeResponse{ Done: true, - StartBatchID: normalizeBatchID + 1, - EndBatchID: syncBatchID, + StartBatchID: batchIDs.NormalizeBatchID + 1, + EndBatchID: batchIDs.SyncBatchID, }, nil } diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index bbed7d2a55..321b60d0d2 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -345,23 +345,26 @@ func (c *SnowflakeConnector) GetLastSyncBatchID(jobName string) (int64, error) { return result.Int64, nil } -func (c *SnowflakeConnector) GetLastSyncAndNormalizeBatchID(jobName string) (int64, int64, error) { +func (c *SnowflakeConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) { rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastSyncNormalizeBatchID_SQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName) if err != nil { - return 0, 0, fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err) + return model.SyncAndNormalizeBatchID{}, fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err) } var syncResult, normResult pgtype.Int8 if !rows.Next() { c.logger.Warn("No row found, returning 0") - return 0, 0, nil + return model.SyncAndNormalizeBatchID{}, nil } err = rows.Scan(&syncResult, &normResult) if err != nil { - return 0, 0, fmt.Errorf("error while reading result row: %w", err) + return model.SyncAndNormalizeBatchID{}, fmt.Errorf("error while reading result row: %w", err) } - return syncResult.Int64, normResult.Int64, nil + return model.SyncAndNormalizeBatchID{ + SyncBatchID: syncResult.Int64, + NormalizeBatchID: normResult.Int64, + }, nil } func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, @@ -590,16 +593,16 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( // NormalizeRecords normalizes raw table to destination table. func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { - syncBatchID, normalizeBatchID, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) if err != nil { return nil, err } // normalize has caught up with sync, chill until more records are loaded. - if normalizeBatchID >= syncBatchID { + if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { return &model.NormalizeResponse{ Done: false, - StartBatchID: normalizeBatchID, - EndBatchID: syncBatchID, + StartBatchID: batchIDs.NormalizeBatchID, + EndBatchID: batchIDs.SyncBatchID, }, nil } @@ -613,12 +616,12 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest Done: false, }, nil } - destinationTableNames, err := c.getDistinctTableNamesInBatch(req.FlowJobName, syncBatchID, normalizeBatchID) + destinationTableNames, err := c.getDistinctTableNamesInBatch(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) if err != nil { return nil, err } - tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, syncBatchID, normalizeBatchID) + tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) if err != nil { return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err) } @@ -636,7 +639,7 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest tableName, tableNametoUnchangedToastCols[tableName], getRawTableIdentifier(req.FlowJobName), - syncBatchID, normalizeBatchID, + batchIDs.SyncBatchID, batchIDs.NormalizeBatchID, req) if err != nil { c.logger.Error("[merge] error while normalizing records", slog.Any("error", err)) @@ -653,15 +656,15 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, syncBatchID) + err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID) if err != nil { return nil, err } return &model.NormalizeResponse{ Done: true, - StartBatchID: normalizeBatchID + 1, - EndBatchID: syncBatchID, + StartBatchID: batchIDs.NormalizeBatchID + 1, + EndBatchID: batchIDs.SyncBatchID, }, nil } diff --git a/flow/model/model.go b/flow/model/model.go index 581b57178b..9b8213c8cd 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -416,6 +416,11 @@ func (r *CDCRecordStream) GetRecords() chan Record { return r.records } +type SyncAndNormalizeBatchID struct { + SyncBatchID int64 + NormalizeBatchID int64 +} + type SyncRecordsRequest struct { Records *CDCRecordStream // FlowJobName is the name of the flow job.