Skip to content

Commit

Permalink
workflows: also avoid storing workflow.Context in structs
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex committed Feb 11, 2024
1 parent 986e0ca commit 5e30571
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
14 changes: 6 additions & 8 deletions flow/concurrency/bound_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,43 @@ import (
)

type BoundSelector struct {
ctx workflow.Context
limit int
futures []workflow.Future
ferrors []error
}

func NewBoundSelector(limit int, ctx workflow.Context) *BoundSelector {
func NewBoundSelector(limit int) *BoundSelector {
return &BoundSelector{
ctx: ctx,
limit: limit,
}
}

func (s *BoundSelector) SpawnChild(chCtx workflow.Context, w interface{}, args ...interface{}) {
if len(s.futures) >= s.limit {
s.waitOne()
s.waitOne(chCtx)
}

future := workflow.ExecuteChildWorkflow(chCtx, w, args...)
s.futures = append(s.futures, future)
}

func (s *BoundSelector) waitOne() {
func (s *BoundSelector) waitOne(ctx workflow.Context) {
if len(s.futures) == 0 {
return
}

f := s.futures[0]
s.futures = s.futures[1:]

err := f.Get(s.ctx, nil)
err := f.Get(ctx, nil)
if err != nil {
s.ferrors = append(s.ferrors, err)
}
}

func (s *BoundSelector) Wait() error {
func (s *BoundSelector) Wait(ctx workflow.Context) error {
for len(s.futures) > 0 {
s.waitOne()
s.waitOne(ctx)
}

if len(s.ferrors) > 0 {
Expand Down
2 changes: 0 additions & 2 deletions flow/workflows/cdc_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,13 @@ func (s *CDCFlowWorkflowState) TruncateProgress(logger log.Logger) {
type CDCFlowWorkflowExecution struct {
flowExecutionID string
logger log.Logger
ctx workflow.Context
}

// NewCDCFlowWorkflowExecution creates a new instance of PeerFlowWorkflowExecution.
func NewCDCFlowWorkflowExecution(ctx workflow.Context) *CDCFlowWorkflowExecution {
return &CDCFlowWorkflowExecution{
flowExecutionID: workflow.GetInfo(ctx).WorkflowExecution.ID,
logger: workflow.GetLogger(ctx),
ctx: ctx,
}
}

Expand Down
12 changes: 6 additions & 6 deletions flow/workflows/snapshot_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func (s *SnapshotFlowExecution) closeSlotKeepAlive(
}

func (s *SnapshotFlowExecution) cloneTable(
ctx workflow.Context,
boundSelector *concurrency.BoundSelector,
childCtx workflow.Context,
snapshotName string,
mapping *protos.TableMapping,
) error {
Expand All @@ -97,7 +97,7 @@ func (s *SnapshotFlowExecution) cloneTable(

srcName := mapping.SourceTableIdentifier
dstName := mapping.DestinationTableIdentifier
childWorkflowIDSideEffect := workflow.SideEffect(childCtx, func(ctx workflow.Context) interface{} {
childWorkflowIDSideEffect := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} {
childWorkflowID := fmt.Sprintf("clone_%s_%s_%s", flowName, dstName, uuid.New().String())
reg := regexp.MustCompile("[^a-zA-Z0-9]+")
return reg.ReplaceAllString(childWorkflowID, "_")
Expand All @@ -118,7 +118,7 @@ func (s *SnapshotFlowExecution) cloneTable(
return queueErr
}

childCtx = workflow.WithChildOptions(childCtx, workflow.ChildWorkflowOptions{
childCtx := workflow.WithChildOptions(ctx, workflow.ChildWorkflowOptions{
WorkflowID: childWorkflowID,
WorkflowTaskTimeout: 5 * time.Minute,
TaskQueue: taskQueue,
Expand Down Expand Up @@ -200,7 +200,7 @@ func (s *SnapshotFlowExecution) cloneTables(
s.logger.Info(fmt.Sprintf("cloning tables for slot name %s and snapshotName %s",
slotInfo.SlotName, slotInfo.SnapshotName))

boundSelector := concurrency.NewBoundSelector(maxParallelClones, ctx)
boundSelector := concurrency.NewBoundSelector(maxParallelClones)

for _, v := range s.config.TableMappings {
source := v.SourceTableIdentifier
Expand All @@ -211,14 +211,14 @@ func (s *SnapshotFlowExecution) cloneTables(
source, destination),
slog.String("snapshotName", snapshotName),
)
err := s.cloneTable(boundSelector, ctx, snapshotName, v)
err := s.cloneTable(ctx, boundSelector, snapshotName, v)
if err != nil {
s.logger.Error("failed to start clone child workflow: ", err)
continue
}
}

if err := boundSelector.Wait(); err != nil {
if err := boundSelector.Wait(ctx); err != nil {
s.logger.Error("failed to clone some tables", "error", err)
return err
}
Expand Down

0 comments on commit 5e30571

Please sign in to comment.