From 68d0ecda1b150abec3269cc011cf00c12bd9a6df Mon Sep 17 00:00:00 2001 From: jaredoconnell Date: Wed, 30 Aug 2023 19:49:36 -0400 Subject: [PATCH] Reduce redundancy, and fix wrong var passed --- src/arcaflow_plugin_sdk/atp.py | 2 +- src/arcaflow_plugin_sdk/schema.py | 48 ++++++++++++------------------- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/arcaflow_plugin_sdk/atp.py b/src/arcaflow_plugin_sdk/atp.py index 55ab70e..b49de47 100644 --- a/src/arcaflow_plugin_sdk/atp.py +++ b/src/arcaflow_plugin_sdk/atp.py @@ -174,7 +174,7 @@ def run_server_read_loop( f"does not match expected step ID '{step_id}'") return unserialized_data = plugin_schema.unserialize_signal_handler_input( - step_id, + received_step_id, received_signal_id, signal_msg["data"] ) diff --git a/src/arcaflow_plugin_sdk/schema.py b/src/arcaflow_plugin_sdk/schema.py index 8b89fb5..4ecfb55 100644 --- a/src/arcaflow_plugin_sdk/schema.py +++ b/src/arcaflow_plugin_sdk/schema.py @@ -5681,6 +5681,17 @@ class SchemaType(Schema): steps: Dict[str, StepType] + def get_step(self, step_id: str): + if step_id not in self.steps: + raise NoSuchStepException(step_id) + return self.steps[step_id] + + def get_signal(self, step_id: str, signal_id: str): + step = self.get_step(step_id) + if signal_id not in step.signal_handlers: + raise NoSuchSignalException(step_id, signal_id) + return step.signal_handlers[signal_id] + def unserialize_step_input(self, step_id: str, serialized_data: Any) -> Any: """ This function unserializes the input from a raw data to data structures, such as dataclasses. This function is @@ -5690,10 +5701,7 @@ def unserialize_step_input(self, step_id: str, serialized_data: Any) -> Any: :param serialized_data: The raw data to unserialize. :return: The unserialized data in the structure the step expects it. """ - if step_id not in self.steps: - raise NoSuchStepException(step_id) - step = self.steps[step_id] - return self._unserialize_step_input(step, serialized_data) + return self._unserialize_step_input(self.get_step(step_id), serialized_data) @staticmethod def _unserialize_step_input(step: StepType, serialized_data: Any) -> Any: @@ -5711,12 +5719,7 @@ def unserialize_signal_handler_input(self, step_id: str, signal_id: str, seriali :param serialized_data: The raw data to unserialize. :return: The unserialized data in the structure the step expects it. """ - if step_id not in self.steps: - raise NoSuchStepException(step_id) - step = self.steps[step_id] - if signal_id not in step.signal_handlers: - raise NoSuchSignalException(step_id, signal_id) - return self._unserialize_signal_handler_input(step, serialized_data) + return self._unserialize_signal_handler_input(self.get_signal(step_id, signal_id), serialized_data) @staticmethod def _unserialize_signal_handler_input(signal: SignalHandlerType, data: Any) -> Any: @@ -5735,10 +5738,7 @@ def call_step(self, step_id: str, input_param: Any) -> typing.Tuple[str, Any]: :param input_param: The unserialized data structure the step expects. :return: The ID of the output, and the data structure returned from the step. """ - if step_id not in self.steps: - raise NoSuchStepException(step_id) - step = self.steps[step_id] - return self._call_step(step, input_param) + return self._call_step(self.get_step(step_id), input_param) def call_step_signal(self, step_id: str, signal_id: str, unserialized_input_param: Any): """ @@ -5749,13 +5749,8 @@ def call_step_signal(self, step_id: str, signal_id: str, unserialized_input_para :param unserialized_input_param: The unserialized data structure the step expects. :return: The ID of the output, and the data structure returned from the step. """ - if step_id not in self.steps: - raise NoSuchStepException(step_id) - step = self.steps[step_id] - - if signal_id not in step.signal_handlers: - raise NoSuchSignalException(step_id, signal_id) - signal = step.signal_handlers[signal_id] + step = self.get_step(step_id) + signal = self.get_signal(step_id, signal_id) return signal(step.initializedObjectData, unserialized_input_param) @staticmethod @@ -5781,10 +5776,7 @@ def serialize_output(self, step_id: str, output_id: str, output_data: Any) -> An :param output_data: The data structure returned from the step. :return: """ - if step_id not in self.steps: - raise NoSuchStepException(step_id) - step = self.steps[step_id] - return self._serialize_output(step, output_id, output_data) + return self._serialize_output(self.get_step(step_id), output_id, output_data) @staticmethod def _serialize_output(step, output_id: str, output_data: Any) -> Any: @@ -5805,10 +5797,8 @@ def __call__( :param skip_serialization: skip result serialization to basic types :return: the result ID, and the resulting data in the structure matching the result ID """ - if step_id not in self.steps: - raise NoSuchStepException(step_id) - step = self.steps[step_id] - input_param = self.unserialize_step_input(step, data) + step = self.get_step(step_id) + input_param = self._unserialize_step_input(step, data) output_id, output_data = self._call_step( step, input_param,