Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor client, and add test #98

Merged
merged 2 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions atp/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package atp

import (
"context"
"fmt"
"github.com/fxamacker/cbor/v2"
"go.arcalot.io/log/v2"
Expand Down Expand Up @@ -36,7 +37,13 @@
// ReadSchema reads the schema from the ATP server.
ReadSchema() (*schema.SchemaSchema, error)
// Execute executes a step with a given context and returns the resulting output. Assumes you called ReadSchema first.
Execute(input schema.Input, receivedSignals <-chan schema.Input, emittedSignals chan<- schema.Input) ExecutionResult
// Params:
// - input: The step input for the run.
// - signalsToStep: A channel to send signals from the client to the plugin.
// - signalsFromStep: A channel to receive signals from the plugin to the client.
// It is recommended to close the signalsToStep channel when either Execute is done or it is known that no more signals
// will be sent to the plugin.
Execute(input schema.Input, signalsToStep <-chan schema.Input, signalsFromStep chan<- schema.Input) ExecutionResult
Close() error
Encoder() *cbor.Encoder
Decoder() *cbor.Decoder
Expand Down Expand Up @@ -66,20 +73,22 @@
if logger == nil {
logger = log.NewLogger(log.LevelDebug, log.NewNOOPLogger())
}
ctx, cancel := context.WithCancel(context.Background())
return &client{
-1, // unknown
channel,
decMode,
logger,
decMode.NewDecoder(channel),
cbor.NewEncoder(channel),
make(chan bool, 5), // Buffer to prevent deadlocks
make([]schema.Input, 0),
make(map[string]*executionEntry),
make(map[string]chan<- schema.Input),
sync.Mutex{},
false,
false,
ctx,
cancel,
sync.WaitGroup{},
}
}
Expand All @@ -99,18 +108,19 @@

type client struct {
atpVersion int64
channel ClientChannel
rawAtpChannels ClientChannel
decMode cbor.DecMode
logger log.Logger
decoder *cbor.Decoder
encoder *cbor.Encoder
doneChannel chan bool
runningSteps []schema.Input
runningStepResultEntries map[string]*executionEntry // Run ID to results
runningStepEmittedSignalChannels map[string]chan<- schema.Input // Run ID to channel of signals emitted from steps
mutex sync.Mutex
readLoopRunning bool
readLoopRunning bool // To prevent duplicate loops across multiple step executions.
done bool
context context.Context
cancelFunc context.CancelFunc
wg sync.WaitGroup // For the read loop.
}

Expand Down Expand Up @@ -165,8 +175,8 @@

func (c *client) Execute(
stepData schema.Input,
receivedSignals <-chan schema.Input,
emittedSignals chan<- schema.Input,
signalsToStep <-chan schema.Input,
signalsFromStep chan<- schema.Input,
) ExecutionResult {
c.logger.Debugf("Executing plugin step %s/%s...", stepData.RunID, stepData.ID)
if len(stepData.RunID) == 0 {
Expand All @@ -177,20 +187,20 @@
StepID: stepData.ID,
Config: stepData.InputData,
}
cborReader := c.decMode.NewDecoder(c.channel)
cborReader := c.decMode.NewDecoder(c.rawChannels)

Check failure on line 190 in atp/client.go

View workflow job for this annotation

GitHub Actions / lint and test / go test

c.rawChannels undefined (type *client has no field or method rawChannels)

Check failure on line 190 in atp/client.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

c.rawChannels undefined (type *client has no field or method rawChannels) (typecheck)

Check failure on line 190 in atp/client.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

c.rawChannels undefined (type *client has no field or method rawChannels)) (typecheck)

Check failure on line 190 in atp/client.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

