Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Atp concurrency improvements #92

Merged
merged 14 commits into from
Jul 9, 2024
135 changes: 98 additions & 37 deletions atp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"io"
"strings"
"sync"
"time"
)

var supportedServerVersions = []int64{1, 3}
Expand Down Expand Up @@ -72,7 +73,7 @@
make(chan bool, 5), // Buffer to prevent deadlocks
make([]schema.Input, 0),
make(map[string]chan schema.Input),
make(map[string]chan ExecutionResult),
make(map[string]*executionEntry),
make(map[string]chan<- schema.Input),
sync.Mutex{},
false,
Expand All @@ -89,6 +90,11 @@
return c.encoder
}

// executionEntry tracks the outcome of a single running step, keyed by run ID
// in the client's runningStepResultEntries map.
type executionEntry struct {
// result stays nil until sendExecutionResult stores the step's result in it.
result *ExecutionResult
// condition is signaled once result has been set; it is created with the
// client's mutex as its Locker, so waiters must hold that mutex.
condition sync.Cond
}

type client struct {
atpVersion int64
channel ClientChannel
Expand All @@ -98,9 +104,9 @@
encoder *cbor.Encoder
doneChannel chan bool
runningSteps []schema.Input
runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps
runningStepResultChannels map[string]chan ExecutionResult // Run ID to channel of results
runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps
runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps
runningStepResultEntries map[string]*executionEntry // Run ID to results
runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps
mutex sync.Mutex
readLoopRunning bool
done bool
Expand Down Expand Up @@ -175,7 +181,9 @@
workStartMsg = RuntimeMessage{RunID: stepData.RunID, MessageID: MessageTypeWorkStart, MessageData: workStartMsg}
// Handle signals to the step
if receivedSignals != nil {
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.executeWriteLoop(stepData.RunID, receivedSignals)
}()
}
Expand Down Expand Up @@ -215,6 +223,7 @@
func (c *client) Close() error {
c.mutex.Lock()
if c.done {
c.mutex.Unlock()
return nil
}
c.done = true
Expand All @@ -235,14 +244,41 @@
clientDoneMessage{},
})
if err != nil {
return fmt.Errorf("client with steps '%s' failed to write client done message with error: %w",
c.getRunningStepIDs(), err)
// add a timeout to the wait to prevent it from causing a deadlock.
// 5 seconds is arbitrary, but gives it enough time to exit.
waitedGracefully := waitWithTimeout(time.Second*5, &c.wg)
if waitedGracefully {
return fmt.Errorf("client with step '%s' failed to write client done message with error: %w",
c.getRunningStepIDs(), err)
} else {
panic(fmt.Errorf("potential deadlock after client with step '%s' failed to write client done message with error: %w",
c.getRunningStepIDs(), err))
}
}
}
c.wg.Wait()
return nil
}
webbnh marked this conversation as resolved.
Show resolved Hide resolved

// waitWithTimeout blocks until wg finishes or duration elapses, whichever
// comes first, so that a WaitGroup that never completes cannot deadlock the
// caller.
//
// It returns true if the WaitGroup finished, and false if the timeout was
// reached first. On timeout the helper goroutine stays blocked in wg.Wait()
// until the WaitGroup eventually completes (if it ever does).
func waitWithTimeout(duration time.Duration, wg *sync.WaitGroup) bool {
	// wg.Wait() cannot take part in a select, so wait in a goroutine and
	// report completion by closing a channel.
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		wg.Wait()
	}()

	// An explicit timer (rather than time.After) is stopped on return, so it
	// is released right away when the WaitGroup finishes before the timeout.
	timer := time.NewTimer(duration)
	defer timer.Stop()

	select {
	case <-finished:
		return true
	case <-timer.C:
		return false
	}
}

