From 9c57be520f26253d3f5328927874fd10d187b7c2 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 27 Jun 2024 10:33:51 -0400 Subject: [PATCH 01/12] Fix many syncronization issues There is a problem with closing --- atp/client.go | 80 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 15 deletions(-) diff --git a/atp/client.go b/atp/client.go index 1272c02..54aa535 100644 --- a/atp/client.go +++ b/atp/client.go @@ -8,6 +8,7 @@ import ( "io" "strings" "sync" + "time" ) var supportedServerVersions = []int64{1, 3} @@ -175,7 +176,9 @@ func (c *client) Execute( workStartMsg = RuntimeMessage{RunID: stepData.RunID, MessageID: MessageTypeWorkStart, MessageData: workStartMsg} // Handle signals to the step if receivedSignals != nil { + c.wg.Add(1) go func() { + defer c.wg.Done() c.executeWriteLoop(stepData.RunID, receivedSignals) }() } @@ -215,6 +218,7 @@ func (c *client) handleStepComplete(runID string, receivedSignals chan schema.In func (c *client) Close() error { c.mutex.Lock() if c.done { + c.mutex.Unlock() return nil } c.done = true @@ -235,14 +239,41 @@ func (c *client) Close() error { clientDoneMessage{}, }) if err != nil { - return fmt.Errorf("client with steps '%s' failed to write client done message with error: %w", - c.getRunningStepIDs(), err) + // add a timeout to the wait to prevent it from causing a deadlock. + // 5 seconds is arbitrary, but gives it enough time to exit. + waitedGracefully := waitWithTimeout(time.Second*5, &c.wg) + if waitedGracefully { + return fmt.Errorf("client with step '%s' failed to write client done message with error: %w", + c.getRunningStepIDs(), err) + } else { + panic(fmt.Errorf("potential deadlock after client with step '%s' failed to write client done message with error: %w", + c.getRunningStepIDs(), err)) + } } } c.wg.Wait() return nil } +// Waits for the WaitGroup to finish, but with a timeout to +// prevent a deadlock. +// Returns true if the WaitGroup finished, and false if +// it reached the end of the timeout. +func waitWithTimeout(duration time.Duration, wg *sync.WaitGroup) bool { + // Run a goroutine to do the waiting + doneChannel := make(chan bool, 1) + go func() { + defer close(doneChannel) + wg.Wait() + }() + select { + case <-doneChannel: + return true + case <-time.After(duration): + return false + } +} + func (c *client) getRunningStepIDs() string { if len(c.runningSteps) == 0 { return "No running steps" @@ -261,6 +292,13 @@ func (c *client) executeWriteLoop( ) { // Add the channel to the client so that it can be kept track of c.mutex.Lock() + if c.done { + c.mutex.Unlock() + // You need to abort to allow proper closure, since the channel would otherwise + // be left open. + c.logger.Warningf("aborting write loop for run ID %q due to done client", runID) + return + } c.runningSignalReceiveLoops[runID] = receivedSignals c.mutex.Unlock() defer func() { @@ -295,28 +333,23 @@ func (c *client) executeWriteLoop( } } -// sendExecutionResult sends the results to the channel, and closes then removes the channels for the +// sendExecutionResult sends the results to the result channel, and closes then removes the channels for the // step results and the signals. +// The caller should have the mutex locked while calling this function. func (c *client) sendExecutionResult(runID string, result ExecutionResult) { - c.logger.Debugf("Providing input for run ID '%s'", runID) - c.mutex.Lock() + c.logger.Debugf("Sending results for run ID '%s'", runID) resultChannel, found := c.runningStepResultChannels[runID] - c.mutex.Unlock() if found { // Send the result + // TODO: Consider replacing the channel with a condition variable. resultChannel <- result - // Close the channel and remove it to detect incorrectly duplicate results. + // Close the channel and remove it to detect incorrectly duplicated results. close(resultChannel) - c.mutex.Lock() - delete(c.runningStepResultChannels, runID) - c.mutex.Unlock() } else { c.logger.Errorf("Step result channel not found for run ID '%s'. This is either a bug in the ATP "+ "client, or the plugin erroneously sent a second result.", runID) } // Now close the signal channel, since it's invalid to send a signal after the step is complete. - c.mutex.Lock() - defer c.mutex.Unlock() signalChannel, found := c.runningStepEmittedSignalChannels[runID] if !found { c.logger.Debugf("Could not find signal output channel for run ID '%s'", runID) @@ -328,9 +361,11 @@ func (c *client) sendExecutionResult(runID string, result ExecutionResult) { func (c *client) sendErrorToAll(err error) { result := NewErrorExecutionResult(err) + c.mutex.Lock() for runID := range c.runningStepResultChannels { c.sendExecutionResult(runID, result) } + c.mutex.Unlock() } //nolint:funlen @@ -356,17 +391,23 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { var doneMessage WorkDoneMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { c.logger.Errorf("Failed to decode work done message (%v) for run ID '%s' ", err, runtimeMessage.RunID) + c.mutex.Lock() c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult( fmt.Errorf("failed to decode work done message (%w)", err))) + c.mutex.Unlock() } + c.mutex.Lock() c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage)) + c.mutex.Unlock() case MessageTypeSignal: var signalMessage SignalMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", runtimeMessage.RunID, err) } + c.mutex.Lock() signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] + c.mutex.Unlock() if !found { c.logger.Warningf( "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ @@ -401,6 +442,8 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { runtimeMessage.MessageID) } c.mutex.Lock() + // TODO: This is likely the cause of the deadlock. The recent reordering of the removal from the + // channel is resulting in the value being left in here. if len(c.runningStepResultChannels) == 0 { c.mutex.Unlock() return // Done @@ -441,6 +484,7 @@ func (c *client) prepareResultChannels( stepData schema.Input, emittedSignals chan<- schema.Input, ) error { + c.logger.Debugf("Preparing result channels for step with run ID %q", stepData.RunID) c.mutex.Lock() defer c.mutex.Unlock() _, existing := c.runningStepResultChannels[stepData.RunID] @@ -448,7 +492,7 @@ func (c *client) prepareResultChannels( return fmt.Errorf("duplicate run ID given '%s'", stepData.RunID) } // Set up the signal and step results channels - resultChannel := make(chan ExecutionResult) + resultChannel := make(chan ExecutionResult, 5) // Arbitrary buffer size to prevent deadlocks. c.runningStepResultChannels[stepData.RunID] = resultChannel if emittedSignals != nil { c.runningStepEmittedSignalChannels[stepData.RunID] = emittedSignals @@ -472,10 +516,11 @@ func (c *client) getResultV2( c.mutex.Lock() resultChannel, found := c.runningStepResultChannels[stepData.RunID] c.mutex.Unlock() + c.logger.Debugf("Got result channel for run ID %q", stepData.RunID) if !found { return NewErrorExecutionResult( - fmt.Errorf("could not find result channel for step with run ID '%s'", - stepData.RunID), + fmt.Errorf("could not find result channel for step with run ID '%s'. Existing channels: %v", + stepData.RunID, c.runningStepResultChannels), ) } // Wait for the result @@ -486,6 +531,11 @@ func (c *client) getResultV2( stepData.RunID), ) } + // This needs to be done after receiving the value, or else the sender will + // not be able to get the channel. + c.mutex.Lock() + defer c.mutex.Unlock() + delete(c.runningStepResultChannels, stepData.RunID) return result } From f40f688b37fd22d67937b5199e7f48b85a471e6b Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 27 Jun 2024 13:58:23 -0400 Subject: [PATCH 02/12] Fix deadlock and switch to Cond variables --- atp/client.go | 63 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/atp/client.go b/atp/client.go index 54aa535..d80a605 100644 --- a/atp/client.go +++ b/atp/client.go @@ -73,7 +73,7 @@ func NewClientWithLogger( make(chan bool, 5), // Buffer to prevent deadlocks make([]schema.Input, 0), make(map[string]chan schema.Input), - make(map[string]chan ExecutionResult), + make(map[string]*executionEntry), make(map[string]chan<- schema.Input), sync.Mutex{}, false, @@ -90,6 +90,11 @@ func (c *client) Encoder() *cbor.Encoder { return c.encoder } +type executionEntry struct { + result *ExecutionResult + condition sync.Cond +} + type client struct { atpVersion int64 channel ClientChannel @@ -99,9 +104,9 @@ type client struct { encoder *cbor.Encoder doneChannel chan bool runningSteps []schema.Input - runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps - runningStepResultChannels map[string]chan ExecutionResult // Run ID to channel of results - runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps + runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps + runningStepResultEntries map[string]*executionEntry // Run ID to results + runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps mutex sync.Mutex readLoopRunning bool done bool @@ -338,13 +343,11 @@ func (c *client) executeWriteLoop( // The caller should have the mutex locked while calling this function. func (c *client) sendExecutionResult(runID string, result ExecutionResult) { c.logger.Debugf("Sending results for run ID '%s'", runID) - resultChannel, found := c.runningStepResultChannels[runID] + resultEntry, found := c.runningStepResultEntries[runID] if found { // Send the result - // TODO: Consider replacing the channel with a condition variable. - resultChannel <- result - // Close the channel and remove it to detect incorrectly duplicated results. - close(resultChannel) + resultEntry.result = &result + resultEntry.condition.Signal() } else { c.logger.Errorf("Step result channel not found for run ID '%s'. This is either a bug in the ATP "+ "client, or the plugin erroneously sent a second result.", runID) @@ -352,7 +355,6 @@ func (c *client) sendExecutionResult(runID string, result ExecutionResult) { // Now close the signal channel, since it's invalid to send a signal after the step is complete. signalChannel, found := c.runningStepEmittedSignalChannels[runID] if !found { - c.logger.Debugf("Could not find signal output channel for run ID '%s'", runID) return } close(signalChannel) @@ -362,7 +364,7 @@ func (c *client) sendExecutionResult(runID string, result ExecutionResult) { func (c *client) sendErrorToAll(err error) { result := NewErrorExecutionResult(err) c.mutex.Lock() - for runID := range c.runningStepResultChannels { + for runID := range c.runningStepResultEntries { c.sendExecutionResult(runID, result) } c.mutex.Unlock() @@ -442,9 +444,15 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { runtimeMessage.MessageID) } c.mutex.Lock() - // TODO: This is likely the cause of the deadlock. The recent reordering of the removal from the - // channel is resulting in the value being left in here. - if len(c.runningStepResultChannels) == 0 { + remainingSteps := 0 + for _, resultEntry := range c.runningStepResultEntries { + // The result is the reliable way to determine if it's done. There is a fraction of + // time when the entry is still in the map, but it's done. + if resultEntry.result == nil { + remainingSteps++ + } + } + if remainingSteps == 0 { c.mutex.Unlock() return // Done } @@ -487,13 +495,16 @@ func (c *client) prepareResultChannels( c.logger.Debugf("Preparing result channels for step with run ID %q", stepData.RunID) c.mutex.Lock() defer c.mutex.Unlock() - _, existing := c.runningStepResultChannels[stepData.RunID] + _, existing := c.runningStepResultEntries[stepData.RunID] if existing { return fmt.Errorf("duplicate run ID given '%s'", stepData.RunID) } // Set up the signal and step results channels - resultChannel := make(chan ExecutionResult, 5) // Arbitrary buffer size to prevent deadlocks. - c.runningStepResultChannels[stepData.RunID] = resultChannel + resultEntry := executionEntry{ + result: nil, + condition: sync.Cond{L: &c.mutex}, + } + c.runningStepResultEntries[stepData.RunID] = &resultEntry if emittedSignals != nil { c.runningStepEmittedSignalChannels[stepData.RunID] = emittedSignals } @@ -514,18 +525,19 @@ func (c *client) getResultV2( stepData schema.Input, ) ExecutionResult { c.mutex.Lock() - resultChannel, found := c.runningStepResultChannels[stepData.RunID] - c.mutex.Unlock() + resultChannel, found := c.runningStepResultEntries[stepData.RunID] c.logger.Debugf("Got result channel for run ID %q", stepData.RunID) if !found { return NewErrorExecutionResult( fmt.Errorf("could not find result channel for step with run ID '%s'. Existing channels: %v", - stepData.RunID, c.runningStepResultChannels), + stepData.RunID, c.runningStepResultEntries), ) } - // Wait for the result - result, received := <-resultChannel - if !received { + if resultChannel.result == nil { + // Wait for the result + resultChannel.condition.Wait() + } + if resultChannel.result == nil { return NewErrorExecutionResult( fmt.Errorf("did not receive result from results channel in ATP client for step with run ID '%s'", stepData.RunID), @@ -533,10 +545,9 @@ func (c *client) getResultV2( } // This needs to be done after receiving the value, or else the sender will // not be able to get the channel. - c.mutex.Lock() defer c.mutex.Unlock() - delete(c.runningStepResultChannels, stepData.RunID) - return result + delete(c.runningStepResultEntries, stepData.RunID) + return *resultChannel.result } func (c *client) processWorkDone( From beed03cfaef511d87f6cd5db799a5917e2fb148c Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 27 Jun 2024 14:15:55 -0400 Subject: [PATCH 03/12] Update terminology --- atp/client.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/atp/client.go b/atp/client.go index d80a605..0ff7057 100644 --- a/atp/client.go +++ b/atp/client.go @@ -338,8 +338,8 @@ func (c *client) executeWriteLoop( } } -// sendExecutionResult sends the results to the result channel, and closes then removes the channels for the -// step results and the signals. +// sendExecutionResult finalizes the result entry for processing by the client's caller, and +// closes then removes the channels for the signals. // The caller should have the mutex locked while calling this function. func (c *client) sendExecutionResult(runID string, result ExecutionResult) { c.logger.Debugf("Sending results for run ID '%s'", runID) @@ -349,7 +349,7 @@ func (c *client) sendExecutionResult(runID string, result ExecutionResult) { resultEntry.result = &result resultEntry.condition.Signal() } else { - c.logger.Errorf("Step result channel not found for run ID '%s'. This is either a bug in the ATP "+ + c.logger.Errorf("Step result entry not found for run ID '%s'. This is either a bug in the ATP "+ "client, or the plugin erroneously sent a second result.", runID) } // Now close the signal channel, since it's invalid to send a signal after the step is complete. @@ -520,34 +520,34 @@ func (c *client) prepareResultChannels( return nil } -// getResultV2 works with the channels that communicate with the RuntimeMessage loop. +// getResultV2 communicates with the RuntimeMessage loop to get the . func (c *client) getResultV2( stepData schema.Input, ) ExecutionResult { c.mutex.Lock() - resultChannel, found := c.runningStepResultEntries[stepData.RunID] - c.logger.Debugf("Got result channel for run ID %q", stepData.RunID) + resultEntry, found := c.runningStepResultEntries[stepData.RunID] + c.logger.Debugf("Got result entry for run ID %q", stepData.RunID) if !found { return NewErrorExecutionResult( - fmt.Errorf("could not find result channel for step with run ID '%s'. Existing channels: %v", + fmt.Errorf("could not find result entry for step with run ID '%s'. Existing entries: %v", stepData.RunID, c.runningStepResultEntries), ) } - if resultChannel.result == nil { + if resultEntry.result == nil { // Wait for the result - resultChannel.condition.Wait() + resultEntry.condition.Wait() } - if resultChannel.result == nil { + if resultEntry.result == nil { return NewErrorExecutionResult( - fmt.Errorf("did not receive result from results channel in ATP client for step with run ID '%s'", + fmt.Errorf("did not receive result from results entry in ATP client for step with run ID '%s'", stepData.RunID), ) } - // This needs to be done after receiving the value, or else the sender will - // not be able to get the channel. defer c.mutex.Unlock() + // This needs to be done after receiving the value, or else the sender will + // not be able to get the entry. delete(c.runningStepResultEntries, stepData.RunID) - return *resultChannel.result + return *resultEntry.result } func (c *client) processWorkDone( From f19dcfa8d216000466039c9589d7a1d3f81f7b32 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Fri, 28 Jun 2024 19:35:40 -0400 Subject: [PATCH 04/12] Addressed review comments and linter error --- atp/client.go | 157 +++++++++++++++++++++++++++++--------------------- 1 file changed, 90 insertions(+), 67 deletions(-) diff --git a/atp/client.go b/atp/client.go index 0ff7057..24d7a5f 100644 --- a/atp/client.go +++ b/atp/client.go @@ -260,8 +260,8 @@ func (c *client) Close() error { return nil } -// Waits for the WaitGroup to finish, but with a timeout to -// prevent a deadlock. +// waitWithTimeout waits for the provided wait group, aborting the wait if +// the provided timeout expires. // Returns true if the WaitGroup finished, and false if // it reached the end of the timeout. func waitWithTimeout(duration time.Duration, wg *sync.WaitGroup) bool { @@ -295,15 +295,16 @@ func (c *client) executeWriteLoop( runID string, receivedSignals chan schema.Input, ) { - // Add the channel to the client so that it can be kept track of c.mutex.Lock() if c.done { c.mutex.Unlock() - // You need to abort to allow proper closure, since the channel would otherwise - // be left open. - c.logger.Warningf("aborting write loop for run ID %q due to done client", runID) + // It is important to abort now since Close() was called. This is to prevent the channel + // from being added to the channel, since Close() uses that map to determine the exit + // condition. Adding to the map would cause it to never exit. + c.logger.Warningf("write called loop for run ID %q on done client; aborting", runID) return } + // Add the channel to the client so that it can be kept track of c.runningSignalReceiveLoops[runID] = receivedSignals c.mutex.Unlock() defer func() { @@ -340,7 +341,7 @@ func (c *client) executeWriteLoop( // sendExecutionResult finalizes the result entry for processing by the client's caller, and // closes then removes the channels for the signals. -// The caller should have the mutex locked while calling this function. +// The caller must have the mutex locked while calling this function. func (c *client) sendExecutionResult(runID string, result ExecutionResult) { c.logger.Debugf("Sending results for run ID '%s'", runID) resultEntry, found := c.runningStepResultEntries[runID] @@ -370,6 +371,79 @@ func (c *client) sendErrorToAll(err error) { c.mutex.Unlock() } +func (c *client) handleWorkDoneMessage(runtimeMessage DecodedRuntimeMessage) { + var doneMessage WorkDoneMessage + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { + c.logger.Errorf("Failed to decode work done message (%v) for run ID '%s' ", err, runtimeMessage.RunID) + c.mutex.Lock() + c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult( + fmt.Errorf("failed to decode work done message (%w)", err))) + c.mutex.Unlock() + return + } + c.mutex.Lock() + c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage)) + c.mutex.Unlock() +} + +func (c *client) handleSignalMessage(runtimeMessage DecodedRuntimeMessage) { + var signalMessage SignalMessage + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { + c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", + runtimeMessage.RunID, err) + } + c.mutex.Lock() + signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] + c.mutex.Unlock() + if !found { + c.logger.Warningf( + "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ + "(emittedSignals is nil).", + runtimeMessage.RunID, signalMessage.SignalID) + } else { + c.logger.Debugf("Got signal from step with run ID '%s' with ID '%s'", runtimeMessage.RunID, + signalMessage.SignalID) + signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) + } +} + +// Returns true if fatal, requiring aborting the read loop. +func (c *client) handleErrorMessage(runtimeMessage DecodedRuntimeMessage) bool { + var errMessage ErrorMessage + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { + c.logger.Errorf("Step with run ID '%s' failed to decode error message: %v", + runtimeMessage.RunID, err) + } + c.logger.Errorf("Step with run ID '%s' sent error message: %v", runtimeMessage.RunID, errMessage) + resultMsg := fmt.Errorf("step '%s' sent error message: %s", runtimeMessage.RunID, + errMessage.ToString(runtimeMessage.RunID)) + if errMessage.ServerFatal { + c.sendErrorToAll(resultMsg) + return true // It's server fatal, so this is the last message from the server. + } else if errMessage.StepFatal { + if runtimeMessage.RunID == "" { + c.sendErrorToAll(fmt.Errorf("step fatal error missing run id (%w)", resultMsg)) + } else { + c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult(resultMsg)) + } + } + return false +} + +func (c *client) hasEntriesRemaining() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + remainingSteps := 0 + for _, resultEntry := range c.runningStepResultEntries { + // The result is the reliable way to determine if it's done. There is a fraction of + // time when the entry is still in the map, but it is done. + if resultEntry.result == nil { + remainingSteps++ + } + } + return remainingSteps != 0 +} + //nolint:funlen func (c *client) executeReadLoop(cborReader *cbor.Decoder) { defer func() { @@ -390,73 +464,22 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { } switch runtimeMessage.MessageID { case MessageTypeWorkDone: - var doneMessage WorkDoneMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { - c.logger.Errorf("Failed to decode work done message (%v) for run ID '%s' ", err, runtimeMessage.RunID) - c.mutex.Lock() - c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult( - fmt.Errorf("failed to decode work done message (%w)", err))) - c.mutex.Unlock() - } - c.mutex.Lock() - c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage)) - c.mutex.Unlock() + c.handleWorkDoneMessage(runtimeMessage) case MessageTypeSignal: - var signalMessage SignalMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { - c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", - runtimeMessage.RunID, err) - } - c.mutex.Lock() - signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] - c.mutex.Unlock() - if !found { - c.logger.Warningf( - "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ - "(emittedSignals is nil).", - runtimeMessage.RunID, signalMessage.SignalID) - } else { - c.logger.Debugf("Got signal from step with run ID '%s' with ID '%s'", runtimeMessage.RunID, - signalMessage.SignalID) - signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) - } + c.handleSignalMessage(runtimeMessage) case MessageTypeError: - var errMessage ErrorMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { - c.logger.Errorf("Step with run ID '%s' failed to decode error message: %v", - runtimeMessage.RunID, err) - } - c.logger.Errorf("Step with run ID '%s' sent error message: %v", runtimeMessage.RunID, errMessage) - resultMsg := fmt.Errorf("step '%s' sent error message: %s", runtimeMessage.RunID, - errMessage.ToString(runtimeMessage.RunID)) - if errMessage.ServerFatal { - c.sendErrorToAll(resultMsg) - return // It's server fatal, so this is the last message from the server. - } else if errMessage.StepFatal { - if runtimeMessage.RunID == "" { - c.sendErrorToAll(fmt.Errorf("step fatal error missing run id (%w)", resultMsg)) - } else { - c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult(resultMsg)) - } + fatal := c.handleErrorMessage(runtimeMessage) + if fatal { + return } default: c.logger.Warningf("Step with run ID '%s' sent unknown message type: %s", runtimeMessage.RunID, runtimeMessage.MessageID) } - c.mutex.Lock() - remainingSteps := 0 - for _, resultEntry := range c.runningStepResultEntries { - // The result is the reliable way to determine if it's done. There is a fraction of - // time when the entry is still in the map, but it's done. - if resultEntry.result == nil { - remainingSteps++ - } - } - if remainingSteps == 0 { - c.mutex.Unlock() - return // Done + // The non-error exit condition is having no more entries remaining. + if !c.hasEntriesRemaining() { + return } - c.mutex.Unlock() } } @@ -520,7 +543,7 @@ func (c *client) prepareResultChannels( return nil } -// getResultV2 communicates with the RuntimeMessage loop to get the . +// getResultV2 communicates with the RuntimeMessage loop to get the ExecutionResult. func (c *client) getResultV2( stepData schema.Input, ) ExecutionResult { From 855a0f5690679a36118683c0f6a89905c5b59868 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Mon, 1 Jul 2024 15:56:43 -0400 Subject: [PATCH 05/12] Addressed review comments --- atp/client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/atp/client.go b/atp/client.go index 24d7a5f..769fe86 100644 --- a/atp/client.go +++ b/atp/client.go @@ -548,8 +548,8 @@ func (c *client) getResultV2( stepData schema.Input, ) ExecutionResult { c.mutex.Lock() + defer c.mutex.Unlock() resultEntry, found := c.runningStepResultEntries[stepData.RunID] - c.logger.Debugf("Got result entry for run ID %q", stepData.RunID) if !found { return NewErrorExecutionResult( fmt.Errorf("could not find result entry for step with run ID '%s'. Existing entries: %v", @@ -566,9 +566,9 @@ func (c *client) getResultV2( stepData.RunID), ) } - defer c.mutex.Unlock() - // This needs to be done after receiving the value, or else the sender will - // not be able to get the entry. + // Deletion of the entry needs to be done in this function after waiting for + // the value to ensure the value's lifetime is long enough in the map. + // It cannot be removed on the sender's side, since that would cause a race. delete(c.runningStepResultEntries, stepData.RunID) return *resultEntry.result } From 22032873b44e91ec7dd6f17e1933fc94b063724e Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Mon, 1 Jul 2024 15:58:05 -0400 Subject: [PATCH 06/12] Remove nolint directive --- atp/client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/atp/client.go b/atp/client.go index 769fe86..e123432 100644 --- a/atp/client.go +++ b/atp/client.go @@ -444,7 +444,6 @@ func (c *client) hasEntriesRemaining() bool { return remainingSteps != 0 } -//nolint:funlen func (c *client) executeReadLoop(cborReader *cbor.Decoder) { defer func() { c.mutex.Lock() From dff8200c2d26011356a527125a9ef43baeeadfed Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Wed, 3 Jul 2024 16:30:07 -0400 Subject: [PATCH 07/12] Addressed review comments --- atp/client.go | 65 ++++++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/atp/client.go b/atp/client.go index e123432..7560095 100644 --- a/atp/client.go +++ b/atp/client.go @@ -298,10 +298,13 @@ func (c *client) executeWriteLoop( c.mutex.Lock() if c.done { c.mutex.Unlock() - // It is important to abort now since Close() was called. This is to prevent the channel - // from being added to the channel, since Close() uses that map to determine the exit - // condition. Adding to the map would cause it to never exit. - c.logger.Warningf("write called loop for run ID %q on done client; aborting", runID) + // Close() was called, so exit now to prevent the channel from being added to + // the map, since Close() uses that map to determine the exit condition. + // Adding to the map would cause it to never exit. + c.logger.Warningf( + "write called loop for run ID %q on done client; skipping receive loop", + runID, + ) return } // Add the channel to the client so that it can be kept track of @@ -373,16 +376,15 @@ func (c *client) sendErrorToAll(err error) { func (c *client) handleWorkDoneMessage(runtimeMessage DecodedRuntimeMessage) { var doneMessage WorkDoneMessage + var result ExecutionResult if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { c.logger.Errorf("Failed to decode work done message (%v) for run ID '%s' ", err, runtimeMessage.RunID) - c.mutex.Lock() - c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult( - fmt.Errorf("failed to decode work done message (%w)", err))) - c.mutex.Unlock() - return + result = NewErrorExecutionResult(fmt.Errorf("failed to decode work done message (%w)", err)) + } else { + result = c.processWorkDone(runtimeMessage.RunID, doneMessage) } c.mutex.Lock() - c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage)) + c.sendExecutionResult(runtimeMessage.RunID, result) c.mutex.Unlock() } @@ -391,6 +393,7 @@ func (c *client) handleSignalMessage(runtimeMessage DecodedRuntimeMessage) { if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", runtimeMessage.RunID, err) + return } c.mutex.Lock() signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] @@ -400,23 +403,23 @@ func (c *client) handleSignalMessage(runtimeMessage DecodedRuntimeMessage) { "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ "(emittedSignals is nil).", runtimeMessage.RunID, signalMessage.SignalID) - } else { - c.logger.Debugf("Got signal from step with run ID '%s' with ID '%s'", runtimeMessage.RunID, - signalMessage.SignalID) - signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) + return } + c.logger.Debugf("Got signal from step with run ID '%s' with ID '%s'", runtimeMessage.RunID, + signalMessage.SignalID) + signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) } -// Returns true if fatal, requiring aborting the read loop. +// Returns true if the error is fatal func (c *client) handleErrorMessage(runtimeMessage DecodedRuntimeMessage) bool { var errMessage ErrorMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { c.logger.Errorf("Step with run ID '%s' failed to decode error message: %v", runtimeMessage.RunID, err) } - c.logger.Errorf("Step with run ID '%s' sent error message: %v", runtimeMessage.RunID, errMessage) - resultMsg := fmt.Errorf("step '%s' sent error message: %s", runtimeMessage.RunID, - errMessage.ToString(runtimeMessage.RunID)) + errorMessageStr := errMessage.ToString(runtimeMessage.RunID) + c.logger.Errorf("Step with run ID %q sent error message: %s", runtimeMessage.RunID, errorMessageStr) + resultMsg := fmt.Errorf("step with run ID %q sent error message: %s", runtimeMessage.RunID, errorMessageStr) if errMessage.ServerFatal { c.sendErrorToAll(resultMsg) return true // It's server fatal, so this is the last message from the server. @@ -456,7 +459,11 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { var runtimeMessage DecodedRuntimeMessage for { if err := cborReader.Decode(&runtimeMessage); err != nil { - c.logger.Errorf("ATP client for steps '%s' failed to read or decode runtime message: %v", c.getRunningStepIDs(), err) + c.logger.Errorf( + "ATP client for steps '%s' failed to read or decode runtime message: %v", + c.getRunningStepIDs(), + err, + ) // This is fatal since the entire structure of the runtime message is invalid. c.sendErrorToAll(fmt.Errorf("failed to read or decode runtime message (%w)", err)) return @@ -467,13 +474,15 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { case MessageTypeSignal: c.handleSignalMessage(runtimeMessage) case MessageTypeError: - fatal := c.handleErrorMessage(runtimeMessage) - if fatal { - return + if c.handleErrorMessage(runtimeMessage) { + return // Fatal } default: - c.logger.Warningf("Step with run ID '%s' sent unknown message type: %s", runtimeMessage.RunID, - runtimeMessage.MessageID) + c.logger.Warningf( + "Step with run ID '%s' sent unknown message type: %d", + runtimeMessage.RunID, + runtimeMessage.MessageID, + ) } // The non-error exit condition is having no more entries remaining. if !c.hasEntriesRemaining() { @@ -543,9 +552,7 @@ func (c *client) prepareResultChannels( } // getResultV2 communicates with the RuntimeMessage loop to get the ExecutionResult. -func (c *client) getResultV2( - stepData schema.Input, -) ExecutionResult { +func (c *client) getResultV2(stepData schema.Input) ExecutionResult { c.mutex.Lock() defer c.mutex.Unlock() resultEntry, found := c.runningStepResultEntries[stepData.RunID] @@ -565,8 +572,8 @@ func (c *client) getResultV2( stepData.RunID), ) } - // Deletion of the entry needs to be done in this function after waiting for - // the value to ensure the value's lifetime is long enough in the map. + // Now that we've received the result for this step, remove it from the list + // of running steps so that we won't see it as running anymore. // It cannot be removed on the sender's side, since that would cause a race. delete(c.runningStepResultEntries, stepData.RunID) return *resultEntry.result From 7baef7be2bdecaa1b3424c9275113f53df00742a Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Wed, 3 Jul 2024 16:33:51 -0400 Subject: [PATCH 08/12] Fixed linting error --- atp/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atp/client.go b/atp/client.go index 7560095..e68da4c 100644 --- a/atp/client.go +++ b/atp/client.go @@ -410,7 +410,7 @@ func (c *client) handleSignalMessage(runtimeMessage DecodedRuntimeMessage) { signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) } -// Returns true if the error is fatal +// Returns true if the error is fatal. func (c *client) handleErrorMessage(runtimeMessage DecodedRuntimeMessage) bool { var errMessage ErrorMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { From 558c0b9f049cb09255be2e75a6ac62bb28f6047d Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Wed, 3 Jul 2024 18:10:25 -0400 Subject: [PATCH 09/12] Add timeout for message sending --- atp/server.go | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/atp/server.go b/atp/server.go index e723b54..0a76e5b 100644 --- a/atp/server.go +++ b/atp/server.go @@ -8,6 +8,7 @@ import ( "io" "os" "sync" + "time" ) // RunATPServer runs an ArcaflowTransportProtocol server with a given schema. @@ -83,12 +84,23 @@ func initializeATPServerSession( func (s *atpServerSession) sendRuntimeMessage(msgID uint32, runID string, message any) error { s.encoderMutex.Lock() - defer s.encoderMutex.Unlock() - return s.cborStdout.Encode(RuntimeMessage{ - MessageID: msgID, - RunID: runID, - MessageData: message, - }) + doneChannel := make(chan error, 1) + go func() { + defer close(doneChannel) + doneChannel <- s.cborStdout.Encode(RuntimeMessage{ + MessageID: msgID, + RunID: runID, + MessageData: message, + }) + }() + select { + case err := <-doneChannel: + s.encoderMutex.Unlock() + return err + case <-time.After(time.Second * 60): + s.encoderMutex.Unlock() + return fmt.Errorf("send timeout exceeded while sending message ID %q for run id %q", msgID, runID) + } } func (s *atpServerSession) handleClosure() []*ServerError { From 3b7f12aca97f011b184f3f0d4b8edde52ebb2e7e Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Fri, 5 Jul 2024 17:27:23 -0400 Subject: [PATCH 10/12] Fix missing lock, and address review comments --- atp/client.go | 8 ++++---- atp/server.go | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/atp/client.go b/atp/client.go index e68da4c..ffa2627 100644 --- a/atp/client.go +++ b/atp/client.go @@ -427,7 +427,9 @@ func (c *client) handleErrorMessage(runtimeMessage DecodedRuntimeMessage) bool { if runtimeMessage.RunID == "" { c.sendErrorToAll(fmt.Errorf("step fatal error missing run id (%w)", resultMsg)) } else { + c.mutex.Lock() c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult(resultMsg)) + c.mutex.Unlock() } } return false @@ -567,10 +569,8 @@ func (c *client) getResultV2(stepData schema.Input) ExecutionResult { resultEntry.condition.Wait() } if resultEntry.result == nil { - return NewErrorExecutionResult( - fmt.Errorf("did not receive result from results entry in ATP client for step with run ID '%s'", - stepData.RunID), - ) + panic(fmt.Errorf("did not receive result from results entry in ATP client for step with run ID '%s'", + stepData.RunID)) } // Now that we've received the result for this step, remove it from the list // of running steps so that we won't see it as running anymore. diff --git a/atp/server.go b/atp/server.go index 0a76e5b..9f3b15d 100644 --- a/atp/server.go +++ b/atp/server.go @@ -93,12 +93,11 @@ func (s *atpServerSession) sendRuntimeMessage(msgID uint32, runID string, messag MessageData: message, }) }() + defer s.encoderMutex.Unlock() select { case err := <-doneChannel: - s.encoderMutex.Unlock() return err case <-time.After(time.Second * 60): - s.encoderMutex.Unlock() return fmt.Errorf("send timeout exceeded while sending message ID %q for run id %q", msgID, runID) } } From 067cc3841c65cf8c82d2fa98113a63d8ab37edc2 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Mon, 8 Jul 2024 19:18:05 -0400 Subject: [PATCH 11/12] Addressed review comments --- atp/client.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/atp/client.go b/atp/client.go index ffa2627..12258ac 100644 --- a/atp/client.go +++ b/atp/client.go @@ -298,9 +298,9 @@ func (c *client) executeWriteLoop( c.mutex.Lock() if c.done { c.mutex.Unlock() - // Close() was called, so exit now to prevent the channel from being added to - // the map, since Close() uses that map to determine the exit condition. - // Adding to the map would cause it to never exit. + // Close() was called, so exit now. + // Failure to exit now may result in this receivedSignals channel not getting + // closed, resulting in this function hanging. c.logger.Warningf( "write called loop for run ID %q on done client; skipping receive loop", runID, @@ -361,8 +361,8 @@ func (c *client) sendExecutionResult(runID string, result ExecutionResult) { if !found { return } - close(signalChannel) delete(c.runningStepEmittedSignalChannels, runID) + close(signalChannel) } func (c *client) sendErrorToAll(err error) { @@ -396,8 +396,8 @@ func (c *client) handleSignalMessage(runtimeMessage DecodedRuntimeMessage) { return } c.mutex.Lock() + defer c.mutex.Unlock() // Hold lock until we send to the channel to prevent premature closing of the channel. signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] - c.mutex.Unlock() if !found { c.logger.Warningf( "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ @@ -418,8 +418,8 @@ func (c *client) handleErrorMessage(runtimeMessage DecodedRuntimeMessage) bool { runtimeMessage.RunID, err) } errorMessageStr := errMessage.ToString(runtimeMessage.RunID) - c.logger.Errorf("Step with run ID %q sent error message: %s", runtimeMessage.RunID, errorMessageStr) resultMsg := fmt.Errorf("step with run ID %q sent error message: %s", runtimeMessage.RunID, errorMessageStr) + c.logger.Errorf(resultMsg.Error()) if errMessage.ServerFatal { c.sendErrorToAll(resultMsg) return true // It's server fatal, so this is the last message from the server. @@ -438,15 +438,15 @@ func (c *client) handleErrorMessage(runtimeMessage DecodedRuntimeMessage) bool { func (c *client) hasEntriesRemaining() bool { c.mutex.Lock() defer c.mutex.Unlock() - remainingSteps := 0 for _, resultEntry := range c.runningStepResultEntries { - // The result is the reliable way to determine if it's done. There is a fraction of - // time when the entry is still in the map, but it is done. + // If any result is nil then we're not done. + // Context: There is a fraction of time when the entry is still in the map + // following completion. It is set to a non-nil value when done. if resultEntry.result == nil { - remainingSteps++ + return true } } - return remainingSteps != 0 + return false } func (c *client) executeReadLoop(cborReader *cbor.Decoder) { From a676ab4053d9426115374f7b2b2d15638da0a5b2 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Mon, 8 Jul 2024 19:22:38 -0400 Subject: [PATCH 12/12] Update comment --- atp/client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/atp/client.go b/atp/client.go index 12258ac..0052ba6 100644 --- a/atp/client.go +++ b/atp/client.go @@ -572,9 +572,9 @@ func (c *client) getResultV2(stepData schema.Input) ExecutionResult { panic(fmt.Errorf("did not receive result from results entry in ATP client for step with run ID '%s'", stepData.RunID)) } - // Now that we've received the result for this step, remove it from the list - // of running steps so that we won't see it as running anymore. - // It cannot be removed on the sender's side, since that would cause a race. + // Now that we've received the result for this step, remove it from the list of running steps. + // We do this here because the sender cannot tell when the message has been received, and so + // it cannot tell when it is safe to remove the entry from the map. delete(c.runningStepResultEntries, stepData.RunID) return *resultEntry.result }