cdc_flow: remove custom retry logic
serprex committed Mar 5, 2024
1 parent 62c8856 commit 6f26120
Showing 1 changed file with 60 additions and 84 deletions.
144 changes: 60 additions & 84 deletions flow/workflows/cdc_flow.go
@@ -53,22 +53,6 @@ func NewCDCFlowWorkflowState(cfg *protos.FlowConnectionConfigs) *CDCFlowWorkflow
}
}

// CDCFlowWorkflowExecution represents the state for execution of a peer flow.
type CDCFlowWorkflowExecution struct {
flowExecutionID string
logger log.Logger
syncFlowFuture workflow.ChildWorkflowFuture
normFlowFuture workflow.ChildWorkflowFuture
}

// NewCDCFlowWorkflowExecution creates a new instance of PeerFlowWorkflowExecution.
func NewCDCFlowWorkflowExecution(ctx workflow.Context, flowName string) *CDCFlowWorkflowExecution {
return &CDCFlowWorkflowExecution{
flowExecutionID: workflow.GetInfo(ctx).WorkflowExecution.ID,
logger: log.With(workflow.GetLogger(ctx), slog.String(string(shared.FlowNameKey), flowName)),
}
}

func GetSideEffect[T any](ctx workflow.Context, f func(workflow.Context) T) T {
sideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} {
return f(ctx)
@@ -103,24 +87,26 @@ const (
maxSyncsPerCdcFlow = 32
)

func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdate(ctx workflow.Context,
func processCDCFlowConfigUpdate(
ctx workflow.Context,
logger log.Logger,
cfg *protos.FlowConnectionConfigs, state *CDCFlowWorkflowState,
mirrorNameSearch map[string]interface{},
) error {
flowConfigUpdate := state.FlowConfigUpdate

if flowConfigUpdate != nil {
w.logger.Info("processing CDCFlowConfigUpdate", slog.Any("updatedState", flowConfigUpdate))
logger.Info("processing CDCFlowConfigUpdate", slog.Any("updatedState", flowConfigUpdate))
if len(flowConfigUpdate.AdditionalTables) == 0 {
return nil
}
if shared.AdditionalTablesHasOverlap(state.SyncFlowOptions.TableMappings, flowConfigUpdate.AdditionalTables) {
w.logger.Warn("duplicate source/destination tables found in additionalTables")
logger.Warn("duplicate source/destination tables found in additionalTables")
return nil
}
state.CurrentFlowStatus = protos.FlowStatus_STATUS_SNAPSHOT

w.logger.Info("altering publication for additional tables")
logger.Info("altering publication for additional tables")
alterPublicationAddAdditionalTablesCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
StartToCloseTimeout: 5 * time.Minute,
})
@@ -129,11 +115,11 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdate(ctx workflow.Conte
flowable.AddTablesToPublication,
cfg, flowConfigUpdate.AdditionalTables)
if err := alterPublicationAddAdditionalTablesFuture.Get(ctx, nil); err != nil {
w.logger.Error("failed to alter publication for additional tables: ", err)
logger.Error("failed to alter publication for additional tables: ", err)
return err
}

w.logger.Info("additional tables added to publication")
logger.Info("additional tables added to publication")
additionalTablesUUID := GetUUID(ctx)
childAdditionalTablesCDCFlowID := GetChildWorkflowID("additional-cdc-flow", cfg.FlowJobName, additionalTablesUUID)
additionalTablesCfg := proto.Clone(cfg).(*protos.FlowConnectionConfigs)
@@ -167,13 +153,14 @@ func (w *CDCFlowWorkflowExecution) processCDCFlowConfigUpdate(ctx workflow.Conte
maps.Copy(state.SyncFlowOptions.TableNameSchemaMapping, res.SyncFlowOptions.TableNameSchemaMapping)

state.SyncFlowOptions.TableMappings = append(state.SyncFlowOptions.TableMappings, flowConfigUpdate.AdditionalTables...)
w.logger.Info("additional tables added to sync flow")
logger.Info("additional tables added to sync flow")
}
return nil
}

func (w *CDCFlowWorkflowExecution) addCdcPropertiesSignalListener(
func addCdcPropertiesSignalListener(
ctx workflow.Context,
logger log.Logger,
selector workflow.Selector,
state *CDCFlowWorkflowState,
) {
@@ -189,21 +176,13 @@ func (w *CDCFlowWorkflowExecution) addCdcPropertiesSignalListener(
// do this irrespective of additional tables being present, for auto unpausing
state.FlowConfigUpdate = cdcConfigUpdate

w.logger.Info("CDC Signal received. Parameters on signal reception:",
logger.Info("CDC Signal received. Parameters on signal reception:",
slog.Int("BatchSize", int(state.SyncFlowOptions.BatchSize)),
slog.Int("IdleTimeout", int(state.SyncFlowOptions.IdleTimeoutSeconds)),
slog.Any("AdditionalTables", cdcConfigUpdate.AdditionalTables))
})
}

func (w *CDCFlowWorkflowExecution) startSyncFlow(ctx workflow.Context, config *protos.FlowConnectionConfigs, options *protos.SyncFlowOptions) {
w.syncFlowFuture = workflow.ExecuteChildWorkflow(ctx, SyncFlowWorkflow, config, options)
}

func (w *CDCFlowWorkflowExecution) startNormFlow(ctx workflow.Context, config *protos.FlowConnectionConfigs) {
w.normFlowFuture = workflow.ExecuteChildWorkflow(ctx, NormalizeFlowWorkflow, config, nil)
}

func CDCFlowWorkflow(
ctx workflow.Context,
cfg *protos.FlowConnectionConfigs,
@@ -217,7 +196,7 @@ func CDCFlowWorkflow(
state = NewCDCFlowWorkflowState(cfg)
}

w := NewCDCFlowWorkflowExecution(ctx, cfg.FlowJobName)
logger := log.With(workflow.GetLogger(ctx), slog.String(string(shared.FlowNameKey), cfg.FlowJobName))
flowSignalChan := model.FlowSignal.GetSignalChannel(ctx)

err := workflow.SetQueryHandler(ctx, shared.CDCFlowStateQuery, func() (CDCFlowWorkflowState, error) {
@@ -248,36 +227,36 @@ func CDCFlowWorkflow(
selector := workflow.NewNamedSelector(ctx, "PauseLoop")
selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {})
flowSignalChan.AddToSelector(selector, func(val model.CDCFlowSignal, _ bool) {
state.ActiveSignal = model.FlowSignalHandler(state.ActiveSignal, val, w.logger)
state.ActiveSignal = model.FlowSignalHandler(state.ActiveSignal, val, logger)
})
w.addCdcPropertiesSignalListener(ctx, selector, state)
addCdcPropertiesSignalListener(ctx, logger, selector, state)

startTime := workflow.Now(ctx)
state.CurrentFlowStatus = protos.FlowStatus_STATUS_PAUSED

for state.ActiveSignal == model.PauseSignal {
// only place we block on receive, so signal processing is immediate
for state.ActiveSignal == model.PauseSignal && state.FlowConfigUpdate == nil && ctx.Err() == nil {
w.logger.Info("mirror has been paused", slog.Any("duration", time.Since(startTime)))
logger.Info("mirror has been paused", slog.Any("duration", time.Since(startTime)))
selector.Select(ctx)
}
if err := ctx.Err(); err != nil {
return state, err
}

if state.FlowConfigUpdate != nil {
err = w.processCDCFlowConfigUpdate(ctx, cfg, state, mirrorNameSearch)
err = processCDCFlowConfigUpdate(ctx, logger, cfg, state, mirrorNameSearch)
if err != nil {
return state, err
}
w.logger.Info("wiping flow state after state update processing")
logger.Info("wiping flow state after state update processing")
// finished processing, wipe it
state.FlowConfigUpdate = nil
state.ActiveSignal = model.NoopSignal
}
}

w.logger.Info("mirror has been resumed after ", time.Since(startTime))
logger.Info("mirror has been resumed after ", time.Since(startTime))
state.CurrentFlowStatus = protos.FlowStatus_STATUS_RUNNING
}

@@ -346,7 +325,7 @@ func CDCFlowWorkflow(
state.SyncFlowOptions.TableNameSchemaMapping,
)
if err := snapshotFlowFuture.Get(snapshotFlowCtx, nil); err != nil {
w.logger.Error("snapshot flow failed", slog.Any("error", err))
logger.Error("snapshot flow failed", slog.Any("error", err))
return state, fmt.Errorf("failed to execute snapshot workflow: %w", err)
}

@@ -385,7 +364,7 @@ func CDCFlowWorkflow(
}

state.CurrentFlowStatus = protos.FlowStatus_STATUS_RUNNING
w.logger.Info("executed setup flow and snapshot flow")
logger.Info("executed setup flow and snapshot flow")

// if initial_copy_only is opted for, we end the flow here.
if cfg.InitialSnapshotOnly {
@@ -424,65 +403,54 @@ func CDCFlowWorkflow(
handleError := func(name string, err error) {
var panicErr *temporal.PanicError
if errors.As(err, &panicErr) {
w.logger.Error(
logger.Error(
"panic in flow",
slog.String("name", name),
slog.Any("error", panicErr.Error()),
slog.String("stack", panicErr.StackTrace()),
)
} else {
w.logger.Error("error in flow", slog.String("name", name), slog.Any("error", err))
logger.Error("error in flow", slog.String("name", name), slog.Any("error", err))
}
}

syncFlowFuture := workflow.ExecuteChildWorkflow(syncCtx, SyncFlowWorkflow, cfg, state.SyncFlowOptions)
normFlowFuture := workflow.ExecuteChildWorkflow(normCtx, NormalizeFlowWorkflow, cfg, nil)

mainLoopSelector := workflow.NewNamedSelector(ctx, "MainLoop")
mainLoopSelector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {})

var handleNormFlow, handleSyncFlow func(workflow.Future)
handleSyncFlow = func(f workflow.Future) {
mainLoopSelector.AddFuture(syncFlowFuture, func(f workflow.Future) {
err := f.Get(ctx, nil)
if err != nil {
handleError("sync", err)
}

if restart {
w.logger.Info("sync finished, finishing normalize")
w.syncFlowFuture = nil
_ = model.NormalizeSignal.SignalChildWorkflow(ctx, w.normFlowFuture, model.NormalizePayload{
Done: true,
SyncBatchID: -1,
}).Get(ctx, nil)
} else {
w.logger.Warn("sync flow ended, restarting", slog.Any("error", err))
w.startSyncFlow(syncCtx, cfg, state.SyncFlowOptions)
mainLoopSelector.AddFuture(w.syncFlowFuture, handleSyncFlow)
logger.Info("sync finished, finishing normalize")
syncFlowFuture = nil
restart = true
err = model.NormalizeSignal.SignalChildWorkflow(ctx, normFlowFuture, model.NormalizePayload{
Done: true,
SyncBatchID: -1,
}).Get(ctx, nil)
if err != nil {
logger.Warn("failed to signal normalize done, finishing", slog.Any("error", err))
finished = true
}
}
handleNormFlow = func(f workflow.Future) {
})
mainLoopSelector.AddFuture(normFlowFuture, func(f workflow.Future) {
err := f.Get(ctx, nil)
if err != nil {
handleError("normalize", err)
}

if restart {
w.logger.Info("normalize finished, finishing")
w.normFlowFuture = nil
finished = true
} else {
w.logger.Warn("normalize flow ended, restarting", slog.Any("error", err))
w.startNormFlow(normCtx, cfg)
mainLoopSelector.AddFuture(w.normFlowFuture, handleNormFlow)
}
}

w.startSyncFlow(syncCtx, cfg, state.SyncFlowOptions)
mainLoopSelector.AddFuture(w.syncFlowFuture, handleSyncFlow)

w.startNormFlow(normCtx, cfg)
mainLoopSelector.AddFuture(w.normFlowFuture, handleNormFlow)
logger.Info("normalize finished, finishing")
normFlowFuture = nil
restart = true
finished = true
})

flowSignalChan.AddToSelector(mainLoopSelector, func(val model.CDCFlowSignal, _ bool) {
state.ActiveSignal = model.FlowSignalHandler(state.ActiveSignal, val, w.logger)
state.ActiveSignal = model.FlowSignalHandler(state.ActiveSignal, val, logger)
})

syncResultChan := model.SyncResultSignal.GetSignalChannel(ctx)
@@ -499,7 +467,9 @@

normChan := model.NormalizeSignal.GetSignalChannel(ctx)
normChan.AddToSelector(mainLoopSelector, func(payload model.NormalizePayload, _ bool) {
_ = model.NormalizeSignal.SignalChildWorkflow(ctx, w.normFlowFuture, payload).Get(ctx, nil)
if normFlowFuture != nil {
_ = model.NormalizeSignal.SignalChildWorkflow(ctx, normFlowFuture, payload).Get(ctx, nil)
}
maps.Copy(state.SyncFlowOptions.TableNameSchemaMapping, payload.TableNameSchemaMapping)
})

@@ -509,13 +479,13 @@
if !parallel {
normDoneChan := model.NormalizeDoneSignal.GetSignalChannel(ctx)
normDoneChan.AddToSelector(mainLoopSelector, func(x struct{}, _ bool) {
if w.syncFlowFuture != nil {
_ = model.NormalizeDoneSignal.SignalChildWorkflow(ctx, w.syncFlowFuture, x).Get(ctx, nil)
if syncFlowFuture != nil {
_ = model.NormalizeDoneSignal.SignalChildWorkflow(ctx, syncFlowFuture, x).Get(ctx, nil)
}
})
}

w.addCdcPropertiesSignalListener(ctx, mainLoopSelector, state)
addCdcPropertiesSignalListener(ctx, logger, mainLoopSelector, state)

state.CurrentFlowStatus = protos.FlowStatus_STATUS_RUNNING
for {
@@ -524,13 +494,19 @@
mainLoopSelector.Select(ctx)
}
if err := ctx.Err(); err != nil {
w.logger.Info("mirror canceled", slog.Any("error", err))
logger.Info("mirror canceled", slog.Any("error", err))
return state, err
}

if state.ActiveSignal == model.PauseSignal || syncCount >= maxSyncsPerCdcFlow {
restart = true
_ = model.SyncStopSignal.SignalChildWorkflow(ctx, w.syncFlowFuture, struct{}{}).Get(ctx, nil)
if syncFlowFuture != nil {
err := model.SyncStopSignal.SignalChildWorkflow(ctx, syncFlowFuture, struct{}{}).Get(ctx, nil)
if err != nil {
logger.Warn("failed to send sync-stop, finishing", slog.Any("error", err))
finished = true
}
}
}

if restart {
@@ -543,7 +519,7 @@
}

if err := ctx.Err(); err != nil {
w.logger.Info("mirror canceled", slog.Any("error", err))
logger.Info("mirror canceled", slog.Any("error", err))
return nil, err
}

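For context, the structure this commit moves to can be sketched roughly as follows: a minimal, hypothetical Temporal workflow that starts its child workflow once, registers the future on a selector, and only toggles flags on completion instead of re-spawning the child through custom restart helpers. Names such as `runOnce` and `childWorkflow` are illustrative and not taken from the repository; retry behavior is assumed to be left to Temporal's own child-workflow policies.

```go
package workflows

import (
	"go.temporal.io/sdk/workflow"
)

// runOnce is a hypothetical workflow illustrating the pattern in the diff:
// the child workflow is executed a single time, its future is handled via a
// selector callback, and completion sets a flag rather than restarting the
// child in place.
func runOnce(ctx workflow.Context, childWorkflow interface{}) error {
	logger := workflow.GetLogger(ctx)

	childFuture := workflow.ExecuteChildWorkflow(ctx, childWorkflow)

	finished := false
	selector := workflow.NewSelector(ctx)
	// Wake up on cancellation as well, mirroring the ctx.Done() receive above.
	selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) {})
	selector.AddFuture(childFuture, func(f workflow.Future) {
		if err := f.Get(ctx, nil); err != nil {
			logger.Error("child workflow failed", "error", err)
		}
		// Mark the loop as done instead of re-executing the child here.
		finished = true
	})

	for !finished && ctx.Err() == nil {
		selector.Select(ctx)
	}
	return ctx.Err()
}
```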