From 9a23deb7272de6e7b901b183bec57c83764bd27b Mon Sep 17 00:00:00 2001 From: Jared O'Connell <46976761+jaredoconnell@users.noreply.github.com> Date: Tue, 25 Jun 2024 10:26:05 -0400 Subject: [PATCH] Detect invalid step return (#134) * Detect invalid step return * Change import format * Fix linting error * Fix linting error * Addressed review comments * Addressed review comments * Addressed review comments --- src/arcaflow_plugin_sdk/schema.py | 11 ++++++-- src/arcaflow_plugin_sdk/test_plugin.py | 35 ++++++++++++++++++++------ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/arcaflow_plugin_sdk/schema.py b/src/arcaflow_plugin_sdk/schema.py index e6e2831..c12382b 100644 --- a/src/arcaflow_plugin_sdk/schema.py +++ b/src/arcaflow_plugin_sdk/schema.py @@ -5989,10 +5989,17 @@ def __call__( input.validate(params, tuple(["input"])) # Run the step result = self._handler(step_local_data.initialized_object, params) + if not isinstance(result, tuple): + raise BadArgumentException( + f"The implementation of step {run_id}/{self.id} returned" + f" type {type(result)}; expected a tuple with two" + " values: output ID string and a step-specific value." + f"\nValue returned: {result}" + ) if len(result) != 2: raise BadArgumentException( - "The step returned {} results instead of 2. Did your step" - " return the correct results?".format(len(result)) + f"The implementation of step {run_id}/{self.id} returned" + f"{len(result)} results instead of 2. Got {result}." ) output_id, output_data = result if output_id not in self.outputs: diff --git a/src/arcaflow_plugin_sdk/test_plugin.py b/src/arcaflow_plugin_sdk/test_plugin.py index 3ca1896..bb6f7d2 100644 --- a/src/arcaflow_plugin_sdk/test_plugin.py +++ b/src/arcaflow_plugin_sdk/test_plugin.py @@ -4,16 +4,16 @@ import typing import unittest -from arcaflow_plugin_sdk import plugin +from arcaflow_plugin_sdk import plugin, schema @dataclasses.dataclass -class StdoutTestInput: +class EmptyTestInput: pass @dataclasses.dataclass -class StdoutTestOutput: +class EmptyTestOutput: pass @@ -21,13 +21,13 @@ class StdoutTestOutput: "stdout-test", "Stdout test", "A test for writing to stdout.", - {"success": StdoutTestOutput}, + {"success": EmptyTestOutput}, ) def stdout_test_step( - input: StdoutTestInput, -) -> typing.Tuple[str, StdoutTestOutput]: + _: EmptyTestInput, +) -> typing.Tuple[str, EmptyTestOutput]: print("Hello world!") - return "success", StdoutTestOutput() + return "success", EmptyTestOutput() class StdoutTest(unittest.TestCase): @@ -52,5 +52,26 @@ def cleanup(): self.assertEqual("Hello world!\n", e.getvalue()) +@plugin.step( + "incorrect-return", + "Incorrect Return", + "A step that returns a bad output which omits the output ID.", + {"success": EmptyTestOutput}, +) +def incorrect_return_step( + _: EmptyTestInput, +) -> typing.Tuple[str, EmptyTestOutput]: + # noinspection PyTypeChecker + return EmptyTestOutput() + + +class CallStepTest(unittest.TestCase): + def test_incorrect_return_args_count(self): + s = plugin.build_schema(incorrect_return_step) + + with self.assertRaises(schema.BadArgumentException): + s.call_step(self.id(), "incorrect-return", EmptyTestInput()) + + if __name__ == "__main__": unittest.main()