From 7c18ab766e64e811f8f66b8623406127f69c847f Mon Sep 17 00:00:00 2001 From: Matt Toohey Date: Tue, 7 Jan 2025 13:28:41 +1100 Subject: [PATCH] fix: race condition while collecting runtime events --- backend/provisioner/inmem_provisioner.go | 60 +++++++++++++++++------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/backend/provisioner/inmem_provisioner.go b/backend/provisioner/inmem_provisioner.go index 96275de655..8592575c84 100644 --- a/backend/provisioner/inmem_provisioner.go +++ b/backend/provisioner/inmem_provisioner.go @@ -6,6 +6,7 @@ import ( "connectrpc.com/connect" "github.com/alecthomas/atomic" + "github.com/alecthomas/types/optional" "github.com/google/uuid" "github.com/puzpuzpuz/xsync/v3" @@ -75,6 +76,11 @@ func (d *InMemProvisioner) Ping(context.Context, *connect.Request[ftlv1.PingRequ return &connect.Response[ftlv1.PingResponse]{}, nil } +type stepCompletedEvent struct { + step *inMemProvisioningStep + event optional.Option[*RuntimeEvent] +} + func (d *InMemProvisioner) Provision(ctx context.Context, req *connect.Request[provisioner.ProvisionRequest]) (*connect.Response[provisioner.ProvisionResponse], error) { logger := log.FromContext(ctx) @@ -95,36 +101,58 @@ func (d *InMemProvisioner) Provision(ctx context.Context, req *connect.Request[p desiredNodes := schema.GetProvisioned(desiredModule) task := &inMemProvisioningTask{} + // use chans to safely collect all events before completing each task + completions := make(chan stepCompletedEvent, 16) + for id, desired := range desiredNodes { previous, ok := previousNodes[id] for _, resource := range desired.GetProvisioned() { if !ok || !resource.IsEqual(previous.GetProvisioned().Get(resource.Kind)) { if slices.Contains(kinds, resource.Kind) { - if handler, ok := d.handlers[resource.Kind]; ok { - step := &inMemProvisioningStep{Done: atomic.New(false)} - task.steps = append(task.steps, step) - go func() { - defer step.Done.Store(true) - event, err := handler(ctx, desiredModule.Name, desired) - if err != nil { - step.Err = err - logger.Errorf(err, "failed to provision resource %s:%s", resource.Kind, desired.ResourceID()) - return - } - if event != nil { - task.events = append(task.events, event) - } - }() - } else { + handler, ok := d.handlers[resource.Kind] + if !ok { err := fmt.Errorf("unsupported resource type: %s", resource.Kind) return nil, connect.NewError(connect.CodeInvalidArgument, err) } + step := &inMemProvisioningStep{Done: atomic.New(false)} + task.steps = append(task.steps, step) + go func() { + event, err := handler(ctx, desiredModule.Name, desired) + if err != nil { + step.Err = err + logger.Errorf(err, "failed to provision resource %s:%s", resource.Kind, desired.ResourceID()) + completions <- stepCompletedEvent{step: step} + return + } + completions <- stepCompletedEvent{ + step: step, + event: optional.From(event, event != nil), + } + }() } } } } + go func() { + for { + done, err := task.Done() + if done || err != nil { + return + } + select { + case <-ctx.Done(): + return + case c := <-completions: + if e, ok := c.event.Get(); ok { + task.events = append(task.events, e) + } + c.step.Done.Store(true) + } + } + }() + token := uuid.New().String() logger.Debugf("started a task with token %s", token) d.running.Store(token, task)