Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use typed channels #97

Merged
merged 8 commits into from
Jul 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 21 additions & 36 deletions atp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package atp
import (
"fmt"
"github.com/fxamacker/cbor/v2"
log "go.arcalot.io/log/v2"
"go.arcalot.io/log/v2"
"go.flow.arcalot.io/pluginsdk/schema"
"io"
"strings"
Expand Down Expand Up @@ -36,13 +36,16 @@ type Client interface {
// ReadSchema reads the schema from the ATP server.
ReadSchema() (*schema.SchemaSchema, error)
// Execute executes a step with a given context and returns the resulting output. Assumes you called ReadSchema first.
Execute(input schema.Input, receivedSignals chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult
Execute(input schema.Input, receivedSignals <-chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult
Close() error
Encoder() *cbor.Encoder
Decoder() *cbor.Decoder
}

// NewClient creates a new ATP client (part of the engine code).
// Currently used only by tests in the Python- and Test-deployers.
//
//goland:noinspection GoUnusedExportedFunction
func NewClient(
webbnh marked this conversation as resolved.
Show resolved Hide resolved
channel ClientChannel,
) Client {
Expand Down Expand Up @@ -72,7 +75,6 @@ func NewClientWithLogger(
cbor.NewEncoder(channel),
make(chan bool, 5), // Buffer to prevent deadlocks
make([]schema.Input, 0),
make(map[string]chan schema.Input),
make(map[string]*executionEntry),
make(map[string]chan<- schema.Input),
sync.Mutex{},
Expand Down Expand Up @@ -104,7 +106,6 @@ type client struct {
encoder *cbor.Encoder
doneChannel chan bool
runningSteps []schema.Input
runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps
runningStepResultEntries map[string]*executionEntry // Run ID to results
runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps
mutex sync.Mutex
Expand Down Expand Up @@ -137,8 +138,9 @@ func (c *client) ReadSchema() (*schema.SchemaSchema, error) {
err := c.validateVersion(hello.Version)

if err != nil {
c.logger.Errorf("Unsupported plugin version. %w", err)
return nil, fmt.Errorf("unsupported plugin version: %w", err)
err = fmt.Errorf("unsupported plugin version: %w", err)
c.logger.Errorf(err.Error())
return nil, err
}
c.atpVersion = hello.Version

Expand All @@ -163,7 +165,7 @@ func (c *client) validateVersion(serverVersion int64) error {

func (c *client) Execute(
stepData schema.Input,
receivedSignals chan schema.Input,
receivedSignals <-chan schema.Input,
emittedSignals chan<- schema.Input,
) ExecutionResult {
c.logger.Debugf("Executing plugin step %s/%s...", stepData.RunID, stepData.ID)
Expand Down Expand Up @@ -199,25 +201,9 @@ func (c *client) Execute(
}
c.logger.Debugf("Step '%s' started, waiting for response...", stepData.ID)

defer c.handleStepComplete(stepData.RunID, receivedSignals)
return c.getResult(stepData, cborReader)
}

// handleStepComplete is the deferred function that will handle closing of the received channel.
func (c *client) handleStepComplete(runID string, receivedSignals chan schema.Input) {
if receivedSignals != nil {
c.logger.Infof("Closing signal channel for finished step")
// Remove from the map to ensure that the client.Close() method doesn't double-close it
c.mutex.Lock()
// Validate that it exists, since Close() could have been called early.
_, exists := c.runningSignalReceiveLoops[runID]
if exists {
delete(c.runningSignalReceiveLoops, runID)
}
c.mutex.Unlock()
}
}

// Close Tells the client that it's done, and can stop listening for more requests.
func (c *client) Close() error {
c.mutex.Lock()
Expand Down Expand Up @@ -285,7 +271,7 @@ func (c *client) getRunningStepIDs() string {
// Listen for received signals, and send them over ATP if available.
func (c *client) executeWriteLoop(
runID string,
receivedSignals chan schema.Input,
receivedSignals <-chan schema.Input,
) {
c.mutex.Lock()
if c.done {
Expand All @@ -299,14 +285,8 @@ func (c *client) executeWriteLoop(
)
return
}
// Add the channel to the client so that it can be kept track of
c.runningSignalReceiveLoops[runID] = receivedSignals
c.mutex.Unlock()
defer func() {
c.mutex.Lock()
delete(c.runningSignalReceiveLoops, runID)
c.mutex.Unlock()
}()

// Looped select that gets signals
for {
signal, ok := <-receivedSignals
Expand All @@ -326,8 +306,13 @@ func (c *client) executeWriteLoop(
SignalID: signal.ID,
Data: signal.InputData,
}}); err != nil {
c.logger.Errorf("Client with steps '%s' failed to write signal (%s) with run id '&s' with error: %w",
c.getRunningStepIDs(), signal.ID, signal.RunID, err)
c.logger.Errorf(
"Client with steps '%s' failed to write signal (%s) with run id %q with error: %v",
c.getRunningStepIDs(),
signal.ID,
signal.RunID,
err,
)
return
}
c.logger.Debugf("Successfully sent signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID)
Expand Down Expand Up @@ -505,9 +490,9 @@ func (c *client) getResultV1(
) ExecutionResult {
var doneMessage WorkDoneMessage
if err := cborReader.Decode(&doneMessage); err != nil {
c.logger.Errorf("Failed to read or decode work done message: (%w) for step %s", err, stepData.ID)
return NewErrorExecutionResult(
fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID))
err = fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID)
c.logger.Errorf(err.Error())
return NewErrorExecutionResult(err)
}
return c.processWorkDone(stepData.RunID, doneMessage)
}
Expand Down