From a318a4750a200e42657abf29c6530141d88da4ee Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 5 Oct 2023 11:10:37 -0400 Subject: [PATCH 1/9] Refactored ATP v2 Made work start message a runtime message to allow graceful closure instead of error on schema-only executions Allowed multiple simultaneous step executions Improved error reporting Fixed deadlocks and race conditions --- atp/client.go | 415 +++++++++++++++++++++++++++++++----------- atp/protocol.go | 31 +++- atp/protocol_test.go | 211 ++++++++++++++++++--- atp/server.go | 292 +++++++++++++++++++++-------- schema/input.go | 1 + schema/schema.go | 6 +- schema/schema_test.go | 2 +- schema/step.go | 75 +++++--- schema/step_test.go | 2 +- 9 files changed, 792 insertions(+), 243 deletions(-) diff --git a/atp/client.go b/atp/client.go index ceb4b56..a2f1dd1 100644 --- a/atp/client.go +++ b/atp/client.go @@ -7,6 +7,7 @@ import ( "go.flow.arcalot.io/pluginsdk/schema" "io" "strings" + "sync" ) const MinSupportedATPVersion = 1 @@ -19,13 +20,24 @@ type ClientChannel interface { io.Closer } +type ExecutionResult struct { + OutputID string + OutputData any + Error error +} + +func NewErrorExecutionResult(err error) ExecutionResult { + return ExecutionResult{"", nil, err} +} + // Client is the way to read information from the ATP server and then send a task to it in the form of a step. // A step can only be sent once, but signals can be sent until the step is over. It is a single session. type Client interface { // ReadSchema reads the schema from the ATP server. ReadSchema() (*schema.SchemaSchema, error) - // Execute executes a step with a given context and returns the resulting output. - Execute(input schema.Input, receivedSignals chan schema.Input, emittedSignals chan<- schema.Input) (outputID string, outputData any, err error) + // Execute executes a step with a given context and returns the resulting output. Assumes you called ReadSchema first. + Execute(input schema.Input, receivedSignals chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult + Close() error Encoder() *cbor.Encoder Decoder() *cbor.Decoder } @@ -58,6 +70,15 @@ func NewClientWithLogger( logger, decMode.NewDecoder(channel), cbor.NewEncoder(channel), + make(chan bool, 5), // Buffer to prevent deadlocks + make([]schema.Input, 0), + make(map[string]chan schema.Input), + make(map[string]chan ExecutionResult), + make(map[string]chan<- schema.Input), + sync.Mutex{}, + false, + false, + sync.WaitGroup{}, } } @@ -70,18 +91,33 @@ func (c *client) Encoder() *cbor.Encoder { } type client struct { - atpVersion int64 - channel ClientChannel - decMode cbor.DecMode - logger log.Logger - decoder *cbor.Decoder - encoder *cbor.Encoder + atpVersion int64 + channel ClientChannel + decMode cbor.DecMode + logger log.Logger + decoder *cbor.Decoder + encoder *cbor.Encoder + doneChannel chan bool + runningSteps []schema.Input + runningSignalReceiveLoops map[string]chan schema.Input // Run ID to channel of signals to steps + runningStepResultChannels map[string]chan ExecutionResult // Run ID to channel of results + runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps + mutex sync.Mutex + readLoopRunning bool + done bool + wg sync.WaitGroup // For the read loop. +} + +func (c *client) sendCBOR(message any) error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.encoder.Encode(message) } func (c *client) ReadSchema() (*schema.SchemaSchema, error) { c.logger.Debugf("Reading plugin schema...") - if err := c.encoder.Encode(nil); err != nil { + if err := c.sendCBOR(nil); err != nil { c.logger.Errorf("Failed to encode ATP start output message: %v", err) return nil, fmt.Errorf("failed to encode start output message (%w)", err) } @@ -107,6 +143,7 @@ func (c *client) ReadSchema() (*schema.SchemaSchema, error) { return nil, fmt.Errorf("invalid schema (%w)", err) } c.logger.Debugf("Schema unserialization complete.") + return unserializedSchema, nil } @@ -114,168 +151,340 @@ func (c *client) Execute( stepData schema.Input, receivedSignals chan schema.Input, emittedSignals chan<- schema.Input, -) (outputID string, outputData any, err error) { - c.logger.Debugf("Executing plugin step %s...", stepData.ID) - if err := c.encoder.Encode(StartWorkMessage{ +) ExecutionResult { + c.logger.Debugf("Executing plugin step %s/%s...", stepData.RunID, stepData.ID) + if len(stepData.RunID) == 0 { + return NewErrorExecutionResult(fmt.Errorf("run ID is blank for step %s", stepData.ID)) + } + var workStartMsg any + workStartMsg = WorkStartMessage{ StepID: stepData.ID, Config: stepData.InputData, - }); err != nil { + } + cborReader := c.decMode.NewDecoder(c.channel) + if c.atpVersion > 1 { + // Wrap it in a runtime message. + workStartMsg = RuntimeMessage{RunID: stepData.RunID, MessageID: MessageTypeWorkStart, MessageData: workStartMsg} + // Handle signals to the step + if receivedSignals != nil { + go func() { + c.executeWriteLoop(stepData.RunID, receivedSignals) + }() + } + // Setup channels for ATP v2 + err := c.prepareResultChannels(cborReader, stepData, emittedSignals) + if err != nil { + return NewErrorExecutionResult(err) + } + } + if err := c.sendCBOR(workStartMsg); err != nil { c.logger.Errorf("Step %s failed to write start work message: %v", stepData.ID, err) - return "", nil, fmt.Errorf("failed to write work start message (%w)", err) + return NewErrorExecutionResult(fmt.Errorf("failed to write work start message (%w)", err)) } c.logger.Debugf("Step %s started, waiting for response...", stepData.ID) - doneChannel := make(chan bool, 1) // Needs a buffer to not hang. - defer handleClientClosure(receivedSignals, doneChannel) - if c.atpVersion > 1 { - go func() { - c.executeWriteLoop(stepData, receivedSignals, doneChannel) - }() - } - return c.executeReadLoop(stepData, receivedSignals) + defer c.handleStepComplete(stepData.RunID, receivedSignals) + return c.getResult(stepData, cborReader) } -// handleClosure is the deferred function that will handle closing of the received channel, -// and notifying the code that it's closed. -// Note: The doneChannel should have a buffer. -func handleClientClosure(receivedSignals chan schema.Input, doneChannel chan bool) { - doneChannel <- true +// handleStepComplete is the deferred function that will handle closing of the received channel +func (c *client) handleStepComplete(runID string, receivedSignals chan schema.Input) { if receivedSignals != nil { + c.logger.Infof("Closing signal channel for finished step") + // Remove from the map to ensure that the client.Close() method doesn't double-close it + c.mutex.Lock() + delete(c.runningSignalReceiveLoops, runID) + c.mutex.Unlock() close(receivedSignals) } } +// Close Tells the client that it's done, and can stop listening for more requests. +func (c *client) Close() error { + c.mutex.Lock() + if c.done { + return nil + } + c.done = true + c.mutex.Unlock() + // First, close channels that could send signals to the clients + // This ends the loop + for runID, signalChannel := range c.runningSignalReceiveLoops { + // TODO: Test why commenting this out results in a deadlock instead of just the steps finishing when they're supposed to. + c.logger.Infof("Closing signal channel for run ID '%s'", runID) + close(signalChannel) + } + // Now tell the server we're done. + // Send the client done message + if c.atpVersion > 1 { + err := c.sendCBOR(RuntimeMessage{ + MessageTypeClientDone, + "", + clientDoneMessage{}, + }) + if err != nil { + return fmt.Errorf("client with steps '%s' failed to write client done message with error: %w", + c.getRunningStepIDs(), err) + } + } + c.wg.Wait() + return nil +} + +func (c *client) getRunningStepIDs() string { + if len(c.runningSteps) == 0 { + return "No running steps" + } + result := "" + for _, step := range c.runningSteps { + result += " " + step.RunID + "/" + step.ID + } + return result +} + // Listen for received signals, and send them over ATP if available. func (c *client) executeWriteLoop( - stepData schema.Input, + runID string, receivedSignals chan schema.Input, - doneChannel chan bool, ) { + // Add the channel to the client so that it can be kept track of + c.mutex.Lock() + c.runningSignalReceiveLoops[runID] = receivedSignals + c.mutex.Unlock() + defer func() { + c.mutex.Lock() + delete(c.runningSignalReceiveLoops, runID) + c.mutex.Unlock() + }() // Looped select that gets signals for { - var signal schema.Input - select { - case <-doneChannel: - // Send the client done message - err := c.encoder.Encode(RuntimeMessage{ - MessageTypeClientDone, - clientDoneMessage{}, - }) - if err != nil { - c.logger.Errorf("Step %s failed to write client done message with error: %w", stepData.ID, err) - } + signal, ok := <-receivedSignals + if !ok { + c.logger.Infof("ATP signal loop done") + return + } + c.logger.Debugf("Sending signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID) + if signal.ID == "" || signal.RunID == "" { + c.logger.Errorf("Invalid run ID (%s) or signal ID (%s)", signal.ID, signal.RunID) return - case receivedSignal, ok := <-receivedSignals: - if !ok { - // It's not supposed to be not ok yet. - c.logger.Errorf("error in channel preparing to send signal (step %s, signal %s) over ATP", - stepData.ID, signal.ID) - return - } - signal = receivedSignal } - c.logger.Debugf("Sending signal with ID '%s' to step with ID '%s'", signal.ID, stepData.ID) - if err := c.encoder.Encode(RuntimeMessage{ + if err := c.sendCBOR(RuntimeMessage{ MessageTypeSignal, + signal.RunID, signalMessage{ - StepID: stepData.ID, SignalID: signal.ID, Data: signal.InputData, }}); err != nil { - c.logger.Errorf("Step %s failed to write signal (%s) with error: %w", stepData.ID, signal.ID, err) + c.logger.Errorf("Client with steps '%s' failed to write signal (%s) with run id '&s' with error: %w", + c.getRunningStepIDs(), signal.ID, signal.RunID, err) return } - c.logger.Debugf("Successfully sent signal with ID '%s' to step with ID '%s'", signal.ID, stepData.ID) + c.logger.Debugf("Successfully sent signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID) } } -// executeReadLoop handles the reading of work done, signals, or any other outputs from the plugins. -// It branches off with different logic for ATP versions 1 and 2. -func (c *client) executeReadLoop( - stepData schema.Input, emittedSignals chan<- schema.Input, -) (outputID string, outputData any, err error) { - cborReader := c.decMode.NewDecoder(c.channel) - if c.atpVersion >= 2 { - return c.executeReadLoopV2(cborReader, stepData, emittedSignals) +// sendExecutionResult sends the results to the channel, and closes then removes the channels for the +// step results and the signals. +func (c *client) sendExecutionResult(runID string, result ExecutionResult) { + c.logger.Debugf("Providing input for run ID '%s'", runID) + c.mutex.Lock() + resultChannel, found := c.runningStepResultChannels[runID] + c.mutex.Unlock() + if found { + // Send the result + resultChannel <- result + // Close the channel and remove it to detect incorrectly duplicate results. + close(resultChannel) + c.mutex.Lock() + delete(c.runningStepResultChannels, runID) + c.mutex.Unlock() } else { - return c.executeReadLoopV1(cborReader, stepData) + c.logger.Errorf("Step result channel not found for run ID '%s'. This is either a bug in the ATP "+ + "client, or the plugin erroneously sent a second result.", runID) + } + // Now close the signal channel, since it's invalid to send a signal after the step is complete. + c.mutex.Lock() + defer c.mutex.Unlock() + signalChannel, found := c.runningStepEmittedSignalChannels[runID] + if !found { + c.logger.Debugf("Could not find signal output channel for run ID '%s'", runID) + return } + close(signalChannel) + delete(c.runningStepEmittedSignalChannels, runID) } -// executeReadLoopV1 is the legacy read loop function, that only waits for work done. -func (c *client) executeReadLoopV1( - cborReader *cbor.Decoder, - stepData schema.Input, -) (outputID string, outputData any, err error) { - var doneMessage workDoneMessage - if err := cborReader.Decode(&doneMessage); err != nil { - c.logger.Errorf("Failed to read or decode work done message: (%w) for step %s", err, stepData.ID) - return "", nil, - fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID) +func (c *client) sendErrorToAll(err error) { + result := NewErrorExecutionResult(err) + for runID, _ := range c.runningStepResultChannels { + c.sendExecutionResult(runID, result) } - return c.handleWorkDone(stepData, doneMessage) } -// executeReadLoopV2 is the new read loop function, that supports the RuntimeMessage loop. -func (c *client) executeReadLoopV2( - cborReader *cbor.Decoder, - stepData schema.Input, - emittedSignals chan<- schema.Input, -) (outputID string, outputData any, err error) { +func (c *client) executeReadLoop(cborReader *cbor.Decoder) { + defer func() { + c.mutex.Lock() + defer c.mutex.Unlock() + c.readLoopRunning = false + c.wg.Done() + }() // Loop and get all messages // The message is generic, so we must find the type and decode the full message next. var runtimeMessage DecodedRuntimeMessage for { if err := cborReader.Decode(&runtimeMessage); err != nil { - c.logger.Errorf("Step %s failed to read or decode runtime message: %v", stepData.ID, err) - return "", nil, - fmt.Errorf("failed to read or decode runtime message (%w)", err) + c.logger.Errorf("ATP client for steps %s failed to read or decode runtime message: %v", c.getRunningStepIDs(), err) + // This is fatal since the entire structure of the runtime message is invalid. + c.sendErrorToAll(fmt.Errorf("failed to read or decode runtime message (%w)", err)) + return } switch runtimeMessage.MessageID { case MessageTypeWorkDone: var doneMessage workDoneMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { - c.logger.Errorf("Failed to decode work done message (%v) for step ID %s ", err, stepData.ID) - return "", nil, - fmt.Errorf("failed to read work done message (%w)", err) + c.logger.Errorf("Failed to decode work done message (%v) for run ID %s ", err, runtimeMessage.RunID) + c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult( + fmt.Errorf("failed to decode work done message (%w)", err))) } - return c.handleWorkDone(stepData, doneMessage) + c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage)) case MessageTypeSignal: var signalMessage signalMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { - c.logger.Errorf("Step %s failed to decode signal message: %v", stepData.ID, err) - } - if stepData.ID != signalMessage.StepID { - c.logger.Warningf("Step %s sent signal %s, but the step ID '%s' sent by the plugin does not match. Ignoring signal.", - stepData.ID, signalMessage.SignalID, signalMessage.StepID) - continue // Don't process the signal + c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", + runtimeMessage.RunID, err) } - if emittedSignals == nil { - c.logger.Warningf("Step '%s' sent signal '%s'. Ignoring; signal handling is not implemented (emittedSignals is nil).", - stepData.ID, signalMessage.SignalID) + signalChannel, found := c.runningStepEmittedSignalChannels[runtimeMessage.RunID] + if !found { + c.logger.Warningf( + "Step with run ID '%s' sent signal '%s'. Ignoring; signal handling is not implemented "+ + "(emittedSignals is nil).", + runtimeMessage.RunID, signalMessage.SignalID) } else { - c.logger.Debugf("Got signal from step '%s' with ID '%s'", stepData.ID, signalMessage.SignalID) - emittedSignals <- signalMessage.ToInput() + c.logger.Debugf("Got signal from step with run ID '%s' with ID '%s'", runtimeMessage.RunID, + signalMessage.SignalID) + signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) + } + case MessageTypeError: + var errMessage errorMessage + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { + c.logger.Errorf("Step with run ID '%s' failed to decode error message: %v", + runtimeMessage.RunID, err) + } + c.logger.Errorf("Step with run ID %s sent error message: %v", runtimeMessage.RunID, errMessage) + resultMsg := fmt.Errorf("step %s sent error message: %s", runtimeMessage.RunID, + errMessage.ToString(runtimeMessage.RunID)) + if errMessage.ServerFatal { + c.sendErrorToAll(resultMsg) + return // It's server fatal, so this is the last message from the server. + } else if errMessage.StepFatal { + c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult(resultMsg)) } default: - c.logger.Warningf("Step %s sent unknown message type: %s", stepData.ID, runtimeMessage.MessageID) + c.logger.Warningf("Step with run ID %s sent unknown message type: %s", runtimeMessage.RunID, + runtimeMessage.MessageID) + } + c.mutex.Lock() + if len(c.runningStepResultChannels) == 0 { + c.mutex.Unlock() + return // Done } + c.mutex.Unlock() + } +} + +// executeStep handles the reading of work done, signals, or any other outputs from the plugins. +// It branches off with different logic for ATP versions 1 and 2. +func (c *client) getResult( + stepData schema.Input, + cborReader *cbor.Decoder, +) ExecutionResult { + if c.atpVersion >= 2 { + return c.getResultV2(stepData) + } else { + return c.getResultV1(cborReader, stepData) + } +} + +// getResultV1 is the legacy function that only waits for work done. +func (c *client) getResultV1( + cborReader *cbor.Decoder, + stepData schema.Input, +) ExecutionResult { + var doneMessage workDoneMessage + if err := cborReader.Decode(&doneMessage); err != nil { + c.logger.Errorf("Failed to read or decode work done message: (%w) for step %s", err, stepData.ID) + return NewErrorExecutionResult( + fmt.Errorf("failed to read or decode work done message (%w) for step %s", err, stepData.ID)) + } + return c.processWorkDone(stepData.RunID, doneMessage) +} + +func (c *client) prepareResultChannels( + cborReader *cbor.Decoder, + stepData schema.Input, + emittedSignals chan<- schema.Input, +) error { + c.mutex.Lock() + defer c.mutex.Unlock() + _, existing := c.runningStepResultChannels[stepData.RunID] + if existing { + return fmt.Errorf("duplicate run ID given '%s'", stepData.RunID) + } + // Set up the signal and step results channels + resultChannel := make(chan ExecutionResult) + c.runningStepResultChannels[stepData.RunID] = resultChannel + if emittedSignals != nil { + c.runningStepEmittedSignalChannels[stepData.RunID] = emittedSignals + } + // Run the loop if it isn't running. + if !c.readLoopRunning { + // Only a single read loop should be running + c.wg.Add(1) // Add here, so that it's before the goroutine to prevent race conditions. + c.readLoopRunning = true + go func() { + c.executeReadLoop(cborReader) + }() } + return nil } -func (c *client) handleWorkDone( +// getResultV2 works with the channels that communicate with the RuntimeMessage loop. +func (c *client) getResultV2( stepData schema.Input, +) ExecutionResult { + c.mutex.Lock() + resultChannel, found := c.runningStepResultChannels[stepData.RunID] + c.mutex.Unlock() + if !found { + return NewErrorExecutionResult( + fmt.Errorf("could not find result channel for step with run ID '%s'", + stepData.RunID), + ) + } + // Wait for the result + result, received := <-resultChannel + if !received { + return NewErrorExecutionResult( + fmt.Errorf("did not receive result from results channel in ATP client for step with run ID '%s'", + stepData.RunID), + ) + } + return result +} + +func (c *client) processWorkDone( + runID string, doneMessage workDoneMessage, -) (outputID string, outputData any, err error) { - c.logger.Debugf("Step %s completed with output ID '%s'.", stepData.ID, doneMessage.OutputID) +) ExecutionResult { + c.logger.Debugf("Step with run ID '%s' completed with output ID '%s'.", runID, doneMessage.OutputID) // Print debug logs from the step as debug. debugLogs := strings.Split(doneMessage.DebugLogs, "\n") for _, line := range debugLogs { if strings.TrimSpace(line) != "" { - c.logger.Debugf("Step %s debug: %s", stepData.ID, line) + c.logger.Debugf("Step %s debug: %s", runID, line) } } - return doneMessage.OutputID, doneMessage.OutputData, nil + return ExecutionResult{doneMessage.OutputID, doneMessage.OutputData, nil} } diff --git a/atp/protocol.go b/atp/protocol.go index 82eb29c..c181553 100644 --- a/atp/protocol.go +++ b/atp/protocol.go @@ -1,6 +1,7 @@ package atp import ( + "fmt" "github.com/fxamacker/cbor/v2" "go.flow.arcalot.io/pluginsdk/schema" ) @@ -12,44 +13,58 @@ type HelloMessage struct { Schema any `cbor:"schema"` } -type StartWorkMessage struct { - StepID string `cbor:"id"` +type WorkStartMessage struct { + StepID string `cbor:"step_id"` Config any `cbor:"config"` } // All messages that can be contained in a RuntimeMessage struct. const ( - MessageTypeWorkDone uint32 = 1 - MessageTypeSignal uint32 = 2 - MessageTypeClientDone uint32 = 3 + MessageTypeWorkStart uint32 = 1 + MessageTypeWorkDone uint32 = 2 + MessageTypeSignal uint32 = 3 + MessageTypeClientDone uint32 = 4 + MessageTypeError uint32 = 5 ) type RuntimeMessage struct { MessageID uint32 `cbor:"id"` + RunID string `cbor:"run_id"` MessageData any `cbor:"data"` } type DecodedRuntimeMessage struct { MessageID uint32 `cbor:"id"` + RunID string `cbor:"run_id"` RawMessageData cbor.RawMessage `cbor:"data"` } type workDoneMessage struct { + StepID string `cbor:"step_id"` OutputID string `cbor:"output_id"` OutputData any `cbor:"output_data"` DebugLogs string `cbor:"debug_logs"` } type signalMessage struct { - StepID string `cbor:"step_id"` SignalID string `cbor:"signal_id"` Data any `cbor:"data"` } -func (s signalMessage) ToInput() schema.Input { - return schema.Input{ID: s.SignalID, InputData: s.Data} +func (s signalMessage) ToInput(runID string) schema.Input { + return schema.Input{RunID: runID, ID: s.SignalID, InputData: s.Data} } type clientDoneMessage struct { // Empty for now. } + +type errorMessage struct { + Error string `cbor:"error"` + StepFatal bool `cbor:"step_fatal"` + ServerFatal bool `cbor:"server_fatal"` +} + +func (e errorMessage) ToString(runID string) string { + return fmt.Sprintf("RunID: %s, err: %s, step fatal: %t, server fatal: %t", runID, e.Error, e.StepFatal, e.ServerFatal) +} diff --git a/atp/protocol_test.go b/atp/protocol_test.go index b1b4ca9..456bd65 100644 --- a/atp/protocol_test.go +++ b/atp/protocol_test.go @@ -11,6 +11,7 @@ import ( "io" "sync" "testing" + "time" ) type helloWorldInput struct { @@ -117,7 +118,7 @@ func TestProtocol_Client_Execute(t *testing.T) { go func() { defer wg.Done() - assert.NoError(t, atp.RunATPServer( + assert.Nil(t, atp.RunATPServer( ctx, stdinReader, stdoutWriter, @@ -137,14 +138,130 @@ func TestProtocol_Client_Execute(t *testing.T) { _, err := cli.ReadSchema() assert.NoError(t, err) - outputID, outputData, err := cli.Execute( + result := cli.Execute( schema.Input{ + RunID: t.Name(), ID: "hello-world", InputData: map[string]any{"name": "Arca Lot"}, }, nil, nil) + assert.NoError(t, cli.Close()) + assert.NoError(t, result.Error) + assert.Equals(t, result.OutputID, "success") + assert.Equals(t, result.OutputData.(map[any]any)["message"].(string), "Hello, Arca Lot!") + }() + + wg.Wait() +} + +func TestProtocol_Client_Execute_Multi_Step_Parallel(t *testing.T) { + // Runs several steps on one client instance at the same time + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(2) + stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + + go func() { + defer wg.Done() + assert.Nil(t, atp.RunATPServer( + ctx, + stdinReader, + stdoutWriter, + helloWorldSchema, + )) + }() + + go func() { + defer wg.Done() + cli := atp.NewClientWithLogger(channel{ + Reader: stdoutReader, + Writer: stdinWriter, + Context: nil, + cancel: cancel, + }, log.NewTestLogger(t)) + + _, err := cli.ReadSchema() assert.NoError(t, err) - assert.Equals(t, outputID, "success") - assert.Equals(t, outputData.(map[any]any)["message"].(string), "Hello, Arca Lot!") + + names := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"} + stepWg := &sync.WaitGroup{} + for _, name := range names { + stepName := name + stepWg.Add(1) + go func() { + defer stepWg.Done() + result := cli.Execute( + schema.Input{ + RunID: t.Name() + "_" + stepName, // Must be unique + ID: "hello-world", + InputData: map[string]any{"name": stepName}, + }, nil, nil) + assert.NoError(t, result.Error) + assert.Equals(t, result.OutputID, "success") + assert.Equals(t, result.OutputData.(map[any]any)["message"].(string), "Hello, "+stepName+"!") + }() + } + stepWg.Wait() + assert.NoError(t, cli.Close()) + }() + + wg.Wait() +} +func TestProtocol_Client_Execute_Multi_Step_Serial(t *testing.T) { + // Runs several steps in one client, but with a long enough delay for each one to finish up + // before the next one runs + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(2) + stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + + go func() { + defer wg.Done() + assert.Nil(t, atp.RunATPServer( + ctx, + stdinReader, + stdoutWriter, + helloWorldSchema, + )) + }() + + go func() { + defer wg.Done() + cli := atp.NewClientWithLogger(channel{ + Reader: stdoutReader, + Writer: stdinWriter, + Context: nil, + cancel: cancel, + }, log.NewTestLogger(t)) + + _, err := cli.ReadSchema() + assert.NoError(t, err) + + names := []string{"a", "b", "c"} + waitTime := 0 + stepWg := &sync.WaitGroup{} + for _, name := range names { + stepName := name + stepWg.Add(1) + stepWaitTime := waitTime + waitTime += 5 + go func() { + defer stepWg.Done() + time.Sleep(time.Duration(stepWaitTime) * time.Millisecond) + result := cli.Execute( + schema.Input{ + RunID: t.Name() + "_" + stepName, // Must be unique + ID: "hello-world", + InputData: map[string]any{"name": stepName}, + }, nil, nil) + assert.NoError(t, result.Error) + assert.Equals(t, result.OutputID, "success") + assert.Equals(t, result.OutputData.(map[any]any)["message"].(string), "Hello, "+stepName+"!") + }() + } + stepWg.Wait() + assert.NoError(t, cli.Close()) }() wg.Wait() @@ -161,12 +278,15 @@ func TestProtocol_Client_ReadSchema(t *testing.T) { go func() { defer wg.Done() - assert.NoError(t, atp.RunATPServer( + t.Logf("Starting ATP server") + assert.Nil(t, atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, )) + t.Logf("ATP server exited without error") + }() go func() { @@ -176,7 +296,7 @@ func TestProtocol_Client_ReadSchema(t *testing.T) { cancel() wg.Done() }() - + t.Logf("Starting client.") cli := atp.NewClientWithLogger(channel{ Reader: stdoutReader, Writer: stdinWriter, @@ -184,7 +304,10 @@ func TestProtocol_Client_ReadSchema(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) _, err := cli.ReadSchema() + err2 := cli.Close() assert.NoError(t, err) + assert.NoError(t, err2) + t.Logf("Client exited without error") }() wg.Wait() @@ -217,6 +340,9 @@ func TestProtocol_Error_Client_StartOutput(t *testing.T) { }() wg.Wait() + + // We don't wait on error, to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } func TestProtocol_Error_Server_StartOutput(t *testing.T) { @@ -229,6 +355,7 @@ func TestProtocol_Error_Server_StartOutput(t *testing.T) { // close the server's cbor decoder's io pipe assert.NoError(t, stdinReader.Close()) + assert.NoError(t, stdoutWriter.Close()) // Close this, because it's unbuffered, so it would deadlock. err := atp.RunATPServer( ctx, @@ -237,7 +364,11 @@ func TestProtocol_Error_Server_StartOutput(t *testing.T) { helloWorldSchema, ) - assert.Error(t, err) + assert.NotNil(t, err) + assert.Equals(t, err.ServerFatal, true) + + // We don't wait on error, to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } func TestProtocol_Error_Client_Hello(t *testing.T) { @@ -278,6 +409,9 @@ func TestProtocol_Error_Client_Hello(t *testing.T) { assert.NoError(t, srvr.decoder.Decode(&empty)) wg.Wait() + + // We don't wait on error, to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } func TestProtocol_Error_Server_Hello(t *testing.T) { @@ -299,7 +433,7 @@ func TestProtocol_Error_Server_Hello(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var test_error error + var test_error *atp.ServerError go func() { defer wg.Done() @@ -325,7 +459,9 @@ func TestProtocol_Error_Server_Hello(t *testing.T) { wgcli.Wait() wg.Wait() - assert.Error(t, test_error) + assert.NotNil(t, test_error) + // We don't wait on error, to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } func TestProtocol_Error_Server_WorkStart(t *testing.T) { @@ -346,7 +482,7 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var test_error error + var testError *atp.ServerError go func() { defer wg.Done() err := atp.RunATPServer( @@ -357,7 +493,7 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { ) if err != nil { - test_error = err + testError = err } }() @@ -369,10 +505,17 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { // close the server's cbor decoder's io pipe assert.NoError(t, stdinReader.Close()) + // Now close the client's stdoutWriter, since it would otherwise deadlock. + assert.NoError(t, stdoutWriter.Close()) }() wg.Wait() - assert.Error(t, test_error) + assert.NotNil(t, testError) + // This may make the test more fragile, but checking the error is the only way + // to know that the error is from where we're testing. + assert.Contains(t, testError.Err.Error(), "failed to read or decode runtime message") + // We don't wait on error, to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } func TestProtocol_Error_Client_WorkStart(t *testing.T) { @@ -396,7 +539,7 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var srvr_error error + var srvr_error *atp.ServerError var cli_error error go func() { defer wg.Done() @@ -416,23 +559,29 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) { _, err := cli.ReadSchema() assert.NoError(t, err) - // close client's cbor encoder's io pipe + // close client's cbor encoder's io pipe. This is intentionally done incorrectly to cause an error. assert.NoError(t, stdinWriter.Close()) - _, _, err = cli.Execute( + result := cli.Execute( schema.Input{ + RunID: t.Name(), ID: "hello-world", InputData: map[string]any{"name": "Arca Lot"}, }, nil, nil) - if err != nil { - cli_error = err + assert.Error(t, cli.Close()) + if result.Error != nil { + cli_error = result.Error } + // Close the other pipe after to unblock the server + assert.NoError(t, stdoutWriter.Close()) }() wgcli.Wait() wg.Wait() - assert.Error(t, srvr_error) + assert.NotNil(t, srvr_error) assert.Error(t, cli_error) + // We don't lock on error to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } func TestProtocol_Error_Client_WorkDone(t *testing.T) { @@ -447,7 +596,7 @@ func TestProtocol_Error_Client_WorkDone(t *testing.T) { stdoutReader, stdoutWriter := io.Pipe() defer cancel() - srvr := newATPServer(channel{ + atpServer := newATPServer(channel{ Reader: stdinReader, Writer: stdoutWriter, Context: ctx, @@ -467,8 +616,8 @@ func TestProtocol_Error_Client_WorkDone(t *testing.T) { go func() { defer wg.Done() - req := atp.StartWorkMessage{} - err := srvr.decoder.Decode(&req) + req := atp.WorkStartMessage{} + err := atpServer.decoder.Decode(&req) if err != nil { srvr_error = err } @@ -479,19 +628,22 @@ func TestProtocol_Error_Client_WorkDone(t *testing.T) { go func() { defer wg.Done() - _, _, err := cli.Execute( + result := cli.Execute( schema.Input{ + RunID: t.Name(), ID: "hello-world", InputData: map[string]any{"name": "Arca Lot"}, }, nil, nil) - if err != nil { - cli_error = err + if result.Error != nil { + cli_error = result.Error } + assert.NoError(t, cli.Close()) }() - wg.Wait() assert.NoError(t, srvr_error) assert.Error(t, cli_error) + // We don't wait on error, to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } func TestProtocol_Error_Server_WorkDone(t *testing.T) { @@ -512,7 +664,7 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var srvr_error error + var srvr_error *atp.ServerError var cli_error error go func() { @@ -536,7 +688,7 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) { // close server's cbor encoder's io pipe assert.NoError(t, stdoutWriter.Close()) - err = cli.Encoder().Encode(atp.StartWorkMessage{ + err = cli.Encoder().Encode(atp.WorkStartMessage{ StepID: "hello-world", Config: map[string]any{"name": "Arca Lot"}, }) @@ -547,7 +699,10 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) { wg.Wait() assert.NoError(t, cli_error) - assert.Error(t, srvr_error) + assert.NotNil(t, srvr_error) + + // We don't wait on error, to prevent deadlocks, so just sleep + time.Sleep(time.Millisecond * 2) } // serverChannel holds the methods to talking to an ATP server (plugin). diff --git a/atp/server.go b/atp/server.go index 95e524e..9b12099 100644 --- a/atp/server.go +++ b/atp/server.go @@ -18,32 +18,45 @@ func RunATPServer( stdin io.ReadCloser, stdout io.WriteCloser, pluginSchema *schema.CallableSchema, -) error { +) *ServerError { session := initializeATPServerSession(ctx, stdin, stdout, pluginSchema) - wg := &sync.WaitGroup{} - wg.Add(1) + session.wg.Add(1) // Run needs to be run in its own goroutine to allow for the closure handling to happen simultaneously. go func() { - session.run(wg) + session.run() }() - workError := session.handleClosure(stdin) + workError := session.handleClosure() // Ensure that the session is done. - wg.Wait() + session.wg.Wait() return workError } type atpServerSession struct { - ctx context.Context - cancel *context.CancelFunc - req StartWorkMessage - cborStdin *cbor.Decoder - cborStdout *cbor.Encoder - workDone chan error - doneChannel chan bool - pluginSchema *schema.CallableSchema + ctx context.Context + cancel *context.CancelFunc + wg *sync.WaitGroup + stdinCloser io.ReadCloser + cborStdin *cbor.Decoder + cborStdout *cbor.Encoder + runningSteps map[string]string // Maps run ID to step ID + workDone chan ServerError + runDoneChannel chan bool + pluginSchema *schema.CallableSchema + encoderMutex sync.Mutex +} + +type ServerError struct { + RunID string + Err error + StepFatal bool + ServerFatal bool +} + +func (e ServerError) String() string { + return fmt.Sprintf("RunID: %s, err: %s, step fatal: %t, server fatal: %t", e.RunID, e.Err, e.StepFatal, e.ServerFatal) } func initializeATPServerSession( @@ -53,11 +66,11 @@ func initializeATPServerSession( pluginSchema *schema.CallableSchema, ) *atpServerSession { subCtx, cancel := context.WithCancel(ctx) - workDone := make(chan error, 1) + workDone := make(chan ServerError, 3) // The ATP protocol uses CBOR. cborStdin := cbor.NewDecoder(stdin) cborStdout := cbor.NewEncoder(stdout) - doneChannel := make(chan bool, 1) // Buffer to prevent it from hanging if something unexpected happens. + runDoneChannel := make(chan bool, 3) // Buffer to prevent it from hanging if something unexpected happens. // Cancel the sub context on sigint or sigterm. sigs := make(chan os.Signal, 1) @@ -73,27 +86,74 @@ func initializeATPServerSession( }() return &atpServerSession{ - ctx: subCtx, - cancel: &cancel, - req: StartWorkMessage{}, - cborStdin: cborStdin, - cborStdout: cborStdout, - workDone: workDone, - doneChannel: doneChannel, - pluginSchema: pluginSchema, + ctx: subCtx, + cancel: &cancel, + cborStdin: cborStdin, + stdinCloser: stdin, + cborStdout: cborStdout, + workDone: workDone, + runDoneChannel: runDoneChannel, + pluginSchema: pluginSchema, + wg: &sync.WaitGroup{}, + runningSteps: make(map[string]string), } } -func (s *atpServerSession) handleClosure(stdin io.ReadCloser) error { +func (s *atpServerSession) sendRuntimeMessage(msgID uint32, runID string, message any) error { + s.encoderMutex.Lock() + defer s.encoderMutex.Unlock() + return s.cborStdout.Encode(RuntimeMessage{ + MessageID: msgID, + RunID: runID, + MessageData: message, + }) +} + +func (s *atpServerSession) handleClosure() *ServerError { // Wait for work done or context complete. - var workError error - select { - case workError = <-s.workDone: - case <-s.ctx.Done(): - // Likely got sigterm. Just close. Ideally gracefully. + var workError *ServerError +closeLoop: + for { + select { + case errorSent, wasError := <-s.workDone: + if wasError { + workError = &errorSent + err := s.sendRuntimeMessage( + MessageTypeError, + errorSent.RunID, + errorMessage{ + Error: errorSent.Err.Error(), + StepFatal: errorSent.StepFatal, + ServerFatal: errorSent.ServerFatal, + }, + ) + // If that didn't send, just send to stderr now. + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error while sending error message: %s\n", err) + } + // If either the error report sending failed, or the error was server fatal, stop here. + if err != nil || errorSent.ServerFatal { + err = s.stdinCloser.Close() + if err != nil { + return &ServerError{ + RunID: workError.RunID, + Err: fmt.Errorf("error closing stdin (%s) after workDone error (%v)", err, workError), + StepFatal: true, + ServerFatal: true, + } + } else { + break closeLoop + } + } + } else { + break closeLoop + } + case <-s.ctx.Done(): + // Likely got sigterm. Just close. Ideally gracefully. + break closeLoop + } } // Now close the pipe that it gets input from. - _ = stdin.Close() return workError } @@ -102,95 +162,177 @@ func (s *atpServerSession) runATPReadLoop() { var runtimeMessage DecodedRuntimeMessage for { // First, decode the message + // Note: This blocks. To abort early, close stdin. if err := s.cborStdin.Decode(&runtimeMessage); err != nil { // Failed to decode. If it's done, that's okay. If not, there's a problem. done := false select { - case done = <-s.doneChannel: + case done = <-s.runDoneChannel: default: // Prevents it from blocking } if !done { - s.workDone <- fmt.Errorf("failed to read or decode runtime message: %w", err) + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("failed to read or decode runtime message: %w", err), + StepFatal: true, + ServerFatal: true, + } } // If done, it didn't get the work done message, which is not ideal. return } + runID := runtimeMessage.RunID switch runtimeMessage.MessageID { + case MessageTypeWorkStart: + var workStartMsg WorkStartMessage + if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &workStartMsg); err != nil { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("failed to decode work start message: %w", err), + StepFatal: true, + ServerFatal: false, + } + continue + } + if runID == "" || workStartMsg.StepID == "" { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("missing runID (%s) or stepID in work start message (%s)", + runID, workStartMsg.StepID), + StepFatal: true, + ServerFatal: false, + } + continue + } + s.runningSteps[runID] = workStartMsg.StepID + s.wg.Add(1) // Wait until the step is done + go func() { + s.runStep(runID, workStartMsg) + s.wg.Done() + }() case MessageTypeSignal: var signalMessage signalMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { - s.workDone <- fmt.Errorf("failed to decode signal message: %w", err) - return + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("failed to decode signal message: %w", err), + StepFatal: false, + ServerFatal: false, + } + continue } - if s.req.StepID != signalMessage.StepID { - s.workDone <- fmt.Errorf("signal sent with mismatched step ID, got %s, expected %s", - signalMessage.StepID, s.req.StepID) - return + if runID == "" { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("RunID missing for signal '%s' in signal message", signalMessage.SignalID), + StepFatal: false, + ServerFatal: false, + } + continue } - if err := s.pluginSchema.CallSignal(s.ctx, signalMessage.StepID, signalMessage.SignalID, signalMessage.Data); err != nil { - s.workDone <- fmt.Errorf("failed while running signal ID %s: %w", - signalMessage.SignalID, err) - return + stepID, found := s.runningSteps[runID] + if !found { + s.workDone <- ServerError{ + RunID: runID, + Err: fmt.Errorf("unknown step with run ID '%s' in signal mesage", runID), + StepFatal: false, + ServerFatal: false, + } + continue } + s.wg.Add(1) // Wait until the signal handler is done + go func() { + if err := s.pluginSchema.CallSignal( + s.ctx, + runID, + stepID, + signalMessage.SignalID, + signalMessage.Data, + ); err != nil { + s.workDone <- ServerError{ + RunID: runID, + Err: fmt.Errorf("failed while running signal ID %s: %w", + signalMessage.SignalID, err), + StepFatal: false, + ServerFatal: false, + } + } + s.wg.Done() + }() case MessageTypeClientDone: + // It's now safe to close the channel + err := s.stdinCloser.Close() + if err != nil { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("error while closing stdin on client done: %s", err), + StepFatal: true, + ServerFatal: true, + } + } return default: - s.workDone <- fmt.Errorf("unknown message ID received: %d", runtimeMessage.MessageID) - return + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("unknown message ID received: %d. This is a sign of incompatible server and client versions", + runtimeMessage.MessageID), + StepFatal: false, + ServerFatal: false, + } + continue } } } -func (s *atpServerSession) run(wg *sync.WaitGroup) { +func (s *atpServerSession) run() { defer func() { - s.doneChannel <- true + s.runDoneChannel <- true close(s.workDone) - wg.Done() + s.wg.Done() }() err := s.sendInitialMessagesToClient() if err != nil { - s.workDone <- err - return - } - - // Now, get the work message that dictates which step to run and the config info. - err = s.cborStdin.Decode(&s.req) - if err != nil { - s.workDone <- fmt.Errorf("failed to CBOR-decode start work message (%w)", err) + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("error while sending initial messages to client (%s)", err), + StepFatal: true, + ServerFatal: true, + } return } // Now, loop through stdin inputs until the step ends. - go func() { // Listen for signals in another thread - s.runATPReadLoop() - }() + s.runATPReadLoop() +} +func (s *atpServerSession) runStep(runID string, req WorkStartMessage) { // Call the step in the provided callable schema. - outputID, outputData, err := s.pluginSchema.CallStep(s.ctx, s.req.StepID, s.req.Config) + outputID, outputData, err := s.pluginSchema.CallStep(s.ctx, runID, req.StepID, req.Config) if err != nil { - s.workDone <- err + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("error calling step (%s)", err), + StepFatal: true, + ServerFatal: false, + } return } - // Lastly, send the work done message. - err = s.cborStdout.Encode( - RuntimeMessage{ - MessageTypeWorkDone, - workDoneMessage{ - outputID, - outputData, - "", - }, + err = s.sendRuntimeMessage( + MessageTypeWorkDone, + runID, + workDoneMessage{ + req.StepID, + outputID, + outputData, + "", }, ) if err != nil { - s.workDone <- fmt.Errorf("failed to encode CBOR response (%w)", err) - return + // At this point, the work done message failed to send, so it's likely that sending an errorMessage would fail. + _, err = fmt.Fprintf(os.Stderr, "error while sending work done message: %s\n", err) } - - // finished with no error! - s.workDone <- nil } func (s *atpServerSession) sendInitialMessagesToClient() error { diff --git a/schema/input.go b/schema/input.go index 9898c17..208c13b 100644 --- a/schema/input.go +++ b/schema/input.go @@ -1,6 +1,7 @@ package schema type Input struct { + RunID string // id identifies the step, signal, or any other case where data is being input ID string // The data being input into the step/signal/other diff --git a/schema/schema.go b/schema/schema.go index dfee22e..b71848a 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -74,6 +74,7 @@ type CallableSchema struct { func (s CallableSchema) CallStep( ctx context.Context, + runID string, stepID string, serializedInputData any, ) ( @@ -91,7 +92,7 @@ func (s CallableSchema) CallStep( if err != nil { return "", nil, InvalidInputError{err} } - outputID, unserializedOutput, err := step.Call(ctx, unserializedInputData) + outputID, unserializedOutput, err := step.Call(ctx, runID, unserializedInputData) if err != nil { return outputID, nil, err } @@ -105,6 +106,7 @@ func (s CallableSchema) CallStep( func (s CallableSchema) CallSignal( ctx context.Context, + runID string, stepID string, signalID string, serializedInputData any, @@ -123,7 +125,7 @@ func (s CallableSchema) CallSignal( return InvalidInputError{err} } - err = step.CallSignal(ctx, signalID, unserializedInputData) + err = step.CallSignal(ctx, runID, signalID, unserializedInputData) if err != nil { return err } diff --git a/schema/schema_test.go b/schema/schema_test.go index 8d59320..4fad74a 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -18,7 +18,7 @@ func TestSchemaCall(t *testing.T) { } ctx := context.Background() - outputID, outputData, err := schemaTestSchema.CallStep(ctx, "hello", data) + outputID, outputData, err := schemaTestSchema.CallStep(ctx, t.Name(), "hello", data) assert.NoError(t, err) assert.Equals(t, outputID, "success") typedData := outputData.(map[string]any) diff --git a/schema/step.go b/schema/step.go index b9cb343..820f0a8 100644 --- a/schema/step.go +++ b/schema/step.go @@ -20,8 +20,8 @@ type Step interface { type CallableStep interface { Step ToStepSchema() *StepSchema - Call(ctx context.Context, data any) (outputID string, outputData any, err error) - CallSignal(ctx context.Context, signalID string, data any) (err error) + Call(ctx context.Context, runID string, data any) (outputID string, outputData any, err error) + CallSignal(ctx context.Context, runID string, signalID string, data any) (err error) } // NewStepSchema defines a new step. @@ -97,6 +97,7 @@ func NewCallableStep[StepInputType any]( SignalEmittersValue: nil, DisplayValue: display, initializer: nil, + stepData: make(map[string]*runningStepData[any]), handler: updatedHandler, } } @@ -125,11 +126,18 @@ func NewCallableStepWithSignals[StepData any, StepInputType any]( SignalEmittersValue: signalEmitters, DisplayValue: display, initializer: initializer, - initializerWG: &wg, handler: handler, + stepData: make(map[string]*runningStepData[StepData]), } } +type runningStepData[StepData any] struct { + runID string + initializedData StepData + startedWG *sync.WaitGroup // For waiting until the step is started. + done bool +} + // CallableStepSchema is a step that can be directly called and is typed to a specific input type. type CallableStepSchema[StepData any, InputType any] struct { IDValue string `json:"id"` @@ -139,9 +147,8 @@ type CallableStepSchema[StepData any, InputType any] struct { OutputsValue map[string]*StepOutputSchema `json:"outputs"` DisplayValue Display `json:"display"` initializer func() StepData - initializerWG *sync.WaitGroup initializerMutex sync.Mutex - initializedData *StepData + stepData map[string]*runningStepData[StepData] // Maps run ID to step data handler func(context.Context, StepData, InputType) (string, any) } @@ -188,23 +195,40 @@ func (s *CallableStepSchema[StepData, InputType]) ToStepSchema() *StepSchema { } } -func (s *CallableStepSchema[StepData, InputType]) Call(ctx context.Context, input any) (string, any, error) { - if err := s.InputValue.Validate(input); err != nil { - return "", nil, InvalidInputError{err} - } - +// Set up the runningStepData struct. This results in a waitgroup available for the signals to wait on, and +// it setups the data shared between the step and signals. +func (s *CallableStepSchema[StepData, InputType]) setupStepData(runID string) *runningStepData[StepData] { s.initializerMutex.Lock() - if s.initializedData == nil && s.initializer != nil { - newInitializedData := s.initializer() - s.initializedData = &newInitializedData - s.initializerWG.Done() + defer s.initializerMutex.Unlock() + + // This will be called by both the signal and step handlers, so it's important to check to ensure this + // isn't getting re-done on the second call. + existingRunningStepData, found := s.stepData[runID] + if found { + return existingRunningStepData // Already done } - s.initializerMutex.Unlock() var stepData StepData if s.initializer != nil { - stepData = *s.initializedData + newInitializedData := s.initializer() + stepData = newInitializedData } - outputID, outputData := s.handler(ctx, stepData, input.(InputType)) + runningStepData := runningStepData[StepData]{ + runID: runID, + initializedData: stepData, + startedWG: &sync.WaitGroup{}, + done: false, + } + s.stepData[runID] = &runningStepData + return &runningStepData +} + +func (s *CallableStepSchema[StepData, InputType]) Call(ctx context.Context, runID string, input any) (string, any, error) { + if err := s.InputValue.Validate(input); err != nil { + return "", nil, InvalidInputError{err} + } + + runningStepData := s.setupStepData(runID) + outputID, outputData := s.handler(ctx, runningStepData.initializedData, input.(InputType)) output, ok := s.OutputsValue[outputID] if !ok { return "", nil, InvalidOutputError{ @@ -214,12 +238,13 @@ func (s *CallableStepSchema[StepData, InputType]) Call(ctx context.Context, inpu return outputID, outputData, output.Validate(outputData) } -func (s *CallableStepSchema[StepData, InputType]) CallSignal(ctx context.Context, signalID string, input any) error { - s.initializerWG.Wait() - if s.initializedData == nil { - return IllegalStateError{ - fmt.Errorf("signal ID '%s' called before step initialization", signalID), - } - } - return s.SignalHandlersValue[signalID].Call(ctx, *s.initializedData, input) +func (s *CallableStepSchema[StepData, InputType]) CallSignal( + ctx context.Context, + runID string, + signalID string, + input any, +) error { + runningStepData := s.setupStepData(runID) + runningStepData.startedWG.Wait() // Wait for the step to start + return s.SignalHandlersValue[signalID].Call(ctx, runningStepData.initializedData, input) } diff --git a/schema/step_test.go b/schema/step_test.go index 39e35e6..fa24907 100644 --- a/schema/step_test.go +++ b/schema/step_test.go @@ -96,7 +96,7 @@ func stepTestHandler(_ context.Context, input stepTestInputData) (string, any) { func TestStepExecution(t *testing.T) { ctx := context.Background() - outputID, outputData, err := testStepSchema.Call(ctx, stepTestInputData{Name: "Arca Lot"}) + outputID, outputData, err := testStepSchema.Call(ctx, t.Name(), stepTestInputData{Name: "Arca Lot"}) assert.NoError(t, err) assert.Equals(t, outputID, "success") assert.Equals(t, outputData.(stepTestSuccessOutput).Message, "Hello, Arca Lot!") From 1de0e27fb2e5affba8025415f274775c5493a6ec Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 5 Oct 2023 12:47:33 -0400 Subject: [PATCH 2/9] Handle panics, and improve error reporting in ATP server --- atp/protocol_test.go | 165 +++++++++++++++++++++++++++++++++---------- atp/server.go | 29 +++++--- 2 files changed, 149 insertions(+), 45 deletions(-) diff --git a/atp/protocol_test.go b/atp/protocol_test.go index 456bd65..ed2cb7e 100644 --- a/atp/protocol_test.go +++ b/atp/protocol_test.go @@ -46,6 +46,10 @@ func helloWorldStepHandler(_ context.Context, _ any, input helloWorldInput) (str } } +func panickingHelloWorldStepHandler(_ context.Context, _ any, input helloWorldInput) (string, any) { + panic("abcde") +} + func helloWorldSignalHandler(_ context.Context, test any, input helloWorldInput) { // Does nothing at the moment } @@ -96,6 +100,45 @@ var helloWorldSchema = schema.NewCallableSchema( ), ) +var panickingHelloWorldSchema = schema.NewCallableSchema( + schema.NewCallableStepWithSignals[any, helloWorldInput]( + /* id */ "hello-world", + /* input */ helloWorldInputSchema, + /* outputs */ map[string]*schema.StepOutputSchema{ + "success": schema.NewStepOutputSchema( + schema.NewScopeSchema( + schema.NewStructMappedObjectSchema[helloWorldOutput]( + "Output", + map[string]*schema.PropertySchema{ + "message": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + }, + ), + ), + nil, + false, + ), + }, + /* signal handlers */ map[string]schema.CallableSignal{ + "hello-world-signal": helloWorldCallableSignal, + }, + /* signal emitters */ map[string]*schema.SignalSchema{ + "hello-world-signal": helloWorldCallableSignal.ToSignalSchema(), + }, + /* Display */ nil, + /* Initializer */ nil, + /* step handler */ panickingHelloWorldStepHandler, + ), +) + type channel struct { io.Reader io.Writer @@ -118,12 +161,13 @@ func TestProtocol_Client_Execute(t *testing.T) { go func() { defer wg.Done() - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) }() go func() { @@ -153,6 +197,54 @@ func TestProtocol_Client_Execute(t *testing.T) { wg.Wait() } +func TestProtocol_Client_Execute_Panicking(t *testing.T) { + // Client ReadSchema and Execute happy path. + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(2) + stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + + go func() { + defer wg.Done() + errors := atp.RunATPServer( + ctx, + stdinReader, + stdoutWriter, + panickingHelloWorldSchema, + ) + assert.Equals(t, len(errors), 2) + }() + + go func() { + defer wg.Done() + cli := atp.NewClientWithLogger(channel{ + Reader: stdoutReader, + Writer: stdinWriter, + Context: nil, + cancel: cancel, + }, log.NewTestLogger(t)) + + _, err := cli.ReadSchema() + assert.NoError(t, err) + + for _, testID := range []string{"a", "b"} { + result := cli.Execute( + schema.Input{ + RunID: t.Name() + "_" + testID, + ID: "hello-world", + InputData: map[string]any{"name": "Arca Lot"}, + }, nil, nil) + assert.Error(t, result.Error) + assert.Contains(t, result.Error.Error(), "abcde") + assert.Equals(t, result.OutputID, "") + } + assert.NoError(t, cli.Close()) + }() + + wg.Wait() +} + func TestProtocol_Client_Execute_Multi_Step_Parallel(t *testing.T) { // Runs several steps on one client instance at the same time ctx, cancel := context.WithCancel(context.Background()) @@ -163,12 +255,13 @@ func TestProtocol_Client_Execute_Multi_Step_Parallel(t *testing.T) { go func() { defer wg.Done() - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) }() go func() { @@ -218,12 +311,13 @@ func TestProtocol_Client_Execute_Multi_Step_Serial(t *testing.T) { go func() { defer wg.Done() - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) }() go func() { @@ -279,12 +373,13 @@ func TestProtocol_Client_ReadSchema(t *testing.T) { go func() { defer wg.Done() t.Logf("Starting ATP server") - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) t.Logf("ATP server exited without error") }() @@ -357,15 +452,16 @@ func TestProtocol_Error_Server_StartOutput(t *testing.T) { assert.NoError(t, stdinReader.Close()) assert.NoError(t, stdoutWriter.Close()) // Close this, because it's unbuffered, so it would deadlock. - err := atp.RunATPServer( + serverErrors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - assert.NotNil(t, err) - assert.Equals(t, err.ServerFatal, true) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) + assert.Equals(t, serverErrors[0].ServerFatal, true) // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) @@ -433,20 +529,18 @@ func TestProtocol_Error_Server_Hello(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var test_error *atp.ServerError + var serverErrors []*atp.ServerError go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - test_error = err - } + serverErrors = errors }() go func() { @@ -459,7 +553,8 @@ func TestProtocol_Error_Server_Hello(t *testing.T) { wgcli.Wait() wg.Wait() - assert.NotNil(t, test_error) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) } @@ -482,19 +577,17 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var testError *atp.ServerError + var serverErrors []*atp.ServerError go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - testError = err - } + serverErrors = errors }() go func() { @@ -510,10 +603,11 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { }() wg.Wait() - assert.NotNil(t, testError) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) // This may make the test more fragile, but checking the error is the only way // to know that the error is from where we're testing. - assert.Contains(t, testError.Err.Error(), "failed to read or decode runtime message") + assert.Contains(t, serverErrors[0].Err.Error(), "failed to read or decode runtime message") // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) } @@ -539,19 +633,18 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var srvr_error *atp.ServerError + var serverErrors []*atp.ServerError var cli_error error go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - srvr_error = err - } + + serverErrors = errors }() go func() { @@ -578,7 +671,8 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) { wgcli.Wait() wg.Wait() - assert.NotNil(t, srvr_error) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) assert.Error(t, cli_error) // We don't lock on error to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) @@ -664,20 +758,18 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var srvr_error *atp.ServerError + var serverErrors []*atp.ServerError var cli_error error go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - srvr_error = err - } + serverErrors = errors }() go func() { @@ -699,7 +791,8 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) { wg.Wait() assert.NoError(t, cli_error) - assert.NotNil(t, srvr_error) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) diff --git a/atp/server.go b/atp/server.go index 9b12099..7321a73 100644 --- a/atp/server.go +++ b/atp/server.go @@ -18,7 +18,7 @@ func RunATPServer( stdin io.ReadCloser, stdout io.WriteCloser, pluginSchema *schema.CallableSchema, -) *ServerError { +) []*ServerError { session := initializeATPServerSession(ctx, stdin, stdout, pluginSchema) session.wg.Add(1) @@ -109,15 +109,15 @@ func (s *atpServerSession) sendRuntimeMessage(msgID uint32, runID string, messag }) } -func (s *atpServerSession) handleClosure() *ServerError { +func (s *atpServerSession) handleClosure() []*ServerError { // Wait for work done or context complete. - var workError *ServerError + var errors []*ServerError closeLoop: for { select { case errorSent, wasError := <-s.workDone: if wasError { - workError = &errorSent + errors = append(errors, &errorSent) err := s.sendRuntimeMessage( MessageTypeError, errorSent.RunID, @@ -135,12 +135,12 @@ closeLoop: if err != nil || errorSent.ServerFatal { err = s.stdinCloser.Close() if err != nil { - return &ServerError{ - RunID: workError.RunID, - Err: fmt.Errorf("error closing stdin (%s) after workDone error (%v)", err, workError), + return append(errors, &ServerError{ + RunID: errorSent.RunID, + Err: fmt.Errorf("error closing stdin (%s) after workDone error (%v)", err, errorSent), StepFatal: true, ServerFatal: true, - } + }) } else { break closeLoop } @@ -154,7 +154,7 @@ closeLoop: } } // Now close the pipe that it gets input from. - return workError + return errors } func (s *atpServerSession) runATPReadLoop() { @@ -308,6 +308,17 @@ func (s *atpServerSession) run() { func (s *atpServerSession) runStep(runID string, req WorkStartMessage) { // Call the step in the provided callable schema. + defer func() { + // Handle and properly report panics + if r := recover(); r != nil { + s.workDone <- ServerError{ + RunID: runID, + Err: fmt.Errorf("panic while running step with Run ID '%s': (%v)", runID, r), + StepFatal: true, + ServerFatal: false, + } + } + }() outputID, outputData, err := s.pluginSchema.CallStep(s.ctx, runID, req.StepID, req.Config) if err != nil { s.workDone <- ServerError{ From 9a9a2d018881365202ac578213b876aa500eee1a Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 5 Oct 2023 13:11:42 -0400 Subject: [PATCH 3/9] Remove fixed TODO --- atp/client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/atp/client.go b/atp/client.go index a2f1dd1..57973db 100644 --- a/atp/client.go +++ b/atp/client.go @@ -210,7 +210,6 @@ func (c *client) Close() error { // First, close channels that could send signals to the clients // This ends the loop for runID, signalChannel := range c.runningSignalReceiveLoops { - // TODO: Test why commenting this out results in a deadlock instead of just the steps finishing when they're supposed to. c.logger.Infof("Closing signal channel for run ID '%s'", runID) close(signalChannel) } From 3d9cc6379fa73647dc2b4cdf215ca363a6e43cbf Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 5 Oct 2023 16:14:43 -0400 Subject: [PATCH 4/9] Fix most linting errors --- atp/client.go | 17 +-- atp/protocol.go | 10 +- atp/protocol_test.go | 1 + atp/server.go | 273 +++++++++++++++++++++++-------------------- 4 files changed, 161 insertions(+), 140 deletions(-) diff --git a/atp/client.go b/atp/client.go index 57973db..23a353e 100644 --- a/atp/client.go +++ b/atp/client.go @@ -187,7 +187,7 @@ func (c *client) Execute( return c.getResult(stepData, cborReader) } -// handleStepComplete is the deferred function that will handle closing of the received channel +// handleStepComplete is the deferred function that will handle closing of the received channel. func (c *client) handleStepComplete(runID string, receivedSignals chan schema.Input) { if receivedSignals != nil { c.logger.Infof("Closing signal channel for finished step") @@ -270,7 +270,7 @@ func (c *client) executeWriteLoop( if err := c.sendCBOR(RuntimeMessage{ MessageTypeSignal, signal.RunID, - signalMessage{ + SignalMessage{ SignalID: signal.ID, Data: signal.InputData, }}); err != nil { @@ -315,11 +315,12 @@ func (c *client) sendExecutionResult(runID string, result ExecutionResult) { func (c *client) sendErrorToAll(err error) { result := NewErrorExecutionResult(err) - for runID, _ := range c.runningStepResultChannels { + for runID := range c.runningStepResultChannels { c.sendExecutionResult(runID, result) } } +//nolint:funlen func (c *client) executeReadLoop(cborReader *cbor.Decoder) { defer func() { c.mutex.Lock() @@ -339,7 +340,7 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { } switch runtimeMessage.MessageID { case MessageTypeWorkDone: - var doneMessage workDoneMessage + var doneMessage WorkDoneMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &doneMessage); err != nil { c.logger.Errorf("Failed to decode work done message (%v) for run ID %s ", err, runtimeMessage.RunID) c.sendExecutionResult(runtimeMessage.RunID, NewErrorExecutionResult( @@ -347,7 +348,7 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { } c.sendExecutionResult(runtimeMessage.RunID, c.processWorkDone(runtimeMessage.RunID, doneMessage)) case MessageTypeSignal: - var signalMessage signalMessage + var signalMessage SignalMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { c.logger.Errorf("ATP client for run ID '%s' failed to decode signal message: %v", runtimeMessage.RunID, err) @@ -364,7 +365,7 @@ func (c *client) executeReadLoop(cborReader *cbor.Decoder) { signalChannel <- signalMessage.ToInput(runtimeMessage.RunID) } case MessageTypeError: - var errMessage errorMessage + var errMessage ErrorMessage if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &errMessage); err != nil { c.logger.Errorf("Step with run ID '%s' failed to decode error message: %v", runtimeMessage.RunID, err) @@ -409,7 +410,7 @@ func (c *client) getResultV1( cborReader *cbor.Decoder, stepData schema.Input, ) ExecutionResult { - var doneMessage workDoneMessage + var doneMessage WorkDoneMessage if err := cborReader.Decode(&doneMessage); err != nil { c.logger.Errorf("Failed to read or decode work done message: (%w) for step %s", err, stepData.ID) return NewErrorExecutionResult( @@ -473,7 +474,7 @@ func (c *client) getResultV2( func (c *client) processWorkDone( runID string, - doneMessage workDoneMessage, + doneMessage WorkDoneMessage, ) ExecutionResult { c.logger.Debugf("Step with run ID '%s' completed with output ID '%s'.", runID, doneMessage.OutputID) diff --git a/atp/protocol.go b/atp/protocol.go index c181553..b52b877 100644 --- a/atp/protocol.go +++ b/atp/protocol.go @@ -39,19 +39,19 @@ type DecodedRuntimeMessage struct { RawMessageData cbor.RawMessage `cbor:"data"` } -type workDoneMessage struct { +type WorkDoneMessage struct { StepID string `cbor:"step_id"` OutputID string `cbor:"output_id"` OutputData any `cbor:"output_data"` DebugLogs string `cbor:"debug_logs"` } -type signalMessage struct { +type SignalMessage struct { SignalID string `cbor:"signal_id"` Data any `cbor:"data"` } -func (s signalMessage) ToInput(runID string) schema.Input { +func (s SignalMessage) ToInput(runID string) schema.Input { return schema.Input{RunID: runID, ID: s.SignalID, InputData: s.Data} } @@ -59,12 +59,12 @@ type clientDoneMessage struct { // Empty for now. } -type errorMessage struct { +type ErrorMessage struct { Error string `cbor:"error"` StepFatal bool `cbor:"step_fatal"` ServerFatal bool `cbor:"server_fatal"` } -func (e errorMessage) ToString(runID string) string { +func (e ErrorMessage) ToString(runID string) string { return fmt.Sprintf("RunID: %s, err: %s, step fatal: %t, server fatal: %t", runID, e.Error, e.StepFatal, e.ServerFatal) } diff --git a/atp/protocol_test.go b/atp/protocol_test.go index ed2cb7e..23770c6 100644 --- a/atp/protocol_test.go +++ b/atp/protocol_test.go @@ -612,6 +612,7 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { time.Sleep(time.Millisecond * 2) } +//nolint:funlen func TestProtocol_Error_Client_WorkStart(t *testing.T) { // Induce error on client's (and server incidentally) // start work message by closing the client's cbor diff --git a/atp/server.go b/atp/server.go index 7321a73..8dd2396 100644 --- a/atp/server.go +++ b/atp/server.go @@ -116,37 +116,36 @@ closeLoop: for { select { case errorSent, wasError := <-s.workDone: - if wasError { - errors = append(errors, &errorSent) - err := s.sendRuntimeMessage( - MessageTypeError, - errorSent.RunID, - errorMessage{ - Error: errorSent.Err.Error(), - StepFatal: errorSent.StepFatal, - ServerFatal: errorSent.ServerFatal, - }, - ) - // If that didn't send, just send to stderr now. + if !wasError { + break closeLoop + } + errors = append(errors, &errorSent) + err := s.sendRuntimeMessage( + MessageTypeError, + errorSent.RunID, + ErrorMessage{ + Error: errorSent.Err.Error(), + StepFatal: errorSent.StepFatal, + ServerFatal: errorSent.ServerFatal, + }, + ) + // If that didn't send, just send to stderr now. + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error while sending error message: %s\n", err) + } + // If either the error report sending failed, or the error was server fatal, stop here. + if err != nil || errorSent.ServerFatal { + err = s.stdinCloser.Close() if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "error while sending error message: %s\n", err) - } - // If either the error report sending failed, or the error was server fatal, stop here. - if err != nil || errorSent.ServerFatal { - err = s.stdinCloser.Close() - if err != nil { - return append(errors, &ServerError{ - RunID: errorSent.RunID, - Err: fmt.Errorf("error closing stdin (%s) after workDone error (%v)", err, errorSent), - StepFatal: true, - ServerFatal: true, - }) - } else { - break closeLoop - } + return append(errors, &ServerError{ + RunID: errorSent.RunID, + Err: fmt.Errorf("error closing stdin (%w) after workDone error (%v)", err, errorSent), + StepFatal: true, + ServerFatal: true, + }) + } else { + break closeLoop } - } else { - break closeLoop } case <-s.ctx.Done(): // Likely got sigterm. Just close. Ideally gracefully. @@ -181,107 +180,127 @@ func (s *atpServerSession) runATPReadLoop() { } // If done, it didn't get the work done message, which is not ideal. return } - runID := runtimeMessage.RunID - switch runtimeMessage.MessageID { - case MessageTypeWorkStart: - var workStartMsg WorkStartMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &workStartMsg); err != nil { - s.workDone <- ServerError{ - RunID: "", - Err: fmt.Errorf("failed to decode work start message: %w", err), - StepFatal: true, - ServerFatal: false, - } - continue - } - if runID == "" || workStartMsg.StepID == "" { - s.workDone <- ServerError{ - RunID: "", - Err: fmt.Errorf("missing runID (%s) or stepID in work start message (%s)", - runID, workStartMsg.StepID), - StepFatal: true, - ServerFatal: false, - } - continue - } - s.runningSteps[runID] = workStartMsg.StepID - s.wg.Add(1) // Wait until the step is done - go func() { - s.runStep(runID, workStartMsg) - s.wg.Done() - }() - case MessageTypeSignal: - var signalMessage signalMessage - if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil { - s.workDone <- ServerError{ - RunID: "", - Err: fmt.Errorf("failed to decode signal message: %w", err), - StepFatal: false, - ServerFatal: false, - } - continue - } - if runID == "" { - s.workDone <- ServerError{ - RunID: "", - Err: fmt.Errorf("RunID missing for signal '%s' in signal message", signalMessage.SignalID), - StepFatal: false, - ServerFatal: false, - } - continue + done := s.onRuntimeMessageReceived(&runtimeMessage) + if done { + return + } + } +} + +// onRuntimeMessageReceived handles the runtime message by determining what type it is, and executing the proper path. +// Returns true if termination should be terminated, which should correspond to only client done or fatal server errors. +func (s *atpServerSession) onRuntimeMessageReceived(message *DecodedRuntimeMessage) bool { + runID := message.RunID + switch message.MessageID { + case MessageTypeWorkStart: + var workStartMsg WorkStartMessage + if err := cbor.Unmarshal(message.RawMessageData, &workStartMsg); err != nil { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("failed to decode work start message: %w", err), + StepFatal: true, + ServerFatal: false, } - stepID, found := s.runningSteps[runID] - if !found { - s.workDone <- ServerError{ - RunID: runID, - Err: fmt.Errorf("unknown step with run ID '%s' in signal mesage", runID), - StepFatal: false, - ServerFatal: false, - } - continue + return false + } + s.handleWorkStartMessage(runID, workStartMsg) + return false + case MessageTypeSignal: + var signalMessage SignalMessage + if err := cbor.Unmarshal(message.RawMessageData, &signalMessage); err != nil { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("failed to decode signal message: %w", err), + StepFatal: false, + ServerFatal: false, } - s.wg.Add(1) // Wait until the signal handler is done - go func() { - if err := s.pluginSchema.CallSignal( - s.ctx, - runID, - stepID, - signalMessage.SignalID, - signalMessage.Data, - ); err != nil { - s.workDone <- ServerError{ - RunID: runID, - Err: fmt.Errorf("failed while running signal ID %s: %w", - signalMessage.SignalID, err), - StepFatal: false, - ServerFatal: false, - } - } - s.wg.Done() - }() - case MessageTypeClientDone: - // It's now safe to close the channel - err := s.stdinCloser.Close() - if err != nil { - s.workDone <- ServerError{ - RunID: "", - Err: fmt.Errorf("error while closing stdin on client done: %s", err), - StepFatal: true, - ServerFatal: true, - } + return false + } + s.handleSignalMessage(runID, signalMessage) + + return false + case MessageTypeClientDone: + // It's now safe to close the channel + err := s.stdinCloser.Close() + if err != nil { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("error while closing stdin on client done: %w", err), + StepFatal: true, + ServerFatal: true, } - return - default: + } + return true // Client done, so terminate loop + default: + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("unknown message ID received: %d. This is a sign of incompatible server and client versions", + message.MessageID), + StepFatal: false, + ServerFatal: false, + } + return false + } +} + +func (s *atpServerSession) handleWorkStartMessage(runID string, workStartMsg WorkStartMessage) { + if runID == "" || workStartMsg.StepID == "" { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("missing runID (%s) or stepID in work start message (%s)", + runID, workStartMsg.StepID), + StepFatal: true, + ServerFatal: false, + } + return + } + s.runningSteps[runID] = workStartMsg.StepID + s.wg.Add(1) // Wait until the step is done + go func() { + s.runStep(runID, workStartMsg) + s.wg.Done() + }() +} + +func (s *atpServerSession) handleSignalMessage(runID string, signalMessage SignalMessage) { + if runID == "" { + s.workDone <- ServerError{ + RunID: "", + Err: fmt.Errorf("RunID missing for signal '%s' in signal message", signalMessage.SignalID), + StepFatal: false, + ServerFatal: false, + } + return + } + stepID, found := s.runningSteps[runID] + if !found { + s.workDone <- ServerError{ + RunID: runID, + Err: fmt.Errorf("unknown step with run ID '%s' in signal mesage", runID), + StepFatal: false, + ServerFatal: false, + } + return + } + s.wg.Add(1) // Wait until the signal handler is done + go func() { + if err := s.pluginSchema.CallSignal( + s.ctx, + runID, + stepID, + signalMessage.SignalID, + signalMessage.Data, + ); err != nil { s.workDone <- ServerError{ - RunID: "", - Err: fmt.Errorf("unknown message ID received: %d. This is a sign of incompatible server and client versions", - runtimeMessage.MessageID), + RunID: runID, + Err: fmt.Errorf("failed while running signal ID %s: %w", + signalMessage.SignalID, err), StepFatal: false, ServerFatal: false, } - continue } - } + s.wg.Done() + }() } func (s *atpServerSession) run() { @@ -295,7 +314,7 @@ func (s *atpServerSession) run() { if err != nil { s.workDone <- ServerError{ RunID: "", - Err: fmt.Errorf("error while sending initial messages to client (%s)", err), + Err: fmt.Errorf("error while sending initial messages to client (%w)", err), StepFatal: true, ServerFatal: true, } @@ -323,7 +342,7 @@ func (s *atpServerSession) runStep(runID string, req WorkStartMessage) { if err != nil { s.workDone <- ServerError{ RunID: "", - Err: fmt.Errorf("error calling step (%s)", err), + Err: fmt.Errorf("error calling step (%w)", err), StepFatal: true, ServerFatal: false, } @@ -333,7 +352,7 @@ func (s *atpServerSession) runStep(runID string, req WorkStartMessage) { err = s.sendRuntimeMessage( MessageTypeWorkDone, runID, - workDoneMessage{ + WorkDoneMessage{ req.StepID, outputID, outputData, @@ -341,8 +360,8 @@ func (s *atpServerSession) runStep(runID string, req WorkStartMessage) { }, ) if err != nil { - // At this point, the work done message failed to send, so it's likely that sending an errorMessage would fail. - _, err = fmt.Fprintf(os.Stderr, "error while sending work done message: %s\n", err) + // At this point, the work done message failed to send, so it's likely that sending an ErrorMessage would fail. + _, _ = fmt.Fprintf(os.Stderr, "error while sending work done message: %s\n", err) } } From 000885913b2ce0731f5e55e92b42a821a3795106 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Mon, 9 Oct 2023 16:12:17 -0400 Subject: [PATCH 5/9] Revert change to retain compatibility with ATP v1 --- atp/protocol.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atp/protocol.go b/atp/protocol.go index b52b877..34db463 100644 --- a/atp/protocol.go +++ b/atp/protocol.go @@ -14,7 +14,7 @@ type HelloMessage struct { } type WorkStartMessage struct { - StepID string `cbor:"step_id"` + StepID string `cbor:"id"` Config any `cbor:"config"` } From e6d647d62d025a0248d3229b27417486d86f3861 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 12 Oct 2023 18:40:07 -0400 Subject: [PATCH 6/9] Update version, and update version handling to have specific versions instead of a range --- atp/client.go | 22 +++++++++++++++------- atp/protocol.go | 2 +- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/atp/client.go b/atp/client.go index 23a353e..b77ea98 100644 --- a/atp/client.go +++ b/atp/client.go @@ -10,8 +10,7 @@ import ( "sync" ) -const MinSupportedATPVersion = 1 -const MaxSupportedATPVersion = 2 +var supportedServerVersions = []int64{1, 3} // ClientChannel holds the methods to talking to an ATP server (plugin). type ClientChannel interface { @@ -129,11 +128,11 @@ func (c *client) ReadSchema() (*schema.SchemaSchema, error) { } c.logger.Debugf("Hello message read, ATP version %d.", hello.Version) - if hello.Version < MinSupportedATPVersion || hello.Version > MaxSupportedATPVersion { - c.logger.Errorf("Incompatible plugin ATP version: %d; expected between %d and %d.", hello.Version, - MinSupportedATPVersion, MaxSupportedATPVersion) - return nil, fmt.Errorf("incompatible plugin ATP version: %d; expected between %d and %d", hello.Version, - MinSupportedATPVersion, MaxSupportedATPVersion) + err := c.validateVersion(hello.Version) + + if err != nil { + c.logger.Errorf("Unsupported plugin version. %w", err) + return nil, fmt.Errorf("unsupported plugin version: %w", err) } c.atpVersion = hello.Version @@ -147,6 +146,15 @@ func (c *client) ReadSchema() (*schema.SchemaSchema, error) { return unserializedSchema, nil } +func (c *client) validateVersion(serverVersion int64) error { + for _, v := range supportedServerVersions { + if serverVersion == v { + return nil + } + } + return fmt.Errorf("unsupported atp version '%d', supported versions: %v", serverVersion, supportedServerVersions) +} + func (c *client) Execute( stepData schema.Input, receivedSignals chan schema.Input, diff --git a/atp/protocol.go b/atp/protocol.go index 34db463..bb15ffe 100644 --- a/atp/protocol.go +++ b/atp/protocol.go @@ -6,7 +6,7 @@ import ( "go.flow.arcalot.io/pluginsdk/schema" ) -const ProtocolVersion int64 = 2 +const ProtocolVersion int64 = 3 type HelloMessage struct { Version int64 `cbor:"version"` From 4dea23e0d910a10129f2c91629b54ef3ab8e1f05 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 19 Oct 2023 16:44:54 -0400 Subject: [PATCH 7/9] Added test to validate backwards compatibility --- atp/protocol_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/atp/protocol_test.go b/atp/protocol_test.go index 23770c6..da2d35e 100644 --- a/atp/protocol_test.go +++ b/atp/protocol_test.go @@ -197,6 +197,74 @@ func TestProtocol_Client_Execute(t *testing.T) { wg.Wait() } +func TestProtocol_Client_ATP_v1(t *testing.T) { + // Client ReadSchema and Execute atp v1 happy path. + // This is not a fragile test because the ATP v1 is not changing. It is the legacy supported version. + wg := &sync.WaitGroup{} + wg.Add(2) + stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + step := "hello-world" + stepInput := map[string]any{"name": "Arca Lot"} + + go func() { + defer wg.Done() + fromClient := cbor.NewDecoder(stdinReader) + toClient := cbor.NewEncoder(stdoutWriter) + // 1: read start output message + var empty any + assert.NoError(t, fromClient.Decode(&empty)) + // 2: Send hello message with version set to 1 and the hello-world schema. + helloMessage := atp.HelloMessage{ + Version: 1, + Schema: assert.NoErrorR[any](t)(helloWorldSchema.SelfSerialize()), + } + assert.NoError(t, toClient.Encode(&helloMessage)) + // 3: Read work start message + var workStartMsg atp.WorkStartMessage + assert.NoError(t, fromClient.Decode(&workStartMsg)) + assert.Equals(t, workStartMsg.StepID, step) + unserializedInput := assert.NoErrorR[any](t)(helloWorldInputSchema.Unserialize(workStartMsg.Config)) + assert.Equals(t, unserializedInput.(helloWorldInput), helloWorldInput{Name: "Arca Lot"}) + + // 4: Send work done message + workDoneMessage := atp.WorkDoneMessage{ + StepID: step, + OutputID: "success", + OutputData: map[string]string{"message": "Hello, Arca Lot!"}, + DebugLogs: "", + } + assert.NoError(t, toClient.Encode(&workDoneMessage)) + + }() + + go func() { + defer wg.Done() + cli := atp.NewClientWithLogger(channel{ + Reader: stdoutReader, + Writer: stdinWriter, + Context: nil, + cancel: nil, + }, log.NewTestLogger(t)) + + _, err := cli.ReadSchema() + assert.NoError(t, err) + + result := cli.Execute( + schema.Input{ + RunID: t.Name(), + ID: step, + InputData: stepInput, + }, nil, nil) + assert.NoError(t, cli.Close()) + assert.NoError(t, result.Error) + assert.Equals(t, result.OutputID, "success") + assert.Equals(t, result.OutputData.(map[any]any)["message"].(string), "Hello, Arca Lot!") + }() + + wg.Wait() +} + func TestProtocol_Client_Execute_Panicking(t *testing.T) { // Client ReadSchema and Execute happy path. ctx, cancel := context.WithCancel(context.Background()) From 92efac16a5ea9a62efa38b6d70771bd1bd80d9b2 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 19 Oct 2023 16:48:39 -0400 Subject: [PATCH 8/9] Fix linting --- atp/protocol_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/atp/protocol_test.go b/atp/protocol_test.go index da2d35e..6ecbde1 100644 --- a/atp/protocol_test.go +++ b/atp/protocol_test.go @@ -197,6 +197,7 @@ func TestProtocol_Client_Execute(t *testing.T) { wg.Wait() } +//nolint:funlen func TestProtocol_Client_ATP_v1(t *testing.T) { // Client ReadSchema and Execute atp v1 happy path. // This is not a fragile test because the ATP v1 is not changing. It is the legacy supported version. From 1d3cf8b54ed759071ba236f07a1f78f7d80d9c91 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Fri, 20 Oct 2023 10:57:09 -0400 Subject: [PATCH 9/9] Remove unused use case, and fixed close ordering --- atp/client.go | 11 ++++++++--- atp/server.go | 20 +------------------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/atp/client.go b/atp/client.go index b77ea98..0d0aa7d 100644 --- a/atp/client.go +++ b/atp/client.go @@ -201,9 +201,13 @@ func (c *client) handleStepComplete(runID string, receivedSignals chan schema.In c.logger.Infof("Closing signal channel for finished step") // Remove from the map to ensure that the client.Close() method doesn't double-close it c.mutex.Lock() - delete(c.runningSignalReceiveLoops, runID) + // Validate that it exists, since Close() could have been called early. + _, exists := c.runningSignalReceiveLoops[runID] + if exists { + delete(c.runningSignalReceiveLoops, runID) + close(receivedSignals) + } c.mutex.Unlock() - close(receivedSignals) } } @@ -214,13 +218,14 @@ func (c *client) Close() error { return nil } c.done = true - c.mutex.Unlock() // First, close channels that could send signals to the clients // This ends the loop for runID, signalChannel := range c.runningSignalReceiveLoops { c.logger.Infof("Closing signal channel for run ID '%s'", runID) + delete(c.runningSignalReceiveLoops, runID) close(signalChannel) } + c.mutex.Unlock() // Now tell the server we're done. // Send the client done message if c.atpVersion > 1 { diff --git a/atp/server.go b/atp/server.go index 8dd2396..2ac6c3d 100644 --- a/atp/server.go +++ b/atp/server.go @@ -7,9 +7,7 @@ import ( "go.flow.arcalot.io/pluginsdk/schema" "io" "os" - "os/signal" "sync" - "syscall" ) // RunATPServer runs an ArcaflowTransportProtocol server with a given schema. @@ -36,7 +34,6 @@ func RunATPServer( type atpServerSession struct { ctx context.Context - cancel *context.CancelFunc wg *sync.WaitGroup stdinCloser io.ReadCloser cborStdin *cbor.Decoder @@ -65,29 +62,14 @@ func initializeATPServerSession( stdout io.WriteCloser, pluginSchema *schema.CallableSchema, ) *atpServerSession { - subCtx, cancel := context.WithCancel(ctx) workDone := make(chan ServerError, 3) // The ATP protocol uses CBOR. cborStdin := cbor.NewDecoder(stdin) cborStdout := cbor.NewEncoder(stdout) runDoneChannel := make(chan bool, 3) // Buffer to prevent it from hanging if something unexpected happens. - // Cancel the sub context on sigint or sigterm. - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - go func() { - select { - case <-sigs: - // Got sigterm. So cancel context. - cancel() - case <-subCtx.Done(): - // Done. No sigterm. - } - }() - return &atpServerSession{ - ctx: subCtx, - cancel: &cancel, + ctx: ctx, cborStdin: cborStdin, stdinCloser: stdin, cborStdout: cborStdout,