Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Signals and ATP v2 #98

Merged
merged 36 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9d1d770
Progress towards signals in SDK
jaredoconnell Aug 23, 2023
f952174
Added signals to ATP server
jaredoconnell Aug 25, 2023
2e07076
Merge branch 'main' into signals
jaredoconnell Aug 25, 2023
82ab335
Add ATP v2 support to client
jaredoconnell Aug 25, 2023
00d0a65
Fix linting errors
jaredoconnell Aug 28, 2023
be01c85
Fix linting error
jaredoconnell Aug 28, 2023
1849d09
Ignore linting issue
jaredoconnell Aug 28, 2023
4b2561a
Fix linting error
jaredoconnell Aug 28, 2023
13cc3b3
Fix errors introduced while trying to fix linting errors
jaredoconnell Aug 28, 2023
317986f
Fix missing params in test
jaredoconnell Aug 28, 2023
c7f0155
Update dependencies
jaredoconnell Aug 28, 2023
b692b72
Update pyyaml
jaredoconnell Aug 28, 2023
7b86c6b
Fixed typo
jaredoconnell Aug 28, 2023
91226f9
Fix predefined schema
jaredoconnell Aug 28, 2023
0edc6e5
Bypass class limitations by passing method names instead
jaredoconnell Aug 29, 2023
f599ad3
Delay signal retrieval
jaredoconnell Aug 30, 2023
a14a04b
Added missing field, and added extra None check
jaredoconnell Aug 30, 2023
ca012ed
Fix parameter passed into test case
jaredoconnell Aug 30, 2023
21e59bf
Fix missing input to server read loop
jaredoconnell Aug 30, 2023
9113031
Fix wrong signal ID
jaredoconnell Aug 30, 2023
28d95e2
Fix extra parameter passed into function
jaredoconnell Aug 30, 2023
a20b61e
Fix missing parameter passed into function
jaredoconnell Aug 30, 2023
7169c78
Fix missing deserialization step
jaredoconnell Aug 30, 2023
7ab7b20
Fix missed function rename, and fix linting err
jaredoconnell Aug 30, 2023
68d0ecd
Reduce redundancy, and fix wrong var passed
jaredoconnell Aug 30, 2023
cce45b7
Added client done message, and signal atp tests
jaredoconnell Sep 12, 2023
7cf6628
Fix linting errors
jaredoconnell Sep 12, 2023
b68efc1
Remove join and add flush
jaredoconnell Sep 15, 2023
b13a14b
Change when read thread is launched
jaredoconnell Sep 15, 2023
df69b67
Ignore sigint and manage stdout correctly
jaredoconnell Sep 15, 2023
cc2960e
Fix linting errors, and added comments
jaredoconnell Sep 18, 2023
c7d9658
Fix ordering problem with steps and signals
jaredoconnell Sep 19, 2023
c79cef8
Added coverage config file
jaredoconnell Sep 19, 2023
3f467d4
Remove unused import
jaredoconnell Sep 19, 2023
ced8b2b
Removed print statements, and added fail when debug logs aren't empty
jaredoconnell Sep 19, 2023
6a47aae
Remove coverage config
jaredoconnell Sep 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix missing deserialization step
  • Loading branch information
