Signals and ATP v2 #98

Merged on Sep 21, 2023 (36 commits). Changes from 31 commits are shown below.
9d1d770
Progress towards signals in SDK
jaredoconnell Aug 23, 2023
f952174
Added signals to ATP server
jaredoconnell Aug 25, 2023
2e07076
Merge branch 'main' into signals
jaredoconnell Aug 25, 2023
82ab335
Add ATP v2 support to client
jaredoconnell Aug 25, 2023
00d0a65
Fix linting errors
jaredoconnell Aug 28, 2023
be01c85
Fix linting error
jaredoconnell Aug 28, 2023
1849d09
Ignore linting issue
jaredoconnell Aug 28, 2023
4b2561a
Fix linting error
jaredoconnell Aug 28, 2023
13cc3b3
Fix errors introduced while trying to fix linting errors
jaredoconnell Aug 28, 2023
317986f
Fix missing params in test
jaredoconnell Aug 28, 2023
c7f0155
Update dependencies
jaredoconnell Aug 28, 2023
b692b72
Update pyyaml
jaredoconnell Aug 28, 2023
7b86c6b
Fixed typo
jaredoconnell Aug 28, 2023
91226f9
Fix predefined schema
jaredoconnell Aug 28, 2023
0edc6e5
Bypass class limitations by passing method names instead
jaredoconnell Aug 29, 2023
f599ad3
Delay signal retrieval
jaredoconnell Aug 30, 2023
a14a04b
Added missing field, and added extra None check
jaredoconnell Aug 30, 2023
ca012ed
Fix parameter passed into test case
jaredoconnell Aug 30, 2023
21e59bf
Fix missing input to server read loop
jaredoconnell Aug 30, 2023
9113031
Fix wrong signal ID
jaredoconnell Aug 30, 2023
28d95e2
Fix extra parameter passed into function
jaredoconnell Aug 30, 2023
a20b61e
Fix missing parameter passed into function
jaredoconnell Aug 30, 2023
7169c78
Fix missing deserialization step
jaredoconnell Aug 30, 2023
7ab7b20
Fix missed function rename, and fix linting err
jaredoconnell Aug 30, 2023
68d0ecd
Reduce redundancy, and fix wrong var passed
jaredoconnell Aug 30, 2023
cce45b7
Added client done message, and signal atp tests
jaredoconnell Sep 12, 2023
7cf6628
Fix linting errors
jaredoconnell Sep 12, 2023
b68efc1
Remove join and add flush
jaredoconnell Sep 15, 2023
b13a14b
Change when read thread is launched
jaredoconnell Sep 15, 2023
df69b67
Ignore sigint and manage stdout correctly
jaredoconnell Sep 15, 2023
cc2960e
Fix linting errors, and added comments
jaredoconnell Sep 18, 2023
c7d9658
Fix ordering problem with steps and signals
jaredoconnell Sep 19, 2023
c79cef8
Added coverage config file
jaredoconnell Sep 19, 2023
3f467d4
Remove unused import
jaredoconnell Sep 19, 2023
ced8b2b
Removed print statements, and added fail when debug logs aren't empty
jaredoconnell Sep 19, 2023
6a47aae
Remove coverage config
jaredoconnell Sep 20, 2023
588 changes: 329 additions & 259 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ homepage = "https://github.com/arcalot/arcaflow-plugin-sdk-python"
[tool.poetry.dependencies]
python = "^3.9"
cbor2 = "^5.4.3"
PyYAML = "^5.4"
PyYAML = "^6.0.1"

[tool.poetry.group.dev.dependencies]
coverage = "^6.5.0"
272 changes: 195 additions & 77 deletions src/arcaflow_plugin_sdk/atp.py
@@ -16,15 +16,29 @@
import dataclasses
import io
import os
import signal
import sys
import typing
import threading
import signal
Comment on lines -19 to +22
Contributor
Do you keep your import lists ordered? (E.g., you might want to try out isort.)

Should the import of enum be in this block?


import cbor2

from enum import Enum

from arcaflow_plugin_sdk import schema


class MessageType(Enum):
"""
An integer ID that indicates the type of runtime message that is stored
in the data field. The corresponding class can then be used to deserialize
the inner data. Look at the go SDK for the reference data structure.
"""
WORK_DONE = 1
SIGNAL = 2
CLIENT_DONE = 3
Comment on lines +37 to +39
Contributor
Consider using enum.auto() instead of explicit values.

Contributor Author
Since we need to sync these between the go and Python SDK, I think it makes sense to make these explicit.
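
To illustrate the trade-off discussed in this thread, here is a minimal sketch. The `AutoType` class and its reordered members are hypothetical and are not part of the PR. It shows that `enum.auto()` assigns values by declaration order, so a reordering silently changes the integer sent on the wire, while explicit values stay pinned.

```python
from enum import Enum, auto


class ExplicitType(Enum):
    # Values are pinned, so they keep matching another implementation
    # even if the members are reordered.
    WORK_DONE = 1
    SIGNAL = 2
    CLIENT_DONE = 3


class AutoType(Enum):
    # Hypothetical reordering: SIGNAL is declared first, so auto() gives it 1.
    SIGNAL = auto()
    WORK_DONE = auto()
    CLIENT_DONE = auto()


explicit_signal = ExplicitType.SIGNAL.value  # stays 2
auto_signal = AutoType.SIGNAL.value  # depends on declaration order
```

With explicit values, a mismatch between two SDKs can only come from an edited number, which is easier to spot in review than a moved line.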



@dataclasses.dataclass
class HelloMessage:
"""
@@ -50,73 +64,144 @@ class HelloMessage:
_HELLO_MESSAGE_SCHEMA = schema.build_object_schema(HelloMessage)


def _handle_exit(_signo, _stack_frame):
print("Exiting normally")
sys.exit(0)
def signal_handler(_sig, _frame):
pass # Do nothing


def run_plugin(
s: schema.SchemaType,
stdin: io.FileIO,
stdout: io.FileIO,
stderr: io.FileIO,
) -> int:
"""
This function wraps running a plugin.
"""
if os.isatty(stdout.fileno()):
print("Cannot run plugin in ATP mode on an interactive terminal.")
return 1
class ATPServer:
stdin: io.FileIO
stdout: io.FileIO
stderr: io.FileIO
step_object: typing.Any

signal.signal(signal.SIGTERM, _handle_exit)
try:
decoder = cbor2.decoder.CBORDecoder(stdin)
encoder = cbor2.encoder.CBOREncoder(stdout)
def __init__(
self,
stdin: io.FileIO,
stdout: io.FileIO,
stderr: io.FileIO,
) -> None:
self.stdin = stdin
self.stdout = stdout
self.stderr = stderr

# Decode empty "start output" message.
decoder.decode()
def run_plugin(
self,
plugin_schema: schema.SchemaType,
) -> int:
"""
This function wraps running a plugin.
"""
signal.signal(signal.SIGINT, signal_handler) # Ignore sigint. Only care about arcaflow signals.
Comment on lines +93 to +94
Contributor
Why use a custom signal handler (i.e., signal_handler) instead of using the standard signal.SIG_IGN to ignore the signal?

Contributor Author
Good point. I'll need to refactor this at some point.
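
For reference, a minimal sketch of the `signal.SIG_IGN` alternative raised above. The helper name `ignore_sigint` is an assumption for illustration; only the standard-library `signal` calls are real. It has to run on the main thread, as any call to `signal.signal` does.

```python
import signal


def ignore_sigint():
    """Ignore SIGINT using the built-in disposition; return the old handler."""
    return signal.signal(signal.SIGINT, signal.SIG_IGN)


previous_handler = ignore_sigint()
current_handler = signal.getsignal(signal.SIGINT)

# Restore the earlier handler so the sketch has no lasting side effect.
if previous_handler is not None:
    signal.signal(signal.SIGINT, previous_handler)
```

One behavioural difference to keep in mind: on POSIX systems an ignored signal stays ignored in programs started through `exec`, whereas a signal caught by a Python function is reset to its default there. Whether that matters for plugins that spawn subprocesses is not settled in this thread.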

if os.isatty(self.stdout.fileno()):
print("Cannot run plugin in ATP mode on an interactive terminal.")
return 1
Comment on lines +95 to +97
Contributor
This is curious...why is this the case? Do/should you have similar restrictions on stderr or stdin?

Contributor Author
The plugins communicate over stdout and stdin

Contributor
Understood, but does that mean that there should never be a terminal connected there (not even for interactive testing)? (Is it even possible that this could happen?) That is, is there actually value to performing this check?

If so, then shouldn't you be checking stdin as well?
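
If the check were extended to stdin as suggested, it could look roughly like the sketch below. `interactive_streams` and `FakeTerminal` are hypothetical names used only for this illustration, and the sketch calls the stream's own `isatty()` method rather than `os.isatty(fileno())` so that it also works with in-memory streams.

```python
import io
import typing


def interactive_streams(
    streams: typing.Mapping[str, typing.Any],
) -> typing.List[str]:
    """Return the names of the streams that are attached to a terminal."""
    return [name for name, stream in streams.items() if stream.isatty()]


class FakeTerminal(io.BytesIO):
    """Stand-in for a terminal-attached stream, for demonstration only."""

    def isatty(self) -> bool:
        return True


offenders = interactive_streams(
    {"stdin": FakeTerminal(), "stdout": io.BytesIO()}
)
if offenders:
    message = "Cannot run plugin in ATP mode on an interactive terminal: " + (
        ", ".join(offenders)
    )
```

Reporting which stream is interactive makes the error easier to act on than a single stdout-only message.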

try:
decoder = cbor2.decoder.CBORDecoder(self.stdin)
encoder = cbor2.encoder.CBOREncoder(self.stdout)

start = HelloMessage(1, s)
serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start)
encoder.encode(serialized_message)
stdout.flush()
# Decode empty "start output" message.
decoder.decode()

message = decoder.decode()
except SystemExit:
return 0
try:
if message is None:
stderr.write("Work start message is None.")
return 1
if message["id"] is None:
stderr.write("Work start message is missing the 'id' field.")
return 1
if message["config"] is None:
stderr.write("Work start message is missing the 'config' field.")
# Serialize then send HelloMessage
start_hello_message = HelloMessage(2, plugin_schema)
Contributor
The 2 here looks like a "magic number"...shouldn't that be a symbolic reference to an externally-defined constant? Or, perhaps, the whole HelloMessage(2, plugin_schema) expression should be encapsulated in a reference?

Contributor Author
I should create a version constant.
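
A version constant along the lines agreed here might be sketched as follows. The constant name `ATP_SERVER_VERSION`, the helper `build_hello_message`, and the two-field stand-in for `HelloMessage` are assumptions for illustration, not the SDK's actual definitions.

```python
import dataclasses
import typing

# Hypothetical name; the PR only says that a version constant should exist.
ATP_SERVER_VERSION = 2


@dataclasses.dataclass
class HelloMessage:
    """Simplified stand-in for the SDK's HelloMessage dataclass."""

    version: int
    schema: typing.Any


def build_hello_message(plugin_schema: typing.Any) -> HelloMessage:
    """Build the hello message without repeating the literal version."""
    return HelloMessage(ATP_SERVER_VERSION, plugin_schema)


hello = build_hello_message({"steps": {}})
```

Keeping the number in one named place also gives the client side something to compare against when it reads the hello message.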

serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start_hello_message)
encoder.encode(serialized_message)
self.stdout.flush()

# Can fail here if only getting schema.
work_start_msg = decoder.decode()
except SystemExit:
return 0
Comment on lines +113 to +114
Contributor
This is a somewhat strange thing to do...are you sure you want to be doing this, especially given that you've removed the code which used to call sys.exit() from old line 55? (Removing this would allow you to remove the try, as well.)

Ditto for (new) line 162.

Contributor Author
This is prior code that I didn't write, I just moved. It may be worth refactoring.

Contributor
You've removed at least one source of a SystemExit exception; if you've in fact removed all of the explicit ones, then you should almost certainly remove these two try/except blocks.

try:
if work_start_msg is None:
self.stderr.write("Work start message is None.")
return 1
if work_start_msg["id"] is None:
self.stderr.write("Work start message is missing the 'id' field.")
return 1
if work_start_msg["config"] is None:
self.stderr.write("Work start message is missing the 'config' field.")
return 1

# Run the step
original_stdout = sys.stdout
original_stderr = sys.stderr
Comment on lines +127 to +128
Contributor
At line 100 you reference self.stdout; here you reference sys.stdout...can they be different from each other? Are you sure that you're referencing the right one in each case? (Shouldn't you be using the same one in both cases?)

BTW, you can potentially skip this step, since sys already does something very like it for you -- see sys.__stdout__.

Contributor Author
We're actually replacing the standard out so that the plugins can use print without interfering with the ATP protocol, which took over stdin/stdout.

Contributor
Understood. But, you have at least two sources for stdin and stdout, so you'll need to take care that you don't get them confused (and, the Reviewer needs to keep this in mind, as well 🙂).
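
The capture pattern described in this thread can also be written with the standard library's context managers, which restore the streams even if the step raises. This is a sketch under that assumption, not the PR's code; `run_captured` and `noisy_step` are hypothetical names.

```python
import contextlib
import io
import sys


def run_captured(step):
    """Run step() with print output captured; return (result, captured text)."""
    buffer = io.StringIO()
    # Both sys.stdout and sys.stderr point at the same buffer while the
    # step runs, and are restored when the with-block exits.
    with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):
        result = step()
    return result, buffer.getvalue()


def noisy_step():
    print("working")
    print("warning", file=sys.stderr)
    return "success"


result, debug_logs = run_captured(noisy_step)
```

The redirection is process-wide rather than per-thread, so anything printed by another thread during the step lands in the same buffer. The protocol streams held by the server (`self.stdout` in the diff) are separate objects and are not affected by the swap.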

out_buffer = io.StringIO()
sys.stdout = out_buffer
sys.stderr = out_buffer
# Run the read loop
read_thread = threading.Thread(target=self.run_server_read_loop, args=(
plugin_schema, # Plugin schema
work_start_msg["id"], # step ID
decoder, # Decoder
))
read_thread.start()

output_id, output_data = plugin_schema.call_step(
work_start_msg["id"],
plugin_schema.unserialize_step_input(work_start_msg["id"], work_start_msg["config"])
)

# Send WorkDoneMessage in a RuntimeMessage
encoder.encode(
{
"id": MessageType.WORK_DONE.value,
Comment on lines +145 to +147
Contributor
Explicitly pulling the value out of an enum kind of undermines the point of using an enum in the first place...won't the code still work if you remove the .value from here and line 292? If not, consider switching to the IntEnum base class (or switching to simple defined constants).

Contributor Author
This may be worth refactoring. We're using the integers.
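
As a sketch of the `IntEnum` option mentioned above (not the PR's implementation): members of a plain `Enum` never compare equal to integers, so dropping `.value` would break the comparisons, whereas `IntEnum` members behave as integers. The `classify` helper is hypothetical.

```python
from enum import Enum, IntEnum


class PlainType(Enum):
    SIGNAL = 2


class MessageType(IntEnum):
    WORK_DONE = 1
    SIGNAL = 2
    CLIENT_DONE = 3


def classify(msg_id: int) -> str:
    """Map a received integer ID to a label without using .value."""
    if msg_id == MessageType.SIGNAL:
        return "signal"
    if msg_id == MessageType.CLIENT_DONE:
        return "client_done"
    return "other"


plain_matches = 2 == PlainType.SIGNAL  # a plain Enum member is not an int
wire_value = int(MessageType.WORK_DONE)  # explicit conversion for encoding
```

Whether the CBOR encoder accepts an `IntEnum` member directly is not verified here; converting with `int(...)` before encoding is the conservative choice.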

"data": {
"output_id": output_id,
"output_data": plugin_schema.serialize_output(
work_start_msg["id"], output_id, output_data
),
"debug_logs": out_buffer.getvalue(),
}
}
)
self.stdout.flush() # Sends it to the ATP client immediately. Needed so it can realize it's done.
read_thread.join() # Wait for the read thread to finish.
# Don't reset stdout/stderr until after the read thread is done.
sys.stdout = original_stdout
sys.stderr = original_stderr
except SystemExit:
return 1
original_stdout = sys.stdout
original_stderr = sys.stderr
out_buffer = io.StringIO()
sys.stdout = out_buffer
sys.stderr = out_buffer
output_id, output_data = s.call_step(
message["id"], s.unserialize_input(message["id"], message["config"])
)
sys.stdout = original_stdout
sys.stderr = original_stderr
encoder.encode(
{
"output_id": output_id,
"output_data": s.serialize_output(
message["id"], output_id, output_data
),
"debug_logs": out_buffer.getvalue(),
}
)
stdout.flush()
except SystemExit:
return 1
return 0
return 0

def run_server_read_loop(
self,
plugin_schema: schema.SchemaType,
step_id: str,
decoder: cbor2.decoder.CBORDecoder,
) -> None:
try:
while True:
# Decode the message
runtime_msg = decoder.decode()
msg_id = runtime_msg["id"]
# Validate
if msg_id is None:
self.stderr.write("Runtime message is missing the 'id' field.")
return
# Then take action
if msg_id == MessageType.SIGNAL.value:
signal_msg = runtime_msg["data"]
received_step_id = signal_msg["step_id"]
received_signal_id = signal_msg["signal_id"]
if received_step_id != step_id: # Ensure they match.
self.stderr.write(f"Received step ID in the signal message '{received_step_id}'"
f"does not match expected step ID '{step_id}'")
return
unserialized_data = plugin_schema.unserialize_signal_handler_input(
received_step_id,
received_signal_id,
signal_msg["data"]
)
# The data is verified and unserialized. Now call the signal.
plugin_schema.call_step_signal(step_id, received_signal_id, unserialized_data)
elif msg_id == MessageType.CLIENT_DONE.value:
return
else:
self.stderr.write(f"Unknown kind of runtime message: {msg_id}")

except cbor2.CBORDecodeError as err:
self.stderr.write(f"Error while decoding CBOR: {err}")
Comment on lines +202 to +203
Contributor
I don't see any attempts to catch this error from the other calls to decoder.decode(). Is this one superfluous, should the others be added, or do we somehow know that the others cannot fail?

Contributor Author
This is a new section I added, so I added it. I didn't check the old ones. It may be worth refactoring the others in the future.



class PluginClientStateException(Exception):
@@ -158,7 +243,7 @@ def start_output(self) -> None:

def read_hello(self) -> HelloMessage:
"""
This function reads the intial "Hello" message from the plugin.
This function reads the initial "Hello" message from the plugin.
"""
message = self.decoder.decode()
return _HELLO_MESSAGE_SCHEMA.unserialize(message)
@@ -174,21 +259,54 @@ def start_work(self, step_id: str, input_data: any):
}
)

def send_signal(self, step_id: str, signal_id: str, input_data: any):
"""
This function sends any signals to the plugin.
Contributor
This function looks like it sends only the one signal to the plugin. 😉

Contributor Author
Is there something wrong with that? This is mostly used for testing.

Contributor
I was referring to the text of the docstring...it's not quite as clear as it might be. 🙂

"""
self.send_runtime_message(MessageType.SIGNAL, {
"step_id": step_id,
"signal_id": signal_id,
"data": input_data,
}
)

def send_client_done(self):
self.send_runtime_message(MessageType.CLIENT_DONE, {})

def send_runtime_message(self, message_type: MessageType, data: any):
self.encoder.encode(
{
"id": message_type.value,
"data": data,
}
)

def read_results(self) -> (str, any, str):
"""
This function reads the results of an execution from the plugin.
This function reads the signals and results of an execution from the plugin.
"""
message = self.decoder.decode()
if message["output_id"] is None:
raise PluginClientStateException(
"Missing 'output_id' in CBOR message. Possibly wrong order of calls?"
)
if message["output_data"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
if message["debug_logs"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
return message["output_id"], message["output_data"], message["debug_logs"]
while True:
runtime_msg = self.decoder.decode()
msg_id = runtime_msg["id"]
if msg_id == MessageType.WORK_DONE.value:
signal_msg = runtime_msg["data"]
if signal_msg["output_id"] is None:
raise PluginClientStateException(
"Missing 'output_id' in CBOR message. Possibly wrong order of calls?"
)
if signal_msg["output_data"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
if signal_msg["debug_logs"] is None:
raise PluginClientStateException(
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?"
)
return signal_msg["output_id"], signal_msg["output_data"], signal_msg["debug_logs"]
elif msg_id == MessageType.SIGNAL.value:
# Do nothing. Should change in the future.
continue
Comment on lines +307 to +309
Contributor
This check seems odd: shouldn't any signal messages have been fielded by the run_server_read_loop thread?

Contributor Author
This is now the client side. Signals are not fully implemented in this client.

Contributor
So, the lingering question then is: if a signal were to show up here, should it be silently ignored, or should it be announced with a klaxon? 🤣
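
One middle ground between silence and a klaxon would be to log a warning and keep reading. The sketch below assumes that choice. The function `handle_runtime_message`, the logger name, and the use of `ValueError` in place of `PluginClientStateException` are all illustrative, not part of the SDK.

```python
import logging

logger = logging.getLogger("atp_client_sketch")  # hypothetical logger name

# Integer IDs as defined by MessageType in the diff.
WORK_DONE = 1
SIGNAL = 2


def handle_runtime_message(runtime_msg: dict):
    """Return the work-done payload, or None if the message was skipped."""
    msg_id = runtime_msg.get("id")
    if msg_id == WORK_DONE:
        return runtime_msg["data"]
    if msg_id == SIGNAL:
        # Not handled by this client yet: record it instead of dropping it.
        logger.warning(
            "Ignoring unexpected signal message: %r", runtime_msg.get("data")
        )
        return None
    raise ValueError(f"Received unknown runtime message ID {msg_id}")


payload = handle_runtime_message({"id": WORK_DONE, "data": {"output_id": "ok"}})
skipped = handle_runtime_message({"id": SIGNAL, "data": {"signal_id": "x"}})
```

A caller would loop until the function returns a payload, which mirrors the `while True` structure of `read_results` in the diff.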

else:
raise PluginClientStateException(
f"Received unknown runtime message ID {msg_id}"
)