Skip to content

Commit

Permalink
Detect invalid step return (#134)
Browse files Browse the repository at this point in the history
* Detect invalid step return

* Change import format

* Fix linting error

* Fix linting error

* Addressed review comments

* Addressed review comments

* Addressed review comments
  • Loading branch information
jaredoconnell authored Jun 25, 2024
1 parent 54cf20e commit 9a23deb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
11 changes: 9 additions & 2 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5989,10 +5989,17 @@ def __call__(
input.validate(params, tuple(["input"]))
# Run the step
result = self._handler(step_local_data.initialized_object, params)
if not isinstance(result, tuple):
raise BadArgumentException(
f"The implementation of step {run_id}/{self.id} returned"
f" type {type(result)}; expected a tuple with two"
" values: output ID string and a step-specific value."
f"\nValue returned: {result}"
)
if len(result) != 2:
raise BadArgumentException(
"The step returned {} results instead of 2. Did your step"
" return the correct results?".format(len(result))
f"The implementation of step {run_id}/{self.id} returned"
f"{len(result)} results instead of 2. Got {result}."
)
output_id, output_data = result
if output_id not in self.outputs:
Expand Down
35 changes: 28 additions & 7 deletions src/arcaflow_plugin_sdk/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@
import typing
import unittest

from arcaflow_plugin_sdk import plugin
from arcaflow_plugin_sdk import plugin, schema


@dataclasses.dataclass
class StdoutTestInput:
class EmptyTestInput:
pass


@dataclasses.dataclass
class StdoutTestOutput:
class EmptyTestOutput:
pass


@plugin.step(
"stdout-test",
"Stdout test",
"A test for writing to stdout.",
{"success": StdoutTestOutput},
{"success": EmptyTestOutput},
)
def stdout_test_step(
input: StdoutTestInput,
) -> typing.Tuple[str, StdoutTestOutput]:
_: EmptyTestInput,
) -> typing.Tuple[str, EmptyTestOutput]:
print("Hello world!")
return "success", StdoutTestOutput()
return "success", EmptyTestOutput()


class StdoutTest(unittest.TestCase):
Expand All @@ -52,5 +52,26 @@ def cleanup():
self.assertEqual("Hello world!\n", e.getvalue())


@plugin.step(
"incorrect-return",
"Incorrect Return",
"A step that returns a bad output which omits the output ID.",
{"success": EmptyTestOutput},
)
def incorrect_return_step(
_: EmptyTestInput,
) -> typing.Tuple[str, EmptyTestOutput]:
# noinspection PyTypeChecker
return EmptyTestOutput()


class CallStepTest(unittest.TestCase):
def test_incorrect_return_args_count(self):
s = plugin.build_schema(incorrect_return_step)

with self.assertRaises(schema.BadArgumentException):
s.call_step(self.id(), "incorrect-return", EmptyTestInput())


if __name__ == "__main__":
unittest.main()

0 comments on commit 9a23deb

Please sign in to comment.