diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0675a3d..d710074 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: python-version: [ '3.9', '3.10', 'pypy3.9' ] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Check out code uses: actions/checkout@v3 @@ -32,20 +32,36 @@ jobs: isort --profile black . black . flake8 . - - name: Install Python Poetry - uses: snok/install-poetry@v1.3.4 - - name: Install project dependencies - run: poetry install --no-interaction --with dev + - name: Install poetry + run: | + python -m pip install poetry==1.4.2 + - name: Configure poetry + run: | + python -m poetry config virtualenvs.in-project true + - name: Upload logs on failure + uses: actions/upload-artifact@v3 + if: failure() + with: + name: logs + path: "*.log" + - name: Cache the virtualenv + uses: actions/cache@v3 + with: + path: ./.venv + key: ${{ runner.os }}-venv-${{ hashFiles('**/poetry.lock') }} + - name: Install dependencies + run: | + python -m poetry install - name: Run tests with coverage run: | # Run the unit tests - poetry run python3 -m coverage run -a -m unittest discover -v src + python -m poetry run coverage run -a -m unittest discover -v src # Run the example plugin - poetry run python3 -m coverage run -a ./example_plugin.py -f example.yaml + python -m poetry run coverage run -a ./example_plugin.py -f example.yaml # Test the example plugin - poetry run python3 -m coverage run -a ./test_example_plugin.py + python -m poetry run coverage run -a ./test_example_plugin.py # Generate the coverage HTML report - poetry run python3 -m coverage html + python -m poetry run coverage html - name: Publish coverage report to job summary # publishing only once if: ${{ matrix.python-version == '3.9'}} diff --git a/src/arcaflow_plugin_sdk/atp.py b/src/arcaflow_plugin_sdk/atp.py index 40c9508..bcc0f09 100644 --- a/src/arcaflow_plugin_sdk/atp.py +++ 
b/src/arcaflow_plugin_sdk/atp.py @@ -20,29 +20,35 @@ import typing import threading import signal +import traceback import cbor2 -from enum import Enum +from enum import IntEnum from arcaflow_plugin_sdk import schema -class MessageType(Enum): +ATP_SERVER_VERSION = 3 + + +class MessageType(IntEnum): """ An integer ID that indicates the type of runtime message that is stored in the data field. The corresponding class can then be used to deserialize the inner data. Look at the go SDK for the reference data structure. """ - WORK_DONE = 1 - SIGNAL = 2 - CLIENT_DONE = 3 + WORK_START = 1 + WORK_DONE = 2 + SIGNAL = 3 + CLIENT_DONE = 4 + ERROR = 5 @dataclasses.dataclass class HelloMessage: """ - This message is the the initial greeting message a plugin sends to the output. + This message is the initial greeting message a plugin sends to the output. """ version: typing.Annotated[ @@ -64,24 +70,26 @@ class HelloMessage: _HELLO_MESSAGE_SCHEMA = schema.build_object_schema(HelloMessage) -def signal_handler(_sig, _frame): - pass # Do nothing - - class ATPServer: - stdin: io.FileIO - stdout: io.FileIO + input_pipe: io.FileIO + output_pipe: io.FileIO stderr: io.FileIO - step_object: typing.Any + step_ids: typing.Dict[str, str] = {} # Run ID to step IDs + encoder: cbor2.encoder + decoder: cbor2.decoder + user_out_buffer: io.StringIO + encoder_lock = threading.Lock() + plugin_schema: schema.SchemaType + running_threads: typing.List[threading.Thread] = [] def __init__( self, - stdin: io.FileIO, - stdout: io.FileIO, + input_pipe: io.FileIO, + output_pipe: io.FileIO, stderr: io.FileIO, ) -> None: - self.stdin = stdin - self.stdout = stdout + self.input_pipe = input_pipe + self.output_pipe = output_pipe self.stderr = stderr def run_plugin( @@ -91,121 +99,185 @@ def run_plugin( """ This function wraps running a plugin. """ - signal.signal(signal.SIGINT, signal_handler) # Ignore sigint. Only care about arcaflow signals. 
- if os.isatty(self.stdout.fileno()): + signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore sigint. Only care about arcaflow signals. + if os.isatty(self.output_pipe.fileno()): print("Cannot run plugin in ATP mode on an interactive terminal.") return 1 - try: - decoder = cbor2.decoder.CBORDecoder(self.stdin) - encoder = cbor2.encoder.CBOREncoder(self.stdout) - - # Decode empty "start output" message. - decoder.decode() - - # Serialize then send HelloMessage - start_hello_message = HelloMessage(2, plugin_schema) - serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start_hello_message) - encoder.encode(serialized_message) - self.stdout.flush() - - # Can fail here if only getting schema. - work_start_msg = decoder.decode() - except SystemExit: - return 0 - try: - if work_start_msg is None: - self.stderr.write("Work start message is None.") - return 1 - if work_start_msg["id"] is None: - self.stderr.write("Work start message is missing the 'id' field.") - return 1 - if work_start_msg["config"] is None: - self.stderr.write("Work start message is missing the 'config' field.") - return 1 - - # Run the step - original_stdout = sys.stdout - original_stderr = sys.stderr - out_buffer = io.StringIO() - sys.stdout = out_buffer - sys.stderr = out_buffer - # Run the read loop - read_thread = threading.Thread(target=self.run_server_read_loop, args=( - plugin_schema, # Plugin schema - work_start_msg["id"], # step ID - decoder, # Decoder - )) - read_thread.start() - output_id, output_data = plugin_schema.call_step( - work_start_msg["id"], - plugin_schema.unserialize_step_input(work_start_msg["id"], work_start_msg["config"]) - ) + self.decoder = cbor2.decoder.CBORDecoder(self.input_pipe) + self.encoder = cbor2.encoder.CBOREncoder(self.output_pipe) + self.plugin_schema = plugin_schema + self.handle_handshake() + + # First replace stdout so that prints are handled by us, instead of potentially interfering with the atp pipes. 
+ original_stdout = sys.stdout + original_stderr = sys.stderr + self.user_out_buffer = io.StringIO() + sys.stdout = self.user_out_buffer + sys.stderr = self.user_out_buffer + + # Run the read loop + read_thread = threading.Thread(target=self.run_server_read_loop, args=()) + read_thread.start() + read_thread.join() # Wait for the read thread to finish. + # Wait for the step/signal threads to finish. If it gets stuck here then there is another thread blocked. + for thread in self.running_threads: + thread.join() + + # Don't reset stdout/stderr until after the read and step/signal threads are done. + sys.stdout = original_stdout + sys.stderr = original_stderr - # Send WorkDoneMessage in a RuntimeMessage - encoder.encode( - { - "id": MessageType.WORK_DONE.value, - "data": { - "output_id": output_id, - "output_data": plugin_schema.serialize_output( - work_start_msg["id"], output_id, output_data - ), - "debug_logs": out_buffer.getvalue(), - } - } - ) - self.stdout.flush() # Sends it to the ATP client immediately. Needed so it can realize it's done. - read_thread.join() # Wait for the read thread to finish. - # Don't reset stdout/stderr until after the read thread is done. - sys.stdout = original_stdout - sys.stderr = original_stderr - except SystemExit: - return 1 return 0 - def run_server_read_loop( - self, - plugin_schema: schema.SchemaType, - step_id: str, - decoder: cbor2.decoder.CBORDecoder, - ) -> None: + def handle_handshake(self): + # Decode empty "start output" message. 
+ self.decoder.decode() + + # Serialize then send HelloMessage + start_hello_message = HelloMessage(ATP_SERVER_VERSION, self.plugin_schema) + serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start_hello_message) + self.send_message(serialized_message) + + def run_server_read_loop(self) -> None: try: while True: # Decode the message - runtime_msg = decoder.decode() + runtime_msg = self.decoder.decode() msg_id = runtime_msg["id"] + run_id = runtime_msg["run_id"] # Validate if msg_id is None: self.stderr.write("Runtime message is missing the 'id' field.") return # Then take action - if msg_id == MessageType.SIGNAL.value: + if msg_id == MessageType.WORK_START: + work_start_msg = runtime_msg["data"] + try: + self.handle_work_start(run_id, work_start_msg) + except Exception as e: + self.send_error_message(run_id, True, False, + f"Exception while handling work start: {e} {traceback.format_exc()}") + elif msg_id == MessageType.SIGNAL: signal_msg = runtime_msg["data"] - received_step_id = signal_msg["step_id"] - received_signal_id = signal_msg["signal_id"] - if received_step_id != step_id: # Ensure they match. - self.stderr.write(f"Received step ID in the signal message '{received_step_id}'" - f"does not match expected step ID '{step_id}'") - return - unserialized_data = plugin_schema.unserialize_signal_handler_input( - received_step_id, - received_signal_id, - signal_msg["data"] - ) - # The data is verified and unserialized. Now call the signal. 
- plugin_schema.call_step_signal(step_id, received_signal_id, unserialized_data) - elif msg_id == MessageType.CLIENT_DONE.value: + try: + self.handle_signal(run_id, signal_msg) + except Exception as e: + self.send_error_message(run_id, False, False, + f"Exception while handling signal: {e} {traceback.format_exc()}") + elif msg_id == MessageType.CLIENT_DONE: return else: + self.send_error_message(run_id, False, False, f"Unknown runtime message ID: {msg_id}") self.stderr.write(f"Unknown kind of runtime message: {msg_id}") except cbor2.CBORDecodeError as err: self.stderr.write(f"Error while decoding CBOR: {err}") + self.send_error_message("", False, True, + f"Error occurred while decoding CBOR: {err} {traceback.format_exc()}") + except Exception as e: + self.send_error_message("", False, True, + f"Exception occurred in ATP server read loop: {e} {traceback.format_exc()}") + + def handle_signal(self, run_id, signal_msg): + saved_step_id = self.step_ids[run_id] + received_signal_id = signal_msg["signal_id"] + + unserialized_data = self.plugin_schema.unserialize_signal_handler_input( + saved_step_id, + received_signal_id, + signal_msg["data"] + ) + # The data is verified and unserialized. Now call the signal in its own thread. 
+ run_thread = threading.Thread(target=self.run_signal, + args=(run_id, saved_step_id, received_signal_id, unserialized_data)) + self.running_threads.append(run_thread) + run_thread.start() + + def run_signal(self, run_id: str, step_id: str, signal_id: str, unserialized_input_param: any): + try: + self.plugin_schema.call_step_signal(run_id, step_id, signal_id, unserialized_input_param) + except Exception as e: + self.send_error_message(run_id, False, False, + f"Error while calling signal for step with run ID {run_id}: {e} " + f"{traceback.format_exc()}" + ) + + def handle_work_start(self, run_id: str, work_start_msg: typing.Dict[str, any]): + if work_start_msg is None: + self.send_error_message(run_id, True, False, + "Work start message is None.") + return + if "id" not in work_start_msg: + self.send_error_message(run_id, True, False, + "Work start message is missing the 'id' field.") + return + if "config" not in work_start_msg: + self.send_error_message(run_id, True, False, + "Work start message is missing the 'config' field.") + return + # Save for later + self.step_ids[run_id] = work_start_msg["id"] + + # Now run the step, so start in a new thread + run_thread = threading.Thread(target=self.start_step, args=(run_id, work_start_msg)) + self.running_threads.append(run_thread) # Save so that we can join with it at the end. 
+ run_thread.start() + + def start_step(self, run_id: str, work_start_msg: any): + try: + output_id, output_data = self.plugin_schema.call_step( + run_id, + work_start_msg["id"], + self.plugin_schema.unserialize_step_input(work_start_msg["id"], work_start_msg["config"]) + ) + + # Send WorkDoneMessage + self.send_runtime_message( + MessageType.WORK_DONE, + run_id, + { + "output_id": output_id, + "output_data": self.plugin_schema.serialize_output( + work_start_msg["id"], output_id, output_data + ), + "debug_logs": self.user_out_buffer.getvalue(), + }, + ) + except Exception as e: + self.send_error_message(run_id, True, False, + f"Error while calling step {run_id}/{work_start_msg.get('id', 'missing')}:" + f"{e} {traceback.format_exc()}") + return + + def send_message(self, data: any): + with self.encoder_lock: + self.encoder.encode(data) + self.output_pipe.flush() # Sends it to the ATP client immediately. + + def send_runtime_message(self, message_type: MessageType, run_id: str, data: any): + self.send_message( + { + "id": message_type, + "run_id": run_id, + "data": data, + } + ) + + def send_error_message(self, run_id: str, step_fatal: bool, server_fatal: bool, error_msg: str): + self.send_runtime_message( + MessageType.ERROR, + run_id, + { + "error": error_msg, + "step_fatal": step_fatal, + "server_fatal": server_fatal, + }, + ) class PluginClientStateException(Exception): """ - This + This exception is for ATP client errors, like problems decoding
""" - stdin: io.FileIO - stdout: io.FileIO + to_server_pipe: io.FileIO # Usually the stdin of the sub-process + to_client_pipe: io.FileIO # Usually the stdout of the sub-process decoder: cbor2.decoder.CBORDecoder def __init__( self, - stdin: io.FileIO, - stdout: io.FileIO, + to_server_pipe: io.FileIO, + to_client_pipe: io.FileIO, ): - self.stdin = stdin - self.stdout = stdout - self.decoder = cbor2.decoder.CBORDecoder(stdout) - self.encoder = cbor2.encoder.CBOREncoder(stdin) + self.to_server_pipe = to_server_pipe + self.to_client_pipe = to_client_pipe + self.decoder = cbor2.decoder.CBORDecoder(to_client_pipe) + self.encoder = cbor2.encoder.CBOREncoder(to_server_pipe) def start_output(self) -> None: self.encoder.encode(None) @@ -247,49 +319,53 @@ def read_hello(self) -> HelloMessage: message = self.decoder.decode() return _HELLO_MESSAGE_SCHEMA.unserialize(message) - def start_work(self, step_id: str, input_data: any): + def start_work(self, run_id: str, step_id: str, config: any): """ After the Hello message has been read, this function starts work in a plugin with the specified data. """ - self.encoder.encode( + self.send_runtime_message( + MessageType.WORK_START, + run_id, { "id": step_id, - "config": input_data, + "config": config, } ) - self.stdin.flush() - def send_signal(self, step_id: str, signal_id: str, input_data: any): + def send_signal(self, run_id: str, signal_id: str, input_data: any): """ This function sends any signals to the plugin. 
""" - self.send_runtime_message(MessageType.SIGNAL, { - "step_id": step_id, + self.send_runtime_message( + MessageType.SIGNAL, + run_id, + { "signal_id": signal_id, "data": input_data, } ) def send_client_done(self): - self.send_runtime_message(MessageType.CLIENT_DONE, {}) + self.send_runtime_message(MessageType.CLIENT_DONE, "", "") - def send_runtime_message(self, message_type: MessageType, data: any): + def send_runtime_message(self, message_type: MessageType, run_id: str, data: any): self.encoder.encode( { - "id": message_type.value, + "id": message_type, + "run_id": run_id, "data": data, } ) - self.stdin.flush() + self.to_server_pipe.flush() - def read_results(self) -> (str, any, str): + def read_single_result(self) -> (str, str, any, str): """ - This function reads the signals and results of an execution from the plugin. + This function reads the next signal or result of an execution from the plugin. """ while True: runtime_msg = self.decoder.decode() msg_id = runtime_msg["id"] - if msg_id == MessageType.WORK_DONE.value: + if msg_id == MessageType.WORK_DONE: signal_msg = runtime_msg["data"] if signal_msg["output_id"] is None: raise PluginClientStateException( @@ -303,10 +379,14 @@ def read_results(self) -> (str, any, str): raise PluginClientStateException( "Missing 'output_data' in CBOR message. Possibly wrong order of calls?" ) - return signal_msg["output_id"], signal_msg["output_data"], signal_msg["debug_logs"] - elif msg_id == MessageType.SIGNAL.value: + return runtime_msg["run_id"], signal_msg["output_id"], signal_msg["output_data"], signal_msg["debug_logs"] + elif msg_id == MessageType.SIGNAL: # Do nothing. Should change in the future. 
continue + elif msg_id == MessageType.ERROR: + raise PluginClientStateException( + "Error received from ATP Server (plugin): " + str(runtime_msg['data']).replace('\\n', '\n') + ) else: raise PluginClientStateException( f"Received unknown runtime message ID {msg_id}" diff --git a/src/arcaflow_plugin_sdk/plugin.py b/src/arcaflow_plugin_sdk/plugin.py index 51cd41c..fb43843 100644 --- a/src/arcaflow_plugin_sdk/plugin.py +++ b/src/arcaflow_plugin_sdk/plugin.py @@ -448,7 +448,7 @@ def _execute_file( sys.stdout = out_buffer sys.stderr = out_buffer try: - output_id, output_data = s(step_id, data) + output_id, output_data = s("file_run", step_id, data) output = { "output_id": output_id, "output_data": output_data, diff --git a/src/arcaflow_plugin_sdk/schema.py b/src/arcaflow_plugin_sdk/schema.py index fdc43a1..8403268 100644 --- a/src/arcaflow_plugin_sdk/schema.py +++ b/src/arcaflow_plugin_sdk/schema.py @@ -5573,19 +5573,32 @@ def __init__( def __call__( self, - step_data: StepObjectT, - params: SignalDataT, + step_object_data: StepObjectT, + signal_input: SignalDataT, ): """ + :param step_object_data: The instantiated object that stores step run-specific data. :param params: Input data parameter for the signal handler. """ input: ScopeType = self.data_schema - input.validate(params, tuple(["input"])) - self._handler(step_data, params) + input.validate(signal_input, tuple(["input"])) + self._handler(step_object_data, signal_input) step_object_constructor_param = Callable[[], StepObjectT] +class _StepLocalData: + """ + Data associated with a single step, including the constructed object, + and the data needed to synchronize and notify steps of the step + being ready. + """ + initialized_object: StepObjectT + step_running: bool = False # So signals know to wait if sent before the step.
+ step_running_condition: threading.Condition = threading.Condition() + + def __init__(self, initialized_object: StepObjectT): + self.initialized_object = initialized_object class StepType(StepSchema): """ @@ -5600,9 +5613,8 @@ class StepType(StepSchema): signal_handler_method_names: List[str] signal_handlers: Dict[ID_TYPE, SignalHandlerType] signal_emitters: Dict[ID_TYPE, SignalSchema] - initialized_object_data: StepObjectT - object_data_ready: bool = False - object_cv: threading.Condition = threading.Condition() + initialized_object_data: Dict[str, _StepLocalData] = {} # Maps run_id to data + initialization_lock: threading.Lock = threading.Lock() def __init__( self, @@ -5620,8 +5632,20 @@ def __init__( self._step_object_constructor = step_object_constructor self.signal_handler_method_names = signal_handler_method_names + def setup_run_data(self, run_id: str): + with self.initialization_lock: + if run_id in self.initialized_object_data: + return self.initialized_object_data[run_id] + if self._step_object_constructor is not None: + new_run_data = _StepLocalData(self._step_object_constructor()) + else: + new_run_data = _StepLocalData(None) + self.initialized_object_data[run_id] = new_run_data + return new_run_data + def __call__( self, + run_id: str, params: StepInputT, skip_input_validation: bool = False, skip_output_validation: bool = False, @@ -5633,17 +5657,18 @@ def __call__( :return: The ID for the output datatype, and the output itself. 
""" # Initialize the step object - if self._step_object_constructor is not None: - self.initialized_object_data = self._step_object_constructor() - else: - self.initialized_object_data = None - self.object_data_ready = True - with self.object_cv: - self.object_cv.notify_all() + step_local_data: _StepLocalData = self.setup_run_data(run_id) + # Notify potentially waiting signals that the step is running + # Ideally, this would be done after, but just before is the only realistic option without more threads or sleep + step_local_data.step_running = True + with step_local_data.step_running_condition: + step_local_data.step_running_condition.notify_all() input: ScopeType = self.input + # Validate input if not skip_input_validation: input.validate(params, tuple(["input"])) - result = self._handler(self.initialized_object_data, params) + # Run the step + result = self._handler(step_local_data.initialized_object, params) if len(result) != 2: raise BadArgumentException( "The step returned {} results instead of 2. Did your step return the correct results?".format( @@ -5687,9 +5712,10 @@ class SchemaType(Schema): steps: Dict[str, StepType] def get_step(self, step_id: str): - if step_id not in self.steps: + found_step = self.steps.get(step_id) + if found_step is None: raise NoSuchStepException(step_id) - return self.steps[step_id] + return found_step def get_signal(self, step_id: str, signal_id: str): step = self.get_step(step_id) @@ -5733,43 +5759,49 @@ def _unserialize_signal_handler_input(signal: SignalHandlerType, data: Any) -> A except ConstraintException as e: raise InvalidInputException(e) from e - def call_step(self, step_id: str, input_param: Any) -> typing.Tuple[str, Any]: + def call_step(self, run_id: str, step_id: str, input_param: Any) -> typing.Tuple[str, Any]: """ This function calls a specific step with the input parameter that has already been unserialized. It expects the data to be already valid, use unserialize_step_input to produce a valid input. 
This function is automatically called by ``__call__`` after unserializing the input. + :param run_id: A unique ID for the run. :param step_id: The ID of the input step to run. :param input_param: The unserialized data structure the step expects. :return: The ID of the output, and the data structure returned from the step. """ - return self._call_step(self.get_step(step_id), input_param) + return self._call_step(self.get_step(step_id), run_id, input_param) - def call_step_signal(self, step_id: str, signal_id: str, unserialized_input_param: Any): + def call_step_signal(self, run_id: str, step_id: str, signal_id: str, unserialized_input_param: Any): """ This function calls a specific step's signal with the input parameter that has already been unserialized. It expects the data to be already valid, use unserialize_signal_input to produce a valid input. + :param run_id: A unique ID for the run, which must match signals associated with this step execution. :param step_id: The ID of the input step to run. + :param signal_id: The signal ID as defined by the plugin. :param unserialized_input_param: The unserialized data structure the step expects. :return: The ID of the output, and the data structure returned from the step. """ step = self.get_step(step_id) signal = self.get_signal(step_id, signal_id) - if not step.object_data_ready: - with step.object_cv: + local_step_data: _StepLocalData = step.setup_run_data(run_id) + with local_step_data.step_running_condition: + if not local_step_data.step_running: # wait to be notified of it being ready. Test this by adding a sleep before the step call. 
- step.object_cv.wait() - return signal(step.initialized_object_data, unserialized_input_param) + local_step_data.step_running_condition.wait() + return signal(local_step_data.initialized_object, unserialized_input_param) @staticmethod def _call_step( step: StepType, + run_id: str, unserialized_input_param: Any, skip_input_validation: bool = False, skip_output_validation: bool = False, ) -> typing.Tuple[str, Any]: return step( + run_id, unserialized_input_param, skip_input_validation=skip_input_validation, skip_output_validation=skip_output_validation, @@ -5795,7 +5827,7 @@ def _serialize_output(step, output_id: str, output_data: Any) -> Any: raise InvalidOutputException(e) from e def __call__( - self, step_id: str, data: Any, skip_serialization: bool = False + self, run_id: str, step_id: str, data: Any, skip_serialization: bool = False ) -> typing.Tuple[str, Any]: """ This function takes the input data, unserializes it for the specified step, calls the specified step, and, @@ -5810,6 +5842,7 @@ def __call__( input_param = self._unserialize_step_input(step, data) output_id, output_data = self._call_step( step, + run_id, input_param, # Skip duplicate verification skip_input_validation=True, diff --git a/src/arcaflow_plugin_sdk/test_atp.py b/src/arcaflow_plugin_sdk/test_atp.py index 9ebff2e..57a9843 100644 --- a/src/arcaflow_plugin_sdk/test_atp.py +++ b/src/arcaflow_plugin_sdk/test_atp.py @@ -30,6 +30,16 @@ def hello_world(params: Input) -> Tuple[str, Union[Output]]: return "success", Output("Hello, {}!".format(params.name)) +@plugin.step( + id="hello-world-broken", + name="Broken!", + description="Throws an exception with the text 'abcde'", + outputs={"success": Output}, +) +def hello_world_broken(_: Input) -> Tuple[str, Union[Output]]: + print("Hello world!") + raise Exception("abcde") + @dataclasses.dataclass class StepTestInput: wait_time_seconds: float @@ -47,8 +57,14 @@ class SignalTestOutput: class SignalTestStep: - signal_values: List[int] = [] - exit_event = 
Event() + signal_values: List[int] + exit_event: Event + + def __init__(self): + # Due to the way Python works, this MUST be done here, and not inlined above, or else it will be + # shared by all objects, resulting in a shared list and event, which would cause problems. + self.signal_values = [] + self.exit_event = Event() @plugin.step_with_signals( id="signal_test_step", @@ -66,15 +82,19 @@ def signal_test_step(self, params: StepTestInput) -> Tuple[str, Union[SignalTest @plugin.signal_handler( id="record_value", name="record value", - description="Records the value, and optionally ends the step.", + description="Records the value, and optionally ends the step. Throws error if it's less than 0, for testing.", ) def signal_test_signal_handler(self, signal_input: SignalTestInput): + if signal_input.value < 0: + self.exit_event.set() + raise Exception("Value below zero.") self.signal_values.append(signal_input.value) if signal_input.final: self.exit_event.set() test_schema = plugin.build_schema(hello_world) +test_broken_schema = plugin.build_schema(hello_world_broken) test_signals_schema = plugin.build_schema(SignalTestStep.signal_test_step) @@ -123,54 +143,56 @@ def _cleanup(self, pid, stdin_writer, stdout_reader): if exit_status != 0: self.fail("Plugin exited with non-zero status: {}".format(exit_status)) - def test_full_simple_workflow(self): + def test_step_simple(self): pid, stdin_writer, stdout_reader = self._execute_plugin(test_schema) try: client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) client.start_output() hello_message = client.read_hello() - self.assertEqual(2, hello_message.version) + self.assertEqual(3, hello_message.version) self.assertEqual( schema.SCHEMA_SCHEMA.serialize(test_schema), schema.SCHEMA_SCHEMA.serialize(hello_message.schema), ) - client.start_work("hello-world", {"name": "Arca Lot"}) + client.start_work(self.id(), "hello-world", {"name": "Arca Lot"}) - output_id, output_data, debug_logs = client.read_results() 
+ run_id, output_id, output_data, debug_logs = client.read_single_result() + self.assertEqual(run_id, self.id()) client.send_client_done() self.assertEqual(output_id, "success") self.assertEqual("Hello world!\n", debug_logs) finally: self._cleanup(pid, stdin_writer, stdout_reader) - def test_full_workflow_with_signals(self): + def test_step_with_signals(self): pid, stdin_writer, stdout_reader = self._execute_plugin(test_signals_schema) try: client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) client.start_output() hello_message = client.read_hello() - self.assertEqual(2, hello_message.version) + self.assertEqual(3, hello_message.version) self.assertEqual( schema.SCHEMA_SCHEMA.serialize(test_signals_schema), schema.SCHEMA_SCHEMA.serialize(hello_message.schema), ) - client.start_work("signal_test_step", {"wait_time_seconds": "5"}) - client.send_signal("signal_test_step", "record_value", + client.start_work(self.id(), "signal_test_step", {"wait_time_seconds": "5"}) + client.send_signal(self.id(), "record_value", {"final": "false", "value": "1"}, ) - client.send_signal("signal_test_step", "record_value", + client.send_signal(self.id(), "record_value", {"final": "false", "value": "2"}, ) - client.send_signal("signal_test_step", "record_value", + client.send_signal(self.id(), "record_value", {"final": "true", "value": "3"}, ) - output_id, output_data, debug_logs = client.read_results() + run_id, output_id, output_data, debug_logs = client.read_single_result() + self.assertEqual(run_id, self.id()) client.send_client_done() self.assertEqual(debug_logs, "") self.assertEqual(output_id, "success") @@ -178,6 +200,148 @@ def test_full_workflow_with_signals(self): finally: self._cleanup(pid, stdin_writer, stdout_reader) + def test_multi_step_with_signals(self): + """ + Starts two steps simultaneously, sends them separate data from signals, then verifies + that each step got the data intended for it.
+ """ + pid, stdin_writer, stdout_reader = self._execute_plugin(test_signals_schema) + + try: + client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) + client.start_output() + hello_message = client.read_hello() + self.assertEqual(3, hello_message.version) + + self.assertEqual( + schema.SCHEMA_SCHEMA.serialize(test_signals_schema), + schema.SCHEMA_SCHEMA.serialize(hello_message.schema), + ) + step_a_id = self.id() + "_a" + step_b_id = self.id() + "_b" + + client.start_work(step_a_id, "signal_test_step", {"wait_time_seconds": "5"}) + client.start_work(step_b_id, "signal_test_step", {"wait_time_seconds": "5"}) + client.send_signal(step_a_id, "record_value", + {"final": "false", "value": "1"}, + ) + client.send_signal(step_b_id, "record_value", + {"final": "true", "value": "2"}, + ) + b_run_id, b_output_id, b_output_data, b_debug_logs = client.read_single_result() + + client.send_signal(step_a_id, "record_value", + {"final": "true", "value": "3"}, + ) + a_run_id, a_output_id, a_output_data, a_debug_logs = client.read_single_result() + client.send_client_done() + self.assertEqual(a_run_id, step_a_id, "Expected 'a' run ID") + self.assertEqual(b_run_id, step_b_id, "Expected 'b' run ID") + self.assertEqual(a_debug_logs, "") + self.assertEqual(b_debug_logs, "") + self.assertEqual(a_output_id, "success") + self.assertEqual(b_output_id, "success") + self.assertListEqual(a_output_data["signals_received"], [1, 3]) + self.assertListEqual(b_output_data["signals_received"], [2]) + finally: + self._cleanup(pid, stdin_writer, stdout_reader) + + def test_broken_step(self): + """ + Runs a step that throws an exception, which is something that should be caught by the plugin, but + we need to test for it since the uncaught exceptions are the hardest to debug without proper handling.
+ """ + pid, stdin_writer, stdout_reader = self._execute_plugin(test_broken_schema) + + try: + client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) + client.start_output() + client.read_hello() + + client.start_work(self.id(), "hello-world-broken", {"name": "Arca Lot"}) + + with self.assertRaises(atp.PluginClientStateException) as context: + _, _, _, _ = client.read_single_result() + client.send_client_done() + self.assertIn("abcde", str(context.exception)) + finally: + self._cleanup(pid, stdin_writer, stdout_reader) + + def test_wrong_step(self): + """ + Tests the error reporting due to an invalid step being called. + """ + pid, stdin_writer, stdout_reader = self._execute_plugin(test_schema) + + try: + client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) + client.start_output() + client.read_hello() + + client.start_work(self.id(), "WRONG", {"name": "Arca Lot"}) + + with self.assertRaises(atp.PluginClientStateException) as context: + _, _, _, _ = client.read_single_result() + client.send_client_done() + self.assertIn("No such step: WRONG", str(context.exception)) + finally: + self._cleanup(pid, stdin_writer, stdout_reader) + + def test_invalid_runtime_message_id(self): + """ + Tests the error reporting due to an invalid runtime message ID being sent.
+ """ + pid, stdin_writer, stdout_reader = self._execute_plugin(test_schema) + + try: + client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) + client.start_output() + client.read_hello() + + client.send_runtime_message(1000, "", "") + + with self.assertRaises(atp.PluginClientStateException) as context: + _, _, _, _ = client.read_single_result() + client.send_client_done() + self.assertIn("Unknown runtime message ID: 1000", str(context.exception)) + finally: + self._cleanup(pid, stdin_writer, stdout_reader) + + def test_error_in_signal(self): + pid, stdin_writer, stdout_reader = self._execute_plugin(test_signals_schema) + + try: + client = atp.PluginClient(stdin_writer.buffer.raw, stdout_reader.buffer.raw) + client.start_output() + hello_message = client.read_hello() + self.assertEqual(3, hello_message.version) + + self.assertEqual( + schema.SCHEMA_SCHEMA.serialize(test_signals_schema), + schema.SCHEMA_SCHEMA.serialize(hello_message.schema), + ) + + client.start_work(self.id(), "signal_test_step", {"wait_time_seconds": "5"}) + client.send_signal(self.id(), "record_value", + {"final": "false", "value": "1"}, + ) + client.send_signal(self.id(), "record_value", + {"final": "false", "value": "-1"}, + ) + run_id, output_id, output_data, debug_logs = client.read_single_result() + self.assertEqual(run_id, self.id()) + self.assertEqual(debug_logs, "") + self.assertEqual(output_id, "success") + self.assertListEqual(output_data["signals_received"], [1]) + + # Note: The exception is raised after the step finishes in the test class + with self.assertRaises(atp.PluginClientStateException) as context: + _, _, _, _ = client.read_single_result() + client.send_client_done() + self.assertIn("Value below zero.", str(context.exception)) + + finally: + self._cleanup(pid, stdin_writer, stdout_reader) if __name__ == "__main__": unittest.main() diff --git a/test_example_plugin.py b/test_example_plugin.py index 82c45c5..e01ab63 100755 --- a/test_example_plugin.py +++ 
b/test_example_plugin.py @@ -21,9 +21,11 @@ def test_serialization(): ) def test_functional(self): - input = example_plugin.InputParams(name=example_plugin.FullName("Arca", "Lot")) + step_input = example_plugin.InputParams(name=example_plugin.FullName("Arca", "Lot")) - output_id, output_data = example_plugin.hello_world(input) + # Note: The call to hello_world is to the output of the decorator, not the function itself. + # So it's calling the StepType + output_id, output_data = example_plugin.hello_world(self.id(), step_input) # The example plugin always returns an error: self.assertEqual("success", output_id)