func (c *client) getRunningStepIDs() string {
if len(c.runningSteps) == 0 {
return "No running steps"
Expand All @@ -261,6 +297,13 @@
) {
// Add the channel to the client so that it can be kept track of
c.mutex.Lock()
if c.done {
c.mutex.Unlock()
// You need to abort to allow proper closure, since the channel would otherwise
// be left open.
c.logger.Warningf("aborting write loop for run ID %q due to done client", runID)
return
webbnh marked this conversation as resolved.
Show resolved Hide resolved
}
c.runningSignalReceiveLoops[runID] = receivedSignals
webbnh marked this conversation as resolved.
Show resolved Hide resolved
c.mutex.Unlock()
defer func() {
Expand Down Expand Up @@ -295,31 +338,23 @@
}
}

// sendExecutionResult sends the results to the channel, and closes then removes the channels for the
// step results and the signals.
// sendExecutionResult finalizes the result entry for processing by the client's caller, and
// closes then removes the channels for the signals.
// The caller should have the mutex locked while calling this function.
func (c *client) sendExecutionResult(runID string, result ExecutionResult) {
webbnh marked this conversation as resolved.
Show resolved Hide resolved
c.logger.Debugf("Providing input for run ID '%s'", runID)
c.mutex.Lock()
resultChannel, found := c.runningStepResultChannels[runID]
c.mutex.Unlock()
c.logger.Debugf("Sending results for run ID '%s'", runID)
resultEntry, found := c.runningStepResultEntries[runID]
if found {
// Send the result
resultChannel <- result
// Close the channel and remove it to detect incorrectly duplicate results.
close(resultChannel)
c.mutex.Lock()
delete(c.runningStepResultChannels, runID)
c.mutex.Unlock()
resultEntry.result = &result
webbnh marked this conversation as resolved.
Show resolved Hide resolved
resultEntry.condition.Signal()
} else {
c.logger.Errorf("Step result channel not found for run ID '%s'. This is either a bug in the ATP "+
c.logger.Errorf("Step result entry not found for run ID '%s'. This is either a bug in the ATP "+
"client, or the plugin erroneously sent a second result.", runID)
}
// Now close the signal channel, since it's invalid to send a signal after the step is complete.
c.mutex.Lock()
defer c.mutex.Unlock()
signalChannel, found := c.runningStepEmittedSignalChannels[runID]
if !found {
c.logger.Debugf("Could not find signal output channel for run ID '%s'", runID)
return
}
close(signalChannel)
Expand All @@ -328,13 +363,15 @@

func (c *client) sendErrorToAll(err error) {
result := NewErrorExecutionResult(err)
for runID := range c.runningStepResultChannels {
c.mutex.Lock()
for runID := range c.runningStepResultEntries {
c.sendExecutionResult(runID, result)
}
c.mutex.Unlock()
}

//nolint:funlen
func (c *client) executeReadLoop(cborReader *cbor.Decoder) {

Check failure on line 374 in atp/client.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

cognitive complexity 33 of func `(*client).executeReadLoop` is high (> 30) (gocognit)
defer func() {
c.mutex.Lock()
defer c.mutex.Unlock()
Expand All @@ -356,17 +393,23 @@
var doneMessage WorkDoneMessage
if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil {
c.logger.Errorf("Failed to decode work done message (%v) for run ID '%s' ", err, runtimeMessage.RunID)
c.mutex.Lock()
c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult(
fmt.Errorf("failed to decode work done message (%w)", err)))
c.mutex.Unlock()
}
c.mutex.Lock()
c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage))
c.mutex.Unlock()
webbnh marked this conversation as resolved.
Show resolved Hide resolved
case MessageTypeSignal:
var signalMessage SignalMessage
if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil {
c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v",
runtimeMessage.RunID, err)
}
c.mutex.Lock()
signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID]
c.mutex.Unlock()
webbnh marked this conversation as resolved.
Show resolved Hide resolved
if !found {
c.logger.Warningf(
"Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+
Expand Down Expand Up @@ -401,7 +444,15 @@
runtimeMessage.MessageID)
}
c.mutex.Lock()
if len(c.runningStepResultChannels) == 0 {
remainingSteps := 0
for _, resultEntry := range c.runningStepResultEntries {
// The result is the reliable way to determine if it's done. There is a fraction of
// time when the entry is still in the map, but it's done.
if resultEntry.result == nil {
remainingSteps++
}
}
if remainingSteps == 0 {
c.mutex.Unlock()
return // Done
}
webbnh marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -441,15 +492,19 @@
stepData schema.Input,
emittedSignals chan<- schema.Input,
) error {
c.logger.Debugf("Preparing result channels for step with run ID %q", stepData.RunID)
c.mutex.Lock()
defer c.mutex.Unlock()
_, existing := c.runningStepResultChannels[stepData.RunID]
_, existing := c.runningStepResultEntries[stepData.RunID]
if existing {
return fmt.Errorf("duplicate run ID given '%s'", stepData.RunID)
}
// Set up the signal and step results channels
resultChannel := make(chan ExecutionResult)
c.runningStepResultChannels[stepData.RunID] = resultChannel
resultEntry := executionEntry{
result: nil,
condition: sync.Cond{L: &c.mutex},
}
c.runningStepResultEntries[stepData.RunID] = &resultEntry
if emittedSignals != nil {
c.runningStepEmittedSignalChannels[stepData.RunID] = emittedSignals
}
Expand All @@ -465,28 +520,34 @@
return nil
}

// getResultV2 works with the channels that communicate with the RuntimeMessage loop.
// getResultV2 communicates with the RuntimeMessage loop to get the execution result for the given step.
webbnh marked this conversation as resolved.
Show resolved Hide resolved
func (c *client) getResultV2(
stepData schema.Input,
) ExecutionResult {
webbnh marked this conversation as resolved.
Show resolved Hide resolved
c.mutex.Lock()
resultChannel, found := c.runningStepResultChannels[stepData.RunID]
c.mutex.Unlock()
resultEntry, found := c.runningStepResultEntries[stepData.RunID]
c.logger.Debugf("Got result entry for run ID %q", stepData.RunID)
webbnh marked this conversation as resolved.
Show resolved Hide resolved
if !found {
return NewErrorExecutionResult(
fmt.Errorf("could not find result channel for step with run ID '%s'",
stepData.RunID),
fmt.Errorf("could not find result entry for step with run ID '%s'. Existing entries: %v",
stepData.RunID, c.runningStepResultEntries),
)
}
// Wait for the result
result, received := <-resultChannel
if !received {
if resultEntry.result == nil {
// Wait for the result
resultEntry.condition.Wait()
}
if resultEntry.result == nil {
return NewErrorExecutionResult(
fmt.Errorf("did not receive result from results channel in ATP client for step with run ID '%s'",
fmt.Errorf("did not receive result from results entry in ATP client for step with run ID '%s'",
stepData.RunID),
)
}
webbnh marked this conversation as resolved.
Show resolved Hide resolved
return result
defer c.mutex.Unlock()
// This needs to be done after receiving the value, or else the sender will
// not be able to get the entry.
delete(c.runningStepResultEntries, stepData.RunID)
webbnh marked this conversation as resolved.
Show resolved Hide resolved
return *resultEntry.result
webbnh marked this conversation as resolved.
Show resolved Hide resolved
}

func (c *client) processWorkDone(
Expand Down
Loading