diff --git a/atp/client.go b/atp/client.go index a60fdad..e0c6fba 100644 --- a/atp/client.go +++ b/atp/client.go @@ -3,7 +3,7 @@ package atp import ( "fmt" "github.com/fxamacker/cbor/v2" - log "go.arcalot.io/log/v2" + "go.arcalot.io/log/v2" "go.flow.arcalot.io/pluginsdk/schema" "io" "strings" @@ -36,13 +36,16 @@ type Client interface { // ReadSchema reads the schema from the ATP server. ReadSchema() (*schema.SchemaSchema, error) // Execute executes a step with a given context and returns the resulting output. Assumes you called ReadSchema first. - Execute(input schema.Input, receivedSignals chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult + Execute(input schema.Input, receivedSignals <-chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult Close() error Encoder() *cbor.Encoder Decoder() *cbor.Decoder } // NewClient creates a new ATP client (part of the engine code). +// Currently used only by tests in the Python- and Test-deployers. +// +//goland:noinspection GoUnusedExportedFunction func NewClient( channel ClientChannel, ) Client { @@ -72,7 +75,6 @@ func NewClientWithLogger( cbor.NewEncoder(channel), make(chan bool, 5), // Buffer to prevent deadlocks make([]schema.Input, 0), - make(map[string]chan schema.Input), make(map[string]*executionEntry), make(map[string]chan<- schema.Input), sync.Mutex{}, @@ -104,7 +106,6 @@ type client struct { encoder *cbor.Encoder doneChannel chan bool runningSteps []schema.Input - runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps runningStepResultEntries map[string]*executionEntry // Run ID to results runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps mutex sync.Mutex @@ -137,8 +138,9 @@ func (c *client) ReadSchema() (*schema.SchemaSchema, error) { err := c.validateVersion(hello.Version) if err != nil { - c.logger.Errorf("Unsupported plugin version. %w", err) - return nil, fmt.Errorf("unsupported plugin version: %w", err) + err = fmt.Errorf("unsupported plugin version: %w", err) + c.logger.Errorf(err.Error()) + return nil, err } c.atpVersion = hello.Version @@ -163,7 +165,7 @@ func (c *client) validateVersion(serverVersion int64) error { func (c *client) Execute( stepData schema.Input, - receivedSignals chan schema.Input, + receivedSignals <-chan schema.Input, emittedSignals chan<- schema.Input, ) ExecutionResult { c.logger.Debugf("Executing plugin step %s/%s...", stepData.RunID, stepData.ID) @@ -199,25 +201,9 @@ func (c *client) Execute( } c.logger.Debugf("Step '%s' started, waiting for response...", stepData.ID) - defer c.handleStepComplete(stepData.RunID, receivedSignals) return c.getResult(stepData, cborReader) } -// handleStepComplete is the deferred function that will handle closing of the received channel. -func (c *client) handleStepComplete(runID string, receivedSignals chan schema.Input) { - if receivedSignals != nil { - c.logger.Infof("Closing signal channel for finished step") - // Remove from the map to ensure that the client.Close() method doesn't double-close it - c.mutex.Lock() - // Validate that it exists, since Close() could have been called early. - _, exists := c.runningSignalReceiveLoops[runID] - if exists { - delete(c.runningSignalReceiveLoops, runID) - } - c.mutex.Unlock() - } -} - // Close Tells the client that it's done, and can stop listening for more requests. func (c *client) Close() error { c.mutex.Lock() @@ -285,7 +271,7 @@ func (c *client) getRunningStepIDs() string { // Listen for received signals, and send them over ATP if available. func (c *client) executeWriteLoop( runID string, - receivedSignals chan schema.Input, + receivedSignals <-chan schema.Input, ) { c.mutex.Lock() if c.done { @@ -299,14 +285,8 @@ func (c *client) executeWriteLoop( ) return } - // Add the channel to the client so that it can be kept track of - c.runningSignalReceiveLoops[runID] = receivedSignals c.mutex.Unlock() - defer func() { - c.mutex.Lock() - delete(c.runningSignalReceiveLoops, runID) - c.mutex.Unlock() - }() + // Looped select that gets signals for { signal, ok := <-receivedSignals @@ -326,8 +306,13 @@ func (c *client) executeWriteLoop( SignalID: signal.ID, Data: signal.InputData, }}); err != nil { - c.logger.Errorf("Client with steps '%s' failed to write signal (%s) with run id '&s' with error: %w", - c.getRunningStepIDs(), signal.ID, signal.RunID, err) + c.logger.Errorf( + "Client with steps '%s' failed to write signal (%s) with run id %q with error: %v", + c.getRunningStepIDs(), + signal.ID, + signal.RunID, + err, + ) return } c.logger.Debugf("Successfully sent signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID) @@ -505,9 +490,9 @@ func (c *client) getResultV1( ) ExecutionResult { var doneMessage WorkDoneMessage if err := cborReader.Decode(&doneMessage); err != nil { - c.logger.Errorf("Failed to read or decode work done message: (%w) for step %s", err, stepData.ID) - return NewErrorExecutionResult( - fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID)) + err = fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID) + c.logger.Errorf(err.Error()) + return NewErrorExecutionResult(err) } return c.processWorkDone(stepData.RunID, doneMessage) }