Skip to content

Commit

Permalink
Added work done message
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredoconnell committed Sep 19, 2023
1 parent 8194670 commit 442d24f
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 21 deletions.
37 changes: 19 additions & 18 deletions atp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,11 @@ func (c *client) Execute(

doneChannel := make(chan bool, 1) // Needs a buffer to not hang.
defer handleClientClosure(receivedSignals, doneChannel)
go func() {
c.executeWriteLoop(stepData, receivedSignals, doneChannel)
}()
if c.atpVersion > 1 {
go func() {
c.executeWriteLoop(stepData, receivedSignals, doneChannel)
}()
}
return c.executeReadLoop(stepData, receivedSignals)
}

Expand All @@ -151,27 +153,26 @@ func (c *client) executeWriteLoop(
) {
// Looped select that gets signals
for {
signal, ok := <-receivedSignals
isDone := false
var signal schema.Input
select {
case isDone = <-doneChannel:
default:
// Non-blocking because of the default.
}
if !ok {
if isDone {
// It's done, so the not ok is expected.
return
} else {
case <-doneChannel:
// Send the client done message
err := c.encoder.Encode(RuntimeMessage{
MessageTypeClientDone,
clientDoneMessage{},
})
if err != nil {
c.logger.Errorf("Step %s failed to write client done message with error: %w", stepData.ID, err)
}
return
case receivedSignal, ok := <-receivedSignals:
if !ok {
// It's not supposed to be not ok yet.
c.logger.Errorf("error in channel preparing to send signal (step %s, signal %s) over ATP",
stepData.ID, signal.ID)
return
}
}
if isDone {
c.logger.Errorf("signal received after step '%s' completed. Ignoring signal '%s'", stepData.ID, signal.ID)
return
signal = receivedSignal
}
c.logger.Debugf("Sending signal with ID '%s' to step with ID '%s'", signal.ID, stepData.ID)
if err := c.encoder.Encode(RuntimeMessage{
Expand Down
9 changes: 7 additions & 2 deletions atp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ type StartWorkMessage struct {

// All messages that can be contained in a RuntimeMessage struct.
const (
MessageTypeWorkDone uint32 = 1
MessageTypeSignal uint32 = 2
MessageTypeWorkDone uint32 = 1
MessageTypeSignal uint32 = 2
MessageTypeClientDone uint32 = 3
)

type RuntimeMessage struct {
Expand Down Expand Up @@ -48,3 +49,7 @@ type signalMessage struct {
func (s signalMessage) ToInput() schema.Input {
return schema.Input{ID: s.SignalID, InputData: s.Data}
}

type clientDoneMessage struct {
// Empty for now.
}
8 changes: 7 additions & 1 deletion atp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,31 @@ func (s *atpServerSession) runATPReadLoop() {
}
if !done {
s.workDone <- fmt.Errorf("failed to read or decode runtime message: %w", err)
}
} // If done, it didn't get the work done message, which is not ideal.
return
}
switch runtimeMessage.MessageID {
case MessageTypeSignal:
var signalMessage signalMessage
if err := cbor.Unmarshal(runtimeMessage.RawMessageData, &signalMessage); err != nil {
s.workDone <- fmt.Errorf("failed to decode signal message: %w", err)
return
}
if s.req.StepID != signalMessage.StepID {
s.workDone <- fmt.Errorf("signal sent with mismatched step ID, got %s, expected %s",
signalMessage.StepID, s.req.StepID)
return
}
if err := s.pluginSchema.CallSignal(s.ctx, signalMessage.StepID, signalMessage.SignalID, signalMessage.Data); err != nil {
s.workDone <- fmt.Errorf("failed while running signal ID %s: %w",
signalMessage.SignalID, err)
return
}
case MessageTypeClientDone:
return
default:
s.workDone <- fmt.Errorf("unknown message ID received: %d", runtimeMessage.MessageID)
return
}
}
}
Expand Down

0 comments on commit 442d24f

Please sign in to comment.