Skip to content

Commit

Permalink
Added client done message, and signal atp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredoconnell committed Sep 12, 2023
1 parent 68d0ecd commit cce45b7
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 17 deletions.
37 changes: 32 additions & 5 deletions src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@


class MessageType(Enum):
WORKDONE = 1
WORK_DONE = 1
SIGNAL = 2
CLIENT_DONE = 3


@dataclasses.dataclass
Expand Down Expand Up @@ -133,7 +134,7 @@ def run_plugin(
# Send WorkDoneMessage in a RuntimeMessage
encoder.encode(
{
"id": MessageType.WORKDONE.value,
"id": MessageType.WORK_DONE.value,
"data": {
"output_id": output_id,
"output_data": plugin_schema.serialize_output(
Expand Down Expand Up @@ -164,6 +165,7 @@ def run_server_read_loop(
# Validate
if msg_id is None:
stderr.write("Runtime message is missing the 'id' field.")
return
# Then take action
if msg_id == MessageType.SIGNAL.value:
signal_msg = runtime_msg["data"]
Expand All @@ -179,11 +181,13 @@ def run_server_read_loop(
signal_msg["data"]
)
plugin_schema.call_step_signal(step_id, received_signal_id, unserialized_data)
elif msg_id == MessageType.CLIENT_DONE.value:
return
else:
stderr.write(f"Unknown kind of runtime message: {msg_id}")

except cbor2.CBORDecodeError:
stderr.write(f"Error while decoding CBOR: {msg_id}")
except cbor2.CBORDecodeError as err:
stderr.write(f"Error while decoding CBOR: {err}")


class PluginClientStateException(Exception):
Expand Down Expand Up @@ -241,14 +245,37 @@ def start_work(self, step_id: str, input_data: any):
}
)

def send_signal(self, step_id: str, signal_id: str, input_data: any):
"""
This function sends any signals to the plugin.
"""
self.send_runtime_message(MessageType.SIGNAL, {
"step_id": step_id,
"signal_id": signal_id,
"data": input_data,
}
)

def send_client_done(self):
self.send_runtime_message(MessageType.CLIENT_DONE, {})

def send_runtime_message(self, message_type: MessageType, data: any):
self.encoder.encode(
{
"id": message_type.value,
"data": data,
}
)


def read_results(self) -> (str, any, str):
"""
This function reads the signals and results of an execution from the plugin.
"""
while True:
runtime_msg = self.decoder.decode()
msg_id = runtime_msg["id"]
if msg_id == MessageType.WORKDONE.value:
if msg_id == MessageType.WORK_DONE.value:
signal_msg = runtime_msg["data"]
if signal_msg["output_id"] is None:
raise PluginClientStateException(
Expand Down
9 changes: 6 additions & 3 deletions src/arcaflow_plugin_sdk/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def signal_handler(
:return: A schema for the signal.
"""

def signal_decorator(func: _step_decorator_param) -> schema.StepType:
def signal_decorator(func: _step_decorator_param) -> schema.SignalHandlerType:
if id == "":
raise BadArgumentException("Signals cannot have empty IDs")
if name == "":
raise BadArgumentException("Signals cannot have empty names")
sig = inspect.signature(func)
if len(sig.parameters) != 2:
raise BadArgumentException(
"The '%s' (id: %s) signal must have exactly two parameters, including self. Currently has %d" %
"The '%s' (id: %s) signal must have exactly two parameters, including self. Currently has %v" %
(name, id, sig.parameters)
)
input_param = list(sig.parameters.values())[1]
Expand Down Expand Up @@ -90,7 +90,7 @@ def step_with_signals(
outputs: Dict[str, Type],
signal_handler_method_names: List[str],
signal_emitters: List[schema.SignalSchema],
step_object_constructor: schema._step_object_constructor_param,
step_object_constructor: schema.step_object_constructor_param,
icon: typing.Optional[str] = None,
) -> Callable[[_step_object_decorator_param], schema.StepType]:
"""
Expand All @@ -101,6 +101,9 @@ def step_with_signals(
:param name: The human-readable name for the step.
:param description: The human-readable description for the step.
:param outputs: A dict linking response IDs to response object types.
:param signal_handler_method_names: A list of methods for all signal handlers.
:param signal_emitters: A list of signal schemas for signal emitters.
:param step_object_constructor: A constructor lambda for the object with the step and signal methods.
:param icon: SVG icon for this step.
:return: A schema for the step.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5583,7 +5583,7 @@ def __call__(
self._handler(step_data, params)


_step_object_constructor_param = Callable[[], StepObjectT]
step_object_constructor_param = Callable[[], StepObjectT]


class StepType(StepSchema):
Expand All @@ -5593,7 +5593,7 @@ class StepType(StepSchema):
"""

_handler: Callable[[StepObjectT, StepInputT], typing.Tuple[str, StepOutputT]]
_step_object_constructor: _step_object_constructor_param
_step_object_constructor: step_object_constructor_param
input: ScopeType
outputs: Dict[ID_TYPE, StepOutputType]
signal_handler_method_names: List[str]
Expand All @@ -5605,7 +5605,7 @@ def __init__(
self,
id: str,
handler: Callable[[StepObjectT, StepInputT], typing.Tuple[str, StepOutputT]],
step_object_constructor: _step_object_constructor_param,
step_object_constructor: step_object_constructor_param,
input: ScopeType,
outputs: Dict[ID_TYPE, StepOutputType],
signal_handler_method_names: List[str],
Expand Down
91 changes: 85 additions & 6 deletions src/arcaflow_plugin_sdk/test_atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import signal
import time
import unittest
from typing import TextIO, Tuple, Union
from threading import Event
from typing import TextIO, Tuple, Union, List

from arcaflow_plugin_sdk import atp, plugin, schema
from arcaflow_plugin_sdk import atp, plugin, schema, predefined_schemas


@dataclasses.dataclass
Expand All @@ -29,11 +30,56 @@ def hello_world(params: Input) -> Tuple[str, Union[Output]]:
return "success", Output("Hello, {}!".format(params.name))


@dataclasses.dataclass
class StepTestInput:
wait_time_seconds: float


@dataclasses.dataclass
class SignalTestInput:
final: bool # The last one will trigger the end of the step.
value: int


@dataclasses.dataclass
class SignalTestOutput:
signals_received: List[int]


class SignalTestStep:
signal_values: List[int] = []
exit_event = Event()

@plugin.step_with_signals(
id="signal_test_step",
name="signal_test_step",
description="waits for signal with timeout",
outputs={"success": SignalTestOutput},
signal_handler_method_names=["signal_test_signal_handler"],
signal_emitters=[],
step_object_constructor=lambda: SignalTestStep(),
)
def signal_test_step(self, params: StepTestInput) -> Tuple[str, Union[SignalTestOutput]]:
self.exit_event.wait(params.wait_time_seconds)
return "success", SignalTestOutput(self.signal_values)

@plugin.signal_handler(
id="record_value",
name="record value",
description="Records the value, and optionally ends the step.",
)
def signal_test_signal_handler(self, signal_input: SignalTestInput):
self.signal_values.append(signal_input.value)
if signal_input.final:
self.exit_event.set()


test_schema = plugin.build_schema(hello_world)
test_signals_schema = plugin.build_schema(SignalTestStep.signal_test_step)


class ATPTest(unittest.TestCase):
def _execute_plugin(self) -> Tuple[int, TextIO, TextIO]:
def _execute_plugin(self, schema) -> Tuple[int, TextIO, TextIO]:
stdin_reader_fd, stdin_writer_fd = os.pipe()
stdout_reader_fd, stdout_writer_fd = os.pipe()
pid = os.fork()
Expand All @@ -49,7 +95,7 @@ def _execute_plugin(self) -> Tuple[int, TextIO, TextIO]:
stdout_writer.buffer.raw,
stdout_writer.buffer.raw,
)
result = atp_server.run_plugin(test_schema)
result = atp_server.run_plugin(schema)
os.close(stdin_reader_fd)
os.close(stdout_writer_fd)
if result != 0:
Expand Down Expand Up @@ -77,8 +123,8 @@ def _cleanup(self, pid, stdin_writer, stdout_reader):
if exit_status != 0:
self.fail("Plugin exited with non-zero status: {}".format(exit_status))

def test_full_workflow(self):
pid, stdin_writer, stdout_reader = self._execute_plugin()
def test_full_simple_workflow(self):
pid, stdin_writer, stdout_reader = self._execute_plugin(test_schema)

try:
client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw)
Expand All @@ -94,10 +140,43 @@ def test_full_workflow(self):
client.start_work("hello-world", {"name": "Arca Lot"})

output_id, output_data, debug_logs = client.read_results()
client.send_client_done()
self.assertEqual(output_id, "success")
self.assertEqual("Hello world!\n", debug_logs)
finally:
self._cleanup(pid, stdin_writer, stdout_reader)

def test_full_workflow_with_signals(self):
pid, stdin_writer, stdout_reader = self._execute_plugin(test_signals_schema)

try:
client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw)
client.start_output()
hello_message = client.read_hello()
self.assertEqual(2, hello_message.version)

self.assertEqual(
schema.SCHEMA_SCHEMA.serialize(test_signals_schema),
schema.SCHEMA_SCHEMA.serialize(hello_message.schema),
)

client.start_work("signal_test_step", {"wait_time_seconds": "0.5"})
client.send_signal("signal_test_step", "record_value",
{"final": "false", "value": "1"},
)
client.send_signal("signal_test_step", "record_value",
{"final": "false", "value": "2"},
)
client.send_signal("signal_test_step", "record_value",
{"final": "true", "value": "3"},
)
output_id, output_data, _ = client.read_results()
client.send_client_done()
self.assertEqual(output_id, "success")
self.assertListEqual(output_data["signals_received"], [1, 2, 3])
finally:
self._cleanup(pid, stdin_writer, stdout_reader)


if __name__ == "__main__":
unittest.main()

0 comments on commit cce45b7

Please sign in to comment.