jaredoconnell committed Aug 30, 2023
commit 7169c78962b050b58879bb2b6f861c4c70a2a8a6
21 changes: 13 additions & 8 deletions src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(

def run_plugin(
jaredoconnell marked this conversation as resolved.
Show resolved Hide resolved
self,
step_schema: schema.SchemaType,
plugin_schema: schema.SchemaType,
) -> int:
"""
This function wraps running a plugin.
Expand All @@ -91,7 +91,7 @@ def run_plugin(
decoder.decode()

# Serialize then send HelloMessage
start_hello_message = HelloMessage(2, step_schema)
start_hello_message = HelloMessage(2, plugin_schema)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 2 here looks like a "magic number"...shouldn't that be a symbolic reference to an externally-defined constant? Or, perhaps, the whole HelloMessage(2, plugin_schema) expression should be encapsulated in a reference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should create a version constant.

serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start_hello_message)
encoder.encode(serialized_message)
self.stdout.flush()
Expand All @@ -112,7 +112,7 @@ def run_plugin(
return 1
# Run the read loop
read_thread = threading.Thread(target=self.run_server_read_loop, args=(
step_schema, #
plugin_schema, # Plugin schema
work_start_msg["id"], # step ID
decoder, # Decoder
self.stderr, # Stderr
Expand All @@ -124,8 +124,8 @@ def run_plugin(
out_buffer = io.StringIO()
sys.stdout = out_buffer
sys.stderr = out_buffer
output_id, output_data = step_schema.call_step(
work_start_msg["id"], step_schema.unserialize_input(work_start_msg["id"], work_start_msg["config"])
output_id, output_data = plugin_schema.call_step(
work_start_msg["id"], plugin_schema.unserialize_input(work_start_msg["id"], work_start_msg["config"])
)
sys.stdout = original_stdout
sys.stderr = original_stderr
Expand All @@ -136,7 +136,7 @@ def run_plugin(
"id": MessageType.WORKDONE.value,
"data": {
"output_id": output_id,
"output_data": step_schema.serialize_output(
"output_data": plugin_schema.serialize_output(
work_start_msg["id"], output_id, output_data
),
"debug_logs": out_buffer.getvalue(),
Expand All @@ -151,7 +151,7 @@ def run_plugin(

def run_server_read_loop(
self,
step: schema.SchemaType,
plugin_schema: schema.SchemaType,
step_id: str,
decoder: cbor2.decoder.CBORDecoder,
stderr: io.FileIO
Expand All @@ -173,7 +173,12 @@ def run_server_read_loop(
stderr.write(f"Received step ID in the signal message '{received_step_id}'"
f"does not match expected step ID '{step_id}'")
return
step.call_step_signal(step_id, received_signal_id, signal_msg["data"])
unserialized_data = plugin_schema.unserialize_signal_handler_input(
step_id,
received_signal_id,
signal_msg["data"]
)
plugin_schema.call_step_signal(step_id, received_signal_id, unserialized_data)
else:
stderr.write(f"Unknown kind of runtime message: {msg_id}")

Expand Down
51 changes: 37 additions & 14 deletions src/arcaflow_plugin_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5681,32 +5681,55 @@ class SchemaType(Schema):

steps: Dict[str, StepType]

def unserialize_input(self, step_id: str, data: Any) -> Any:
def unserialize_step_input(self, step_id: str, serialized_data: Any) -> Any:
"""
This function unserializes the input from a raw data to data structures, such as dataclasses. This function is
automatically called by ``__call__`` before running the step with the unserialized input.

:param step_id: The step ID to use to look up the schema for unserialization.
:param data: The raw data to unserialize.
:param serialized_data: The raw data to unserialize.
:return: The unserialized data in the structure the step expects it.
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
return self._unserialize_input(step, data)
return self._unserialize_step_input(step, serialized_data)

@staticmethod
def _unserialize_input(step: StepType, data: Any) -> Any:
def _unserialize_step_input(step: StepType, serialized_data: Any) -> Any:
try:
return step.input.unserialize(data)
return step.input.unserialize(serialized_data)
except ConstraintException as e:
raise InvalidInputException(e) from e

def unserialize_signal_handler_input(self, step_id: str, signal_id: str, serialized_data: Any) -> Any:
"""
This function unserializes the input from a raw data to data structures, such as dataclasses. This function is
automatically called by ``__call__`` before running the step with the unserialized input.

:param step_id: The step ID to use to look up the schema for unserialization.
:param serialized_data: The raw data to unserialize.
:return: The unserialized data in the structure the step expects it.
"""
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
if signal_id not in step.signal_handlers:
raise NoSuchSignalException(step_id, signal_id)
return self._unserialize_signal_handler_input(step, serialized_data)

@staticmethod
def _unserialize_signal_handler_input(signal: SignalHandlerType, data: Any) -> Any:
try:
return signal.data_schema.unserialize(data)
except ConstraintException as e:
raise InvalidInputException(e) from e

def call_step(self, step_id: str, input_param: Any) -> typing.Tuple[str, Any]:
"""
This function calls a specific step with the input parameter that has already been unserialized. It expects the
data to be already valid, use unserialize_input to produce a valid input. This function is automatically called
by ``__call__`` after unserializing the input.
data to be already valid, use unserialize_step_input to produce a valid input. This function is automatically
called by ``__call__`` after unserializing the input.

:param step_id: The ID of the input step to run.
:param input_param: The unserialized data structure the step expects.
Expand All @@ -5717,13 +5740,13 @@ def call_step(self, step_id: str, input_param: Any) -> typing.Tuple[str, Any]:
step = self.steps[step_id]
return self._call_step(step, input_param)

def call_step_signal(self, step_id: str, signal_id: str, input_param: Any):
def call_step_signal(self, step_id: str, signal_id: str, unserialized_input_param: Any):
"""
This function calls a specific step's signal with the input parameter that has already been unserialized. It expects the
data to be already valid, use unserialize_input to produce a valid input.
data to be already valid, use unserialize_signal_input to produce a valid input.

:param step_id: The ID of the input step to run.
:param input_param: The unserialized data structure the step expects.
:param unserialized_input_param: The unserialized data structure the step expects.
:return: The ID of the output, and the data structure returned from the step.
"""
if step_id not in self.steps:
Expand All @@ -5733,17 +5756,17 @@ def call_step_signal(self, step_id: str, signal_id: str, input_param: Any):
if signal_id not in step.signal_handlers:
raise NoSuchSignalException(step_id, signal_id)
signal = step.signal_handlers[signal_id]
return signal(step.initializedObjectData, input_param)
return signal(step.initializedObjectData, unserialized_input_param)

@staticmethod
def _call_step(
step: StepType,
input_param: Any,
unserialized_input_param: Any,
skip_input_validation: bool = False,
skip_output_validation: bool = False,
) -> typing.Tuple[str, Any]:
return step(
input_param,
unserialized_input_param,
skip_input_validation=skip_input_validation,
skip_output_validation=skip_output_validation,
)
Expand Down Expand Up @@ -5785,7 +5808,7 @@ def __call__(
if step_id not in self.steps:
raise NoSuchStepException(step_id)
step = self.steps[step_id]
input_param = self._unserialize_input(step, data)
input_param = self.unserialize_step_input(step, data)
output_id, output_data = self._call_step(
step,
input_param,
Expand Down