Skip to content

Commit

Permalink
Reduce redundancy, and fix wrong var passed
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredoconnell committed Aug 30, 2023
1 parent 7ab7b20 commit 68d0ecd
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def run_server_read_loop(
f"does not match expected step ID '{step_id}'")
return
unserialized_data = plugin_schema.unserialize_signal_handler_input(
step_id,
received_step_id,
received_signal_id,
signal_msg["data"]
)
Expand Down
48 changes: 19 additions & 29 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5681,6 +5681,17 @@ class SchemaType(Schema):

steps: Dict[str, StepType]

def get_step(self, step_id: str):
if step_id not in self.steps:
raise NoSuchStepException(step_id)
return self.steps[step_id]

def get_signal(self, step_id: str, signal_id: str):
step = self.get_step(step_id)
if signal_id not in step.signal_handlers:
raise NoSuchSignalException(step_id, signal_id)
return step.signal_handlers[signal_id]

def unserialize_step_input(self, step_id: str, serialized_data: Any) -> Any:
"""
This function unserializes the input from a raw data to data structures, such as dataclasses. This function is
Expand All @@ -5690,10 +5701,7 @@ def unserialize_step_input(self, step_id: str, serialized_data: Any) -> Any:
:param serialized_data: The raw data to unserialize.
:return: The unserialized data in the structure the step expects it.
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
return self._unserialize_step_input(step, serialized_data)
return self._unserialize_step_input(self.get_step(step_id), serialized_data)

@staticmethod
def _unserialize_step_input(step: StepType, serialized_data: Any) -> Any:
Expand All @@ -5711,12 +5719,7 @@ def unserialize_signal_handler_input(self, step_id: str, signal_id: str, seriali
:param serialized_data: The raw data to unserialize.
:return: The unserialized data in the structure the step expects it.
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
if signal_id not in step.signal_handlers:
raise NoSuchSignalException(step_id, signal_id)
return self._unserialize_signal_handler_input(step, serialized_data)
return self._unserialize_signal_handler_input(self.get_signal(step_id, signal_id), serialized_data)

@staticmethod
def _unserialize_signal_handler_input(signal: SignalHandlerType, data: Any) -> Any:
Expand All @@ -5735,10 +5738,7 @@ def call_step(self, step_id: str, input_param: Any) -> typing.Tuple[str, Any]:
:param input_param: The unserialized data structure the step expects.
:return: The ID of the output, and the data structure returned from the step.
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
return self._call_step(step, input_param)
return self._call_step(self.get_step(step_id), input_param)

def call_step_signal(self, step_id: str, signal_id: str, unserialized_input_param: Any):
"""
Expand All @@ -5749,13 +5749,8 @@ def call_step_signal(self, step_id: str, signal_id: str, unserialized_input_para
:param unserialized_input_param: The unserialized data structure the step expects.
:return: The ID of the output, and the data structure returned from the step.
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]

if signal_id not in step.signal_handlers:
raise NoSuchSignalException(step_id, signal_id)
signal = step.signal_handlers[signal_id]
step = self.get_step(step_id)
signal = self.get_signal(step_id, signal_id)
return signal(step.initializedObjectData, unserialized_input_param)

@staticmethod
Expand All @@ -5781,10 +5776,7 @@ def serialize_output(self, step_id: str, output_id: str, output_data: Any) -> An
:param output_data: The data structure returned from the step.
:return:
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
return self._serialize_output(step, output_id, output_data)
return self._serialize_output(self.get_step(step_id), output_id, output_data)

@staticmethod
def _serialize_output(step, output_id: str, output_data: Any) -> Any:
Expand All @@ -5805,10 +5797,8 @@ def __call__(
:param skip_serialization: skip result serialization to basic types
:return: the result ID, and the resulting data in the structure matching the result ID
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
input_param = self.unserialize_step_input(step, data)
step = self.get_step(step_id)
input_param = self._unserialize_step_input(step, data)
output_id, output_data = self._call_step(
step,
input_param,
Expand Down

0 comments on commit 68d0ecd

Please sign in to comment.