c.rawChannels undefined (type *client has no field or method rawChannels)) (typecheck)
if c.atpVersion > 1 {
// Wrap it in a runtime message.
workStartMsg = RuntimeMessage{RunID: stepData.RunID, MessageID: MessageTypeWorkStart, MessageData: workStartMsg}
// Handle signals to the step
if receivedSignals != nil {
if signalsToStep != nil {
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.executeWriteLoop(stepData.RunID, receivedSignals)
c.executeWriteLoop(stepData.RunID, signalsToStep)
}()
}
// Setup channels for ATP v2
err := c.prepareResultChannels(cborReader, stepData, emittedSignals)
err := c.prepareResultChannels(cborReader, stepData, signalsFromStep)
if err != nil {
return NewErrorExecutionResult(err)
}
Expand All @@ -206,6 +216,7 @@

// Close Tells the client that it's done, and can stop listening for more requests.
func (c *client) Close() error {
c.cancelFunc()
c.mutex.Lock()
if c.done {
c.mutex.Unlock()
Expand Down Expand Up @@ -271,14 +282,12 @@
// Listen for received signals, and send them over ATP if available.
func (c *client) executeWriteLoop(
runID string,
receivedSignals <-chan schema.Input,
signalsToStep <-chan schema.Input,
) {
c.mutex.Lock()
if c.done {
c.mutex.Unlock()
// Close() was called, so exit now.
// Failure to exit now may result in this receivedSignals channel not getting
// closed, resulting in this function hanging.
c.logger.Warningf(
"write called loop for run ID %q on done client; skipping receive loop",
runID,
Expand All @@ -289,9 +298,16 @@

// Looped select that gets signals
for {
signal, ok := <-receivedSignals
if !ok {
c.logger.Infof("ATP signal loop done")
var signal schema.Input
var ok bool
select {
case signal, ok = <-signalsToStep:
if !ok {
c.logger.Debugf("ATP signal loop done; channel closed")
return
}
case <-c.context.Done():
c.logger.Debugf("ATP signal loop exited; context closed")
return
}
c.logger.Debugf("Sending signal with ID '%s' to step with run ID '%s'", signal.ID, signal.RunID)
Expand Down
60 changes: 60 additions & 0 deletions atp/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"github.com/fxamacker/cbor/v2"
"go.arcalot.io/assert"
"go.arcalot.io/log/v2"
"go.flow.arcalot.io/pluginsdk/atp"

Check failure on line 9 in atp/protocol_test.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

could not import go.flow.arcalot.io/pluginsdk/atp (-: # go.flow.arcalot.io/pluginsdk/atp
"go.flow.arcalot.io/pluginsdk/schema"
"io"
"sync"
Expand Down Expand Up @@ -197,6 +197,66 @@
wg.Wait()
}

func TestProtocol_Client_Execute_With_Signals(t *testing.T) {
testExecuteWithChannels(true, t)
}

func TestProtocol_Client_Execute_With_Signals_Unclosed(t *testing.T) {
testExecuteWithChannels(false, t)
}

func testExecuteWithChannels(closeChannel bool, t *testing.T) {
// Client ReadSchema and Execute happy path with signal handlers passed
// into the Execute call.
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(2)
stdinReader, stdinWriter := io.Pipe()
stdoutReader, stdoutWriter := io.Pipe()

go func() {
defer wg.Done()
errors := atp.RunATPServer(
ctx,
stdinReader,
stdoutWriter,
helloWorldSchema,
)
assert.Equals(t, len(errors), 0)
}()

go func() {
defer wg.Done()
cli := atp.NewClientWithLogger(channel{
Reader: stdoutReader,
Writer: stdinWriter,
Context: nil,
cancel: cancel,
}, log.NewTestLogger(t))

_, err := cli.ReadSchema()
assert.NoError(t, err)
toStepChan := make(chan schema.Input)
fromStepChan := make(chan schema.Input)

result := cli.Execute(
schema.Input{
RunID: t.Name(),
ID: "hello-world",
InputData: map[string]any{"name": "Arca Lot"},
}, toStepChan, fromStepChan)
if closeChannel {
close(toStepChan)
}
assert.NoError(t, cli.Close())
assert.NoError(t, result.Error)
assert.Equals(t, result.OutputID, "success")
assert.Equals(t, result.OutputData.(map[any]any)["message"].(string), "Hello, Arca Lot!")
}()

wg.Wait()
}

//nolint:funlen
func TestProtocol_Client_ATP_v1(t *testing.T) {
// Client ReadSchema and Execute atp v1 happy path.
Expand Down Expand Up @@ -305,7 +365,7 @@
InputData: map[string]any{"name": "Arca Lot"},
}, nil, nil)
assert.Error(t, result.Error)
assert.Contains(t, result.Error.Error(), "abcde")

Check failure on line 368 in atp/protocol_test.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

cannot infer T (/home/runner/go/pkg/mod/go.arcalot.io/[email protected]/equality.go:10:1) (typecheck)
assert.Equals(t, result.OutputID, "")
}
assert.NoError(t, cli.Close())
Expand Down Expand Up @@ -676,7 +736,7 @@
assert.Equals(t, len(serverErrors), 1)
// This may make the test more fragile, but checking the error is the only way
// to know that the error is from where we're testing.
assert.Contains(t, serverErrors[0].Err.Error(), "failed to read or decode runtime message")

Check failure on line 739 in atp/protocol_test.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

cannot infer T (/home/runner/go/pkg/mod/go.arcalot.io/[email protected]/equality.go:10:1) (typecheck)
// We don't wait on error, to prevent deadlocks, so just sleep
time.Sleep(time.Millisecond * 2)
}
Expand Down
Loading