Skip to content

Commit

Permalink
Handle panics, and improve error reporting in ATP server
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredoconnell committed Oct 5, 2023
1 parent 799160e commit 1de0e27
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 45 deletions.
165 changes: 129 additions & 36 deletions atp/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func helloWorldStepHandler(_ context.Context, _ any, input helloWorldInput) (str
}
}

func panickingHelloWorldStepHandler(_ context.Context, _ any, input helloWorldInput) (string, any) {
panic("abcde")
}

func helloWorldSignalHandler(_ context.Context, test any, input helloWorldInput) {
// Does nothing at the moment
}
Expand Down Expand Up @@ -96,6 +100,45 @@ var helloWorldSchema = schema.NewCallableSchema(
),
)

var panickingHelloWorldSchema = schema.NewCallableSchema(
schema.NewCallableStepWithSignals[any, helloWorldInput](
/* id */ "hello-world",
/* input */ helloWorldInputSchema,
/* outputs */ map[string]*schema.StepOutputSchema{
"success": schema.NewStepOutputSchema(
schema.NewScopeSchema(
schema.NewStructMappedObjectSchema[helloWorldOutput](
"Output",
map[string]*schema.PropertySchema{
"message": schema.NewPropertySchema(
schema.NewStringSchema(nil, nil, nil),
nil,
true,
nil,
nil,
nil,
nil,
nil,
),
},
),
),
nil,
false,
),
},
/* signal handlers */ map[string]schema.CallableSignal{
"hello-world-signal": helloWorldCallableSignal,
},
/* signal emitters */ map[string]*schema.SignalSchema{
"hello-world-signal": helloWorldCallableSignal.ToSignalSchema(),
},
/* Display */ nil,
/* Initializer */ nil,
/* step handler */ panickingHelloWorldStepHandler,
),
)

type channel struct {
io.Reader
io.Writer
Expand All @@ -118,12 +161,13 @@ func TestProtocol_Client_Execute(t *testing.T) {

go func() {
defer wg.Done()
assert.Nil(t, atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
))
)
assert.Equals(t, len(errors), 0)
}()

go func() {
Expand Down Expand Up @@ -153,6 +197,54 @@ func TestProtocol_Client_Execute(t *testing.T) {
wg.Wait()
}

func TestProtocol_Client_Execute_Panicking(t *testing.T) {
// Client ReadSchema and Execute happy path.
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(2)
stdinReader, stdinWriter := io.Pipe()
stdoutReader, stdoutWriter := io.Pipe()

go func() {
defer wg.Done()
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
panickingHelloWorldSchema,
)
assert.Equals(t, len(errors), 2)
}()

go func() {
defer wg.Done()
cli := atp.NewClientWithLogger(channel{
Reader: stdoutReader,
Writer: stdinWriter,
Context: nil,
cancel: cancel,
}, log.NewTestLogger(t))

_, err := cli.ReadSchema()
assert.NoError(t, err)

for _, testID := range []string{"a", "b"} {
result := cli.Execute(
schema.Input{
RunID: t.Name() + "_" + testID,
ID: "hello-world",
InputData: map[string]any{"name": "Arca Lot"},
}, nil, nil)
assert.Error(t, result.Error)
assert.Contains(t, result.Error.Error(), "abcde")
assert.Equals(t, result.OutputID, "")
}
assert.NoError(t, cli.Close())
}()

wg.Wait()
}

func TestProtocol_Client_Execute_Multi_Step_Parallel(t *testing.T) {
// Runs several steps on one client instance at the same time
ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -163,12 +255,13 @@ func TestProtocol_Client_Execute_Multi_Step_Parallel(t *testing.T) {

go func() {
defer wg.Done()
assert.Nil(t, atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
))
)
assert.Equals(t, len(errors), 0)
}()

go func() {
Expand Down Expand Up @@ -218,12 +311,13 @@ func TestProtocol_Client_Execute_Multi_Step_Serial(t *testing.T) {

go func() {
defer wg.Done()
assert.Nil(t, atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
))
)
assert.Equals(t, len(errors), 0)
}()

