diff --git a/flow/concurrency/bound_selector.go b/flow/concurrency/bound_selector.go index 29735c09d6..ed40f0775c 100644 --- a/flow/concurrency/bound_selector.go +++ b/flow/concurrency/bound_selector.go @@ -7,44 +7,45 @@ import ( ) type BoundSelector struct { - ctx workflow.Context - limit int - futures []workflow.Future - ferrors []error + ctx workflow.Context + limit int + count int + selector workflow.Selector + ferrors []error } func NewBoundSelector(limit int, total int, ctx workflow.Context) *BoundSelector { return &BoundSelector{ - ctx: ctx, - limit: limit, + ctx: ctx, + limit: limit, + selector: workflow.NewSelector(ctx), } } func (s *BoundSelector) SpawnChild(chCtx workflow.Context, w interface{}, args ...interface{}) { - if len(s.futures) >= s.limit { + if s.count >= s.limit { s.waitOne() } future := workflow.ExecuteChildWorkflow(chCtx, w, args...) - s.futures = append(s.futures, future) + s.selector.AddFuture(future, func(f workflow.Future) { + err := f.Get(s.ctx, nil) + if err != nil { + s.ferrors = append(s.ferrors, err) + } + }) + s.count += 1 } func (s *BoundSelector) waitOne() { - if len(s.futures) == 0 { - return - } - - f := s.futures[0] - s.futures = s.futures[1:] - - err := f.Get(s.ctx, nil) - if err != nil { - s.ferrors = append(s.ferrors, err) + if s.count > 0 { + s.selector.Select(s.ctx) + s.count -= 1 } } func (s *BoundSelector) Wait() error { - for len(s.futures) > 0 { + for s.count > 0 { s.waitOne() }