diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index e8cdebbcc3..20917d7b08 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -197,7 +197,6 @@ func (a *FlowableActivity) StartFlow(ctx context.Context, } defer connectors.CloseConnector(dstConn) - activity.RecordHeartbeat(ctx, "initialized table schema") slog.InfoContext(ctx, "pulling records...") tblNameMapping := make(map[string]model.NameAndExclude) for _, v := range input.FlowConnectionConfigs.TableMappings { @@ -379,26 +378,22 @@ func (a *FlowableActivity) StartNormalize( } defer connectors.CloseConnector(dstConn) - lastSyncBatchID, err := dstConn.GetLastSyncBatchID(input.FlowConnectionConfigs.FlowJobName) - if err != nil { - return nil, fmt.Errorf("failed to get last sync batch ID: %v", err) - } - err = monitoring.UpdateEndTimeForCDCBatch(ctx, a.CatalogPool, input.FlowConnectionConfigs.FlowJobName, - lastSyncBatchID) + input.SyncBatchID) return nil, err } else if err != nil { return nil, err } defer connectors.CloseConnector(dstConn) - shutdown := utils.HeartbeatRoutine(ctx, 2*time.Minute, func() string { + shutdown := utils.HeartbeatRoutine(ctx, 15*time.Second, func() string { return fmt.Sprintf("normalizing records from batch for job - %s", input.FlowConnectionConfigs.FlowJobName) }) defer shutdown() res, err := dstConn.NormalizeRecords(&model.NormalizeRecordsRequest{ FlowJobName: input.FlowConnectionConfigs.FlowJobName, + SyncBatchID: input.SyncBatchID, SoftDelete: input.FlowConnectionConfigs.SoftDelete, SoftDeleteColName: input.FlowConnectionConfigs.SoftDeleteColName, SyncedAtColName: input.FlowConnectionConfigs.SyncedAtColName, @@ -423,10 +418,8 @@ func (a *FlowableActivity) StartNormalize( } // log the number of batches normalized - if res != nil { - slog.InfoContext(ctx, fmt.Sprintf("normalized records from batch %d to batch %d\n", - res.StartBatchID, res.EndBatchID)) - } + slog.InfoContext(ctx, fmt.Sprintf("normalized records from batch %d to batch %d\n", + res.StartBatchID, res.EndBatchID)) return res, nil } diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 7177e4bb42..15ee6861e6 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -340,6 +340,7 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { query := fmt.Sprintf("SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) q := c.client.Query(query) + q.DisableQueryCache = true it, err := q.Read(c.ctx) if err != nil { err = fmt.Errorf("failed to run query %s on BigQuery:\n %w", query, err) @@ -361,35 +362,28 @@ func (c *BigQueryConnector) GetLastSyncBatchID(jobName string) (int64, error) { } } -func (c *BigQueryConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) { - query := fmt.Sprintf("SELECT sync_batch_id, normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", +func (c *BigQueryConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { + query := fmt.Sprintf("SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name = '%s'", c.datasetID, MirrorJobsTable, jobName) q := c.client.Query(query) + q.DisableQueryCache = true it, err := q.Read(c.ctx) if err != nil { err = fmt.Errorf("failed to run query %s on BigQuery:\n %w", query, err) - return model.SyncAndNormalizeBatchID{}, err + return 0, err } var row []bigquery.Value err = it.Next(&row) if err != nil { c.logger.Info("no row found for job") - return model.SyncAndNormalizeBatchID{}, nil + return 0, nil } - syncBatchID := int64(0) - normBatchID := int64(0) if row[0] != nil { - syncBatchID = row[0].(int64) - } - if row[1] != nil { - normBatchID = row[1].(int64) + return row[0].(int64), nil } - return model.SyncAndNormalizeBatchID{ - SyncBatchID: syncBatchID, - NormalizeBatchID: normBatchID, - }, nil + return 0, nil } func (c *BigQueryConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, @@ -527,7 +521,7 @@ func (c *BigQueryConnector) syncRecordsViaAvro( func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { rawTableName := c.getRawTableName(req.FlowJobName) - batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err) } @@ -538,18 +532,18 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } // if job is not yet found in the peerdb_mirror_jobs_table // OR sync is lagging end normalize - if !hasJob || batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { + if !hasJob || normBatchID >= req.SyncBatchID { c.logger.Info("waiting for sync to catch up, so finishing") return &model.NormalizeResponse{ Done: false, - StartBatchID: batchIDs.NormalizeBatchID, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID, + EndBatchID: req.SyncBatchID, }, nil } distinctTableNames, err := c.getDistinctTableNamesInBatch( req.FlowJobName, - batchIDs.SyncBatchID, - batchIDs.NormalizeBatchID, + req.SyncBatchID, + normBatchID, ) if err != nil { return nil, fmt.Errorf("couldn't get distinct table names to normalize: %w", err) @@ -557,8 +551,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols( req.FlowJobName, - batchIDs.SyncBatchID, - batchIDs.NormalizeBatchID, + req.SyncBatchID, + normBatchID, ) if err != nil { return nil, fmt.Errorf("couldn't get tablename to unchanged cols mapping: %w", err) @@ -579,8 +573,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) dstTableName: tableName, dstDatasetTable: dstDatasetTable, normalizedTableSchema: req.TableNameSchemaMapping[tableName], - syncBatchID: batchIDs.SyncBatchID, - normalizeBatchID: batchIDs.NormalizeBatchID, + syncBatchID: req.SyncBatchID, + normalizeBatchID: normBatchID, peerdbCols: &protos.PeerDBColumns{ SoftDeleteColName: req.SoftDeleteColName, SyncedAtColName: req.SyncedAtColName, @@ -603,7 +597,7 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) // update metadata to make the last normalized batch id to the recent last sync batch id. updateMetadataStmt := fmt.Sprintf( "UPDATE %s.%s SET normalize_batch_id=%d WHERE mirror_job_name='%s';", - c.datasetID, MirrorJobsTable, batchIDs.SyncBatchID, req.FlowJobName) + c.datasetID, MirrorJobsTable, req.SyncBatchID, req.FlowJobName) _, err = c.client.Query(updateMetadataStmt).Read(c.ctx) if err != nil { @@ -612,8 +606,8 @@ func (c *BigQueryConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) return &model.NormalizeResponse{ Done: true, - StartBatchID: batchIDs.NormalizeBatchID + 1, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID + 1, + EndBatchID: req.SyncBatchID, }, nil } diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 8c8113911b..57504c96fa 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -29,11 +29,11 @@ const ( createRawTableBatchIDIndexSQL = "CREATE INDEX IF NOT EXISTS %s_batchid_idx ON %s.%s(_peerdb_batch_id)" createRawTableDstTableIndexSQL = "CREATE INDEX IF NOT EXISTS %s_dst_table_idx ON %s.%s(_peerdb_destination_table_name)" - getLastOffsetSQL = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1" - setLastOffsetSQL = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2" - getLastSyncBatchID_SQL = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1" - getLastSyncAndNormalizeBatchID_SQL = "SELECT sync_batch_id,normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1" - createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)" + getLastOffsetSQL = "SELECT lsn_offset FROM %s.%s WHERE mirror_job_name=$1" + setLastOffsetSQL = "UPDATE %s.%s SET lsn_offset=GREATEST(lsn_offset, $1) WHERE mirror_job_name=$2" + getLastSyncBatchID_SQL = "SELECT sync_batch_id FROM %s.%s WHERE mirror_job_name=$1" + getLastNormalizeBatchID_SQL = "SELECT normalize_batch_id FROM %s.%s WHERE mirror_job_name=$1" + createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)" insertJobMetadataSQL = "INSERT INTO %s.%s VALUES ($1,$2,$3,$4)" checkIfJobMetadataExistsSQL = "SELECT COUNT(1)::TEXT::BOOL FROM %s.%s WHERE mirror_job_name=$1" @@ -441,24 +441,21 @@ func (c *PostgresConnector) GetLastSyncBatchID(jobName string) (int64, error) { return result.Int64, nil } -func (c *PostgresConnector) GetLastSyncAndNormalizeBatchID(jobName string) (*model.SyncAndNormalizeBatchID, error) { - var syncResult, normalizeResult pgtype.Int8 +func (c *PostgresConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { + var result pgtype.Int8 err := c.pool.QueryRow(c.ctx, fmt.Sprintf( - getLastSyncAndNormalizeBatchID_SQL, + getLastNormalizeBatchID_SQL, c.metadataSchema, mirrorJobsTableIdentifier, - ), jobName).Scan(&syncResult, &normalizeResult) + ), jobName).Scan(&result) if err != nil { if err == pgx.ErrNoRows { c.logger.Info("No row found, returning 0") - return &model.SyncAndNormalizeBatchID{}, nil + return 0, nil } - return nil, fmt.Errorf("error while reading result row: %w", err) + return 0, fmt.Errorf("error while reading result row: %w", err) } - return &model.SyncAndNormalizeBatchID{ - SyncBatchID: syncResult.Int64, - NormalizeBatchID: normalizeResult.Int64, - }, nil + return result.Int64, nil } func (c *PostgresConnector) jobMetadataExists(jobName string) (bool, error) { diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index d0a7f4db52..eb0071e218 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -364,6 +364,8 @@ func (c *PostgresConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S return &model.SyncResponse{ LastSyncedCheckPointID: 0, NumRecordsSynced: 0, + TableSchemaDeltas: tableSchemaDeltas, + RelationMessageMapping: <-req.Records.RelationMessageMapping, }, nil } @@ -436,28 +438,29 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) }, nil } - batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) if err != nil { return nil, fmt.Errorf("failed to get batch for the current mirror: %v", err) } + // normalize has caught up with sync, chill until more records are loaded. - if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { + if normBatchID >= req.SyncBatchID { c.logger.Info(fmt.Sprintf("no records to normalize: syncBatchID %d, normalizeBatchID %d", - batchIDs.SyncBatchID, batchIDs.NormalizeBatchID)) + req.SyncBatchID, normBatchID)) return &model.NormalizeResponse{ Done: false, - StartBatchID: batchIDs.NormalizeBatchID, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID, + EndBatchID: req.SyncBatchID, }, nil } destinationTableNames, err := c.getDistinctTableNamesInBatch( - req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) + req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, err } unchangedToastColsMap, err := c.getTableNametoUnchangedCols(req.FlowJobName, - batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) + req.SyncBatchID, normBatchID) if err != nil { return nil, err } @@ -496,7 +499,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) } normalizeStatements := normalizeStmtGen.generateNormalizeStatements() for _, normalizeStatement := range normalizeStatements { - mergeStatementsBatch.Queue(normalizeStatement, batchIDs.NormalizeBatchID, batchIDs.SyncBatchID, destinationTableName).Exec( + mergeStatementsBatch.Queue(normalizeStatement, normBatchID, req.SyncBatchID, destinationTableName).Exec( func(ct pgconn.CommandTag) error { totalRowsAffected += int(ct.RowsAffected()) return nil @@ -513,7 +516,7 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) c.logger.Info(fmt.Sprintf("normalized %d records", totalRowsAffected)) // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID, normalizeRecordsTx) + err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID, normalizeRecordsTx) if err != nil { return nil, err } @@ -525,8 +528,8 @@ func (c *PostgresConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) return &model.NormalizeResponse{ Done: true, - StartBatchID: batchIDs.NormalizeBatchID + 1, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID + 1, + EndBatchID: req.SyncBatchID, }, nil } @@ -720,7 +723,8 @@ func (c *PostgresConnector) SetupNormalizedTables(req *protos.SetupNormalizedTab // ReplayTableSchemaDelta changes a destination table to match the schema at source // This could involve adding or dropping multiple columns. -func (c *PostgresConnector) ReplayTableSchemaDeltas(flowJobName string, +func (c *PostgresConnector) ReplayTableSchemaDeltas( + flowJobName string, schemaDeltas []*protos.TableSchemaDelta, ) error { if len(schemaDeltas) == 0 { diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 10bb93b22c..e4d62a5f01 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -69,14 +69,14 @@ const ( checkIfTableExistsSQL = `SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=? and TABLE_NAME=?` - checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?" - getLastOffsetSQL = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?" - setLastOffsetSQL = "UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?) WHERE MIRROR_JOB_NAME=?" - getLastSyncBatchID_SQL = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" - getLastSyncNormalizeBatchID_SQL = "SELECT SYNC_BATCH_ID, NORMALIZE_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" - dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s" - deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=?" - checkSchemaExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?" + checkIfJobMetadataExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM %s.%s WHERE MIRROR_JOB_NAME=?" + getLastOffsetSQL = "SELECT OFFSET FROM %s.%s WHERE MIRROR_JOB_NAME=?" + setLastOffsetSQL = "UPDATE %s.%s SET OFFSET=GREATEST(OFFSET, ?) WHERE MIRROR_JOB_NAME=?" + getLastSyncBatchID_SQL = "SELECT SYNC_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" + getLastNormalizeBatchID_SQL = "SELECT NORMALIZE_BATCH_ID FROM %s.%s WHERE MIRROR_JOB_NAME=?" + dropTableIfExistsSQL = "DROP TABLE IF EXISTS %s.%s" + deleteJobMetadataSQL = "DELETE FROM %s.%s WHERE MIRROR_JOB_NAME=?" + checkSchemaExistsSQL = "SELECT TO_BOOLEAN(COUNT(1)) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME=?" ) type SnowflakeConnector struct { @@ -316,28 +316,24 @@ func (c *SnowflakeConnector) GetLastSyncBatchID(jobName string) (int64, error) { return result.Int64, nil } -func (c *SnowflakeConnector) GetLastSyncAndNormalizeBatchID(jobName string) (model.SyncAndNormalizeBatchID, error) { - rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastSyncNormalizeBatchID_SQL, c.metadataSchema, +func (c *SnowflakeConnector) GetLastNormalizeBatchID(jobName string) (int64, error) { + rows, err := c.database.QueryContext(c.ctx, fmt.Sprintf(getLastNormalizeBatchID_SQL, c.metadataSchema, mirrorJobsTableIdentifier), jobName) if err != nil { - return model.SyncAndNormalizeBatchID{}, - fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err) + return 0, fmt.Errorf("error querying Snowflake peer for last normalizeBatchId: %w", err) } defer rows.Close() - var syncResult, normResult pgtype.Int8 + var normBatchID pgtype.Int8 if !rows.Next() { c.logger.Warn("No row found, returning 0") - return model.SyncAndNormalizeBatchID{}, nil + return 0, nil } - err = rows.Scan(&syncResult, &normResult) + err = rows.Scan(&normBatchID) if err != nil { - return model.SyncAndNormalizeBatchID{}, fmt.Errorf("error while reading result row: %w", err) + return 0, fmt.Errorf("error while reading result row: %w", err) } - return model.SyncAndNormalizeBatchID{ - SyncBatchID: syncResult.Int64, - NormalizeBatchID: normResult.Int64, - }, nil + return normBatchID.Int64, nil } func (c *SnowflakeConnector) getDistinctTableNamesInBatch(flowJobName string, syncBatchID int64, @@ -575,16 +571,17 @@ func (c *SnowflakeConnector) syncRecordsViaAvro( // NormalizeRecords normalizes raw table to destination table. func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest) (*model.NormalizeResponse, error) { - batchIDs, err := c.GetLastSyncAndNormalizeBatchID(req.FlowJobName) + normBatchID, err := c.GetLastNormalizeBatchID(req.FlowJobName) if err != nil { return nil, err } + // normalize has caught up with sync, chill until more records are loaded. - if batchIDs.NormalizeBatchID >= batchIDs.SyncBatchID { + if normBatchID >= req.SyncBatchID { return &model.NormalizeResponse{ Done: false, - StartBatchID: batchIDs.NormalizeBatchID, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID, + EndBatchID: req.SyncBatchID, }, nil } @@ -600,14 +597,14 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } destinationTableNames, err := c.getDistinctTableNamesInBatch( req.FlowJobName, - batchIDs.SyncBatchID, - batchIDs.NormalizeBatchID, + req.SyncBatchID, + normBatchID, ) if err != nil { return nil, err } - tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, batchIDs.SyncBatchID, batchIDs.NormalizeBatchID) + tableNametoUnchangedToastCols, err := c.getTableNametoUnchangedCols(req.FlowJobName, req.SyncBatchID, normBatchID) if err != nil { return nil, fmt.Errorf("couldn't tablename to unchanged cols mapping: %w", err) } @@ -623,8 +620,8 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest mergeGen := &mergeStmtGenerator{ rawTableName: getRawTableIdentifier(req.FlowJobName), dstTableName: tableName, - syncBatchID: batchIDs.SyncBatchID, - normalizeBatchID: batchIDs.NormalizeBatchID, + syncBatchID: req.SyncBatchID, + normalizeBatchID: normBatchID, normalizedTableSchema: req.TableNameSchemaMapping[tableName], unchangedToastColumns: tableNametoUnchangedToastCols[tableName], peerdbCols: &protos.PeerDBColumns{ @@ -670,15 +667,15 @@ func (c *SnowflakeConnector) NormalizeRecords(req *model.NormalizeRecordsRequest } // updating metadata with new normalizeBatchID - err = c.updateNormalizeMetadata(req.FlowJobName, batchIDs.SyncBatchID) + err = c.updateNormalizeMetadata(req.FlowJobName, req.SyncBatchID) if err != nil { return nil, err } return &model.NormalizeResponse{ Done: true, - StartBatchID: batchIDs.NormalizeBatchID + 1, - EndBatchID: batchIDs.SyncBatchID, + StartBatchID: normBatchID + 1, + EndBatchID: req.SyncBatchID, }, nil } diff --git a/flow/model/model.go b/flow/model/model.go index b3bd44b0ed..776157559d 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -523,11 +523,6 @@ func (r *CDCRecordStream) GetRecords() <-chan Record { return r.records } -type SyncAndNormalizeBatchID struct { - SyncBatchID int64 - NormalizeBatchID int64 -} - type SyncRecordsRequest struct { Records *CDCRecordStream // FlowJobName is the name of the flow job. @@ -542,6 +537,7 @@ type SyncRecordsRequest struct { type NormalizeRecordsRequest struct { FlowJobName string + SyncBatchID int64 SoftDelete bool SoftDeleteColName string SyncedAtColName string @@ -563,6 +559,17 @@ type SyncResponse struct { RelationMessageMapping RelationMessageMapping } +type NormalizeSignal struct { + Done bool + SyncBatchID int64 + TableNameSchemaMapping map[string]*protos.TableSchema +} + +type NormalizeFlowResponse struct { + Results []NormalizeResponse + Errors []string +} + type NormalizeResponse struct { // Flag to depict if normalization is done Done bool diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index 70ac1723ba..9f94e53a7b 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -1,6 +1,7 @@ package peerflow import ( + "errors" "fmt" "log/slog" "strings" @@ -22,13 +23,13 @@ const ( type CDCFlowLimits struct { // Number of sync flows to execute in total. - // If 0, the number of sync flows will be continuously executed until the peer flow is cancelled. + // If 0, the number of sync flows will be continuously executed until the peer flow is canceled. // This is typically non-zero for testing purposes. TotalSyncFlows int // Maximum number of rows in a sync flow batch. MaxBatchSize uint32 // Rows synced after which we can say a test is done. - ExitAfterRecords int + ExitAfterRecords int64 } type CDCFlowWorkflowState struct { @@ -37,7 +38,7 @@ type CDCFlowWorkflowState struct { // Accumulates status for sync flows spawned. SyncFlowStatuses []*model.SyncResponse // Accumulates status for sync flows spawned. - NormalizeFlowStatuses []*model.NormalizeResponse + NormalizeFlowStatuses []model.NormalizeResponse // Current signalled state of the peer flow. ActiveSignal shared.CDCFlowSignal // Errors encountered during child sync flow executions. @@ -337,7 +338,45 @@ func CDCFlowWorkflowWithConfig( }) currentSyncFlowNum := 0 - totalRecordsSynced := 0 + totalRecordsSynced := int64(0) + + normalizeFlowID, err := GetChildWorkflowID(ctx, "normalize-flow", cfg.FlowJobName) + if err != nil { + return state, err + } + + childNormalizeFlowOpts := workflow.ChildWorkflowOptions{ + WorkflowID: normalizeFlowID, + ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, + RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 20, + }, + SearchAttributes: mirrorNameSearch, + WaitForCancellation: true, + } + normCtx := workflow.WithChildOptions(ctx, childNormalizeFlowOpts) + childNormalizeFlowFuture := workflow.ExecuteChildWorkflow( + normCtx, + NormalizeFlowWorkflow, + cfg, + normalizeFlowOptions, + ) + + finishNormalize := func() { + childNormalizeFlowFuture.SignalChildWorkflow(ctx, "Sync", model.NormalizeSignal{Done: true}) + var childNormalizeFlowRes *model.NormalizeFlowResponse + if err := childNormalizeFlowFuture.Get(ctx, &childNormalizeFlowRes); err != nil { + w.logger.Error("failed to execute normalize flow: ", err) + var panicErr *temporal.PanicError + if errors.As(err, &panicErr) { + w.logger.Error("PANIC", panicErr.Error(), panicErr.StackTrace()) + } + state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, err.Error()) + } else { + state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, childNormalizeFlowRes.Errors...) + state.NormalizeFlowStatuses = append(state.NormalizeFlowStatuses, childNormalizeFlowRes.Results...) + } + } for { // check and act on signals before a fresh flow starts. @@ -369,6 +408,7 @@ func CDCFlowWorkflowWithConfig( // check if the peer flow has been shutdown if state.ActiveSignal == shared.ShutdownSignal { + finishNormalize() w.logger.Info("peer flow has been shutdown") state.CurrentFlowState = protos.FlowStatus_STATUS_TERMINATED return state, nil @@ -394,6 +434,7 @@ func CDCFlowWorkflowWithConfig( syncFlowID, err := GetChildWorkflowID(ctx, "sync-flow", cfg.FlowJobName) if err != nil { + finishNormalize() return state, err } @@ -424,80 +465,59 @@ func CDCFlowWorkflowWithConfig( state.SyncFlowStatuses = append(state.SyncFlowStatuses, childSyncFlowRes) if childSyncFlowRes != nil { state.RelationMessageMapping = childSyncFlowRes.RelationMessageMapping - totalRecordsSynced += int(childSyncFlowRes.NumRecordsSynced) + totalRecordsSynced += childSyncFlowRes.NumRecordsSynced } } w.logger.Info("Total records synced: ", totalRecordsSynced) - var tableSchemaDeltas []*protos.TableSchemaDelta = nil if childSyncFlowRes != nil { - tableSchemaDeltas = childSyncFlowRes.TableSchemaDeltas - } + tableSchemaDeltasCount := len(childSyncFlowRes.TableSchemaDeltas) - // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes. - if tableSchemaDeltas != nil { - modifiedSrcTables := make([]string, 0) - modifiedDstTables := make([]string, 0) + var normalizeTableNameSchemaMapping map[string]*protos.TableSchema + // slightly hacky: table schema mapping is cached, so we need to manually update it if schema changes. + if tableSchemaDeltasCount != 0 { + modifiedSrcTables := make([]string, 0, tableSchemaDeltasCount) + modifiedDstTables := make([]string, 0, tableSchemaDeltasCount) - for _, tableSchemaDelta := range tableSchemaDeltas { - modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName) - modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName) - } + for _, tableSchemaDelta := range childSyncFlowRes.TableSchemaDeltas { + modifiedSrcTables = append(modifiedSrcTables, tableSchemaDelta.SrcTableName) + modifiedDstTables = append(modifiedDstTables, tableSchemaDelta.DstTableName) + } - getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 5 * time.Minute, - }) - getModifiedSchemaFuture := workflow.ExecuteActivity(getModifiedSchemaCtx, flowable.GetTableSchema, - &protos.GetTableSchemaBatchInput{ - PeerConnectionConfig: cfg.Source, - TableIdentifiers: modifiedSrcTables, - FlowName: cfg.FlowJobName, + getModifiedSchemaCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 5 * time.Minute, }) - - var getModifiedSchemaRes *protos.GetTableSchemaBatchOutput - if err := getModifiedSchemaFuture.Get(ctx, &getModifiedSchemaRes); err != nil { - w.logger.Error("failed to execute schema update at source: ", err) - state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error()) - } else { - for i := range modifiedSrcTables { - state.TableNameSchemaMapping[modifiedDstTables[i]] = getModifiedSchemaRes.TableNameSchemaMapping[modifiedSrcTables[i]] + getModifiedSchemaFuture := workflow.ExecuteActivity(getModifiedSchemaCtx, flowable.GetTableSchema, + &protos.GetTableSchemaBatchInput{ + PeerConnectionConfig: cfg.Source, + TableIdentifiers: modifiedSrcTables, + FlowName: cfg.FlowJobName, + }) + + var getModifiedSchemaRes *protos.GetTableSchemaBatchOutput + if err := getModifiedSchemaFuture.Get(ctx, &getModifiedSchemaRes); err != nil { + w.logger.Error("failed to execute schema update at source: ", err) + state.SyncFlowErrors = append(state.SyncFlowErrors, err.Error()) + } else { + for i := range modifiedSrcTables { + state.TableNameSchemaMapping[modifiedDstTables[i]] = getModifiedSchemaRes.TableNameSchemaMapping[modifiedSrcTables[i]] + } + normalizeTableNameSchemaMapping = state.TableNameSchemaMapping } } - } - normalizeFlowID, err := GetChildWorkflowID(ctx, "normalize-flow", cfg.FlowJobName) - if err != nil { - return state, err - } - - childNormalizeFlowOpts := workflow.ChildWorkflowOptions{ - WorkflowID: normalizeFlowID, - ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, - RetryPolicy: &temporal.RetryPolicy{ - MaximumAttempts: 20, - }, - SearchAttributes: mirrorNameSearch, - WaitForCancellation: true, + childNormalizeFlowFuture.SignalChildWorkflow(ctx, "Sync", model.NormalizeSignal{ + Done: false, + SyncBatchID: childSyncFlowRes.CurrentSyncBatchID, + TableNameSchemaMapping: normalizeTableNameSchemaMapping, + }) } - normCtx := workflow.WithChildOptions(ctx, childNormalizeFlowOpts) - childNormalizeFlowFuture := workflow.ExecuteChildWorkflow( - normCtx, - NormalizeFlowWorkflow, - cfg, - normalizeFlowOptions, - ) - var childNormalizeFlowRes *model.NormalizeResponse - if err := childNormalizeFlowFuture.Get(normCtx, &childNormalizeFlowRes); err != nil { - w.logger.Error("failed to execute normalize flow: ", err) - state.NormalizeFlowErrors = append(state.NormalizeFlowErrors, err.Error()) - } else { - state.NormalizeFlowStatuses = append(state.NormalizeFlowStatuses, childNormalizeFlowRes) - } cdcPropertiesSelector.Select(ctx) } + finishNormalize() state.TruncateProgress(w.logger) return nil, workflow.NewContinueAsNewError(ctx, CDCFlowWorkflowWithConfig, cfg, limits, state) } diff --git a/flow/workflows/normalize_flow.go b/flow/workflows/normalize_flow.go index ebf23051f7..599bd7ea61 100644 --- a/flow/workflows/normalize_flow.go +++ b/flow/workflows/normalize_flow.go @@ -6,63 +6,75 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" - "go.temporal.io/sdk/log" "go.temporal.io/sdk/workflow" ) -type NormalizeFlowState struct { - CDCFlowName string - Progress []string -} - -type NormalizeFlowExecution struct { - NormalizeFlowState - executionID string - logger log.Logger -} - -func NewNormalizeFlowExecution(ctx workflow.Context, state *NormalizeFlowState) *NormalizeFlowExecution { - return &NormalizeFlowExecution{ - NormalizeFlowState: *state, - executionID: workflow.GetInfo(ctx).WorkflowExecution.ID, - logger: workflow.GetLogger(ctx), - } -} - func NormalizeFlowWorkflow(ctx workflow.Context, config *protos.FlowConnectionConfigs, options *protos.NormalizeFlowOptions, -) (*model.NormalizeResponse, error) { - s := NewNormalizeFlowExecution(ctx, &NormalizeFlowState{ - CDCFlowName: config.FlowJobName, - Progress: []string{}, - }) - - return s.executeNormalizeFlow(ctx, config, options) -} - -func (s *NormalizeFlowExecution) executeNormalizeFlow( - ctx workflow.Context, - config *protos.FlowConnectionConfigs, - options *protos.NormalizeFlowOptions, -) (*model.NormalizeResponse, error) { - s.logger.Info("executing normalize flow - ", s.CDCFlowName) +) (*model.NormalizeFlowResponse, error) { + logger := workflow.GetLogger(ctx) + tableNameSchemaMapping := options.TableNameSchemaMapping normalizeFlowCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ StartToCloseTimeout: 7 * 24 * time.Hour, - HeartbeatTimeout: 5 * time.Minute, + HeartbeatTimeout: time.Minute, }) - startNormalizeInput := &protos.StartNormalizeInput{ - FlowConnectionConfigs: config, - TableNameSchemaMapping: options.TableNameSchemaMapping, - } - fStartNormalize := workflow.ExecuteActivity(normalizeFlowCtx, flowable.StartNormalize, startNormalizeInput) + results := make([]model.NormalizeResponse, 0, 4) + errors := make([]string, 0) + syncChan := workflow.GetSignalChannel(ctx, "Sync") + + var stopLoop, canceled bool + var lastSyncBatchID, syncBatchID int64 + selector := workflow.NewNamedSelector(ctx, fmt.Sprintf("%s-normalize", config.FlowJobName)) + selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { + canceled = true + }) + selector.AddReceive(syncChan, func(c workflow.ReceiveChannel, _ bool) { + var s model.NormalizeSignal + c.ReceiveAsync(&s) + if s.Done { + stopLoop = true + } + if s.SyncBatchID != 0 { + syncBatchID = s.SyncBatchID + } + if len(s.TableNameSchemaMapping) != 0 { + tableNameSchemaMapping = s.TableNameSchemaMapping + } + }) + for !stopLoop { + selector.Select(ctx) + for !canceled && selector.HasPending() { + selector.Select(ctx) + } + if canceled || (stopLoop && lastSyncBatchID == syncBatchID) { + break + } + if lastSyncBatchID == syncBatchID { + continue + } + lastSyncBatchID = syncBatchID + + logger.Info("executing normalize - ", config.FlowJobName) + startNormalizeInput := &protos.StartNormalizeInput{ + FlowConnectionConfigs: config, + TableNameSchemaMapping: tableNameSchemaMapping, + SyncBatchID: syncBatchID, + } + fStartNormalize := workflow.ExecuteActivity(normalizeFlowCtx, flowable.StartNormalize, startNormalizeInput) - var normalizeResponse *model.NormalizeResponse - if err := fStartNormalize.Get(normalizeFlowCtx, &normalizeResponse); err != nil { - return nil, fmt.Errorf("failed to flow: %w", err) + var normalizeResponse *model.NormalizeResponse + if err := fStartNormalize.Get(normalizeFlowCtx, &normalizeResponse); err != nil { + errors = append(errors, err.Error()) + } else if normalizeResponse != nil { + results = append(results, *normalizeResponse) + } } - return normalizeResponse, nil + return &model.NormalizeFlowResponse{ + Results: results, + Errors: errors, + }, nil } diff --git a/protos/flow.proto b/protos/flow.proto index 60d15d9d91..3bba098d28 100644 --- a/protos/flow.proto +++ b/protos/flow.proto @@ -125,6 +125,7 @@ message StartFlowInput { message StartNormalizeInput { FlowConnectionConfigs flow_connection_configs = 1; map table_name_schema_mapping = 2; + int64 SyncBatchID = 3; } message GetLastSyncedIDInput {