From 5e30571fd0690b4fb97b54155b7cf3539f604d54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Sun, 11 Feb 2024 05:52:02 +0000 Subject: [PATCH] workflows: also avoid storing workflow.Context in structs --- flow/concurrency/bound_selector.go | 14 ++++++-------- flow/workflows/cdc_flow.go | 2 -- flow/workflows/snapshot_flow.go | 12 ++++++------ 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/flow/concurrency/bound_selector.go b/flow/concurrency/bound_selector.go index 75d6a5c983..922a422740 100644 --- a/flow/concurrency/bound_selector.go +++ b/flow/concurrency/bound_selector.go @@ -7,29 +7,27 @@ import ( ) type BoundSelector struct { - ctx workflow.Context limit int futures []workflow.Future ferrors []error } -func NewBoundSelector(limit int, ctx workflow.Context) *BoundSelector { +func NewBoundSelector(limit int) *BoundSelector { return &BoundSelector{ - ctx: ctx, limit: limit, } } func (s *BoundSelector) SpawnChild(chCtx workflow.Context, w interface{}, args ...interface{}) { if len(s.futures) >= s.limit { - s.waitOne() + s.waitOne(chCtx) } future := workflow.ExecuteChildWorkflow(chCtx, w, args...) s.futures = append(s.futures, future) } -func (s *BoundSelector) waitOne() { +func (s *BoundSelector) waitOne(ctx workflow.Context) { if len(s.futures) == 0 { return } @@ -37,15 +35,15 @@ func (s *BoundSelector) waitOne() { f := s.futures[0] s.futures = s.futures[1:] - err := f.Get(s.ctx, nil) + err := f.Get(ctx, nil) if err != nil { s.ferrors = append(s.ferrors, err) } } -func (s *BoundSelector) Wait() error { +func (s *BoundSelector) Wait(ctx workflow.Context) error { for len(s.futures) > 0 { - s.waitOne() + s.waitOne(ctx) } if len(s.ferrors) > 0 { diff --git a/flow/workflows/cdc_flow.go b/flow/workflows/cdc_flow.go index 51272ffb25..afc17927ce 100644 --- a/flow/workflows/cdc_flow.go +++ b/flow/workflows/cdc_flow.go @@ -109,7 +109,6 @@ func (s *CDCFlowWorkflowState) TruncateProgress(logger log.Logger) { type CDCFlowWorkflowExecution struct { flowExecutionID string logger log.Logger - ctx workflow.Context } // NewCDCFlowWorkflowExecution creates a new instance of PeerFlowWorkflowExecution. @@ -117,7 +116,6 @@ func NewCDCFlowWorkflowExecution(ctx workflow.Context) *CDCFlowWorkflowExecution return &CDCFlowWorkflowExecution{ flowExecutionID: workflow.GetInfo(ctx).WorkflowExecution.ID, logger: workflow.GetLogger(ctx), - ctx: ctx, } } diff --git a/flow/workflows/snapshot_flow.go b/flow/workflows/snapshot_flow.go index 39eafaf87c..f7de2f8850 100644 --- a/flow/workflows/snapshot_flow.go +++ b/flow/workflows/snapshot_flow.go @@ -85,8 +85,8 @@ func (s *SnapshotFlowExecution) closeSlotKeepAlive( } func (s *SnapshotFlowExecution) cloneTable( + ctx workflow.Context, boundSelector *concurrency.BoundSelector, - childCtx workflow.Context, snapshotName string, mapping *protos.TableMapping, ) error { @@ -97,7 +97,7 @@ func (s *SnapshotFlowExecution) cloneTable( srcName := mapping.SourceTableIdentifier dstName := mapping.DestinationTableIdentifier - childWorkflowIDSideEffect := workflow.SideEffect(childCtx, func(ctx workflow.Context) interface{} { + childWorkflowIDSideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} { childWorkflowID := fmt.Sprintf("clone_%s_%s_%s", flowName, dstName, uuid.New().String()) reg := regexp.MustCompile("[^a-zA-Z0-9]+") return reg.ReplaceAllString(childWorkflowID, "_") @@ -118,7 +118,7 @@ func (s *SnapshotFlowExecution) cloneTable( return queueErr } - childCtx = workflow.WithChildOptions(childCtx, workflow.ChildWorkflowOptions{ + childCtx := workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{ WorkflowID: childWorkflowID, WorkflowTaskTimeout: 5 * time.Minute, TaskQueue: taskQueue, @@ -200,7 +200,7 @@ func (s *SnapshotFlowExecution) cloneTables( s.logger.Info(fmt.Sprintf("cloning tables for slot name %s and snapshotName %s", slotInfo.SlotName, slotInfo.SnapshotName)) - boundSelector := concurrency.NewBoundSelector(maxParallelClones, ctx) + boundSelector := concurrency.NewBoundSelector(maxParallelClones) for _, v := range s.config.TableMappings { source := v.SourceTableIdentifier @@ -211,14 +211,14 @@ func (s *SnapshotFlowExecution) cloneTables( source, destination), slog.String("snapshotName", snapshotName), ) - err := s.cloneTable(boundSelector, ctx, snapshotName, v) + err := s.cloneTable(ctx, boundSelector, snapshotName, v) if err != nil { s.logger.Error("failed to start clone child workflow: ", err) continue } } - if err := boundSelector.Wait(); err != nil { + if err := boundSelector.Wait(ctx); err != nil { s.logger.Error("failed to clone some tables", "error", err) return err }