go func() {
Expand Down Expand Up @@ -279,12 +373,13 @@ func TestProtocol_Client_ReadSchema(t *testing.T) {
go func() {
defer wg.Done()
t.Logf("Starting ATP server")
assert.Nil(t, atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
))
)
assert.Equals(t, len(errors), 0)
t.Logf("ATP server exited without error")

}()
Expand Down Expand Up @@ -357,15 +452,16 @@ func TestProtocol_Error_Server_StartOutput(t *testing.T) {
assert.NoError(t, stdinReader.Close())
assert.NoError(t, stdoutWriter.Close()) // Close this, because it's unbuffered, so it would deadlock.

err := atp.RunATPServer(
serverErrors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
)

assert.NotNil(t, err)
assert.Equals(t, err.ServerFatal, true)
assert.NotNil(t, serverErrors)
assert.Equals(t, len(serverErrors), 1)
assert.Equals(t, serverErrors[0].ServerFatal, true)

// We don't wait on error, to prevent deadlocks, so just sleep
time.Sleep(time.Millisecond * 2)
Expand Down Expand Up @@ -433,20 +529,18 @@ func TestProtocol_Error_Server_Hello(t *testing.T) {
cancel: cancel,
}, log.NewTestLogger(t))

var test_error *atp.ServerError
var serverErrors []*atp.ServerError

go func() {
defer wg.Done()
err := atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
)

if err != nil {
test_error = err
}
serverErrors = errors
}()

go func() {
Expand All @@ -459,7 +553,8 @@ func TestProtocol_Error_Server_Hello(t *testing.T) {
wgcli.Wait()

wg.Wait()
assert.NotNil(t, test_error)
assert.NotNil(t, serverErrors)
assert.Equals(t, len(serverErrors), 1)
// We don't wait on error, to prevent deadlocks, so just sleep
time.Sleep(time.Millisecond * 2)
}
Expand All @@ -482,19 +577,17 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) {
cancel: cancel,
}, log.NewTestLogger(t))

var testError *atp.ServerError
var serverErrors []*atp.ServerError
go func() {
defer wg.Done()
err := atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
)

if err != nil {
testError = err
}
serverErrors = errors
}()

go func() {
Expand All @@ -510,10 +603,11 @@ func TestProtocol_Error_Server_WorkStart(t *testing.T) {
}()

wg.Wait()
assert.NotNil(t, testError)
assert.NotNil(t, serverErrors)
assert.Equals(t, len(serverErrors), 1)
// This may make the test more fragile, but checking the error is the only way
// to know that the error is from where we're testing.
assert.Contains(t, testError.Err.Error(), "failed to read or decode runtime message")
assert.Contains(t, serverErrors[0].Err.Error(), "failed to read or decode runtime message")
// We don't wait on error, to prevent deadlocks, so just sleep
time.Sleep(time.Millisecond * 2)
}
Expand All @@ -539,19 +633,18 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) {
cancel: cancel,
}, log.NewTestLogger(t))

var srvr_error *atp.ServerError
var serverErrors []*atp.ServerError
var cli_error error
go func() {
defer wg.Done()
err := atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
)
if err != nil {
srvr_error = err
}

serverErrors = errors
}()

go func() {
Expand All @@ -578,7 +671,8 @@ func TestProtocol_Error_Client_WorkStart(t *testing.T) {
wgcli.Wait()

wg.Wait()
assert.NotNil(t, srvr_error)
assert.NotNil(t, serverErrors)
assert.Equals(t, len(serverErrors), 1)
assert.Error(t, cli_error)
// We don't lock on error to prevent deadlocks, so just sleep
time.Sleep(time.Millisecond * 2)
Expand Down Expand Up @@ -664,20 +758,18 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) {
cancel: cancel,
}, log.NewTestLogger(t))

var srvr_error *atp.ServerError
var serverErrors []*atp.ServerError
var cli_error error

go func() {
defer wg.Done()
err := atp.RunATPServer(
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
)
if err != nil {
srvr_error = err
}
serverErrors = errors
}()

go func() {
Expand All @@ -699,7 +791,8 @@ func TestProtocol_Error_Server_WorkDone(t *testing.T) {

wg.Wait()
assert.NoError(t, cli_error)
assert.NotNil(t, srvr_error)
assert.NotNil(t, serverErrors)
assert.Equals(t, len(serverErrors), 1)

// We don't wait on error, to prevent deadlocks, so just sleep
time.Sleep(time.Millisecond * 2)
Expand Down
Loading

0 comments on commit 1de0e27

Please sign in to comment.