Skip to content

Commit

Permalink
Detect invalid step return
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredoconnell committed Jun 24, 2024
1 parent 54cf20e commit 74ec642
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
5 changes: 5 additions & 0 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5989,6 +5989,11 @@ def __call__(
input.validate(params, tuple(["input"]))
# Run the step
result = self._handler(step_local_data.initialized_object, params)
if not isinstance(result, tuple):
raise BadArgumentException(
"The step returned type {};".format(type(result))
+ " expected a tuple with two values: output ID string and a step-specific value."
)
if len(result) != 2:
raise BadArgumentException(
"The step returned {} results instead of 2. Did your step"
Expand Down
33 changes: 27 additions & 6 deletions src/arcaflow_plugin_sdk/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,30 @@
import unittest

from arcaflow_plugin_sdk import plugin
from arcaflow_plugin_sdk import schema


@dataclasses.dataclass
class StdoutTestInput:
class EmptyTestInput:
pass


@dataclasses.dataclass
class StdoutTestOutput:
class EmptyTestOutput:
pass


@plugin.step(
"stdout-test",
"Stdout test",
"A test for writing to stdout.",
{"success": StdoutTestOutput},
{"success": EmptyTestOutput},
)
def stdout_test_step(
input: StdoutTestInput,
) -> typing.Tuple[str, StdoutTestOutput]:
_: EmptyTestInput,
) -> typing.Tuple[str, EmptyTestOutput]:
print("Hello world!")
return "success", StdoutTestOutput()
return "success", EmptyTestOutput()


class StdoutTest(unittest.TestCase):
Expand All @@ -52,5 +53,25 @@ def cleanup():
self.assertEqual("Hello world!\n", e.getvalue())


@plugin.step(
"incorrect-return",
"Incorrect Return",
"A test that doesn't include the output ID.",
{"success": EmptyTestOutput},
)
def incorrect_return_step(
_: EmptyTestInput,
): # Skip return type, since we're purposefully not doing it right.
return EmptyTestOutput()


class CallStepTest(unittest.TestCase):
def test_incorrect_return_args_count(self):
s = plugin.build_schema(incorrect_return_step)

with self.assertRaises(schema.BadArgumentException):
s.call_step(self.id(), "incorrect-return", EmptyTestInput())


if __name__ == "__main__":
unittest.main()

0 comments on commit 74ec642

Please sign in to comment.