diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 2eaf9f1d3f..c5680ebe6c 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -198,7 +198,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
     }
     defer connectors.CloseConnector(dstConn)

-    activity.RecordHeartbeat(ctx, "initialized table schema")
     slog.InfoContext(ctx, "pulling records...")
     tblNameMapping := make(map[string]model.NameAndExclude)
     for _, v := range input.FlowConnectionConfigs.TableMappings {
@@ -268,6 +267,7 @@ func (a *FlowableActivity) StartFlow(ctx context.Context,
         }

         return &model.SyncResponse{
+            CurrentSyncBatchID:     -1,
             TableSchemaDeltas:      recordBatch.SchemaDeltas,
             RelationMessageMapping: input.RelationMessageMapping,
         }, nil
@@ -379,13 +379,8 @@ func (a *FlowableActivity) StartNormalize(
         }
         defer connectors.CloseConnector(dstConn)

-        lastSyncBatchID, err := dstConn.GetLastSyncBatchID(input.FlowConnectionConfigs.FlowJobName)
-        if err != nil {
-            return nil, fmt.Errorf("failed to get last sync batch ID: %w", err)
-        }
-
         err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName,
-            lastSyncBatchID)
+            input.SyncBatchID)
         return nil, err
     } else if err != nil {
         return nil, err
@@ -399,6 +394,7 @@ func (a *FlowableActivity) StartNormalize(

     res, err := dstConn.NormalizeRecords(&model.NormalizeRecordsRequest{
         FlowJobName:            input.FlowConnectionConfigs.FlowJobName,
+        SyncBatchID:            input.SyncBatchID,
         SoftDelete:             input.FlowConnectionConfigs.SoftDelete,
         SoftDeleteColName:      input.FlowConnectionConfigs.SoftDeleteColName,
         SyncedAtColName:        input.FlowConnectionConfigs.SyncedAtColName,
@@ -423,10 +419,8 @@ func (a *FlowableActivity) StartNormalize(
     }

     // log the number of batches normalized
-    if res != nil {
-        slog.InfoContext(ctx, fmt.Sprintf("normalized records from batch %d to batch %d",
-            res.StartBatchID, res.EndBatchID))
-    }
+    slog.InfoContext(ctx, fmt.Sprintf("normalized records from batch %d to batch %d",
+        res.StartBatchID, res.EndBatchID))

     return res, nil
 }
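Review note: with this change StartNormalize no longer asks the destination peer for the last sync batch ID; the workflow hands it over on the activity input, and the early-return path reports CurrentSyncBatchID: -1 so the caller can tell "no new batch" apart from batch 0. A minimal sketch of the new call shape, assuming the reworked StartNormalizeInput from this patch (surrounding variables are illustrative only):

    // The normalize activity is told which batch to normalize up to,
    // instead of re-reading sync_batch_id from the destination peer.
    input := &protos.StartNormalizeInput{
        FlowConnectionConfigs:  cfg,                    // existing field
        TableNameSchemaMapping: tableNameSchemaMapping, // existing field
        SyncBatchID:            syncBatchID,            // new field, carried from the sync flow's response
    }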
diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 9f501d9fcc..3443993512 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -359,6 +359,7 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
     query := fmt.Sprintf("SELECT sync_batch_id FROM %s WHERE mirror_job_name = '%s'",
         MirrorJobsTable, jobName)
     q := c.client.Query(query)
+    q.DisableQueryCache = true
     q.DefaultProjectID = c.projectID
     q.DefaultDatasetID = c.datasetID
     it, err := q.Read(c.ctx)
@@ -382,37 +383,30 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) {
     }
 }

-func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) {
-    query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s WHERE mirror_job_name = '%s'",
+func (c *BigQueryConnector) GetLastNormalizeBatchID(jobName string) (int64, error) {
+    query := fmt.Sprintf("SELECT normalize_batch_id FROM %s WHERE mirror_job_name = '%s'",
         MirrorJobsTable, jobName)
     q := c.client.Query(query)
+    q.DisableQueryCache = true
     q.DefaultProjectID = c.projectID
     q.DefaultDatasetID = c.datasetID
     it, err := q.Read(c.ctx)
     if err != nil {
         err = fmt.Errorf("failed to run query %s on BigQuery:\n %w", query, err)
-        return model.SyncAndNormalizeBatchID{}, err
+        return 0, err
     }

     var row []bigquery.Value
     err = it.Next(&row)
     if err != nil {
         c.logger.Info("no row found for job")
-        return model.SyncAndNormalizeBatchID{}, nil
+        return 0, nil
     }

-    syncBatchID := int64(0)
-    normBatchID := int64(0)
     if row[0] != nil {
-        syncBatchID = row[0].(int64)
-    }
-    if row[1] != nil {
-        normBatchID = row[1].(int64)
+        return row[0].(int64), nil
     }
-    return model.SyncAndNormalizeBatchID{
-        SyncBatchID:      syncBatchID,
-        NormalizeBatchID: normBatchID,
-    }, nil
+    return 0, nil
 }

 func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64,
@@ -546,7 +540,7 @@ func (c *BigQueryConnector) syncRecordsViaAvro(
 func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) {
     rawTableName := c.getRawTableName(req.FlowJobName)

-    batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName)
+    normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName)
     if err != nil {
         return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err)
     }
@@ -557,18 +551,18 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
     }
     // if job is not yet found in the peerdb_mirror_jobs_table
     // OR sync is lagging end normalize
-    if !hasJob || batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID {
+    if !hasJob || normBatchID >= req.SyncBatchID {
         c.logger.Info("waiting for sync to catch up, so finishing")
         return &model.NormalizeResponse{
             Done:         false,
-            StartBatchID: batchIDs.NormalizeBatchID,
-            EndBatchID:   batchIDs.SyncBatchID,
+            StartBatchID: normBatchID,
+            EndBatchID:   req.SyncBatchID,
         }, nil
     }
     distinctTableNames, err := c.getDistinctTableNamesInBatch(
         req.FlowJobName,
-        batchIDs.SyncBatchID,
-        batchIDs.NormalizeBatchID,
+        req.SyncBatchID,
+        normBatchID,
     )
     if err != nil {
         return nil, fmt.Errorf("couldn't get distinct table names to normalize: %w", err)
@@ -576,8 +570,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)

     tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(
         req.FlowJobName,
-        batchIDs.SyncBatchID,
-        batchIDs.NormalizeBatchID,
+        req.SyncBatchID,
+        normBatchID,
     )
     if err != nil {
         return nil, fmt.Errorf("couldn't get tablename to unchanged cols mapping: %w", err)
@@ -599,8 +593,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
             dstTableName:          tableName,
             dstDatasetTable:       dstDatasetTable,
             normalizedTableSchema: req.TableNameSchemaMapping[tableName],
-            syncBatchID:           batchIDs.SyncBatchID,
-            normalizeBatchID:      batchIDs.NormalizeBatchID,
+            syncBatchID:           req.SyncBatchID,
+            normalizeBatchID:      normBatchID,
             peerdbCols: &protos.PeerDBColumns{
                 SoftDeleteColName: req.SoftDeleteColName,
                 SyncedAtColName:   req.SyncedAtColName,
@@ -625,7 +619,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
     // update metadata to make the last normalized batch id to the recent last sync batch id.
     updateMetadataStmt := fmt.Sprintf(
         "UPDATE %s SET normalize_batch_id=%d WHERE mirror_job_name='%s';",
-        MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName)
+        MirrorJobsTable, req.SyncBatchID, req.FlowJobName)

     query := c.client.Query(updateMetadataStmt)
     query.DefaultProjectID = c.projectID
@@ -637,8 +631,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)

     return &model.NormalizeResponse{
         Done:         true,
-        StartBatchID: batchIDs.NormalizeBatchID + 1,
-        EndBatchID:   batchIDs.SyncBatchID,
+        StartBatchID: normBatchID + 1,
+        EndBatchID:   req.SyncBatchID,
     }, nil
 }
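Review note: DisableQueryCache matters now that sync and normalize can run concurrently. BigQuery may serve a cached result for a textually identical metadata query, so a normalize pass could keep reading a stale normalize_batch_id. The pattern, restated outside the connector (illustrative, not the connector's exact code):

    // Force a fresh read of mirror-job metadata; a cached result could
    // make a parallel normalize loop spin on an out-of-date batch ID.
    q := client.Query(metadataQuery)
    q.DisableQueryCache = true
    q.DefaultProjectID = projectID
    q.DefaultDatasetID = datasetID
    it, err := q.Read(ctx)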
diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go
index c4c9a879dc..c6ae8f7d7f 100644
--- a/flow/connectors/postgres/client.go
+++ b/flow/connectors/postgres/client.go
@@ -38,11 +38,11 @@ const (
     createRawTableBatchIDIndexSQL  = "CREATE INDEX IF NOT EXISTS %s_batchid_idx ON %s.%s(_peerdb_batch_id)"
     createRawTableDstTableIndexSQL = "CREATE INDEX IF NOT EXISTS %s_dst_table_idx ON %s.%s(_peerdb_destination_table_name)"

-    getLastOffsetSQL                   = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1"
-    setLastOffsetSQL                   = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2"
-    getLastSyncBatchID_SQL             = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1"
-    getLastSyncAndNormalizeBatchID_SQL = "SELECT sync_batch_id,normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1"
-    createNormalizedTableSQL           = "CREATE TABLE IF NOT EXISTS %s(%s)"
+    getLastOffsetSQL            = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1"
+    setLastOffsetSQL            = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2"
+    getLastSyncBatchID_SQL      = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1"
+    getLastNormalizeBatchID_SQL = "SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1"
+    createNormalizedTableSQL    = "CREATE TABLE IF NOT EXISTS %s(%s)"

     upsertJobMetadataForSyncSQL = `INSERT INTO %s.%s AS j VALUES ($1,$2,$3,$4)
      ON CONFLICT(mirror_job_name) DO UPDATE SET lsn_offset=GREATEST(j.lsn_offset, EXCLUDED.lsn_offset),
      sync_batch_id=EXCLUDED.sync_batch_id`
@@ -471,24 +471,21 @@ func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) {
     return result.Int64, nil
 }

-func (c *PostgresConnector) GetLastSyncAndNormalizeBatchID(jobName string) (*model.SyncAndNormalizeBatchID, error) {
-    var syncResult, normalizeResult pgtype.Int8
+func (c *PostgresConnector) GetLastNormalizeBatchID(jobName string) (int64, error) {
+    var result pgtype.Int8
     err := c.pool.QueryRow(c.ctx, fmt.Sprintf(
-        getLastSyncAndNormalizeBatchID_SQL,
+        getLastNormalizeBatchID_SQL,
         c.metadataSchema,
         mirrorJobsTableIdentifier,
-    ), jobName).Scan(&syncResult, &normalizeResult)
+    ), jobName).Scan(&result)
     if err != nil {
         if err == pgx.ErrNoRows {
             c.logger.Info("No row found, returning 0")
-            return &model.SyncAndNormalizeBatchID{}, nil
+            return 0, nil
         }
-        return nil, fmt.Errorf("error while reading result row: %w", err)
+        return 0, fmt.Errorf("error while reading result row: %w", err)
     }

-    return &model.SyncAndNormalizeBatchID{
-        SyncBatchID:      syncResult.Int64,
-        NormalizeBatchID: normalizeResult.Int64,
-    }, nil
+    return result.Int64, nil
 }

 func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) {
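Review note: the narrowed getter keeps the "missing metadata row means batch 0" convention, mapping pgx.ErrNoRows to (0, nil). A hypothetical caller, showing the contract the connectors rely on:

    // A mirror that has never normalized simply sees normBatchID == 0;
    // only genuine query/scan failures surface as errors.
    normBatchID, err := conn.GetLastNormalizeBatchID(jobName)
    if err != nil {
        return fmt.Errorf("could not read normalize batch ID: %w", err)
    }
    _ = normBatchID // 0 on first run, the last normalized batch otherwise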
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 79c6fb29a0..231d5f78f2 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -431,28 +431,29 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
         }, nil
     }

-    batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName)
+    normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName)
     if err != nil {
         return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err)
     }
+    // normalize has caught up with sync, chill until more records are loaded.
-    if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID {
+    if normBatchID >= req.SyncBatchID {
         c.logger.Info(fmt.Sprintf("no records to normalize: syncBatchID %d, normalizeBatchID %d",
-            batchIDs.SyncBatchID, batchIDs.NormalizeBatchID))
+            req.SyncBatchID, normBatchID))
         return &model.NormalizeResponse{
             Done:         false,
-            StartBatchID: batchIDs.NormalizeBatchID,
-            EndBatchID:   batchIDs.SyncBatchID,
+            StartBatchID: normBatchID,
+            EndBatchID:   req.SyncBatchID,
         }, nil
     }

     destinationTableNames, err := c.getDistinctTableNamesInBatch(
-        req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)
+        req.FlowJobName, req.SyncBatchID, normBatchID)
     if err != nil {
         return nil, err
     }
     unchangedToastColsMap, err := c.getTableNametoUnchangedCols(req.FlowJobName,
-        batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)
+        req.SyncBatchID, normBatchID)
     if err != nil {
         return nil, err
     }
@@ -491,7 +492,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
         }
         normalizeStatements := normalizeStmtGen.generateNormalizeStatements()
         for _, normalizeStatement := range normalizeStatements {
-            mergeStatementsBatch.Queue(normalizeStatement, batchIDs.NormalizeBatchID, batchIDs.SyncBatchID, destinationTableName).Exec(
+            mergeStatementsBatch.Queue(normalizeStatement, normBatchID, req.SyncBatchID, destinationTableName).Exec(
                 func(ct pgconn.CommandTag) error {
                     totalRowsAffected += int(ct.RowsAffected())
                     return nil
@@ -508,7 +509,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)
     c.logger.Info(fmt.Sprintf("normalized %d records", totalRowsAffected))

     // updating metadata with new normalizeBatchID
-    err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID, normalizeRecordsTx)
+    err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID, normalizeRecordsTx)
     if err != nil {
         return nil, err
     }
@@ -520,8 +521,8 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest)

     return &model.NormalizeResponse{
         Done:         true,
-        StartBatchID: batchIDs.NormalizeBatchID + 1,
-        EndBatchID:   batchIDs.SyncBatchID,
+        StartBatchID: normBatchID + 1,
+        EndBatchID:   req.SyncBatchID,
     }, nil
 }

@@ -711,7 +712,8 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab

 // ReplayTableSchemaDelta changes a destination table to match the schema at source
 // This could involve adding or dropping multiple columns.
-func (c *PostgresConnector) ReplayTableSchemaDeltas(flowJobName string,
+func (c *PostgresConnector) ReplayTableSchemaDeltas(
+    flowJobName string,
     schemaDeltas []*protos.TableSchemaDelta,
 ) error {
     if len(schemaDeltas) == 0 {
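Review note on the batch window: normalize processes the half-open window (normBatchID, req.SyncBatchID]. A worked example under assumed values, not taken from the patch:

    // With normBatchID = 3 and req.SyncBatchID = 5, batches 4 and 5 are pending:
    start, end := normBatchID+1, req.SyncBatchID // StartBatchID: 4, EndBatchID: 5
    if start > end {
        // empty window: normalize already caught up, so the connector
        // returns Done: false without touching any destination tables
    }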
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index 763de043dd..ac8e8badcf 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -70,15 +70,15 @@ const (
     checkIfTableExistsSQL = `SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.TABLES
      WHERE TABLE_SCHEMA=? and TABLE_NAME=?`
-    checkIfJobMetadataExistsSQL     = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?"
-    getLastOffsetSQL                = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?"
-    setLastOffsetSQL                = "UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?) WHERE MIRROR_JOB_NAME=?"
-    getLastSyncBatchID_SQL          = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?"
-    getLastSyncNormalizeBatchID_SQL = "SELECT SYNC_BATCH_ID, NORMALIZE_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?"
-    dropTableIfExistsSQL            = "DROP TABLE IF EXISTS %s.%s"
-    deleteJobMetadataSQL            = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=?"
-    dropSchemaIfExistsSQL           = "DROP SCHEMA IF EXISTS %s"
-    checkSchemaExistsSQL            = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?"
+    checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?"
+    getLastOffsetSQL            = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?"
+    setLastOffsetSQL            = "UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?) WHERE MIRROR_JOB_NAME=?"
+    getLastSyncBatchID_SQL      = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?"
+    getLastNormalizeBatchID_SQL = "SELECT NORMALIZE_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?"
+    dropTableIfExistsSQL        = "DROP TABLE IF EXISTS %s.%s"
+    deleteJobMetadataSQL        = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=?"
+    dropSchemaIfExistsSQL       = "DROP SCHEMA IF EXISTS %s"
+    checkSchemaExistsSQL        = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?"
 )

 type SnowflakeConnector struct {
@@ -381,28 +381,24 @@ func (c *SnowflakeConnector) GetLastSyncBatchID(jobName string) (int64, error) {
     return result.Int64, nil
 }

-func (c *SnowflakeConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) {
-    rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastSyncNormalizeBatchID_SQL, c.metadataSchema,
+func (c *SnowflakeConnector) GetLastNormalizeBatchID(jobName string) (int64, error) {
+    rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastNormalizeBatchID_SQL, c.metadataSchema,
         mirrorJobsTableIdentifier), jobName)
     if err != nil {
-        return model.SyncAndNormalizeBatchID{},
-            fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err)
+        return 0, fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err)
     }
     defer rows.Close()

-    var syncResult, normResult pgtype.Int8
+    var normBatchID pgtype.Int8
     if !rows.Next() {
         c.logger.Warn("No row found, returning 0")
-        return model.SyncAndNormalizeBatchID{}, nil
+        return 0, nil
     }
-    err = rows.Scan(&syncResult, &normResult)
+    err = rows.Scan(&normBatchID)
     if err != nil {
-        return model.SyncAndNormalizeBatchID{}, fmt.Errorf("error while reading result row: %w", err)
+        return 0, fmt.Errorf("error while reading result row: %w", err)
     }
-    return model.SyncAndNormalizeBatchID{
-        SyncBatchID:      syncResult.Int64,
-        NormalizeBatchID: normResult.Int64,
-    }, nil
+    return normBatchID.Int64, nil
 }

 func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64,
@@ -633,16 +629,17 @@ func (c *SnowflakeConnector) syncRecordsViaAvro(

 // NormalizeRecords normalizes raw table to destination table.
 func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) {
-    batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName)
+    normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName)
     if err != nil {
         return nil, err
     }
+    // normalize has caught up with sync, chill until more records are loaded.
-    if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID {
+    if normBatchID >= req.SyncBatchID {
         return &model.NormalizeResponse{
             Done:         false,
-            StartBatchID: batchIDs.NormalizeBatchID,
-            EndBatchID:   batchIDs.SyncBatchID,
+            StartBatchID: normBatchID,
+            EndBatchID:   req.SyncBatchID,
         }, nil
     }

@@ -658,14 +655,14 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest
     }
     destinationTableNames, err := c.getDistinctTableNamesInBatch(
         req.FlowJobName,
-        batchIDs.SyncBatchID,
-        batchIDs.NormalizeBatchID,
+        req.SyncBatchID,
+        normBatchID,
     )
     if err != nil {
         return nil, err
     }

-    tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)
+    tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, req.SyncBatchID, normBatchID)
     if err != nil {
         return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err)
     }
@@ -681,8 +678,8 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest
         mergeGen := &mergeStmtGenerator{
             rawTableName:          getRawTableIdentifier(req.FlowJobName),
             dstTableName:          tableName,
-            syncBatchID:           batchIDs.SyncBatchID,
-            normalizeBatchID:      batchIDs.NormalizeBatchID,
+            syncBatchID:           req.SyncBatchID,
+            normalizeBatchID:      normBatchID,
             normalizedTableSchema: req.TableNameSchemaMapping[tableName],
             unchangedToastColumns: tableNametoUnchangedToastCols[tableName],
             peerdbCols: &protos.PeerDBColumns{
@@ -728,15 +725,15 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest
     }

     // updating metadata with new normalizeBatchID
-    err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID)
+    err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID)
     if err != nil {
         return nil, err
     }

     return &model.NormalizeResponse{
         Done:         true,
-        StartBatchID: batchIDs.NormalizeBatchID + 1,
-        EndBatchID:   batchIDs.SyncBatchID,
+        StartBatchID: normBatchID + 1,
+        EndBatchID:   req.SyncBatchID,
     }, nil
 }
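Review note: BigQuery, Postgres, and Snowflake now all expose the same narrowed contract. A sketch of the implied connector interface (an assumption for illustration; the real interface lives elsewhere in the repo and may differ):

    type normalizeBatchConnector interface {
        GetLastSyncBatchID(jobName string) (int64, error)
        // Replaces GetLastSyncAndNormalizeBatchID; the sync-side batch ID
        // now travels in NormalizeRecordsRequest.SyncBatchID instead.
        GetLastNormalizeBatchID(jobName string) (int64, error)
        NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error)
    }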
diff --git a/flow/model/model.go b/flow/model/model.go
index b008e624f1..d8e3a7c751 100644
--- a/flow/model/model.go
+++ b/flow/model/model.go
@@ -507,11 +507,6 @@ func (r *CDCRecordStream) GetRecords() <-chan Record {
     return r.records
 }

-type SyncAndNormalizeBatchID struct {
-    SyncBatchID      int64
-    NormalizeBatchID int64
-}
-
 type SyncRecordsRequest struct {
     SyncBatchID int64
     Records     *CDCRecordStream
@@ -525,6 +520,7 @@ type SyncRecordsRequest struct {

 type NormalizeRecordsRequest struct {
     FlowJobName       string
+    SyncBatchID       int64
     SoftDelete        bool
     SoftDeleteColName string
     SyncedAtColName   string
@@ -546,6 +542,17 @@ type SyncResponse struct {
     RelationMessageMapping RelationMessageMapping
 }

+type NormalizeSignal struct {
+    Done                   bool
+    SyncBatchID            int64
+    TableNameSchemaMapping map[string]*protos.TableSchema
+}
+
+type NormalizeFlowResponse struct {
+    Results []NormalizeResponse
+    Errors  []string
+}
+
 type NormalizeResponse struct {
     // Flag to depict if normalization is done
     Done bool
diff --git a/flow/peerdbenv/config.go b/flow/peerdbenv/config.go
index ca238899ac..65254de73a 100644
--- a/flow/peerdbenv/config.go
+++ b/flow/peerdbenv/config.go
@@ -74,3 +74,8 @@ func PeerDBCatalogDatabase() string {
 func PeerDBEnableWALHeartbeat() bool {
     return getEnvBool("PEERDB_ENABLE_WAL_HEARTBEAT", false)
 }
+
+// PEERDB_ENABLE_PARALLEL_SYNC_NORMALIZE
+func PeerDBEnableParallelSyncNormalize() bool {
+    return getEnvBool("PEERDB_ENABLE_PARALLEL_SYNC_NORMALIZE", false)
+}
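Review note: NormalizeSignal is the only message the CDC workflow sends the normalize child. Done: true ends the child's loop, SyncBatchID advances the watermark, and a non-empty TableNameSchemaMapping refreshes the cached schemas after a schema change; PEERDB_ENABLE_PARALLEL_SYNC_NORMALIZE picks between serial and parallel modes. A hedged sketch of the sender's side (variable names illustrative):

    // Serial mode (flag off, the default): each sync waits for normalize-sync-done.
    // Parallel mode (flag on): syncs keep going and signals coalesce in the child.
    sig := model.NormalizeSignal{
        Done:        false,
        SyncBatchID: res.CurrentSyncBatchID,
        // TableNameSchemaMapping is set only when this sync saw schema deltas
    }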
diff --git a/flow/shared/constants.go b/flow/shared/constants.go
index 119514fb76..a75f9a0f52 100644
--- a/flow/shared/constants.go
+++ b/flow/shared/constants.go
@@ -14,6 +14,8 @@ const (
     // Signals
     FlowSignalName                 = "peer-flow-signal"
     CDCDynamicPropertiesSignalName = "cdc-dynamic-properties"
+    NormalizeSyncSignalName        = "normalize-sync"
+    NormalizeSyncDoneSignalName    = "normalize-sync-done"

     // Queries
     CDCFlowStateQuery = "q-cdc-flow-status"
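Review note: the two new signal names form a request/acknowledge pair, used below in cdc_flow.go and normalize_flow.go. A usage sketch with the Temporal Go SDK (variable names illustrative):

    // parent -> child: hand the normalize workflow a new batch watermark
    childFuture.SignalChildWorkflow(ctx, shared.NormalizeSyncSignalName,
        model.NormalizeSignal{SyncBatchID: batchID})

    // child -> parent: serial mode's "normalize finished, start the next sync"
    done := workflow.GetSignalChannel(ctx, shared.NormalizeSyncDoneSignalName)
    done.Receive(ctx, nil)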
diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go
index f2a5a9cddd..a0d73aa593 100644
--- a/flow/workflows/cdc_flow.go
+++ b/flow/workflows/cdc_flow.go
@@ -1,6 +1,7 @@
 package peerflow

 import (
+    "errors"
     "fmt"
     "log/slog"
     "strings"
@@ -15,6 +16,7 @@ import (

     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model"
+    "github.com/PeerDB-io/peer-flow/peerdbenv"
     "github.com/PeerDB-io/peer-flow/shared"
 )

@@ -26,7 +28,7 @@ type CDCFlowLimits struct {
     // Maximum number of rows in a sync flow batch.
     MaxBatchSize uint32
     // Rows synced after which we can say a test is done.
-    ExitAfterRecords int
+    ExitAfterRecords int64
 }

 type CDCFlowWorkflowState struct {
@@ -35,7 +37,7 @@ type CDCFlowWorkflowState struct {
     // Accumulates status for sync flows spawned.
     SyncFlowStatuses []*model.SyncResponse
     // Accumulates status for sync flows spawned.
-    NormalizeFlowStatuses []*model.NormalizeResponse
+    NormalizeFlowStatuses []model.NormalizeResponse
     // Current signalled state of the peer flow.
     ActiveSignal shared.CDCFlowSignal
     // Errors encountered during child sync flow executions.
@@ -137,7 +139,7 @@ type CDCFlowWorkflowResult = CDCFlowWorkflowState

 func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Context,
     cfg *protos.FlowConnectionConfigs, state *CDCFlowWorkflowState,
-    limits *CDCFlowLimits, mirrorNameSearch *map[string]interface{},
+    limits *CDCFlowLimits, mirrorNameSearch map[string]interface{},
 ) error {
     for _, flowConfigUpdate := range state.FlowConfigUpdates {
         if len(flowConfigUpdate.AdditionalTables) == 0 {
@@ -177,7 +179,7 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Cont
             RetryPolicy: &temporal.RetryPolicy{
                 MaximumAttempts: 20,
             },
-            SearchAttributes:    *mirrorNameSearch,
+            SearchAttributes:    mirrorNameSearch,
             WaitForCancellation: true,
         }
         childAdditionalTablesCDCFlowCtx := workflow.WithChildOptions(ctx, childAdditionalTablesCDCFlowOpts)
@@ -393,19 +395,65 @@ func CDCFlowWorkflowWithConfig(
     })

     currentSyncFlowNum := 0
-    totalRecordsSynced := 0
+    totalRecordsSynced := int64(0)
+
+    normalizeFlowID, err := GetChildWorkflowID(ctx, "normalize-flow", cfg.FlowJobName)
+    if err != nil {
+        return state, err
+    }
+
+    childNormalizeFlowOpts := workflow.ChildWorkflowOptions{
+        WorkflowID:        normalizeFlowID,
+        ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL,
+        RetryPolicy: &temporal.RetryPolicy{
+            MaximumAttempts: 20,
+        },
+        SearchAttributes:    mirrorNameSearch,
+        WaitForCancellation: true,
+    }
+    normCtx := workflow.WithChildOptions(ctx, childNormalizeFlowOpts)
+    childNormalizeFlowFuture := workflow.ExecuteChildWorkflow(
+        normCtx,
+        NormalizeFlowWorkflow,
+        cfg,
+        normalizeFlowOptions,
+    )
+
+    var normWaitChan workflow.ReceiveChannel
+    if !peerdbenv.PeerDBEnableParallelSyncNormalize() {
+        normWaitChan = workflow.GetSignalChannel(ctx, shared.NormalizeSyncDoneSignalName)
+    }
+
+    finishNormalize := func() {
+        childNormalizeFlowFuture.SignalChildWorkflow(ctx, shared.NormalizeSyncSignalName, model.NormalizeSignal{
+            Done:        true,
+            SyncBatchID: -1,
+        })
+        var childNormalizeFlowRes *model.NormalizeFlowResponse
+        if err := childNormalizeFlowFuture.Get(ctx, &childNormalizeFlowRes); err != nil {
+            w.logger.Error("failed to execute normalize flow: ", err)
+            var panicErr *temporal.PanicError
+            if errors.As(err, &panicErr) {
+                w.logger.Error("PANIC", panicErr.Error(), panicErr.StackTrace())
+            }
+            state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, err.Error())
+        } else {
+            state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, childNormalizeFlowRes.Errors...)
+            state.NormalizeFlowStatuses = append(state.NormalizeFlowStatuses, childNormalizeFlowRes.Results...)
+        }
+    }

     var canceled bool
     signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
     mainLoopSelector := workflow.NewSelector(ctx)
+    mainLoopSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {
+        canceled = true
+    })
     mainLoopSelector.AddReceive(signalChan, func(c workflow.ReceiveChannel, _ bool) {
         var signalVal shared.CDCFlowSignal
         c.ReceiveAsync(&signalVal)
         state.ActiveSignal = shared.FlowSignalHandler(state.ActiveSignal, signalVal, w.logger)
     })
-    mainLoopSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {
-        canceled = true
-    })

     for {
         for !canceled && mainLoopSelector.HasPending() {
@@ -429,7 +477,7 @@ func CDCFlowWorkflowWithConfig(
                 state.ActiveSignal = shared.FlowSignalHandler(state.ActiveSignal, signalVal, w.logger)
                 // only process config updates when going from STATUS_PAUSED to STATUS_RUNNING
                 if state.ActiveSignal == shared.NoopSignal {
-                    err = w.processCDCFlowConfigUpdates(ctx, cfg, state, limits, &mirrorNameSearch)
+                    err = w.processCDCFlowConfigUpdates(ctx, cfg, state, limits, mirrorNameSearch)
                     if err != nil {
                         return state, err
                     }
@@ -444,6 +492,7 @@ func CDCFlowWorkflowWithConfig(

         // check if the peer flow has been shutdown
         if state.ActiveSignal == shared.ShutdownSignal {
+            finishNormalize()
             w.logger.Info("peer flow has been shutdown")
             state.CurrentFlowStatus = protos.FlowStatus_STATUS_TERMINATED
             return state, nil
@@ -470,6 +519,7 @@ func CDCFlowWorkflowWithConfig(

         syncFlowID, err := GetChildWorkflowID(ctx, "sync-flow", cfg.FlowJobName)
         if err != nil {
+            finishNormalize()
             return state, err
         }

@@ -493,53 +543,66 @@ func CDCFlowWorkflowWithConfig(
         )

         var syncDone bool
-        var childSyncFlowRes *model.SyncResponse
+        var normalizeSignalError error
+        normDone := normWaitChan == nil
         mainLoopSelector.AddFuture(childSyncFlowFuture, func(f workflow.Future) {
             syncDone = true
+
+            var childSyncFlowRes *model.SyncResponse
             if err := f.Get(syncCtx, &childSyncFlowRes); err != nil {
                 w.logger.Error("failed to execute sync flow: ", err)
                 state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error())
             } else if childSyncFlowRes != nil {
                 state.SyncFlowStatuses = append(state.SyncFlowStatuses, childSyncFlowRes)
                 state.RelationMessageMapping = childSyncFlowRes.RelationMessageMapping
-                totalRecordsSynced += int(childSyncFlowRes.NumRecordsSynced)
+                totalRecordsSynced += childSyncFlowRes.NumRecordsSynced
                 w.logger.Info("Total records synced: ", totalRecordsSynced)
             }

-            var tableSchemaDeltas []*protos.TableSchemaDelta = nil
             if childSyncFlowRes != nil {
-                tableSchemaDeltas = childSyncFlowRes.TableSchemaDeltas
-            }
-
-            // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes.
-            if tableSchemaDeltas != nil {
-                modifiedSrcTables := make([]string, 0)
-                modifiedDstTables := make([]string, 0)
-
-                for _, tableSchemaDelta := range tableSchemaDeltas {
-                    modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName)
-                    modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName)
-                }
+                tableSchemaDeltasCount := len(childSyncFlowRes.TableSchemaDeltas)
+
+                var normalizeTableNameSchemaMapping map[string]*protos.TableSchema
+                // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes.
+                if tableSchemaDeltasCount != 0 {
+                    modifiedSrcTables := make([]string, 0, tableSchemaDeltasCount)
+                    modifiedDstTables := make([]string, 0, tableSchemaDeltasCount)
+                    for _, tableSchemaDelta := range childSyncFlowRes.TableSchemaDeltas {
+                        modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName)
+                        modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName)
+                    }

-                getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
-                    StartToCloseTimeout: 5 * time.Minute,
-                })
-                getModifiedSchemaFuture := workflow.ExecuteActivity(getModifiedSchemaCtx, flowable.GetTableSchema,
-                    &protos.GetTableSchemaBatchInput{
-                        PeerConnectionConfig: cfg.Source,
-                        TableIdentifiers:     modifiedSrcTables,
-                        FlowName:             cfg.FlowJobName,
+                    getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
+                        StartToCloseTimeout: 5 * time.Minute,
                     })
-
-                var getModifiedSchemaRes *protos.GetTableSchemaBatchOutput
-                if err := getModifiedSchemaFuture.Get(ctx, &getModifiedSchemaRes); err != nil {
-                    w.logger.Error("failed to execute schema update at source: ", err)
-                    state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error())
-                } else {
-                    for i := range modifiedSrcTables {
-                        state.TableNameSchemaMapping[modifiedDstTables[i]] = getModifiedSchemaRes.TableNameSchemaMapping[modifiedSrcTables[i]]
+                    getModifiedSchemaFuture := workflow.ExecuteActivity(getModifiedSchemaCtx, flowable.GetTableSchema,
+                        &protos.GetTableSchemaBatchInput{
+                            PeerConnectionConfig: cfg.Source,
+                            TableIdentifiers:     modifiedSrcTables,
+                            FlowName:             cfg.FlowJobName,
+                        })
+
+                    var getModifiedSchemaRes *protos.GetTableSchemaBatchOutput
+                    if err := getModifiedSchemaFuture.Get(ctx, &getModifiedSchemaRes); err != nil {
+                        w.logger.Error("failed to execute schema update at source: ", err)
+                        state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error())
+                    } else {
+                        for i, srcTable := range modifiedSrcTables {
+                            dstTable := modifiedDstTables[i]
+                            state.TableNameSchemaMapping[dstTable] = getModifiedSchemaRes.TableNameSchemaMapping[srcTable]
+                        }
+                        normalizeTableNameSchemaMapping = state.TableNameSchemaMapping
                     }
                 }
+
+                signalFuture := childNormalizeFlowFuture.SignalChildWorkflow(ctx, shared.NormalizeSyncSignalName, model.NormalizeSignal{
+                    Done:                   false,
+                    SyncBatchID:            childSyncFlowRes.CurrentSyncBatchID,
+                    TableNameSchemaMapping: normalizeTableNameSchemaMapping,
+                })
+                normalizeSignalError = signalFuture.Get(ctx, nil)
+            } else {
+                normDone = true
             }
         })
@@ -555,38 +618,15 @@ func CDCFlowWorkflowWithConfig(
             state.CurrentFlowStatus = protos.FlowStatus_STATUS_TERMINATED
             return state, nil
         }
-
-        normalizeFlowID, err := GetChildWorkflowID(ctx, "normalize-flow", cfg.FlowJobName)
-        if err != nil {
-            return state, err
+        if normalizeSignalError != nil {
+            return state, normalizeSignalError
         }
-
-        childNormalizeFlowOpts := workflow.ChildWorkflowOptions{
-            WorkflowID:        normalizeFlowID,
-            ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL,
-            RetryPolicy: &temporal.RetryPolicy{
-                MaximumAttempts: 20,
-            },
-            SearchAttributes:    mirrorNameSearch,
-            WaitForCancellation: true,
-        }
-        normCtx := workflow.WithChildOptions(ctx, childNormalizeFlowOpts)
-        childNormalizeFlowFuture := workflow.ExecuteChildWorkflow(
-            normCtx,
-            NormalizeFlowWorkflow,
-            cfg,
-            normalizeFlowOptions,
-        )
-
-        var childNormalizeFlowRes *model.NormalizeResponse
-        if err := childNormalizeFlowFuture.Get(normCtx, &childNormalizeFlowRes); err != nil {
-            w.logger.Error("failed to execute normalize flow: ", err)
-            state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, err.Error())
-        } else {
-            state.NormalizeFlowStatuses = append(state.NormalizeFlowStatuses, childNormalizeFlowRes)
+        if !normDone {
+            normWaitChan.Receive(ctx, nil)
         }
     }

+    finishNormalize()
+
     state.TruncateProgress(w.logger)
     return state, workflow.NewContinueAsNewError(ctx, CDCFlowWorkflowWithConfig, cfg, limits, state)
 }
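Review note: the structural change in this file is that the normalize child workflow is started once per CDC run, before the main loop, and then driven purely by signals; the per-iteration ExecuteChildWorkflow/Get pair is gone. A condensed skeleton of the new lifecycle (illustrative, not the literal workflow code; runSyncFlow is a hypothetical stand-in for the sync child workflow):

    normFuture := workflow.ExecuteChildWorkflow(normCtx, NormalizeFlowWorkflow, cfg, opts)
    for moreBatches {
        res := runSyncFlow() // one sync batch
        normFuture.SignalChildWorkflow(ctx, shared.NormalizeSyncSignalName,
            model.NormalizeSignal{SyncBatchID: res.CurrentSyncBatchID})
        if normWaitChan != nil { // serial mode: wait for normalize-sync-done
            normWaitChan.Receive(ctx, nil)
        }
    }
    finishNormalize() // Done: true, then Get() collects the NormalizeFlowResponse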
diff --git a/flow/workflows/normalize_flow.go b/flow/workflows/normalize_flow.go
index 70f6463aef..a598c51792 100644
--- a/flow/workflows/normalize_flow.go
+++ b/flow/workflows/normalize_flow.go
@@ -4,66 +4,97 @@ import (
     "fmt"
     "time"

-    "go.temporal.io/sdk/log"
     "go.temporal.io/sdk/workflow"

     "github.com/PeerDB-io/peer-flow/generated/protos"
     "github.com/PeerDB-io/peer-flow/model"
+    "github.com/PeerDB-io/peer-flow/peerdbenv"
+    "github.com/PeerDB-io/peer-flow/shared"
 )

-type NormalizeFlowState struct {
-    CDCFlowName string
-    Progress    []string
-}
-
-type NormalizeFlowExecution struct {
-    NormalizeFlowState
-    executionID string
-    logger      log.Logger
-}
-
-func NewNormalizeFlowExecution(ctx workflow.Context, state *NormalizeFlowState) *NormalizeFlowExecution {
-    return &NormalizeFlowExecution{
-        NormalizeFlowState: *state,
-        executionID:        workflow.GetInfo(ctx).WorkflowExecution.ID,
-        logger:             workflow.GetLogger(ctx),
-    }
-}
-
 func NormalizeFlowWorkflow(ctx workflow.Context,
     config *protos.FlowConnectionConfigs,
     options *protos.NormalizeFlowOptions,
-) (*model.NormalizeResponse, error) {
-    s := NewNormalizeFlowExecution(ctx, &NormalizeFlowState{
-        CDCFlowName: config.FlowJobName,
-        Progress:    []string{},
-    })
-
-    return s.executeNormalizeFlow(ctx, config, options)
-}
-
-func (s *NormalizeFlowExecution) executeNormalizeFlow(
-    ctx workflow.Context,
-    config *protos.FlowConnectionConfigs,
-    options *protos.NormalizeFlowOptions,
-) (*model.NormalizeResponse, error) {
-    s.logger.Info("executing normalize flow - ", s.CDCFlowName)
+) (*model.NormalizeFlowResponse, error) {
+    logger := workflow.GetLogger(ctx)
+    tableNameSchemaMapping := options.TableNameSchemaMapping

     normalizeFlowCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
         StartToCloseTimeout: 7 * 24 * time.Hour,
         HeartbeatTimeout:    time.Minute,
     })

-    startNormalizeInput := &protos.StartNormalizeInput{
-        FlowConnectionConfigs:  config,
-        TableNameSchemaMapping: options.TableNameSchemaMapping,
-    }
-    fStartNormalize := workflow.ExecuteActivity(normalizeFlowCtx, flowable.StartNormalize, startNormalizeInput)
+    results := make([]model.NormalizeResponse, 0, 4)
+    errors := make([]string, 0)
+    syncChan := workflow.GetSignalChannel(ctx, shared.NormalizeSyncSignalName)
+
+    var stopLoop, canceled bool
+    var lastSyncBatchID, syncBatchID int64
+    lastSyncBatchID = -1
+    syncBatchID = -1
+    selector := workflow.NewNamedSelector(ctx, fmt.Sprintf("%s-normalize", config.FlowJobName))
+    selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {
+        canceled = true
+    })
+    selector.AddReceive(syncChan, func(c workflow.ReceiveChannel, _ bool) {
+        var s model.NormalizeSignal
+        c.ReceiveAsync(&s)
+        if s.Done {
+            stopLoop = true
+        }
+        if s.SyncBatchID > syncBatchID {
+            syncBatchID = s.SyncBatchID
+        }
+        if len(s.TableNameSchemaMapping) != 0 {
+            tableNameSchemaMapping = s.TableNameSchemaMapping
+        }
+    })
+    for !stopLoop {
+        selector.Select(ctx)
+        for !canceled && selector.HasPending() {
+            selector.Select(ctx)
+        }
+        if canceled || (stopLoop && lastSyncBatchID == syncBatchID) {
+            if canceled {
+                logger.Info("normalize canceled - ", config.FlowJobName)
+            } else {
+                logger.Info("normalize finished - ", config.FlowJobName)
+            }
+            break
+        }
+        if lastSyncBatchID != syncBatchID {
+            lastSyncBatchID = syncBatchID
+
+            logger.Info("executing normalize - ", config.FlowJobName)
+            startNormalizeInput := &protos.StartNormalizeInput{
+                FlowConnectionConfigs:  config,
+                TableNameSchemaMapping: tableNameSchemaMapping,
+                SyncBatchID:            syncBatchID,
+            }
+            fStartNormalize := workflow.ExecuteActivity(normalizeFlowCtx, flowable.StartNormalize, startNormalizeInput)
+
+            var normalizeResponse *model.NormalizeResponse
+            if err := fStartNormalize.Get(normalizeFlowCtx, &normalizeResponse); err != nil {
+                errors = append(errors, err.Error())
+            } else if normalizeResponse != nil {
+                results = append(results, *normalizeResponse)
+            }
+        }

-    var normalizeResponse *model.NormalizeResponse
-    if err := fStartNormalize.Get(normalizeFlowCtx, &normalizeResponse); err != nil {
-        return nil, fmt.Errorf("failed to flow: %w", err)
+        if !peerdbenv.PeerDBEnableParallelSyncNormalize() {
+            parent := workflow.GetInfo(ctx).ParentWorkflowExecution
+            workflow.SignalExternalWorkflow(
+                ctx,
+                parent.ID,
+                parent.RunID,
+                shared.NormalizeSyncDoneSignalName,
+                struct{}{},
+            )
+        }
     }

-    return normalizeResponse, nil
+    return &model.NormalizeFlowResponse{
+        Results: results,
+        Errors:  errors,
+    }, nil
 }
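Review note: the drain loop makes signals coalesce, since only the highest SyncBatchID seen matters. A worked example under assumed values:

    // Suppose lastSyncBatchID == 3 and syncs for batches 4, 5 and 6 all signal
    // while one normalize pass is running. Draining the selector leaves
    // syncBatchID == 6, so the next StartNormalize activity covers batches
    // 4..6 in a single pass, then lastSyncBatchID = 6 and the loop blocks again.

This is what lets the parent fire-and-forget in parallel mode without queueing one normalize run per sync.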
diff --git a/flow/workflows/qrep_flow.go b/flow/workflows/qrep_flow.go
index e7e87a41d5..bb4f2323c6 100644
--- a/flow/workflows/qrep_flow.go
+++ b/flow/workflows/qrep_flow.go
@@ -362,9 +362,7 @@ func (q *QRepFlowExecution) handleTableRenameForResync(ctx workflow.Context, sta
     return nil
 }

-func (q *QRepFlowExecution) receiveAndHandleSignalAsync(ctx workflow.Context) {
-    signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
-
+func (q *QRepFlowExecution) receiveAndHandleSignalAsync(signalChan workflow.ReceiveChannel) {
     var signalVal shared.CDCFlowSignal
     ok := signalChan.ReceiveAsync(&signalVal)
     if ok {
@@ -509,11 +507,11 @@ func QRepFlowWorkflow(

     // here, we handle signals after the end of the flow because a new workflow does not inherit the signals
     // and the chance of missing a signal is much higher if the check is before the time consuming parts run
-    q.receiveAndHandleSignalAsync(ctx)
+    signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
+    q.receiveAndHandleSignalAsync(signalChan)
     if q.activeSignal == shared.PauseSignal {
         startTime := time.Now()
         state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED
-        signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
         var signalVal shared.CDCFlowSignal

         for q.activeSignal == shared.PauseSignal {
diff --git a/flow/workflows/xmin_flow.go b/flow/workflows/xmin_flow.go
index c6885253df..84f6ba73cb 100644
--- a/flow/workflows/xmin_flow.go
+++ b/flow/workflows/xmin_flow.go
@@ -117,11 +117,11 @@ func XminFlowWorkflow(

     // here, we handle signals after the end of the flow because a new workflow does not inherit the signals
     // and the chance of missing a signal is much higher if the check is before the time consuming parts run
-    q.receiveAndHandleSignalAsync(ctx)
+    signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
+    q.receiveAndHandleSignalAsync(signalChan)
     if x.activeSignal == shared.PauseSignal {
         startTime := time.Now()
         state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED
-        signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName)
         var signalVal shared.CDCFlowSignal

         for x.activeSignal == shared.PauseSignal {
diff --git a/protos/flow.proto b/protos/flow.proto
index 41ba8bf966..fe0f84ff86 100644
--- a/protos/flow.proto
+++ b/protos/flow.proto
@@ -125,6 +125,7 @@ message StartFlowInput {
 message StartNormalizeInput {
   FlowConnectionConfigs flow_connection_configs = 1;
   map<string, TableSchema> table_name_schema_mapping = 2;
+  int64 SyncBatchID = 3;
 }

 message GetLastSyncedIDInput {
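Review note: the qrep/xmin change hoists GetSignalChannel so the async check and the pause loop share one channel instance instead of creating it twice. On the proto side, the new field uses CamelCase (SyncBatchID) rather than proto3's conventional snake_case; assuming standard protoc-gen-go output, it still generates the expected Go field and accessor:

    // Sketch of the generated API for the new field:
    in := &protos.StartNormalizeInput{SyncBatchID: 42}
    _ = in.GetSyncBatchID() // 42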