diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 34ce7eb70e..8e0916fda5 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -988,3 +988,24 @@ func (a *FlowableActivity) ReplicateXminPartition(ctx context.Context, return currentSnapshotXmin, nil } + +func (a *FlowableActivity) AddTablesToPublication(ctx context.Context, cfg *protos.FlowConnectionConfigs, + additionalTableMappings []*protos.TableMapping, +) error { + ctx = context.WithValue(ctx, shared.FlowNameKey, cfg.FlowJobName) + srcConn, err := connectors.GetCDCPullConnector(ctx, cfg.Source) + if err != nil { + return fmt.Errorf("failed to get source connector: %w", err) + } + defer connectors.CloseConnector(srcConn) + + err = srcConn.AddTablesToPublication(&protos.AddTablesToPublicationInput{ + FlowJobName: cfg.FlowJobName, + PublicationName: cfg.PublicationName, + AdditionalTables: additionalTableMappings, + }) + if err != nil { + a.Alerter.LogFlowError(ctx, cfg.FlowJobName, err) + } + return err +} diff --git a/flow/cmd/handler.go b/flow/cmd/handler.go index b655607989..a4a43b4891 100644 --- a/flow/cmd/handler.go +++ b/flow/cmd/handler.go @@ -329,7 +329,7 @@ func (h *FlowRequestHandler) ShutdownFlow( ctx, req.WorkflowId, "", - shared.CDCFlowSignalName, + shared.FlowSignalName, shared.ShutdownSignal, ) if err != nil { @@ -444,8 +444,13 @@ func (h *FlowRequestHandler) FlowStateChange( if err != nil { return nil, err } + isCDCFlow, err := h.isCDCFlow(ctx, req.FlowJobName) + if err != nil { + return nil, err + } + if req.RequestedFlowState == protos.FlowStatus_STATUS_PAUSED && - *currState == protos.FlowStatus_STATUS_RUNNING { + currState == protos.FlowStatus_STATUS_RUNNING { err = h.updateWorkflowStatus(ctx, workflowID, protos.FlowStatus_STATUS_PAUSING) if err != nil { return nil, err @@ -454,20 +459,27 @@ func (h *FlowRequestHandler) FlowStateChange( ctx, workflowID, "", - shared.CDCFlowSignalName, + shared.FlowSignalName, shared.PauseSignal, ) } else if 
req.RequestedFlowState == protos.FlowStatus_STATUS_RUNNING && - *currState == protos.FlowStatus_STATUS_PAUSED { + currState == protos.FlowStatus_STATUS_PAUSED { + if isCDCFlow && req.FlowConfigUpdate.GetCdcFlowConfigUpdate() != nil { + err = h.attemptCDCFlowConfigUpdate(ctx, workflowID, + req.FlowConfigUpdate.GetCdcFlowConfigUpdate()) + if err != nil { + return nil, err + } + } err = h.temporalClient.SignalWorkflow( ctx, workflowID, "", - shared.CDCFlowSignalName, + shared.FlowSignalName, shared.NoopSignal, ) } else if req.RequestedFlowState == protos.FlowStatus_STATUS_TERMINATED && - (*currState == protos.FlowStatus_STATUS_RUNNING || *currState == protos.FlowStatus_STATUS_PAUSED) { + (currState == protos.FlowStatus_STATUS_RUNNING || currState == protos.FlowStatus_STATUS_PAUSED) { err = h.updateWorkflowStatus(ctx, workflowID, protos.FlowStatus_STATUS_TERMINATING) if err != nil { return nil, err @@ -484,7 +496,7 @@ func (h *FlowRequestHandler) FlowStateChange( req.RequestedFlowState, currState) } if err != nil { - return nil, fmt.Errorf("unable to signal CDCFlow workflow: %w", err) + return nil, fmt.Errorf("unable to signal workflow: %w", err) } return &protos.FlowStateChangeResponse{ diff --git a/flow/cmd/mirror_status.go b/flow/cmd/mirror_status.go index 9c904ab716..a42360112b 100644 --- a/flow/cmd/mirror_status.go +++ b/flow/cmd/mirror_status.go @@ -51,7 +51,7 @@ func (h *FlowRequestHandler) MirrorStatus( Status: &protos.MirrorStatusResponse_CdcStatus{ CdcStatus: cdcStatus, }, - CurrentFlowState: *currState, + CurrentFlowState: currState, }, nil } else { qrepStatus, err := h.QRepFlowStatus(ctx, req) @@ -66,7 +66,7 @@ func (h *FlowRequestHandler) MirrorStatus( Status: &protos.MirrorStatusResponse_QrepStatus{ QrepStatus: qrepStatus, }, - CurrentFlowState: *currState, + CurrentFlowState: currState, }, nil } } @@ -334,17 +334,41 @@ func (h *FlowRequestHandler) isCDCFlow(ctx context.Context, flowJobName string) return false, nil } -func (h *FlowRequestHandler) 
getWorkflowStatus(ctx context.Context, workflowID string) (*protos.FlowStatus, error) { +func (h *FlowRequestHandler) getCloneTableFlowNames(ctx context.Context, flowJobName string) ([]string, error) { + q := "SELECT flow_name FROM peerdb_stats.qrep_runs WHERE flow_name ILIKE $1" + rows, err := h.pool.Query(ctx, q, "clone_"+flowJobName+"_%") + if err != nil { + return nil, fmt.Errorf("unable to getCloneTableFlowNames: %w", err) + } + defer rows.Close() + + flowNames := []string{} + for rows.Next() { + var name pgtype.Text + if err := rows.Scan(&name); err != nil { + return nil, fmt.Errorf("unable to scan flow row: %w", err) + } + if name.Valid { + flowNames = append(flowNames, name.String) + } + } + + return flowNames, nil +} + +func (h *FlowRequestHandler) getWorkflowStatus(ctx context.Context, workflowID string) (protos.FlowStatus, error) { res, err := h.temporalClient.QueryWorkflow(ctx, workflowID, "", shared.FlowStatusQuery) if err != nil { slog.Error(fmt.Sprintf("failed to get state in workflow with ID %s: %s", workflowID, err.Error())) - return nil, fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err) + return protos.FlowStatus_STATUS_UNKNOWN, + fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err) } - var state *protos.FlowStatus + var state protos.FlowStatus err = res.Get(&state) if err != nil { slog.Error(fmt.Sprintf("failed to get state in workflow with ID %s: %s", workflowID, err.Error())) - return nil, fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err) + return protos.FlowStatus_STATUS_UNKNOWN, + fmt.Errorf("failed to get state in workflow with ID %s: %w", workflowID, err) } return state, nil } @@ -361,3 +385,14 @@ func (h *FlowRequestHandler) updateWorkflowStatus( } return nil } + +func (h *FlowRequestHandler) attemptCDCFlowConfigUpdate(ctx context.Context, + workflowID string, cdcFlowConfigUpdate *protos.CDCFlowConfigUpdate, +) error { + _, err := 
h.temporalClient.UpdateWorkflow(ctx, workflowID, "", + shared.CDCFlowConfigUpdate, cdcFlowConfigUpdate) + if err != nil { + return fmt.Errorf("failed to update config in CDC workflow with ID %s: %w", workflowID, err) + } + return nil +} diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 44a7f4a250..df82959e9a 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -48,6 +48,9 @@ type CDCPullConnector interface { // GetOpenConnectionsForUser returns the number of open connections for the user configured in the peer. GetOpenConnectionsForUser() (*protos.GetOpenConnectionsForUserResult, error) + + // AddTablesToPublication adds additional tables added to a mirror to the publication also + AddTablesToPublication(req *protos.AddTablesToPublicationInput) error } type CDCSyncConnector interface { diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 4fd1b3dd79..fd9b1a6201 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -556,7 +556,7 @@ func (p *PostgresCDCSource) processMessage(batch *model.CDCRecordStream, xld pgl // TODO (kaushik): consider persistent state for a mirror job // to be stored somewhere in temporal state. 
We might need to persist // the state of the relation message somewhere - p.logger.Debug(fmt.Sprintf("RelationMessage => RelationID: %d, Namespace: %s, RelationName: %s, Columns: %v", + p.logger.Info(fmt.Sprintf("RelationMessage => RelationID: %d, Namespace: %s, RelationName: %s, Columns: %v", msg.RelationID, msg.Namespace, msg.RelationName, msg.Columns)) if p.relationMessageMapping[msg.RelationID] == nil { p.relationMessageMapping[msg.RelationID] = convertRelationMessageToProto(msg) diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 0a99bce668..794fb45330 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -606,3 +606,7 @@ func (c *PostgresConnector) getCurrentLSN() (pglogrepl.LSN, error) { } return pglogrepl.ParseLSN(result.String) } + +func (c *PostgresConnector) getDefaultPublicationName(jobName string) string { + return fmt.Sprintf("peerflow_pub_%s", jobName) +} diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index fd2aac3735..5236fb2add 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "regexp" + "strings" "time" "github.com/PeerDB-io/peer-flow/connectors/utils" @@ -206,8 +207,7 @@ func (c *PostgresConnector) PullRecords(catalogPool *pgxpool.Pool, req *model.Pu slotName = req.OverrideReplicationSlotName } - // Publication name would be the job name prefixed with "peerflow_pub_" - publicationName := fmt.Sprintf("peerflow_pub_%s", req.FlowJobName) + publicationName := c.getDefaultPublicationName(req.FlowJobName) if req.OverridePublicationName != "" { publicationName = req.OverridePublicationName } @@ -820,8 +820,7 @@ func (c *PostgresConnector) SetupReplication(signal SlotSignal, req *protos.Setu slotName = req.ExistingReplicationSlotName } - // Publication name would be the job name prefixed with "peerflow_pub_" - publicationName := 
fmt.Sprintf("peerflow_pub_%s", req.FlowJobName) + publicationName := c.getDefaultPublicationName(req.FlowJobName) if req.ExistingPublicationName != "" { publicationName = req.ExistingPublicationName } @@ -853,8 +852,7 @@ func (c *PostgresConnector) PullFlowCleanup(jobName string) error { // Slotname would be the job name prefixed with "peerflow_slot_" slotName := fmt.Sprintf("peerflow_slot_%s", jobName) - // Publication name would be the job name prefixed with "peerflow_pub_" - publicationName := fmt.Sprintf("peerflow_pub_%s", jobName) + publicationName := c.getDefaultPublicationName(jobName) pullFlowCleanupTx, err := c.pool.Begin(c.ctx) if err != nil { @@ -932,3 +930,23 @@ func (c *PostgresConnector) GetOpenConnectionsForUser() (*protos.GetOpenConnecti CurrentOpenConnections: result.Int64, }, nil } + +func (c *PostgresConnector) AddTablesToPublication(req *protos.AddTablesToPublicationInput) error { + // don't modify custom publications + if req == nil || req.PublicationName != "" || len(req.AdditionalTables) == 0 { + return nil + } + + additionalSrcTables := make([]string, 0, len(req.AdditionalTables)) + for _, additionalTableMapping := range req.AdditionalTables { + additionalSrcTables = append(additionalSrcTables, additionalTableMapping.SourceTableIdentifier) + } + additionalSrcTablesString := strings.Join(additionalSrcTables, ",") + + _, err := c.pool.Exec(c.ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s", + c.getDefaultPublicationName(req.FlowJobName), additionalSrcTablesString)) + if err != nil { + return fmt.Errorf("failed to alter publication: %w", err) + } + return nil +} diff --git a/flow/connectors/utils/array.go b/flow/connectors/utils/array.go index 3db4d53bb3..2a0d983fca 100644 --- a/flow/connectors/utils/array.go +++ b/flow/connectors/utils/array.go @@ -1,13 +1,13 @@ package utils -func ArrayMinus(first []string, second []string) []string { - lookup := make(map[string]struct{}, len(second)) +func ArrayMinus[T comparable](first, second []T) []T 
{ + lookup := make(map[T]struct{}, len(second)) // Add elements from arrayB to the lookup map for _, element := range second { lookup[element] = struct{}{} } // Iterate over arrayA and check if the element is present in the lookup map - var result []string + var result []T for _, element := range first { _, exists := lookup[element] if !exists { @@ -29,3 +29,19 @@ func ArrayChunks[T any](slice []T, size int) [][]T { return partitions } + +func ArraysHaveOverlap[T comparable](first, second []T) bool { + lookup := make(map[T]struct{}) + + for _, element := range first { + lookup[element] = struct{}{} + } + + for _, element := range second { + if _, exists := lookup[element]; exists { + return true + } + } + + return false +} diff --git a/flow/e2e/test_utils.go b/flow/e2e/test_utils.go index 641fa594e5..b0629179f9 100644 --- a/flow/e2e/test_utils.go +++ b/flow/e2e/test_utils.go @@ -135,38 +135,30 @@ func EnvWaitForEqualTables( env *testsuite.TestWorkflowEnvironment, suite e2eshared.RowSource, reason string, table string, cols string, ) { - suite.T().Helper() - EnvWaitForEqualTablesWithNames(env, suite, reason, table, table, cols) -} - -func EnvWaitForEqualTablesWithNames( - env *testsuite.TestWorkflowEnvironment, - suite e2eshared.RowSource, - reason string, - srcTable string, - dstTable string, - cols string, -) { - t := suite.T() - t.Helper() - - EnvWaitFor(t, env, 3*time.Minute, reason, func() bool { - t.Helper() - - suffix := suite.Suffix() - pool := suite.Pool() - pgRows, err := GetPgRows(pool, suffix, srcTable, cols) - if err != nil { - return false - } + // wait for PeerFlowStatusQuery to finish setup + // sleep for 5 second to allow the workflow to start + time.Sleep(5 * time.Second) + for { + response, err := env.QueryWorkflow( + shared.CDCFlowStateQuery, + connectionGen.FlowJobName, + ) + if err == nil { + var state peerflow.CDCFlowWorkflowState + err = response.Get(&state) + if err != nil { + slog.Error(err.Error()) + } - rows, err := suite.GetRows(dstTable, cols) - if err != nil { - return false + if state.CurrentFlowStatus == 
protos.FlowStatus_STATUS_RUNNING { + break + } + } else { + // log the error for informational purposes + slog.Error(err.Error()) } - - return e2eshared.CheckEqualRecordBatches(t, pgRows, rows) - }) + time.Sleep(1 * time.Second) + } } func SetupCDCFlowStatusQuery(t *testing.T, env *testsuite.TestWorkflowEnvironment, diff --git a/flow/shared/constants.go b/flow/shared/constants.go index 07cfea307a..f549123a98 100644 --- a/flow/shared/constants.go +++ b/flow/shared/constants.go @@ -12,7 +12,7 @@ const ( snapshotFlowTaskQueue = "snapshot-flow-task-queue" // Signals - CDCFlowSignalName = "peer-flow-signal" + FlowSignalName = "peer-flow-signal" CDCDynamicPropertiesSignalName = "cdc-dynamic-properties" // Queries @@ -21,7 +21,8 @@ const ( FlowStatusQuery = "q-flow-status" // Updates - FlowStatusUpdate = "u-flow-status" + FlowStatusUpdate = "u-flow-status" + CDCFlowConfigUpdate = "u-cdc-flow-config-update" ) const MirrorNameSearchAttribute = "MirrorName" diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index f6a5bda189..2aea56edec 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" "github.com/PeerDB-io/peer-flow/shared" @@ -14,6 +15,7 @@ import ( "go.temporal.io/sdk/log" "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/workflow" + "google.golang.org/protobuf/proto" ) const ( @@ -48,7 +50,9 @@ type CDCFlowWorkflowState struct { // Needed to support schema changes. 
RelationMessageMapping model.RelationMessageMapping // current workflow state - CurrentFlowState protos.FlowStatus + CurrentFlowStatus protos.FlowStatus + // flow config update request, set to nil after processed + FlowConfigUpdates []*protos.CDCFlowConfigUpdate } type SignalProps struct { @@ -72,7 +76,8 @@ func NewCDCFlowWorkflowState() *CDCFlowWorkflowState { RelationName: "protobuf_workaround", }, }, - CurrentFlowState: protos.FlowStatus_STATUS_SETUP, + CurrentFlowStatus: protos.FlowStatus_STATUS_SETUP, + FlowConfigUpdates: nil, } } @@ -136,7 +141,7 @@ func GetChildWorkflowID( type CDCFlowWorkflowResult = CDCFlowWorkflowState func (w *CDCFlowWorkflowExecution) receiveAndHandleSignalAsync(ctx workflow.Context, state *CDCFlowWorkflowState) { - signalChan := workflow.GetSignalChannel(ctx, shared.CDCFlowSignalName) + signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) var signalVal shared.CDCFlowSignal ok := signalChan.ReceiveAsync(&signalVal) @@ -145,26 +150,171 @@ func (w *CDCFlowWorkflowExecution) receiveAndHandleSignalAsync(ctx workflow.Cont } } +func additionalTablesHasOverlap(currentTableMappings []*protos.TableMapping, + additionalTableMappings []*protos.TableMapping, +) bool { + currentSrcTables := make([]string, 0, len(currentTableMappings)) + currentDstTables := make([]string, 0, len(currentTableMappings)) + additionalSrcTables := make([]string, 0, len(additionalTableMappings)) + additionalDstTables := make([]string, 0, len(additionalTableMappings)) + + for _, currentTableMapping := range currentTableMappings { + currentSrcTables = append(currentSrcTables, currentTableMapping.SourceTableIdentifier) + currentDstTables = append(currentDstTables, currentTableMapping.DestinationTableIdentifier) + } + for _, additionalTableMapping := range additionalTableMappings { + additionalSrcTables = append(additionalSrcTables, additionalTableMapping.SourceTableIdentifier) + additionalDstTables = append(additionalDstTables, 
additionalTableMapping.DestinationTableIdentifier) + } + + return utils.ArraysHaveOverlap[string](currentSrcTables, additionalSrcTables) || + utils.ArraysHaveOverlap[string](currentDstTables, additionalDstTables) +} + +func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdates(ctx workflow.Context, + cfg *protos.FlowConnectionConfigs, state *CDCFlowWorkflowState, + mirrorNameSearch *map[string]interface{}, +) error { + for _, flowConfigUpdate := range state.FlowConfigUpdates { + if len(flowConfigUpdate.AdditionalTables) == 0 { + continue + } + if additionalTablesHasOverlap(cfg.TableMappings, flowConfigUpdate.AdditionalTables) { + return fmt.Errorf("duplicate source/destination tables found in additionalTables") + } + + alterPublicationAddAdditionalTablesCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 5 * time.Minute, + }) + alterPublicationAddAdditionalTablesFuture := workflow.ExecuteActivity( + alterPublicationAddAdditionalTablesCtx, + flowable.AddTablesToPublication, + cfg, flowConfigUpdate.AdditionalTables) + if err := alterPublicationAddAdditionalTablesFuture.Get(ctx, nil); err != nil { + w.logger.Error("failed to alter publication for additional tables: ", err) + return err + } + + additionalTablesSetupFlowID, err := GetChildWorkflowID(ctx, + "additional-tables-setup-flow", cfg.FlowJobName) + if err != nil { + return err + } + additionalTablesSetupFlowOpts := workflow.ChildWorkflowOptions{ + WorkflowID: additionalTablesSetupFlowID, + ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, + RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 2, + }, + SearchAttributes: *mirrorNameSearch, + } + + additionalTablesWorkflowCfg := proto.Clone(cfg).(*protos.FlowConnectionConfigs) + additionalTablesWorkflowCfg.DoInitialCopy = true + additionalTablesWorkflowCfg.TableMappings = flowConfigUpdate.AdditionalTables + additionalTablesWorkflowCfg.FlowJobName = fmt.Sprintf("%s_additional_tables_%s", cfg.FlowJobName, + 
strings.ToLower(shared.RandomString(8))) + + additionalTablesSetupCtx := workflow.WithChildOptions(ctx, + additionalTablesSetupFlowOpts) + additionalTablesSetupFlowFuture := workflow.ExecuteChildWorkflow( + additionalTablesSetupCtx, + SetupFlowWorkflow, + additionalTablesWorkflowCfg, + ) + if err := additionalTablesSetupFlowFuture.Get(additionalTablesSetupCtx, + &additionalTablesWorkflowCfg); err != nil { + w.logger.Error("failed to execute SetupFlow for additional tables: ", err) + return fmt.Errorf("failed to execute SetupFlow for additional tables: %w", err) + } + + // next part of the setup is to snapshot-initial-copy and setup replication slots. + additionalTablesSnapshotFlowID, err := GetChildWorkflowID(ctx, + "additional-tables-snapshot-flow", cfg.FlowJobName) + if err != nil { + return err + } + + taskQueue, err := shared.GetPeerFlowTaskQueueName(shared.SnapshotFlowTaskQueueID) + if err != nil { + return err + } + + additionalTablesSnapshotFlowOpts := workflow.ChildWorkflowOptions{ + WorkflowID: additionalTablesSnapshotFlowID, + ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, + RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 20, + }, + TaskQueue: taskQueue, + SearchAttributes: *mirrorNameSearch, + } + additionalTablesSnapshotFlowCtx := workflow.WithChildOptions(ctx, additionalTablesSnapshotFlowOpts) + additionalTablesSnapshotFlowFuture := workflow.ExecuteChildWorkflow(additionalTablesSnapshotFlowCtx, + SnapshotFlowWorkflow, additionalTablesWorkflowCfg) + if err := additionalTablesSnapshotFlowFuture.Get(additionalTablesSnapshotFlowCtx, nil); err != nil { + return fmt.Errorf("failed to execute child workflow: %w", err) + } + + additionalTablesDropFlowID, err := GetChildWorkflowID(ctx, + "additional-tables-drop-flow", cfg.FlowJobName) + if err != nil { + return err + } + additionalTablesDropFlowOpts := workflow.ChildWorkflowOptions{ + WorkflowID: additionalTablesDropFlowID, + ParentClosePolicy: enums.PARENT_CLOSE_POLICY_REQUEST_CANCEL, + 
RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 1, + }, + SearchAttributes: *mirrorNameSearch, + } + additionalTablesDropCtx := workflow.WithChildOptions(ctx, + additionalTablesDropFlowOpts) + additionalTablesDropFlowFuture := workflow.ExecuteChildWorkflow( + additionalTablesDropCtx, + DropFlowWorkflow, + &protos.ShutdownRequest{ + WorkflowId: additionalTablesSetupFlowID, + FlowJobName: additionalTablesWorkflowCfg.FlowJobName, + SourcePeer: cfg.Source, + DestinationPeer: cfg.Destination, + RemoveFlowEntry: false, + }, + ) + if err := additionalTablesDropFlowFuture.Get(additionalTablesDropCtx, nil); err != nil { + w.logger.Error("failed to execute DropFlow for additional tables: ", err) + return fmt.Errorf("failed to execute DropFlow for additional tables: %w", err) + } + + for tableID, tableName := range additionalTablesWorkflowCfg.SrcTableIdNameMapping { + cfg.SrcTableIdNameMapping[tableID] = tableName + } + for tableName, tableSchema := range additionalTablesWorkflowCfg.TableNameSchemaMapping { + cfg.TableNameSchemaMapping[tableName] = tableSchema + } + cfg.TableMappings = append(cfg.TableMappings, flowConfigUpdate.AdditionalTables...) 
+ // finished processing, wipe it + state.FlowConfigUpdates = nil + } + return nil +} + func CDCFlowWorkflowWithConfig( ctx workflow.Context, cfg *protos.FlowConnectionConfigs, limits *CDCFlowLimits, state *CDCFlowWorkflowState, ) (*CDCFlowWorkflowResult, error) { - if state == nil { - state = NewCDCFlowWorkflowState() - } - if cfg == nil { return nil, fmt.Errorf("invalid connection configs") } + if state == nil { + state = NewCDCFlowWorkflowState() + } w := NewCDCFlowWorkflowExecution(ctx) - if limits.TotalSyncFlows == 0 { - limits.TotalSyncFlows = maxSyncFlowsPerCDCFlow - } - err := workflow.SetQueryHandler(ctx, shared.CDCFlowStateQuery, func() (CDCFlowWorkflowState, error) { return *state, nil }) @@ -172,19 +322,30 @@ func CDCFlowWorkflowWithConfig( return state, fmt.Errorf("failed to set `%s` query handler: %w", shared.CDCFlowStateQuery, err) } err = workflow.SetQueryHandler(ctx, shared.FlowStatusQuery, func() (protos.FlowStatus, error) { - return state.CurrentFlowState, nil + return state.CurrentFlowStatus, nil }) if err != nil { return state, fmt.Errorf("failed to set `%s` query handler: %w", shared.FlowStatusQuery, err) } err = workflow.SetUpdateHandler(ctx, shared.FlowStatusUpdate, func(status protos.FlowStatus) error { - state.CurrentFlowState = status + state.CurrentFlowStatus = status return nil }) if err != nil { return state, fmt.Errorf("failed to set `%s` update handler: %w", shared.FlowStatusUpdate, err) } - + err = workflow.SetUpdateHandler(ctx, shared.CDCFlowConfigUpdate, + func(cdcFlowConfigUpdate *protos.CDCFlowConfigUpdate) error { + if state.CurrentFlowStatus == protos.FlowStatus_STATUS_PAUSED { + state.FlowConfigUpdates = append(state.FlowConfigUpdates, cdcFlowConfigUpdate) + return nil + } + return fmt.Errorf(`flow config updates can only be sent when workflow is paused, + current status: %v`, state.CurrentFlowStatus) + }) + if err != nil { + return state, fmt.Errorf("failed to set `%s` update handler: %w", shared.CDCFlowConfigUpdate, err) + 
} mirrorNameSearch := map[string]interface{}{ shared.MirrorNameSearchAttribute: cfg.FlowJobName, } @@ -193,7 +354,7 @@ func CDCFlowWorkflowWithConfig( // because Resync modifies TableMappings before Setup and also before Snapshot // for safety, rely on the idempotency of SetupFlow instead // also, no signals are being handled until the loop starts, so no PAUSE/DROP will take here. - if state.CurrentFlowState != protos.FlowStatus_STATUS_RUNNING { + if state.CurrentFlowStatus != protos.FlowStatus_STATUS_RUNNING { // if resync is true, alter the table name schema mapping to temporarily add // a suffix to the table names. if cfg.Resync { @@ -224,7 +385,7 @@ func CDCFlowWorkflowWithConfig( if err := setupFlowFuture.Get(setupFlowCtx, &cfg); err != nil { return state, fmt.Errorf("failed to execute child workflow: %w", err) } - state.CurrentFlowState = protos.FlowStatus_STATUS_SNAPSHOT + state.CurrentFlowStatus = protos.FlowStatus_STATUS_SNAPSHOT // next part of the setup is to snapshot-initial-copy and setup replication slots. snapshotFlowID, err := GetChildWorkflowID(ctx, "snapshot-flow", cfg.FlowJobName) @@ -287,7 +448,7 @@ func CDCFlowWorkflowWithConfig( } } - state.CurrentFlowState = protos.FlowStatus_STATUS_RUNNING + state.CurrentFlowStatus = protos.FlowStatus_STATUS_RUNNING state.Progress = append(state.Progress, "executed setup flow and snapshot flow") // if initial_copy_only is opted for, we end the flow here. 
@@ -296,6 +457,10 @@ func CDCFlowWorkflowWithConfig( } } + if limits.TotalSyncFlows == 0 { + limits.TotalSyncFlows = maxSyncFlowsPerCDCFlow + } + syncFlowOptions := &protos.SyncFlowOptions{ BatchSize: int32(limits.MaxBatchSize), IdleTimeoutSeconds: 0, @@ -339,8 +504,8 @@ func CDCFlowWorkflowWithConfig( if state.ActiveSignal == shared.PauseSignal { startTime := time.Now() - state.CurrentFlowState = protos.FlowStatus_STATUS_PAUSED - signalChan := workflow.GetSignalChannel(ctx, shared.CDCFlowSignalName) + state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED + signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) var signalVal shared.CDCFlowSignal for state.ActiveSignal == shared.PauseSignal { @@ -349,6 +514,13 @@ func CDCFlowWorkflowWithConfig( ok, _ := signalChan.ReceiveWithTimeout(ctx, 1*time.Minute, &signalVal) if ok { state.ActiveSignal = shared.FlowSignalHandler(state.ActiveSignal, signalVal, w.logger) + // only process config updates when going from STATUS_PAUSED to STATUS_RUNNING + if state.ActiveSignal == shared.NoopSignal { + err = w.processCDCFlowConfigUpdates(ctx, cfg, state, &mirrorNameSearch) + if err != nil { + return state, err + } + } } else if err := ctx.Err(); err != nil { return nil, err } @@ -360,11 +532,11 @@ func CDCFlowWorkflowWithConfig( // check if the peer flow has been shutdown if state.ActiveSignal == shared.ShutdownSignal { w.logger.Info("peer flow has been shutdown") - state.CurrentFlowState = protos.FlowStatus_STATUS_TERMINATED + state.CurrentFlowStatus = protos.FlowStatus_STATUS_TERMINATED return state, nil } - state.CurrentFlowState = protos.FlowStatus_STATUS_RUNNING + state.CurrentFlowStatus = protos.FlowStatus_STATUS_RUNNING // check if total sync flows have been completed // since this happens immediately after we check for signals, the case of a signal being missed diff --git a/flow/workflows/qrep_flow.go b/flow/workflows/qrep_flow.go index b44e0df207..2df08187e3 100644 --- a/flow/workflows/qrep_flow.go +++ 
b/flow/workflows/qrep_flow.go @@ -42,6 +42,7 @@ func NewQRepFlowState() *protos.QRepFlowState { }, NumPartitionsProcessed: 0, NeedsResync: true, + CurrentFlowStatus: protos.FlowStatus_STATUS_RUNNING, } } @@ -367,7 +368,7 @@ func (q *QRepFlowExecution) handleTableRenameForResync(ctx workflow.Context, sta } func (q *QRepFlowExecution) receiveAndHandleSignalAsync(ctx workflow.Context) { - signalChan := workflow.GetSignalChannel(ctx, shared.CDCFlowSignalName) + signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) var signalVal shared.CDCFlowSignal ok := signalChan.ReceiveAsync(&signalVal) @@ -426,7 +427,24 @@ func QRepFlowWorkflow( err := setWorkflowQueries(ctx, state) if err != nil { - return err + return fmt.Errorf("failed to set `%s` query handler: %w", shared.QRepFlowStateQuery, err) + } + + // Support a Query for the current status of the arep flow. + err = workflow.SetQueryHandler(ctx, shared.FlowStatusQuery, func() (*protos.FlowStatus, error) { + return &state.CurrentFlowStatus, nil + }) + if err != nil { + return fmt.Errorf("failed to set `%s` query handler: %w", shared.FlowStatusQuery, err) + } + + // Support an Update for the current status of the qrep flow. 
+ err = workflow.SetUpdateHandler(ctx, shared.FlowStatusUpdate, func(status *protos.FlowStatus) error { + state.CurrentFlowStatus = *status + return nil + }) + if err != nil { + return fmt.Errorf("failed to register query handler: %w", err) } // get qrep run uuid via side-effect @@ -507,7 +525,8 @@ func QRepFlowWorkflow( q.receiveAndHandleSignalAsync(ctx) if q.activeSignal == shared.PauseSignal { startTime := time.Now() - signalChan := workflow.GetSignalChannel(ctx, shared.CDCFlowSignalName) + state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED + signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) var signalVal shared.CDCFlowSignal for q.activeSignal == shared.PauseSignal { @@ -521,6 +540,7 @@ func QRepFlowWorkflow( } if q.activeSignal == shared.ShutdownSignal { q.logger.Info("terminating workflow - ", config.FlowJobName) + state.CurrentFlowStatus = protos.FlowStatus_STATUS_TERMINATED return nil } diff --git a/flow/workflows/xmin_flow.go b/flow/workflows/xmin_flow.go index 1394d17353..387ab2e0a7 100644 --- a/flow/workflows/xmin_flow.go +++ b/flow/workflows/xmin_flow.go @@ -119,7 +119,8 @@ func XminFlowWorkflow( q.receiveAndHandleSignalAsync(ctx) if x.activeSignal == shared.PauseSignal { startTime := time.Now() - signalChan := workflow.GetSignalChannel(ctx, shared.CDCFlowSignalName) + state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED + signalChan := workflow.GetSignalChannel(ctx, shared.FlowSignalName) var signalVal shared.CDCFlowSignal for x.activeSignal == shared.PauseSignal { @@ -131,8 +132,9 @@ func XminFlowWorkflow( } } } - if x.activeSignal == shared.ShutdownSignal { - x.logger.Info("terminating workflow - ", config.FlowJobName) + if q.activeSignal == shared.ShutdownSignal { + q.logger.Info("terminating workflow - ", config.FlowJobName) + state.CurrentFlowStatus = protos.FlowStatus_STATUS_TERMINATED return nil } diff --git a/nexus/flow-rs/src/grpc.rs b/nexus/flow-rs/src/grpc.rs index b8adf9ea84..4fc40b2003 100644 --- 
a/nexus/flow-rs/src/grpc.rs +++ b/nexus/flow-rs/src/grpc.rs @@ -107,13 +107,14 @@ impl FlowGrpcClient { flow_job_name: &str, workflow_details: WorkflowDetails, state: pt::peerdb_flow::FlowStatus, + flow_config_update: Option<pt::peerdb_flow::FlowConfigUpdate>, ) -> anyhow::Result<()> { let state_change_req = pt::peerdb_route::FlowStateChangeRequest { flow_job_name: flow_job_name.to_owned(), requested_flow_state: state.into(), source_peer: Some(workflow_details.source_peer), destination_peer: Some(workflow_details.destination_peer), - flow_state_update: None, + flow_config_update }; let response = self.client.flow_state_change(state_change_req).await?; let state_change_response = response.into_inner(); diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs index 4a1eebe7e1..8cb6e48db3 100644 --- a/nexus/server/src/main.rs +++ b/nexus/server/src/main.rs @@ -272,6 +272,7 @@ impl NexusBackend { flow_job_name, workflow_details, pt::peerdb_flow::FlowStatus::StatusTerminated, + None, ) .await .map_err(|err| { @@ -687,6 +688,7 @@ impl NexusBackend { flow_job_name, workflow_details, pt::peerdb_flow::FlowStatus::StatusPaused, + None, ) .await .map_err(|err| { @@ -748,6 +750,18 @@ impl NexusBackend { flow_job_name, workflow_details, pt::peerdb_flow::FlowStatus::StatusRunning, + Some(pt::peerdb_flow::FlowConfigUpdate { + update: Some(pt::peerdb_flow::flow_config_update::Update::CdcFlowConfigUpdate( + pt::peerdb_flow::CdcFlowConfigUpdate { + additional_tables: vec![pt::peerdb_flow::TableMapping { + source_table_identifier: "public.oss2".to_string(), + destination_table_identifier: "public.oss2dst" + .to_string(), + partition_key: "".to_string(), + exclude: vec![], + }], + })), + }), ) .await .map_err(|err| { @@ -755,9 +769,9 @@ impl NexusBackend { format!("unable to resume flow job: {:?}", err).into(), ) })?; - let drop_mirror_success = format!("RESUME MIRROR {}", flow_job_name); + let resume_mirror_success = format!("RESUME MIRROR {}", flow_job_name); Ok(vec![Response::Execution(Tag::new( - 
&drop_mirror_success, + &resume_mirror_success ))]) } else if *if_exists { let no_mirror_success = "NO SUCH MIRROR"; diff --git a/protos/flow.proto b/protos/flow.proto index 24c9d9f87c..02e6b8f460 100644 --- a/protos/flow.proto +++ b/protos/flow.proto @@ -60,7 +60,6 @@ message FlowConnectionConfigs { string snapshot_staging_path = 17; string cdc_staging_path = 18; - // currently only works for snowflake bool soft_delete = 19; string replication_slot_name = 20; @@ -377,7 +376,7 @@ message QRepFlowState { uint64 num_partitions_processed = 2; bool needs_resync = 3; bool disable_wait_for_new_rows = 4; - FlowStatus current_flow_state = 5; + FlowStatus current_flow_status = 5; } message PeerDBColumns { @@ -415,15 +414,22 @@ enum FlowStatus { STATUS_TERMINATED = 7; } -message CDCFlowStateUpdate { +message CDCFlowConfigUpdate { + repeated TableMapping additional_tables = 1; } -message QRepFlowStateUpdate { +message QRepFlowConfigUpdate { } -message FlowStateUpdate { +message FlowConfigUpdate { oneof update { - CDCFlowStateUpdate cdc_flow_state_update = 1; - QRepFlowStateUpdate qrep_flow_state_update = 2; + CDCFlowConfigUpdate cdc_flow_config_update = 1; + QRepFlowConfigUpdate qrep_flow_config_update = 2; } } + +message AddTablesToPublicationInput{ + string flow_job_name = 1; + string publication_name = 2; + repeated TableMapping additional_tables = 3; +} diff --git a/protos/route.proto b/protos/route.proto index 577be49a4f..6dcfbe14ad 100644 --- a/protos/route.proto +++ b/protos/route.proto @@ -203,7 +203,7 @@ message FlowStateChangeRequest { peerdb_peers.Peer source_peer = 3; peerdb_peers.Peer destination_peer = 4; // only can be sent in certain situations - optional peerdb_flow.FlowStateUpdate flow_state_update = 5; + optional peerdb_flow.FlowConfigUpdate flow_config_update = 5; } message FlowStateChangeResponse {