From 1de0e27fb2e5affba8025415f274775c5493a6ec Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Thu, 5 Oct 2023 12:47:33 -0400 Subject: [PATCH] Handle panics, and improve error reporting in ATP server --- atp/protocol_test.go | 165 +++++++++++++++++++++++++++++++++---------- atp/server.go | 29 +++++--- 2 files changed, 149 insertions(+), 45 deletions(-) diff --git a/atp/protocol_test.go b/atp/protocol_test.go index 456bd65..ed2cb7e 100644 --- a/atp/protocol_test.go +++ b/atp/protocol_test.go @@ -46,6 +46,10 @@ func helloWorldStepHandler(_ context.Context, _ any, input helloWorldInput) (str } } +func panickingHelloWorldStepHandler(_ context.Context, _ any, input helloWorldInput) (string, any) { + panic("abcde") +} + func helloWorldSignalHandler(_ context.Context, test any, input helloWorldInput) { // Does nothing at the moment } @@ -96,6 +100,45 @@ var helloWorldSchema = schema.NewCallableSchema( ), ) +var panickingHelloWorldSchema = schema.NewCallableSchema( + schema.NewCallableStepWithSignals[any, helloWorldInput]( + /* id */ "hello-world", + /* input */ helloWorldInputSchema, + /* outputs */ map[string]*schema.StepOutputSchema{ + "success": schema.NewStepOutputSchema( + schema.NewScopeSchema( + schema.NewStructMappedObjectSchema[helloWorldOutput]( + "Output", + map[string]*schema.PropertySchema{ + "message": schema.NewPropertySchema( + schema.NewStringSchema(nil, nil, nil), + nil, + true, + nil, + nil, + nil, + nil, + nil, + ), + }, + ), + ), + nil, + false, + ), + }, + /* signal handlers */ map[string]schema.CallableSignal{ + "hello-world-signal": helloWorldCallableSignal, + }, + /* signal emitters */ map[string]*schema.SignalSchema{ + "hello-world-signal": helloWorldCallableSignal.ToSignalSchema(), + }, + /* Display */ nil, + /* Initializer */ nil, + /* step handler */ panickingHelloWorldStepHandler, + ), +) + type channel struct { io.Reader io.Writer @@ -118,12 +161,13 @@ func TestProtocol_Client_Execute(t *testing.T) { go func() { defer wg.Done() - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) }() go func() { @@ -153,6 +197,54 @@ func TestProtocol_Client_Execute(t *testing.T) { wg.Wait() } +func TestProtocol_Client_Execute_Panicking(t *testing.T) { + // Client ReadSchema and Execute happy path. + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(2) + stdinReader, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + + go func() { + defer wg.Done() + errors := atp.RunATPServer( + ctx, + stdinReader, + stdoutWriter, + panickingHelloWorldSchema, + ) + assert.Equals(t, len(errors), 2) + }() + + go func() { + defer wg.Done() + cli := atp.NewClientWithLogger(channel{ + Reader: stdoutReader, + Writer: stdinWriter, + Context: nil, + cancel: cancel, + }, log.NewTestLogger(t)) + + _, err := cli.ReadSchema() + assert.NoError(t, err) + + for _, testID := range []string{"a", "b"} { + result := cli.Execute( + schema.Input{ + RunID: t.Name() + "_" + testID, + ID: "hello-world", + InputData: map[string]any{"name": "Arca Lot"}, + }, nil, nil) + assert.Error(t, result.Error) + assert.Contains(t, result.Error.Error(), "abcde") + assert.Equals(t, result.OutputID, "") + } + assert.NoError(t, cli.Close()) + }() + + wg.Wait() +} + func TestProtocol_Client_Execute_Multi_Step_Parallel(t *testing.T) { // Runs several steps on one client instance at the same time ctx, cancel := context.WithCancel(context.Background()) @@ -163,12 +255,13 @@ func TestProtocol_Client_Execute_Multi_Step_Parallel(t *testing.T) { go func() { defer wg.Done() - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) }() go func() { @@ -218,12 +311,13 @@ func TestProtocol_Client_Execute_Multi_Step_Serial(t *testing.T) { go func() { defer wg.Done() - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) }() go func() { @@ -279,12 +373,13 @@ func TestProtocol_Client_ReadSchema(t *testing.T) { go func() { defer wg.Done() t.Logf("Starting ATP server") - assert.Nil(t, atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, - )) + ) + assert.Equals(t, len(errors), 0) t.Logf("ATP server exited without error") }() @@ -357,15 +452,16 @@ func TestProtocol_Error_Server_StartOutput(t *testing.T) { assert.NoError(t, stdinReader.Close()) assert.NoError(t, stdoutWriter.Close()) // Close this, because it's unbuffered, so it would deadlock. - err := atp.RunATPServer( + serverErrors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - assert.NotNil(t, err) - assert.Equals(t, err.ServerFatal, true) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) + assert.Equals(t, serverErrors[0].ServerFatal, true) // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) @@ -433,20 +529,18 @@ func TestProtocol_Error_Server_Hello(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var test_error *atp.ServerError + var serverErrors []*atp.ServerError go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - test_error = err - } + serverErrors = errors }() go func() { @@ -459,7 +553,8 @@ func TestProtocol_Error_Server_Hello(t *testing.T) { wgcli.Wait() wg.Wait() - assert.NotNil(t, test_error) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) } @@ -482,19 +577,17 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var testError *atp.ServerError + var serverErrors []*atp.ServerError go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - testError = err - } + serverErrors = errors }() go func() { @@ -510,10 +603,11 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) { }() wg.Wait() - assert.NotNil(t, testError) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) // This may make the test more fragile, but checking the error is the only way // to know that the error is from where we're testing. - assert.Contains(t, testError.Err.Error(), "failed to read or decode runtime message") + assert.Contains(t, serverErrors[0].Err.Error(), "failed to read or decode runtime message") // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) } @@ -539,19 +633,18 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var srvr_error *atp.ServerError + var serverErrors []*atp.ServerError var cli_error error go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - srvr_error = err - } + + serverErrors = errors }() go func() { @@ -578,7 +671,8 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) { wgcli.Wait() wg.Wait() - assert.NotNil(t, srvr_error) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) assert.Error(t, cli_error) // We don't lock on error to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) @@ -664,20 +758,18 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) { cancel: cancel, }, log.NewTestLogger(t)) - var srvr_error *atp.ServerError + var serverErrors []*atp.ServerError var cli_error error go func() { defer wg.Done() - err := atp.RunATPServer( + errors := atp.RunATPServer( ctx, stdinReader, stdoutWriter, helloWorldSchema, ) - if err != nil { - srvr_error = err - } + serverErrors = errors }() go func() { @@ -699,7 +791,8 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) { wg.Wait() assert.NoError(t, cli_error) - assert.NotNil(t, srvr_error) + assert.NotNil(t, serverErrors) + assert.Equals(t, len(serverErrors), 1) // We don't wait on error, to prevent deadlocks, so just sleep time.Sleep(time.Millisecond * 2) diff --git a/atp/server.go b/atp/server.go index 9b12099..7321a73 100644 --- a/atp/server.go +++ b/atp/server.go @@ -18,7 +18,7 @@ func RunATPServer( stdin io.ReadCloser, stdout io.WriteCloser, pluginSchema *schema.CallableSchema, -) *ServerError { +) []*ServerError { session := initializeATPServerSession(ctx, stdin, stdout, pluginSchema) session.wg.Add(1) @@ -109,15 +109,15 @@ func (s *atpServerSession) sendRuntimeMessage(msgID uint32, runID string, messag }) } -func (s *atpServerSession) handleClosure() *ServerError { +func (s *atpServerSession) handleClosure() []*ServerError { // Wait for work done or context complete. - var workError *ServerError + var errors []*ServerError closeLoop: for { select { case errorSent, wasError := <-s.workDone: if wasError { - workError = &errorSent + errors = append(errors, &errorSent) err := s.sendRuntimeMessage( MessageTypeError, errorSent.RunID, @@ -135,12 +135,12 @@ closeLoop: if err != nil || errorSent.ServerFatal { err = s.stdinCloser.Close() if err != nil { - return &ServerError{ - RunID: workError.RunID, - Err: fmt.Errorf("error closing stdin (%s) after workDone error (%v)", err, workError), + return append(errors, &ServerError{ + RunID: errorSent.RunID, + Err: fmt.Errorf("error closing stdin (%s) after workDone error (%v)", err, errorSent), StepFatal: true, ServerFatal: true, - } + }) } else { break closeLoop } @@ -154,7 +154,7 @@ closeLoop: } } // Now close the pipe that it gets input from. - return workError + return errors } func (s *atpServerSession) runATPReadLoop() { @@ -308,6 +308,17 @@ func (s *atpServerSession) run() { func (s *atpServerSession) runStep(runID string, req WorkStartMessage) { // Call the step in the provided callable schema. + defer func() { + // Handle and properly report panics + if r := recover(); r != nil { + s.workDone <- ServerError{ + RunID: runID, + Err: fmt.Errorf("panic while running step with Run ID '%s': (%v)", runID, r), + StepFatal: true, + ServerFatal: false, + } + } + }() outputID, outputData, err := s.pluginSchema.CallStep(s.ctx, runID, req.StepID, req.Config) if err != nil { s.workDone <- ServerError{