From 8adab3f26b8dbc7a9222ba16448be1c4a757974a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20Dub=C3=A9?=
Date: Thu, 25 Jan 2024 16:46:56 +0000
Subject: [PATCH] Normalize concurrently with sync flows (#893)

Previously, after each sync we'd pause reading the slot to process table schema deltas & normalize

This had two problems:
1. we want to always be reading the slot, & we aren't reading it during normalize
2. merging multiple batches at once can be less expensive

Now NormalizeFlow is created as a child workflow at the start of the cdc flow, & a signal carrying schema updates is sent to it after each sync flow

Normalize consumes all signals received since it last checked, merging their processing & running in parallel with sync flows

NormalizeFlow only reads up to the signal's batch id, to avoid normalizing a batch before receiving its schema changes

This creates a range `(normid..syncid]` in which normid is always catching up to syncid, as we normalize batches `normid+1` through `syncid`

The normalize logic already handled such a range, so it goes untouched in this change

`PEERDB_ENABLE_PARALLEL_SYNC_NORMALIZE` needs to be set to true; for now this change stays behind a feature flag to avoid potentially increasing data warehouse costs

(Two illustrative sketches of the signal loop & the batch range follow the diff.)
---
 flow/activities/flowable.go            |  16 +--
 flow/connectors/bigquery/bigquery.go   |  48 +++----
 flow/connectors/postgres/client.go     |  27 ++--
 flow/connectors/postgres/postgres.go   |  26 ++--
 flow/connectors/snowflake/snowflake.go |  63 +++++----
 flow/model/model.go                    |  17 ++-
 flow/peerdbenv/config.go               |   5 +
 flow/shared/constants.go               |   2 +
 flow/workflows/cdc_flow.go             | 176 +++++++++++++++----------
 flow/workflows/normalize_flow.go       | 119 ++++++++++-------
 flow/workflows/qrep_flow.go            |   8 +-
 flow/workflows/xmin_flow.go            |   4 +-
 protos/flow.proto                      |   1 +
 13 files changed, 290 insertions(+), 222 deletions(-)

diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 2eaf9f1d3f..c5680ebe6c 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -198,7 +198,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 	}
 	defer connectors.CloseConnector(dstConn)
 
-	activity.RecordHeartbeat(ctx, "initialized table schema")
 	slog.InfoContext(ctx, "pulling records...")
 	tblNameMapping := make(map[string]model.NameAndExclude)
 	for _, v := range input.FlowConnectionConfigs.TableMappings {
@@ -268,6 +267,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
 	}
 
 	return &model.SyncResponse{
+		CurrentSyncBatchID:     -1,
 		TableSchemaDeltas:      recordBatch.SchemaDeltas,
 		RelationMessageMapping: input.RelationMessageMapping,
 	}, nil
@@ -379,13 +379,8 @@ func (a *FlowableActivity) StartNormalize(
 		}
 		defer connectors.CloseConnector(dstConn)
 
-		lastSyncBatchID, err := dstConn.GetLastSyncBatchID(input.FlowConnectionConfigs.FlowJobName)
-		if err != nil {
-			return nil, fmt.Errorf("failed to get last sync batch ID: %w", err)
-		}
-
 		err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
-			lastSyncBatchID)
+			input.SyncBatchID)
 		return nil, err
 	} else if err != nil {
 		return nil, err
@@ -399,6 +394,7 @@ func (a *FlowableActivity) StartNormalize(
 
 	res, err := dstConn.NormalizeRecords(&model.NormalizeRecordsRequest{
 		FlowJobName:            input.FlowConnectionConfigs.FlowJobName,
+		SyncBatchID:            input.SyncBatchID,
 		SoftDelete:             input.FlowConnectionConfigs.SoftDelete,
 		SoftDeleteColName:      input.FlowConnectionConfigs.SoftDeleteColName,
 		SyncedAtColName:        input.FlowConnectionConfigs.SyncedAtColName,
@@ -423,10 +419,8 @@
 	}
 
 	// log the number of batches normalized
-	if res != nil {
-		
slog.InfoContext(ctx, fmt.Sprintf("normalized records from batch %d to batch %d", - res.StartBatchID, res.EndBatchID)) - } + slog.InfoContext(ctx, fmt.Sprintf("normalized records from batch %d to batch %d", + res.StartBatchID, res.EndBatchID)) return res, nil } diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 9f501d9fcc..3443993512 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -359,6 +359,7 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { query := fmt.Sprintf("SELECT sync_batch_id FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName) q := c.client.Query(query) + q.DisableQueryCache = true q.DefaultProjectID = c.projectID q.DefaultDatasetID = c.datasetID it, err := q.Read(c.ctx) @@ -382,37 +383,30 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { } } -func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) { - query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s WHERE mirror_job_name = '%s'", +func (c *BigQueryConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { + query := fmt.Sprintf("SELECT normalize_batch_id FROM %s WHERE mirror_job_name = '%s'", MirrorJobsTable, jobName) q := c.client.Query(query) + q.DisableQueryCache = true q.DefaultProjectID = c.projectID q.DefaultDatasetID = c.datasetID it, err := q.Read(c.ctx) if err != nil { err = fmt.Errorf("failed to run query %s on BigQuery:\n %w", query, err) - return model.SyncAndNormalizeBatchID{}, err + return 0, err } var row []bigquery.Value err = it.Next(&row) if err != nil { c.logger.Info("no row found for job") - return model.SyncAndNormalizeBatchID{}, nil + return 0, nil } - syncBatchID := int64(0) - normBatchID := int64(0) if row[0] != nil { - syncBatchID = row[0].(int64) - } - if row[1] != nil { - normBatchID = row[1].(int64) + return row[0].(int64), nil } - return model.SyncAndNormalizeBatchID{ - SyncBatchID: syncBatchID, - NormalizeBatchID: normBatchID, - }, nil + return 0, nil } func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, @@ -546,7 +540,7 @@ func (c *BigQueryConnector) syncRecordsViaAvro( func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { rawTableName := c.getRawTableName(req.FlowJobName) - batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err) } @@ -557,18 +551,18 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } // if job is not yet found in the peerdb_mirror_jobs_table // OR sync is lagging end normalize - if !hasJob || batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { + if !hasJob || normBatchID >= req.SyncBatchID { c.logger.Info("waiting for sync to catch up, so finishing") return &model.NormalizeResponse{ Done: false, - StartBatchID: batchIDs.NormalizeBatchID, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID, + EndBatchID: req.SyncBatchID, }, nil } distinctTableNames, err := c.getDistinctTableNamesInBatch( req.FlowJobName, - batchIDs.SyncBatchID, - batchIDs.NormalizeBatchID, + req.SyncBatchID, + normBatchID, ) if err != nil { return nil, fmt.Errorf("couldn't get distinct table names to normalize: %w", err) @@ -576,8 +570,8 @@ 
func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols( req.FlowJobName, - batchIDs.SyncBatchID, - batchIDs.NormalizeBatchID, + req.SyncBatchID, + normBatchID, ) if err != nil { return nil, fmt.Errorf("couldn't get tablename to unchanged cols mapping: %w", err) @@ -599,8 +593,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) dstTableName: tableName, dstDatasetTable: dstDatasetTable, normalizedTableSchema: req.TableNameSchemaMapping[tableName], - syncBatchID: batchIDs.SyncBatchID, - normalizeBatchID: batchIDs.NormalizeBatchID, + syncBatchID: req.SyncBatchID, + normalizeBatchID: normBatchID, peerdbCols: &protos.PeerDBColumns{ SoftDeleteColName: req.SoftDeleteColName, SyncedAtColName: req.SyncedAtColName, @@ -625,7 +619,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) // update metadata to make the last normalized batch id to the recent last sync batch id. updateMetadataStmt := fmt.Sprintf( "UPDATE %s SET normalize_batch_id=%d WHERE mirror_job_name='%s';", - MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName) + MirrorJobsTable, req.SyncBatchID, req.FlowJobName) query := c.client.Query(updateMetadataStmt) query.DefaultProjectID = c.projectID @@ -637,8 +631,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) return &model.NormalizeResponse{ Done: true, - StartBatchID: batchIDs.NormalizeBatchID + 1, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID + 1, + EndBatchID: req.SyncBatchID, }, nil } diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index c4c9a879dc..c6ae8f7d7f 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -38,11 +38,11 @@ const ( createRawTableBatchIDIndexSQL = "CREATE INDEX IF NOT EXISTS %s_batchid_idx ON %s.%s(_peerdb_batch_id)" createRawTableDstTableIndexSQL = "CREATE INDEX IF NOT EXISTS %s_dst_table_idx ON %s.%s(_peerdb_destination_table_name)" - getLastOffsetSQL = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1" - setLastOffsetSQL = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2" - getLastSyncBatchID_SQL = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1" - getLastSyncAndNormalizeBatchID_SQL = "SELECT sync_batch_id,normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1" - createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)" + getLastOffsetSQL = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1" + setLastOffsetSQL = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2" + getLastSyncBatchID_SQL = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1" + getLastNormalizeBatchID_SQL = "SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1" + createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)" upsertJobMetadataForSyncSQL = `INSERT INTO %s.%s AS j VALUES ($1,$2,$3,$4) ON CONFLICT(mirror_job_name) DO UPDATE SET lsn_offset=GREATEST(j.lsn_offset, EXCLUDED.lsn_offset), sync_batch_id=EXCLUDED.sync_batch_id` @@ -471,24 +471,21 @@ func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) { return result.Int64, nil } -func (c *PostgresConnector) GetLastSyncAndNormalizeBatchID(jobName string) (*model.SyncAndNormalizeBatchID, error) { - var syncResult, normalizeResult pgtype.Int8 +func (c *PostgresConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { + var 
result pgtype.Int8 err := c.pool.QueryRow(c.ctx, fmt.Sprintf( - getLastSyncAndNormalizeBatchID_SQL, + getLastNormalizeBatchID_SQL, c.metadataSchema, mirrorJobsTableIdentifier, - ), jobName).Scan(&syncResult, &normalizeResult) + ), jobName).Scan(&result) if err != nil { if err == pgx.ErrNoRows { c.logger.Info("No row found, returning 0") - return &model.SyncAndNormalizeBatchID{}, nil + return 0, nil } - return nil, fmt.Errorf("error while reading result row: %w", err) + return 0, fmt.Errorf("error while reading result row: %w", err) } - return &model.SyncAndNormalizeBatchID{ - SyncBatchID: syncResult.Int64, - NormalizeBatchID: normalizeResult.Int64, - }, nil + return result.Int64, nil } func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 79c6fb29a0..231d5f78f2 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -431,28 +431,29 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) }, nil } - batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err) } + // normalize has caught up with sync, chill until more records are loaded. - if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { + if normBatchID >= req.SyncBatchID { c.logger.Info(fmt.Sprintf("no records to normalize: syncBatchID %d, normalizeBatchID %d", - batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)) + req.SyncBatchID, normBatchID)) return &model.NormalizeResponse{ Done: false, - StartBatchID: batchIDs.NormalizeBatchID, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID, + EndBatchID: req.SyncBatchID, }, nil } destinationTableNames, err := c.getDistinctTableNamesInBatch( - req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) + req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, err } unchangedToastColsMap, err := c.getTableNametoUnchangedCols(req.FlowJobName, - batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) + req.SyncBatchID, normBatchID) if err != nil { return nil, err } @@ -491,7 +492,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } normalizeStatements := normalizeStmtGen.generateNormalizeStatements() for _, normalizeStatement := range normalizeStatements { - mergeStatementsBatch.Queue(normalizeStatement, batchIDs.NormalizeBatchID, batchIDs.SyncBatchID, destinationTableName).Exec( + mergeStatementsBatch.Queue(normalizeStatement, normBatchID, req.SyncBatchID, destinationTableName).Exec( func(ct pgconn.CommandTag) error { totalRowsAffected += int(ct.RowsAffected()) return nil @@ -508,7 +509,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) c.logger.Info(fmt.Sprintf("normalized %d records", totalRowsAffected)) // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID, normalizeRecordsTx) + err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID, normalizeRecordsTx) if err != nil { return nil, err } @@ -520,8 +521,8 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) return &model.NormalizeResponse{ Done: true, - StartBatchID: batchIDs.NormalizeBatchID + 1, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID + 1, + EndBatchID: req.SyncBatchID, 
}, nil } @@ -711,7 +712,8 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab // ReplayTableSchemaDelta changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. -func (c *PostgresConnector) ReplayTableSchemaDeltas(flowJobName string, +func (c *PostgresConnector) ReplayTableSchemaDeltas( + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { if len(schemaDeltas) == 0 { diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 763de043dd..ac8e8badcf 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -70,15 +70,15 @@ const ( checkIfTableExistsSQL = `SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=? and TABLE_NAME=?` - checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?" - getLastOffsetSQL = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?" - setLastOffsetSQL = "UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?) WHERE MIRROR_JOB_NAME=?" - getLastSyncBatchID_SQL = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" - getLastSyncNormalizeBatchID_SQL = "SELECT SYNC_BATCH_ID, NORMALIZE_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" - dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s" - deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=?" - dropSchemaIfExistsSQL = "DROP SCHEMA IF EXISTS %s" - checkSchemaExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?" + checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?" + getLastOffsetSQL = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?" + setLastOffsetSQL = "UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?) WHERE MIRROR_JOB_NAME=?" + getLastSyncBatchID_SQL = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" + getLastNormalizeBatchID_SQL = "SELECT NORMALIZE_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" + dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s" + deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=?" + dropSchemaIfExistsSQL = "DROP SCHEMA IF EXISTS %s" + checkSchemaExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?" 
) type SnowflakeConnector struct { @@ -381,28 +381,24 @@ func (c *SnowflakeConnector) GetLastSyncBatchID(jobName string) (int64, error) { return result.Int64, nil } -func (c *SnowflakeConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) { - rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastSyncNormalizeBatchID_SQL, c.metadataSchema, +func (c *SnowflakeConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { + rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastNormalizeBatchID_SQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName) if err != nil { - return model.SyncAndNormalizeBatchID{}, - fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err) + return 0, fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err) } defer rows.Close() - var syncResult, normResult pgtype.Int8 + var normBatchID pgtype.Int8 if !rows.Next() { c.logger.Warn("No row found, returning 0") - return model.SyncAndNormalizeBatchID{}, nil + return 0, nil } - err = rows.Scan(&syncResult, &normResult) + err = rows.Scan(&normBatchID) if err != nil { - return model.SyncAndNormalizeBatchID{}, fmt.Errorf("error while reading result row: %w", err) + return 0, fmt.Errorf("error while reading result row: %w", err) } - return model.SyncAndNormalizeBatchID{ - SyncBatchID: syncResult.Int64, - NormalizeBatchID: normResult.Int64, - }, nil + return normBatchID.Int64, nil } func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, @@ -633,16 +629,17 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( // NormalizeRecords normalizes raw table to destination table. func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { - batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) if err != nil { return nil, err } + // normalize has caught up with sync, chill until more records are loaded. 
- if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { + if normBatchID >= req.SyncBatchID { return &model.NormalizeResponse{ Done: false, - StartBatchID: batchIDs.NormalizeBatchID, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID, + EndBatchID: req.SyncBatchID, }, nil } @@ -658,14 +655,14 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } destinationTableNames, err := c.getDistinctTableNamesInBatch( req.FlowJobName, - batchIDs.SyncBatchID, - batchIDs.NormalizeBatchID, + req.SyncBatchID, + normBatchID, ) if err != nil { return nil, err } - tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) + tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err) } @@ -681,8 +678,8 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest mergeGen := &mergeStmtGenerator{ rawTableName: getRawTableIdentifier(req.FlowJobName), dstTableName: tableName, - syncBatchID: batchIDs.SyncBatchID, - normalizeBatchID: batchIDs.NormalizeBatchID, + syncBatchID: req.SyncBatchID, + normalizeBatchID: normBatchID, normalizedTableSchema: req.TableNameSchemaMapping[tableName], unchangedToastColumns: tableNametoUnchangedToastCols[tableName], peerdbCols: &protos.PeerDBColumns{ @@ -728,15 +725,15 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID) + err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID) if err != nil { return nil, err } return &model.NormalizeResponse{ Done: true, - StartBatchID: batchIDs.NormalizeBatchID + 1, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID + 1, + EndBatchID: req.SyncBatchID, }, nil } diff --git a/flow/model/model.go b/flow/model/model.go index b008e624f1..d8e3a7c751 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -507,11 +507,6 @@ func (r *CDCRecordStream) GetRecords() <-chan Record { return r.records } -type SyncAndNormalizeBatchID struct { - SyncBatchID int64 - NormalizeBatchID int64 -} - type SyncRecordsRequest struct { SyncBatchID int64 Records *CDCRecordStream @@ -525,6 +520,7 @@ type SyncRecordsRequest struct { type NormalizeRecordsRequest struct { FlowJobName string + SyncBatchID int64 SoftDelete bool SoftDeleteColName string SyncedAtColName string @@ -546,6 +542,17 @@ type SyncResponse struct { RelationMessageMapping RelationMessageMapping } +type NormalizeSignal struct { + Done bool + SyncBatchID int64 + TableNameSchemaMapping map[string]*protos.TableSchema +} + +type NormalizeFlowResponse struct { + Results []NormalizeResponse + Errors []string +} + type NormalizeResponse struct { // Flag to depict if normalization is done Done bool diff --git a/flow/peerdbenv/config.go b/flow/peerdbenv/config.go index ca238899ac..65254de73a 100644 --- a/flow/peerdbenv/config.go +++ b/flow/peerdbenv/config.go @@ -74,3 +74,8 @@ func PeerDBCatalogDatabase() string { func PeerDBEnableWALHeartbeat() bool { return getEnvBool("PEERDB_ENABLE_WAL_HEARTBEAT", false) } + +// PEERDB_ENABLE_PARALLEL_SYNC_NORMALIZE +func PeerDBEnableParallelSyncNormalize() bool { + return getEnvBool("PEERDB_ENABLE_PARALLEL_SYNC_NORMALIZE", false) +} diff --git a/flow/shared/constants.go b/flow/shared/constants.go index 
119514fb76..a75f9a0f52 100644 --- a/flow/shared/constants.go +++ b/flow/shared/constants.go @@ -14,6 +14,8 @@ const ( // Signals FlowSignalName = "peer-flow-signal" CDCDynamicPropertiesSignalName = "cdc-dynamic-properties" + NormalizeSyncSignalName = "normalize-sync" + NormalizeSyncDoneSignalName = "normalize-sync-done" // Queries CDCFlowStateQuery = "q-cdc-flow-status" diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index f2a5a9cddd..a0d73aa593 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -1,6 +1,7 @@ package peerflow import ( + "errors" "fmt" "log/slog" "strings" @@ -15,6 +16,7 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/peerdbenv" "github.com/PeerDB-io/peer-flow/shared" ) @@ -26,7 +28,7 @@ type CDCFlowLimits struct { // Maximum number of rows in a sync flow batch. MaxBatchSize uint32 // Rows synced after which we can say a test is done. - ExitAfterRecords int + ExitAfterRecords int64 } type CDCFlowWorkflowState struct { @@ -35,7 +37,7 @@ type CDCFlowWorkflowState struct { // Accumulates status for sync flows spawned. SyncFlowStatuses []*model.SyncResponse // Accumulates status for sync flows spawned. - NormalizeFlowStatuses []*model.NormalizeResponse + NormalizeFlowStatuses []model.NormalizeResponse // Current signalled state of the peer flow. ActiveSignal shared.CDCFlowSignal // Errors encountered during child sync flow executions. @@ -137,7 +139,7 @@ type CDCFlowWorkflowResult = CDCFlowWorkflowState func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Context, cfg *protos.FlowConnectionConfigs, state *CDCFlowWorkflowState, - limits *CDCFlowLimits, mirrorNameSearch *map[string]interface{}, + limits *CDCFlowLimits, mirrorNameSearch map[string]interface{}, ) error { for _, flowConfigUpdate := range state.FlowConfigUpdates { if len(flowConfigUpdate.AdditionalTables) == 0 { @@ -177,7 +179,7 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Cont RetryPolicy: &temporal.RetryPolicy{ MaximumAttempts: 20, }, - SearchAttributes: *mirrorNameSearch, + SearchAttributes: mirrorNameSearch, WaitForCancellation: true, } childAdditionalTablesCDCFlowCtx := workflow.WithChildOptions(ctx, childAdditionalTablesCDCFlowOpts) @@ -393,19 +395,65 @@ func CDCFlowWorkflowWithConfig( }) currentSyncFlowNum := 0 - totalRecordsSynced := 0 + totalRecordsSynced := int64(0) + + normalizeFlowID, err := GetChildWorkflowID(ctx, "normalize-flow", cfg.FlowJobName) + if err != nil { + return state, err + } + + childNormalizeFlowOpts := workflow.ChildWorkflowOptions{ + WorkflowID: normalizeFlowID, + ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, + RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 20, + }, + SearchAttributes: mirrorNameSearch, + WaitForCancellation: true, + } + normCtx := workflow.WithChildOptions(ctx, childNormalizeFlowOpts) + childNormalizeFlowFuture := workflow.ExecuteChildWorkflow( + normCtx, + NormalizeFlowWorkflow, + cfg, + normalizeFlowOptions, + ) + + var normWaitChan workflow.ReceiveChannel + if !peerdbenv.PeerDBEnableParallelSyncNormalize() { + normWaitChan = workflow.GetSignalChannel(ctx, shared.NormalizeSyncDoneSignalName) + } + + finishNormalize := func() { + childNormalizeFlowFuture.SignalChildWorkflow(ctx, shared.NormalizeSyncSignalName, model.NormalizeSignal{ + Done: true, + SyncBatchID: -1, + }) + var childNormalizeFlowRes *model.NormalizeFlowResponse + if err := 
childNormalizeFlowFuture.Get(ctx, &childNormalizeFlowRes); err != nil { + w.logger.Error("failed to execute normalize flow: ", err) + var panicErr *temporal.PanicError + if errors.As(err, &panicErr) { + w.logger.Error("PANIC", panicErr.Error(), panicErr.StackTrace()) + } + state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, err.Error()) + } else { + state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, childNormalizeFlowRes.Errors...) + state.NormalizeFlowStatuses = append(state.NormalizeFlowStatuses, childNormalizeFlowRes.Results...) + } + } var canceled bool signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) mainLoopSelector := workflow.NewSelector(ctx) + mainLoopSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { + canceled = true + }) mainLoopSelector.AddReceive(signalChan, func(c workflow.ReceiveChannel, _ bool) { var signalVal shared.CDCFlowSignal c.ReceiveAsync(&signalVal) state.ActiveSignal = shared.FlowSignalHandler(state.ActiveSignal, signalVal, w.logger) }) - mainLoopSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { - canceled = true - }) for { for !canceled && mainLoopSelector.HasPending() { @@ -429,7 +477,7 @@ func CDCFlowWorkflowWithConfig( state.ActiveSignal = shared.FlowSignalHandler(state.ActiveSignal, signalVal, w.logger) // only process config updates when going from STATUS_PAUSED to STATUS_RUNNING if state.ActiveSignal == shared.NoopSignal { - err = w.processCDCFlowConfigUpdates(ctx, cfg, state, limits, &mirrorNameSearch) + err = w.processCDCFlowConfigUpdates(ctx, cfg, state, limits, mirrorNameSearch) if err != nil { return state, err } @@ -444,6 +492,7 @@ func CDCFlowWorkflowWithConfig( // check if the peer flow has been shutdown if state.ActiveSignal == shared.ShutdownSignal { + finishNormalize() w.logger.Info("peer flow has been shutdown") state.CurrentFlowStatus = protos.FlowStatus_STATUS_TERMINATED return state, nil @@ -470,6 +519,7 @@ func CDCFlowWorkflowWithConfig( syncFlowID, err := GetChildWorkflowID(ctx, "sync-flow", cfg.FlowJobName) if err != nil { + finishNormalize() return state, err } @@ -493,53 +543,66 @@ func CDCFlowWorkflowWithConfig( ) var syncDone bool - var childSyncFlowRes *model.SyncResponse + var normalizeSignalError error + normDone := normWaitChan == nil mainLoopSelector.AddFuture(childSyncFlowFuture, func(f workflow.Future) { syncDone = true + + var childSyncFlowRes *model.SyncResponse if err := f.Get(syncCtx, &childSyncFlowRes); err != nil { w.logger.Error("failed to execute sync flow: ", err) state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error()) } else if childSyncFlowRes != nil { state.SyncFlowStatuses = append(state.SyncFlowStatuses, childSyncFlowRes) state.RelationMessageMapping = childSyncFlowRes.RelationMessageMapping - totalRecordsSynced += int(childSyncFlowRes.NumRecordsSynced) + totalRecordsSynced += childSyncFlowRes.NumRecordsSynced w.logger.Info("Total records synced: ", totalRecordsSynced) } - var tableSchemaDeltas []*protos.TableSchemaDelta = nil if childSyncFlowRes != nil { - tableSchemaDeltas = childSyncFlowRes.TableSchemaDeltas - } - - // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes. 
- if tableSchemaDeltas != nil { - modifiedSrcTables := make([]string, 0) - modifiedDstTables := make([]string, 0) - - for _, tableSchemaDelta := range tableSchemaDeltas { - modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName) - modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName) - } + tableSchemaDeltasCount := len(childSyncFlowRes.TableSchemaDeltas) + + var normalizeTableNameSchemaMapping map[string]*protos.TableSchema + // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes. + if tableSchemaDeltasCount != 0 { + modifiedSrcTables := make([]string, 0, tableSchemaDeltasCount) + modifiedDstTables := make([]string, 0, tableSchemaDeltasCount) + for _, tableSchemaDelta := range childSyncFlowRes.TableSchemaDeltas { + modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName) + modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName) + } - getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 5 * time.Minute, - }) - getModifiedSchemaFuture := workflow.ExecuteActivity(getModifiedSchemaCtx, flowable.GetTableSchema, - &protos.GetTableSchemaBatchInput{ - PeerConnectionConfig: cfg.Source, - TableIdentifiers: modifiedSrcTables, - FlowName: cfg.FlowJobName, + getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 5 * time.Minute, }) - - var getModifiedSchemaRes *protos.GetTableSchemaBatchOutput - if err := getModifiedSchemaFuture.Get(ctx, &getModifiedSchemaRes); err != nil { - w.logger.Error("failed to execute schema update at source: ", err) - state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error()) - } else { - for i := range modifiedSrcTables { - state.TableNameSchemaMapping[modifiedDstTables[i]] = getModifiedSchemaRes.TableNameSchemaMapping[modifiedSrcTables[i]] + getModifiedSchemaFuture := workflow.ExecuteActivity(getModifiedSchemaCtx, flowable.GetTableSchema, + &protos.GetTableSchemaBatchInput{ + PeerConnectionConfig: cfg.Source, + TableIdentifiers: modifiedSrcTables, + FlowName: cfg.FlowJobName, + }) + + var getModifiedSchemaRes *protos.GetTableSchemaBatchOutput + if err := getModifiedSchemaFuture.Get(ctx, &getModifiedSchemaRes); err != nil { + w.logger.Error("failed to execute schema update at source: ", err) + state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error()) + } else { + for i, srcTable := range modifiedSrcTables { + dstTable := modifiedDstTables[i] + state.TableNameSchemaMapping[dstTable] = getModifiedSchemaRes.TableNameSchemaMapping[srcTable] + } + normalizeTableNameSchemaMapping = state.TableNameSchemaMapping } } + + signalFuture := childNormalizeFlowFuture.SignalChildWorkflow(ctx, shared.NormalizeSyncSignalName, model.NormalizeSignal{ + Done: false, + SyncBatchID: childSyncFlowRes.CurrentSyncBatchID, + TableNameSchemaMapping: normalizeTableNameSchemaMapping, + }) + normalizeSignalError = signalFuture.Get(ctx, nil) + } else { + normDone = true } }) @@ -555,38 +618,15 @@ func CDCFlowWorkflowWithConfig( state.CurrentFlowStatus = protos.FlowStatus_STATUS_TERMINATED return state, nil } - - normalizeFlowID, err := GetChildWorkflowID(ctx, "normalize-flow", cfg.FlowJobName) - if err != nil { - return state, err + if normalizeSignalError != nil { + return state, normalizeSignalError } - - childNormalizeFlowOpts := workflow.ChildWorkflowOptions{ - WorkflowID: normalizeFlowID, - ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, - 
RetryPolicy: &temporal.RetryPolicy{ - MaximumAttempts: 20, - }, - SearchAttributes: mirrorNameSearch, - WaitForCancellation: true, - } - normCtx := workflow.WithChildOptions(ctx, childNormalizeFlowOpts) - childNormalizeFlowFuture := workflow.ExecuteChildWorkflow( - normCtx, - NormalizeFlowWorkflow, - cfg, - normalizeFlowOptions, - ) - - var childNormalizeFlowRes *model.NormalizeResponse - if err := childNormalizeFlowFuture.Get(normCtx, &childNormalizeFlowRes); err != nil { - w.logger.Error("failed to execute normalize flow: ", err) - state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, err.Error()) - } else { - state.NormalizeFlowStatuses = append(state.NormalizeFlowStatuses, childNormalizeFlowRes) + if !normDone { + normWaitChan.Receive(ctx, nil) } } + finishNormalize() state.TruncateProgress(w.logger) return state, workflow.NewContinueAsNewError(ctx, CDCFlowWorkflowWithConfig, cfg, limits, state) } diff --git a/flow/workflows/normalize_flow.go b/flow/workflows/normalize_flow.go index 70f6463aef..a598c51792 100644 --- a/flow/workflows/normalize_flow.go +++ b/flow/workflows/normalize_flow.go @@ -4,66 +4,97 @@ import ( "fmt" "time" - "go.temporal.io/sdk/log" "go.temporal.io/sdk/workflow" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/peerdbenv" + "github.com/PeerDB-io/peer-flow/shared" ) -type NormalizeFlowState struct { - CDCFlowName string - Progress []string -} - -type NormalizeFlowExecution struct { - NormalizeFlowState - executionID string - logger log.Logger -} - -func NewNormalizeFlowExecution(ctx workflow.Context, state *NormalizeFlowState) *NormalizeFlowExecution { - return &NormalizeFlowExecution{ - NormalizeFlowState: *state, - executionID: workflow.GetInfo(ctx).WorkflowExecution.ID, - logger: workflow.GetLogger(ctx), - } -} - func NormalizeFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionConfigs, options *protos.NormalizeFlowOptions, -) (*model.NormalizeResponse, error) { - s := NewNormalizeFlowExecution(ctx, &NormalizeFlowState{ - CDCFlowName: config.FlowJobName, - Progress: []string{}, - }) - - return s.executeNormalizeFlow(ctx, config, options) -} - -func (s *NormalizeFlowExecution) executeNormalizeFlow( - ctx workflow.Context, - config *protos.FlowConnectionConfigs, - options *protos.NormalizeFlowOptions, -) (*model.NormalizeResponse, error) { - s.logger.Info("executing normalize flow - ", s.CDCFlowName) +) (*model.NormalizeFlowResponse, error) { + logger := workflow.GetLogger(ctx) + tableNameSchemaMapping := options.TableNameSchemaMapping normalizeFlowCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ StartToCloseTimeout: 7 * 24 * time.Hour, HeartbeatTimeout: time.Minute, }) - startNormalizeInput := &protos.StartNormalizeInput{ - FlowConnectionConfigs: config, - TableNameSchemaMapping: options.TableNameSchemaMapping, - } - fStartNormalize := workflow.ExecuteActivity(normalizeFlowCtx, flowable.StartNormalize, startNormalizeInput) + results := make([]model.NormalizeResponse, 0, 4) + errors := make([]string, 0) + syncChan := workflow.GetSignalChannel(ctx, shared.NormalizeSyncSignalName) + + var stopLoop, canceled bool + var lastSyncBatchID, syncBatchID int64 + lastSyncBatchID = -1 + syncBatchID = -1 + selector := workflow.NewNamedSelector(ctx, fmt.Sprintf("%s-normalize", config.FlowJobName)) + selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { + canceled = true + }) + selector.AddReceive(syncChan, func(c workflow.ReceiveChannel, _ bool) { + 
var s model.NormalizeSignal + c.ReceiveAsync(&s) + if s.Done { + stopLoop = true + } + if s.SyncBatchID > syncBatchID { + syncBatchID = s.SyncBatchID + } + if len(s.TableNameSchemaMapping) != 0 { + tableNameSchemaMapping = s.TableNameSchemaMapping + } + }) + for !stopLoop { + selector.Select(ctx) + for !canceled && selector.HasPending() { + selector.Select(ctx) + } + if canceled || (stopLoop && lastSyncBatchID == syncBatchID) { + if canceled { + logger.Info("normalize canceled - ", config.FlowJobName) + } else { + logger.Info("normalize finished - ", config.FlowJobName) + } + break + } + if lastSyncBatchID != syncBatchID { + lastSyncBatchID = syncBatchID + + logger.Info("executing normalize - ", config.FlowJobName) + startNormalizeInput := &protos.StartNormalizeInput{ + FlowConnectionConfigs: config, + TableNameSchemaMapping: tableNameSchemaMapping, + SyncBatchID: syncBatchID, + } + fStartNormalize := workflow.ExecuteActivity(normalizeFlowCtx, flowable.StartNormalize, startNormalizeInput) + + var normalizeResponse *model.NormalizeResponse + if err := fStartNormalize.Get(normalizeFlowCtx, &normalizeResponse); err != nil { + errors = append(errors, err.Error()) + } else if normalizeResponse != nil { + results = append(results, *normalizeResponse) + } + } - var normalizeResponse *model.NormalizeResponse - if err := fStartNormalize.Get(normalizeFlowCtx, &normalizeResponse); err != nil { - return nil, fmt.Errorf("failed to flow: %w", err) + if !peerdbenv.PeerDBEnableParallelSyncNormalize() { + parent := workflow.GetInfo(ctx).ParentWorkflowExecution + workflow.SignalExternalWorkflow( + ctx, + parent.ID, + parent.RunID, + shared.NormalizeSyncDoneSignalName, + struct{}{}, + ) + } } - return normalizeResponse, nil + return &model.NormalizeFlowResponse{ + Results: results, + Errors: errors, + }, nil } diff --git a/flow/workflows/qrep_flow.go b/flow/workflows/qrep_flow.go index e7e87a41d5..bb4f2323c6 100644 --- a/flow/workflows/qrep_flow.go +++ b/flow/workflows/qrep_flow.go @@ -362,9 +362,7 @@ func (q *QRepFlowExecution) handleTableRenameForResync(ctx workflow.Context, sta return nil } -func (q *QRepFlowExecution) receiveAndHandleSignalAsync(ctx workflow.Context) { - signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) - +func (q *QRepFlowExecution) receiveAndHandleSignalAsync(signalChan workflow.ReceiveChannel) { var signalVal shared.CDCFlowSignal ok := signalChan.ReceiveAsync(&signalVal) if ok { @@ -509,11 +507,11 @@ func QRepFlowWorkflow( // here, we handle signals after the end of the flow because a new workflow does not inherit the signals // and the chance of missing a signal is much higher if the check is before the time consuming parts run - q.receiveAndHandleSignalAsync(ctx) + signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) + q.receiveAndHandleSignalAsync(signalChan) if q.activeSignal == shared.PauseSignal { startTime := time.Now() state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED - signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) var signalVal shared.CDCFlowSignal for q.activeSignal == shared.PauseSignal { diff --git a/flow/workflows/xmin_flow.go b/flow/workflows/xmin_flow.go index c6885253df..84f6ba73cb 100644 --- a/flow/workflows/xmin_flow.go +++ b/flow/workflows/xmin_flow.go @@ -117,11 +117,11 @@ func XminFlowWorkflow( // here, we handle signals after the end of the flow because a new workflow does not inherit the signals // and the chance of missing a signal is much higher if the check is before the time consuming parts run - 
q.receiveAndHandleSignalAsync(ctx)
+	signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
+	q.receiveAndHandleSignalAsync(signalChan)
 	if x.activeSignal == shared.PauseSignal {
 		startTime := time.Now()
 		state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED
-		signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
 		var signalVal shared.CDCFlowSignal
 
 		for x.activeSignal == shared.PauseSignal {
diff --git a/protos/flow.proto b/protos/flow.proto
index 41ba8bf966..fe0f84ff86 100644
--- a/protos/flow.proto
+++ b/protos/flow.proto
@@ -125,6 +125,7 @@ message StartFlowInput {
 message StartNormalizeInput {
   FlowConnectionConfigs flow_connection_configs = 1;
   map<string, TableSchema> table_name_schema_mapping = 2;
+  int64 SyncBatchID = 3;
 }
 
 message GetLastSyncedIDInput {
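
Sketch 1: the signal-coalescing loop at the heart of the new NormalizeFlowWorkflow, condensed from the normalize_flow.go hunk above into a standalone form. Illustrative only, not part of the patch: the helper name coalescingNormalizeLoop and the runNormalize callback are invented for the example, and error collection, schema-mapping updates, and the serial done-handshake are elided.

package peerflow

import (
	"go.temporal.io/sdk/workflow"

	"github.com/PeerDB-io/peer-flow/model"
	"github.com/PeerDB-io/peer-flow/shared"
)

// coalescingNormalizeLoop (illustrative name) drains every queued
// NormalizeSignal before running normalize, so several sync batches
// collapse into a single normalize pass.
func coalescingNormalizeLoop(ctx workflow.Context, runNormalize func(syncBatchID int64)) {
	syncChan := workflow.GetSignalChannel(ctx, shared.NormalizeSyncSignalName)

	var stopLoop, canceled bool
	lastSyncBatchID, syncBatchID := int64(-1), int64(-1)

	selector := workflow.NewSelector(ctx)
	selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {
		canceled = true
	})
	selector.AddReceive(syncChan, func(c workflow.ReceiveChannel, _ bool) {
		var s model.NormalizeSignal
		c.ReceiveAsync(&s)
		if s.Done {
			stopLoop = true
		}
		if s.SyncBatchID > syncBatchID {
			syncBatchID = s.SyncBatchID // remember only the newest batch id
		}
	})

	for !stopLoop {
		selector.Select(ctx) // block until at least one signal arrives
		for !canceled && selector.HasPending() {
			selector.Select(ctx) // drain everything queued since the last pass
		}
		if canceled || (stopLoop && lastSyncBatchID == syncBatchID) {
			break
		}
		if lastSyncBatchID != syncBatchID {
			lastSyncBatchID = syncBatchID
			// one StartNormalize activity covers (lastSyncBatchID..syncBatchID]
			runNormalize(syncBatchID)
		}
	}
}

The inner drain is why a burst of sync completions costs one normalize pass instead of one per batch, which is exactly the "merging multiple batches at once can be less expensive" point in the commit message.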
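Sketch 2: the `(normid..syncid]` range made concrete. Every connector's NormalizeRecords merges raw rows whose batch id falls in that half-open window. A sketch under assumptions: the `_peerdb_batch_id` column comes from the raw-table DDL in flow/connectors/postgres/client.go above, but the predicate below is illustrative; the SQL each connector actually generates differs per destination.

package main

import "fmt"

func main() {
	// Example ids: normBatchID from GetLastNormalizeBatchID,
	// syncBatchID from the most recent NormalizeSignal.
	normBatchID, syncBatchID := int64(7), int64(10)

	// Rows in the half-open range (normBatchID..syncBatchID] are merged in
	// one normalize pass, i.e. batches 8, 9 and 10 here.
	predicate := fmt.Sprintf(
		"_peerdb_batch_id > %d AND _peerdb_batch_id <= %d",
		normBatchID, syncBatchID,
	)
	fmt.Println(predicate)

	// After the pass, normalize_batch_id is advanced to syncBatchID (10),
	// restoring normid <= syncid until the next signal arrives.
}

With `PEERDB_ENABLE_PARALLEL_SYNC_NORMALIZE` unset or false, the parent CDC workflow still blocks on NormalizeSyncDoneSignalName after every sync, so the window never grows past a single batch and behavior matches the old serial flow.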