diff --git a/atp/client.go b/atp/client.go index e0c6fba..d633929 100644 --- a/atp/client.go +++ b/atp/client.go @@ -1,6 +1,7 @@ package atp import ( + "context" "fmt" "github.com/fxamacker/cbor/v2" "go.arcalot.io/log/v2" @@ -36,7 +37,13 @@ type Client interface { // ReadSchema reads the schema from the ATP server. ReadSchema() (*schema.SchemaSchema, error) // Execute executes a step with a given context and returns the resulting output. Assumes you called ReadSchema first. - Execute(input schema.Input, receivedSignals <-chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult + // Params: + // - input: The step input for the run. + // - signalsToStep: A channel to send signals from the client to the plugin. + // - signalsFromStep: A channel to receive signals from the plugin to the client. + // It is recommended to close the signalsToStep channel when either Execute is done or it is known that no more signals + // will be sent to the plugin. + Execute(input schema.Input, signalsToStep <-chan schema.Input, signalsFromStep chan<- schema.Input) ExecutionResult Close() error Encoder() *cbor.Encoder Decoder() *cbor.Decoder @@ -66,6 +73,7 @@ func NewClientWithLogger( if logger == nil { logger = log.NewLogger(log.LevelDebug, log.NewNOOPLogger()) } + ctx, cancel := context.WithCancel(context.Background()) return &client{ -1, // unknown channel, @@ -73,13 +81,14 @@ func NewClientWithLogger( logger, decMode.NewDecoder(channel), cbor.NewEncoder(channel), - make(chan bool, 5), // Buffer to prevent deadlocks make([]schema.Input, 0), make(map[string]*executionEntry), make(map[string]chan<- schema.Input), sync.Mutex{}, false, false, + ctx, + cancel, sync.WaitGroup{}, } } @@ -99,18 +108,19 @@ type executionEntry struct { type client struct { atpVersion int64 - channel ClientChannel + rawAtpChannels ClientChannel decMode cbor.DecMode logger log.Logger decoder *cbor.Decoder encoder *cbor.Encoder - doneChannel chan bool runningSteps []schema.Input runningStepResultEntries map[string]*executionEntry // Run ID to results runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps mutex sync.Mutex - readLoopRunning bool + readLoopRunning bool // To prevent duplicate loops across multiple step executions. done bool + context context.Context + cancelFunc context.CancelFunc wg sync.WaitGroup // For the read loop. } @@ -165,8 +175,8 @@ func (c *client) validateVersion(serverVersion int64) error { func (c *client) Execute( stepData schema.Input, - receivedSignals <-chan schema.Input, - emittedSignals chan<- schema.Input, + signalsToStep <-chan schema.Input, + signalsFromStep chan<- schema.Input, ) ExecutionResult { c.logger.Debugf("Executing plugin step %s/%s...", stepData.RunID, stepData.ID) if len(stepData.RunID) == 0 { @@ -177,20 +187,20 @@ func (c *client) Execute( StepID: stepData.ID, Config: stepData.InputData, } - cborReader := c.decMode.NewDecoder(c.channel) + cborReader := c.decMode.NewDecoder(c.rawChannels) if c.atpVersion > 1 { // Wrap it in a runtime message. workStartMsg = RuntimeMessage{RunID: stepData.RunID, MessageID: MessageTypeWorkStart, MessageData: workStartMsg} // Handle signals to the step - if receivedSignals != nil { + if signalsToStep != nil { c.wg.Add(1) go func() { defer c.wg.Done() - c.executeWriteLoop(stepData.RunID, receivedSignals) + c.executeWriteLoop(stepData.RunID, signalsToStep) }() } // Setup channels for ATP v2 - err := c.prepareResultChannels(cborReader, stepData, emittedSignals) + err := c.prepareResultChannels(cborReader, stepData, signalsFromStep) if err != nil { return NewErrorExecutionResult(err) } @@ -206,6 +216,7 @@ func (c *client) Execute( // Close Tells the client that it's done, and can stop listening for more requests. func (c *client) Close() error { + c.cancelFunc() c.mutex.Lock() if c.done { c.mutex.Unlock() @@ -271,14 +282,12 @@ func (c *client) getRunningStepIDs() string { // Listen for received signals, and send them over ATP if available. func (c *client) executeWriteLoop( runID string, - receivedSignals <-chan schema.Input, + signalsToStep <-chan schema.Input, ) { c.mutex.Lock() if c.done { c.mutex.Unlock() // Close() was called, so exit now. - // Failure to exit now may result in this receivedSignals channel not getting - // closed, resulting in this function hanging. c.logger.Warningf( "write called loop for run ID %q on done client; skipping receive loop", runID, @@ -289,9 +298,16 @@ func (c *client) executeWriteLoop( // Looped select that gets signals for { - signal, ok := <-receivedSignals - if !ok { - c.logger.Infof("ATP signal loop done") + var signal schema.Input + var ok bool + select { + case signal, ok = <-signalsToStep: + if !ok { + c.logger.Debugf("ATP signal loop done; channel closed") + return + } + case <-c.context.Done(): + c.logger.Debugf("ATP signal loop exited; context closed") return } c.logger.Debugf("Sending signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID) diff --git a/atp/protocol_test.go b/atp/protocol_test.go index 6ecbde1..e5a24d6 100644 --- a/atp/protocol_test.go +++ b/atp/protocol_test.go @@ -197,6 +197,66 @@ func TestProtocol_Client_Execute(t *testing.T) { wg.Wait() } +func TestProtocol_Client_Execute_With_Signals(t *testing.T) { + testExecuteWithChannels(true, t) +} + +func TestProtocol_Client_Execute_With_Signals_Unclosed(t *testing.T) { + testExecuteWithChannels(false, t) +} + +func testExecuteWithChannels(closeChannel bool, t *testing.T) { + // Client ReadSchema and Execute happy path with signal handlers passed + // into the Execute call. + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(2) + stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + + go func() { + defer wg.Done() + errors := atp.RunATPServer( + ctx, + stdinReader, + stdoutWriter, + helloWorldSchema, + ) + assert.Equals(t, len(errors), 0) + }() + + go func() { + defer wg.Done() + cli := atp.NewClientWithLogger(channel{ + Reader: stdoutReader, + Writer: stdinWriter, + Context: nil, + cancel: cancel, + }, log.NewTestLogger(t)) + + _, err := cli.ReadSchema() + assert.NoError(t, err) + toStepChan := make(chan schema.Input) + fromStepChan := make(chan schema.Input) + + result := cli.Execute( + schema.Input{ + RunID: t.Name(), + ID: "hello-world", + InputData: map[string]any{"name": "Arca Lot"}, + }, toStepChan, fromStepChan) + if closeChannel { + close(toStepChan) + } + assert.NoError(t, cli.Close()) + assert.NoError(t, result.Error) + assert.Equals(t, result.OutputID, "success") + assert.Equals(t, result.OutputData.(map[any]any)["message"].(string), "Hello, Arca Lot!") + }() + + wg.Wait() +} + //nolint:funlen func TestProtocol_Client_ATP_v1(t *testing.T) { // Client ReadSchema and Execute atp v1 happy path.