Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use typed channels #97

Merged
merged 8 commits into from
Jul 29, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions atp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package atp
import (
"fmt"
"github.com/fxamacker/cbor/v2"
log "go.arcalot.io/log/v2"
"go.arcalot.io/log/v2"
"go.flow.arcalot.io/pluginsdk/schema"
"io"
"strings"
Expand Down Expand Up @@ -43,6 +43,9 @@ type Client interface {
}

// NewClient creates a new ATP client (part of the engine code).
// Currently used only by tests in the Python- and Test-deployers.
//
//goland:noinspection GoUnusedExportedFunction
func NewClient(
webbnh marked this conversation as resolved.
Show resolved Hide resolved
channel ClientChannel,
) Client {
Expand Down Expand Up @@ -72,7 +75,7 @@ func NewClientWithLogger(
cbor.NewEncoder(channel),
make(chan bool, 5), // Buffer to prevent deadlocks
make([]schema.Input, 0),
make(map[string]chan schema.Input),
make(map[string]chan<- schema.Input),
make(map[string]*executionEntry),
make(map[string]chan<- schema.Input),
sync.Mutex{},
Expand Down Expand Up @@ -104,7 +107,7 @@ type client struct {
encoder *cbor.Encoder
doneChannel chan bool
runningSteps []schema.Input
runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps
runningSignalReceiveLoops map[string]chan<- schema.Input // Run ID to channel of signals to steps
runningStepResultEntries map[string]*executionEntry // Run ID to results
runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps
mutex sync.Mutex
Expand Down Expand Up @@ -137,8 +140,9 @@ func (c *client) ReadSchema() (*schema.SchemaSchema, error) {
err := c.validateVersion(hello.Version)

if err != nil {
c.logger.Errorf("Unsupported plugin version. %w", err)
return nil, fmt.Errorf("unsupported plugin version: %w", err)
err = fmt.Errorf("unsupported plugin version: %w", err)
c.logger.Errorf(err.Error())
return nil, err
}
c.atpVersion = hello.Version

Expand Down Expand Up @@ -199,23 +203,20 @@ func (c *client) Execute(
}
c.logger.Debugf("Step '%s' started, waiting for response...", stepData.ID)

defer c.handleStepComplete(stepData.RunID, receivedSignals)
if receivedSignals != nil {
defer c.handleStepComplete(stepData.RunID)
}
dbutenhof marked this conversation as resolved.
Show resolved Hide resolved
return c.getResult(stepData, cborReader)
}

// handleStepComplete is the deferred function that will handle closing of the received channel.
func (c *client) handleStepComplete(runID string, receivedSignals chan schema.Input) {
if receivedSignals != nil {
c.logger.Infof("Closing signal channel for finished step")
// Remove from the map to ensure that the client.Close() method doesn't double-close it
c.mutex.Lock()
// Validate that it exists, since Close() could have been called early.
_, exists := c.runningSignalReceiveLoops[runID]
if exists {
delete(c.runningSignalReceiveLoops, runID)
}
c.mutex.Unlock()
}
func (c *client) handleStepComplete(runID string) {
c.logger.Infof("Closing signal channel for finished step")
// Remove from the map to ensure that the client.Close() method doesn't double-close it
c.mutex.Lock()
// Validate that it exists, since Close() could have been called early.
dbutenhof marked this conversation as resolved.
Show resolved Hide resolved
delete(c.runningSignalReceiveLoops, runID)
c.mutex.Unlock()
}

// Close Tells the client that it's done, and can stop listening for more requests.
Expand Down Expand Up @@ -326,8 +327,13 @@ func (c *client) executeWriteLoop(
SignalID: signal.ID,
Data: signal.InputData,
}}); err != nil {
c.logger.Errorf("Client with steps '%s' failed to write signal (%s) with run id '&s' with error: %w",
c.getRunningStepIDs(), signal.ID, signal.RunID, err)
c.logger.Errorf(
"Client with steps '%s' failed to write signal (%s) with run id %q with error: %v",
c.getRunningStepIDs(),
signal.ID,
signal.RunID,
err,
)
return
}
c.logger.Debugf("Successfully sent signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID)
Expand Down Expand Up @@ -505,9 +511,9 @@ func (c *client) getResultV1(
) ExecutionResult {
var doneMessage WorkDoneMessage
if err := cborReader.Decode(&doneMessage); err != nil {
c.logger.Errorf("Failed to read or decode work done message: (%w) for step %s", err, stepData.ID)
return NewErrorExecutionResult(
fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID))
err = fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID)
c.logger.Errorf(err.Error())
return NewErrorExecutionResult(err)
}
return c.processWorkDone(stepData.RunID, doneMessage)
}
Expand Down