diff --git a/flow/workflows/drop_flow.go b/flow/workflows/drop_flow.go index e9092b7a2b..decf1b945c 100644 --- a/flow/workflows/drop_flow.go +++ b/flow/workflows/drop_flow.go @@ -20,19 +20,20 @@ func DropFlowWorkflow(ctx workflow.Context, req *protos.ShutdownRequest) error { }) ctx = workflow.WithValue(ctx, shared.FlowNameKey, req.FlowJobName) - dropSourceFuture := workflow.ExecuteActivity(ctx, flowable.DropFlowSource, req) - dropDestinationFuture := workflow.ExecuteActivity(ctx, flowable.DropFlowDestination, req) var sourceError, destinationError error - var sourceOk, destinationOk bool + var sourceOk, destinationOk, canceled bool selector := workflow.NewNamedSelector(ctx, fmt.Sprintf("%s-drop", req.FlowJobName)) + selector.AddReceive(ctx.Done(), func(_ workflow.ReceiveChannel, _ bool) { + canceled = true + }) var dropSource, dropDestination func(f workflow.Future) dropSource = func(f workflow.Future) { sourceError = f.Get(ctx, nil) sourceOk = sourceError == nil if !sourceOk { - dropSourceFuture = workflow.ExecuteActivity(ctx, flowable.DropFlowSource, req) + dropSourceFuture := workflow.ExecuteActivity(ctx, flowable.DropFlowSource, req) selector.AddFuture(dropSourceFuture, dropSource) _ = workflow.Sleep(ctx, time.Second) } @@ -41,16 +42,20 @@ func DropFlowWorkflow(ctx workflow.Context, req *protos.ShutdownRequest) error { destinationError = f.Get(ctx, nil) destinationOk = destinationError == nil if !destinationOk { - dropDestinationFuture = workflow.ExecuteActivity(ctx, flowable.DropFlowDestination, req) + dropDestinationFuture := workflow.ExecuteActivity(ctx, flowable.DropFlowDestination, req) selector.AddFuture(dropDestinationFuture, dropDestination) _ = workflow.Sleep(ctx, time.Second) } } + dropSourceFuture := workflow.ExecuteActivity(ctx, flowable.DropFlowSource, req) + selector.AddFuture(dropSourceFuture, dropSource) + dropDestinationFuture := workflow.ExecuteActivity(ctx, flowable.DropFlowDestination, req) + selector.AddFuture(dropDestinationFuture, dropDestination) for { selector.Select(ctx) - if ctx.Err() != nil { - return errors.Join(sourceError, destinationError) + if canceled { + return errors.Join(ctx.Err(), sourceError, destinationError) } else if sourceOk && destinationOk { return nil }