From cce45b7f08ac14d23abcd439d19c4f2019240f95 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Tue, 12 Sep 2023 16:33:33 -0400 Subject: [PATCH] Added client done message, and signal atp tests --- src/arcaflow_plugin_sdk/atp.py | 37 ++++++++++-- src/arcaflow_plugin_sdk/plugin.py | 9 ++- src/arcaflow_plugin_sdk/schema.py | 6 +- src/arcaflow_plugin_sdk/test_atp.py | 91 +++++++++++++++++++++++++++-- 4 files changed, 126 insertions(+), 17 deletions(-) diff --git a/src/arcaflow_plugin_sdk/atp.py b/src/arcaflow_plugin_sdk/atp.py index b49de47..cae21b8 100644 --- a/src/arcaflow_plugin_sdk/atp.py +++ b/src/arcaflow_plugin_sdk/atp.py @@ -28,8 +28,9 @@ class MessageType(Enum): - WORKDONE = 1 + WORK_DONE = 1 SIGNAL = 2 + CLIENT_DONE = 3 @dataclasses.dataclass @@ -133,7 +134,7 @@ def run_plugin( # Send WorkDoneMessage in a RuntimeMessage encoder.encode( { - "id": MessageType.WORKDONE.value, + "id": MessageType.WORK_DONE.value, "data": { "output_id": output_id, "output_data": plugin_schema.serialize_output( @@ -164,6 +165,7 @@ def run_server_read_loop( # Validate if msg_id is None: stderr.write("Runtime message is missing the 'id' field.") + return # Then take action if msg_id == MessageType.SIGNAL.value: signal_msg = runtime_msg["data"] @@ -179,11 +181,13 @@ def run_server_read_loop( signal_msg["data"] ) plugin_schema.call_step_signal(step_id, received_signal_id, unserialized_data) + elif msg_id == MessageType.CLIENT_DONE.value: + return else: stderr.write(f"Unknown kind of runtime message: {msg_id}") - except cbor2.CBORDecodeError: - stderr.write(f"Error while decoding CBOR: {msg_id}") + except cbor2.CBORDecodeError as err: + stderr.write(f"Error while decoding CBOR: {err}") class PluginClientStateException(Exception): @@ -241,6 +245,29 @@ def start_work(self, step_id: str, input_data: any): } ) + def send_signal(self, step_id: str, signal_id: str, input_data: any): + """ + This function sends any signals to the plugin. + """ + self.send_runtime_message(MessageType.SIGNAL, { + "step_id": step_id, + "signal_id": signal_id, + "data": input_data, + } + ) + + def send_client_done(self): + self.send_runtime_message(MessageType.CLIENT_DONE, {}) + + def send_runtime_message(self, message_type: MessageType, data: any): + self.encoder.encode( + { + "id": message_type.value, + "data": data, + } + ) + + def read_results(self) -> (str, any, str): """ This function reads the signals and results of an execution from the plugin. @@ -248,7 +275,7 @@ def read_results(self) -> (str, any, str): while True: runtime_msg = self.decoder.decode() msg_id = runtime_msg["id"] - if msg_id == MessageType.WORKDONE.value: + if msg_id == MessageType.WORK_DONE.value: signal_msg = runtime_msg["data"] if signal_msg["output_id"] is None: raise PluginClientStateException( diff --git a/src/arcaflow_plugin_sdk/plugin.py b/src/arcaflow_plugin_sdk/plugin.py index 4b441cc..28e8218 100644 --- a/src/arcaflow_plugin_sdk/plugin.py +++ b/src/arcaflow_plugin_sdk/plugin.py @@ -45,7 +45,7 @@ def signal_handler( :return: A schema for the signal. """ - def signal_decorator(func: _step_decorator_param) -> schema.StepType: + def signal_decorator(func: _step_decorator_param) -> schema.SignalHandlerType: if id == "": raise BadArgumentException("Signals cannot have empty IDs") if name == "": @@ -53,7 +53,7 @@ def signal_decorator(func: _step_decorator_param) -> schema.StepType: sig = inspect.signature(func) if len(sig.parameters) != 2: raise BadArgumentException( - "The '%s' (id: %s) signal must have exactly two parameters, including self. Currently has %d" % + "The '%s' (id: %s) signal must have exactly two parameters, including self. Currently has %v" % (name, id, sig.parameters) ) input_param = list(sig.parameters.values())[1] @@ -90,7 +90,7 @@ def step_with_signals( outputs: Dict[str, Type], signal_handler_method_names: List[str], signal_emitters: List[schema.SignalSchema], - step_object_constructor: schema._step_object_constructor_param, + step_object_constructor: schema.step_object_constructor_param, icon: typing.Optional[str] = None, ) -> Callable[[_step_object_decorator_param], schema.StepType]: """ @@ -101,6 +101,9 @@ def step_with_signals( :param name: The human-readable name for the step. :param description: The human-readable description for the step. :param outputs: A dict linking response IDs to response object types. + :param signal_handler_method_names: A list of methods for all signal handlers. + :param signal_emitters: A list of signal schemas for signal emitters. + :param step_object_constructor: A constructor lambda for the object with the step and signal methods. :param icon: SVG icon for this step. :return: A schema for the step. """ diff --git a/src/arcaflow_plugin_sdk/schema.py b/src/arcaflow_plugin_sdk/schema.py index 4ecfb55..487c57f 100644 --- a/src/arcaflow_plugin_sdk/schema.py +++ b/src/arcaflow_plugin_sdk/schema.py @@ -5583,7 +5583,7 @@ def __call__( self._handler(step_data, params) -_step_object_constructor_param = Callable[[], StepObjectT] +step_object_constructor_param = Callable[[], StepObjectT] class StepType(StepSchema): @@ -5593,7 +5593,7 @@ class StepType(StepSchema): """ _handler: Callable[[StepObjectT, StepInputT], typing.Tuple[str, StepOutputT]] - _step_object_constructor: _step_object_constructor_param + _step_object_constructor: step_object_constructor_param input: ScopeType outputs: Dict[ID_TYPE, StepOutputType] signal_handler_method_names: List[str] @@ -5605,7 +5605,7 @@ def __init__( self, id: str, handler: Callable[[StepObjectT, StepInputT], typing.Tuple[str, StepOutputT]], - step_object_constructor: _step_object_constructor_param, + step_object_constructor: step_object_constructor_param, input: ScopeType, outputs: Dict[ID_TYPE, StepOutputType], signal_handler_method_names: List[str], diff --git a/src/arcaflow_plugin_sdk/test_atp.py b/src/arcaflow_plugin_sdk/test_atp.py index 3c0a491..73564b6 100644 --- a/src/arcaflow_plugin_sdk/test_atp.py +++ b/src/arcaflow_plugin_sdk/test_atp.py @@ -3,9 +3,10 @@ import signal import time import unittest -from typing import TextIO, Tuple, Union +from threading import Event +from typing import TextIO, Tuple, Union, List -from arcaflow_plugin_sdk import atp, plugin, schema +from arcaflow_plugin_sdk import atp, plugin, schema, predefined_schemas @dataclasses.dataclass @@ -29,11 +30,56 @@ def hello_world(params: Input) -> Tuple[str, Union[Output]]: return "success", Output("Hello, {}!".format(params.name)) +@dataclasses.dataclass +class StepTestInput: + wait_time_seconds: float + + +@dataclasses.dataclass +class SignalTestInput: + final: bool # The last one will trigger the end of the step. + value: int + + +@dataclasses.dataclass +class SignalTestOutput: + signals_received: List[int] + + +class SignalTestStep: + signal_values: List[int] = [] + exit_event = Event() + + @plugin.step_with_signals( + id="signal_test_step", + name="signal_test_step", + description="waits for signal with timeout", + outputs={"success": SignalTestOutput}, + signal_handler_method_names=["signal_test_signal_handler"], + signal_emitters=[], + step_object_constructor=lambda: SignalTestStep(), + ) + def signal_test_step(self, params: StepTestInput) -> Tuple[str, Union[SignalTestOutput]]: + self.exit_event.wait(params.wait_time_seconds) + return "success", SignalTestOutput(self.signal_values) + + @plugin.signal_handler( + id="record_value", + name="record value", + description="Records the value, and optionally ends the step.", + ) + def signal_test_signal_handler(self, signal_input: SignalTestInput): + self.signal_values.append(signal_input.value) + if signal_input.final: + self.exit_event.set() + + test_schema = plugin.build_schema(hello_world) +test_signals_schema = plugin.build_schema(SignalTestStep.signal_test_step) class ATPTest(unittest.TestCase): - def _execute_plugin(self) -> Tuple[int, TextIO, TextIO]: + def _execute_plugin(self, schema) -> Tuple[int, TextIO, TextIO]: stdin_reader_fd, stdin_writer_fd = os.pipe() stdout_reader_fd, stdout_writer_fd = os.pipe() pid = os.fork() @@ -49,7 +95,7 @@ def _execute_plugin(self) -> Tuple[int, TextIO, TextIO]: stdout_writer.buffer.raw, stdout_writer.buffer.raw, ) - result = atp_server.run_plugin(test_schema) + result = atp_server.run_plugin(schema) os.close(stdin_reader_fd) os.close(stdout_writer_fd) if result != 0: @@ -77,8 +123,8 @@ def _cleanup(self, pid, stdin_writer, stdout_reader): if exit_status != 0: self.fail("Plugin exited with non-zero status: {}".format(exit_status)) - def test_full_workflow(self): - pid, stdin_writer, stdout_reader = self._execute_plugin() + def test_full_simple_workflow(self): + pid, stdin_writer, stdout_reader = self._execute_plugin(test_schema) try: client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) @@ -94,10 +140,43 @@ def test_full_workflow(self): client.start_work("hello-world", {"name": "Arca Lot"}) output_id, output_data, debug_logs = client.read_results() + client.send_client_done() + self.assertEqual(output_id, "success") self.assertEqual("Hello world!\n", debug_logs) finally: self._cleanup(pid, stdin_writer, stdout_reader) + def test_full_workflow_with_signals(self): + pid, stdin_writer, stdout_reader = self._execute_plugin(test_signals_schema) + + try: + client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) + client.start_output() + hello_message = client.read_hello() + self.assertEqual(2, hello_message.version) + + self.assertEqual( + schema.SCHEMA_SCHEMA.serialize(test_signals_schema), + schema.SCHEMA_SCHEMA.serialize(hello_message.schema), + ) + + client.start_work("signal_test_step", {"wait_time_seconds": "0.5"}) + client.send_signal("signal_test_step", "record_value", + {"final": "false", "value": "1"}, + ) + client.send_signal("signal_test_step", "record_value", + {"final": "false", "value": "2"}, + ) + client.send_signal("signal_test_step", "record_value", + {"final": "true", "value": "3"}, + ) + output_id, output_data, _ = client.read_results() + client.send_client_done() + self.assertEqual(output_id, "success") + self.assertListEqual(output_data["signals_received"], [1, 2, 3]) + finally: + self._cleanup(pid, stdin_writer, stdout_reader) + if __name__ == "__main__": unittest.main()