Signals and ATP v2 #98
Changes from 31 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,15 +16,29 @@ | |
import dataclasses | ||
import io | ||
import os | ||
import signal | ||
import sys | ||
import typing | ||
import threading | ||
import signal | ||
|
||
import cbor2 | ||
|
||
from enum import Enum | ||
|
||
from arcaflow_plugin_sdk import schema | ||
|
||
|
||
class MessageType(Enum): | ||
""" | ||
An integer ID that indicates the type of runtime message that is stored | ||
in the data field. The corresponding class can then be used to deserialize | ||
the inner data. Look at the go SDK for the reference data structure. | ||
""" | ||
WORK_DONE = 1 | ||
SIGNAL = 2 | ||
CLIENT_DONE = 3 | ||
Comment on lines +37 to +39
Consider using
Since we need to sync these between the Go and Python SDKs, I think it makes sense to make these explicit.
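A minimal sketch of why explicit values matter for IDs that two SDKs must agree on. It compares the hard-coded values from the diff with `enum.auto()`, which is used here only as one possible alternative for illustration; `AutoMessageType` and `NEW_MESSAGE` are hypothetical names that do not appear in the PR.

```python
from enum import Enum, auto


class MessageType(Enum):
    # Explicit values: these are wire-level IDs, so they stay fixed
    # even if members are reordered or new ones are inserted.
    WORK_DONE = 1
    SIGNAL = 2
    CLIENT_DONE = 3


class AutoMessageType(Enum):
    # Hypothetical variant: auto() numbers members by definition order,
    # so inserting a member in the middle shifts every later ID.
    WORK_DONE = auto()
    NEW_MESSAGE = auto()  # hypothetical member added later
    SIGNAL = auto()
    CLIENT_DONE = auto()


explicit_signal_id = MessageType.SIGNAL.value
auto_signal_id = AutoMessageType.SIGNAL.value
```

With explicit values, `SIGNAL` keeps the ID 2 regardless of edits elsewhere in the class, whereas the automatically numbered variant would disagree with a peer that still expects 2.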
||
|
||
|
||
@dataclasses.dataclass | ||
class HelloMessage: | ||
""" | ||
|
@@ -50,73 +64,144 @@ class HelloMessage: | |
_HELLO_MESSAGE_SCHEMA = schema.build_object_schema(HelloMessage) | ||
|
||
|
||
def _handle_exit(_signo, _stack_frame): | ||
print("Exiting normally") | ||
sys.exit(0) | ||
def signal_handler(_sig, _frame): | ||
pass # Do nothing | ||
|
||
|
||
def run_plugin( | ||
s: schema.SchemaType, | ||
stdin: io.FileIO, | ||
stdout: io.FileIO, | ||
stderr: io.FileIO, | ||
) -> int: | ||
""" | ||
This function wraps running a plugin. | ||
""" | ||
if os.isatty(stdout.fileno()): | ||
print("Cannot run plugin in ATP mode on an interactive terminal.") | ||
return 1 | ||
class ATPServer: | ||
stdin: io.FileIO | ||
stdout: io.FileIO | ||
stderr: io.FileIO | ||
step_object: typing.Any | ||
|
||
signal.signal(signal.SIGTERM, _handle_exit) | ||
try: | ||
decoder = cbor2.decoder.CBORDecoder(stdin) | ||
encoder = cbor2.encoder.CBOREncoder(stdout) | ||
def __init__( | ||
self, | ||
stdin: io.FileIO, | ||
stdout: io.FileIO, | ||
stderr: io.FileIO, | ||
) -> None: | ||
self.stdin = stdin | ||
self.stdout = stdout | ||
self.stderr = stderr | ||
|
||
# Decode empty "start output" message. | ||
decoder.decode() | ||
def run_plugin( | ||
|
||
self, | ||
plugin_schema: schema.SchemaType, | ||
) -> int: | ||
""" | ||
This function wraps running a plugin. | ||
""" | ||
signal.signal(signal.SIGINT, signal_handler) # Ignore sigint. Only care about arcaflow signals. | ||
Comment on lines +93 to +94
Why use a custom signal handler (i.e.,
Good point. I'll need to refactor this at some point.
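For context, the standard library already provides a built-in "ignore" disposition, `signal.SIG_IGN`, which is one alternative to a Python-level handler that does nothing. The sketch below only demonstrates that standard-library feature; whether it is the refactoring the thread had in mind is an assumption.

```python
import signal

# Install the built-in "ignore" disposition for SIGINT instead of a
# Python-level handler whose body is just `pass`.
# Note: signal.signal() may only be called from the main thread.
previous_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
current_handler = signal.getsignal(signal.SIGINT)

# Restore the earlier handler so this sketch has no lasting effect.
# (previous_handler is None only if the handler was not installed from Python.)
if previous_handler is not None:
    signal.signal(signal.SIGINT, previous_handler)
```

The two approaches are not fully equivalent: a Python-level handler still runs and can interrupt blocking system calls, while `SIG_IGN` discards the signal at the OS level.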
||
if os.isatty(self.stdout.fileno()): | ||
print("Cannot run plugin in ATP mode on an interactive terminal.") | ||
return 1 | ||
Comment on lines +95 to +97
This is curious...why is this the case? Do/should you have similar restrictions on
The plugins communicate over stdout and stdin.
Understood, but does that mean that there should never be a terminal connected there (not even for interactive testing)? (Is it even possible that this could happen?) That is, is there actually value to performing this check? If so, then shouldn't you be checking
||
try: | ||
decoder = cbor2.decoder.CBORDecoder(self.stdin) | ||
encoder = cbor2.encoder.CBOREncoder(self.stdout) | ||
|
||
start = HelloMessage(1, s) | ||
serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start) | ||
encoder.encode(serialized_message) | ||
stdout.flush() | ||
# Decode empty "start output" message. | ||
decoder.decode() | ||
|
||
message = decoder.decode() | ||
except SystemExit: | ||
return 0 | ||
try: | ||
if message is None: | ||
stderr.write("Work start message is None.") | ||
return 1 | ||
if message["id"] is None: | ||
stderr.write("Work start message is missing the 'id' field.") | ||
return 1 | ||
if message["config"] is None: | ||
stderr.write("Work start message is missing the 'config' field.") | ||
# Serialize then send HelloMessage | ||
start_hello_message = HelloMessage(2, plugin_schema) | ||
The
I should create a version constant.
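A minimal sketch of the version constant mentioned in the reply. The name `ATP_SERVER_VERSION` and the helper `build_hello` are hypothetical, and the `HelloMessage` class here is a simplified stand-in for the SDK's dataclass, not its actual definition.

```python
import dataclasses
import typing

# Hypothetical constant name; the value 2 matches the literal passed to
# HelloMessage in the diff above.
ATP_SERVER_VERSION = 2


@dataclasses.dataclass
class HelloMessage:
    # Simplified stand-in for the SDK's HelloMessage, for illustration only.
    version: int
    schema: typing.Any


def build_hello(plugin_schema: typing.Any) -> HelloMessage:
    # The protocol version now comes from one named place instead of a
    # bare integer at the call site.
    return HelloMessage(ATP_SERVER_VERSION, plugin_schema)


hello = build_hello({"steps": {}})
```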
||
serialized_message = _HELLO_MESSAGE_SCHEMA.serialize(start_hello_message) | ||
encoder.encode(serialized_message) | ||
self.stdout.flush() | ||
|
||
# Can fail here if only getting schema. | ||
work_start_msg = decoder.decode() | ||
except SystemExit: | ||
return 0 | ||
Comment on lines +113 to +114
This is a somewhat strange thing to do...are you sure you want to be doing this, especially given that you've removed the code which used to call Ditto line (new) line 162.
This is prior code that I didn't write; I just moved it. It may be worth refactoring.
You've removed at least one source of a
||
try: | ||
if work_start_msg is None: | ||
self.stderr.write("Work start message is None.") | ||
return 1 | ||
if work_start_msg["id"] is None: | ||
self.stderr.write("Work start message is missing the 'id' field.") | ||
return 1 | ||
if work_start_msg["config"] is None: | ||
self.stderr.write("Work start message is missing the 'config' field.") | ||
return 1 | ||
|
||
# Run the step | ||
original_stdout = sys.stdout | ||
original_stderr = sys.stderr | ||
Comment on lines +127 to +128
At line 100 you reference BTW, you can potentially skip this step, since
We're actually replacing the standard out so that the plugins can use print without interfering with the ATP protocol, which took over stdin/stdout.
Understood. But, you have at least two sources for
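As a hedged illustration of the capture step under discussion: the standard library's `contextlib.redirect_stdout` and `contextlib.redirect_stderr` perform the same save, replace, and restore sequence, and they also restore the streams if the step raises. This is a sketch of that alternative, not the PR's code; `run_captured` and `example_step` are hypothetical names.

```python
import contextlib
import io


def run_captured(step):
    """Run step() and return (result, captured_text)."""
    out_buffer = io.StringIO()
    # Both streams go into the same buffer, and both are restored on exit
    # from the with-block, including when step() raises.
    with contextlib.redirect_stdout(out_buffer), \
            contextlib.redirect_stderr(out_buffer):
        result = step()
    return result, out_buffer.getvalue()


def example_step():
    # Stands in for a plugin step that uses print() for debugging.
    print("debug line from the plugin")
    return "success"


result, debug_logs = run_captured(example_step)
```

The redirection is process-wide, so anything another thread prints while the block is active lands in the same buffer. In the diff, that means the block would need to stay open until the read thread has been joined.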
||
out_buffer = io.StringIO() | ||
sys.stdout = out_buffer | ||
sys.stderr = out_buffer | ||
# Run the read loop | ||
read_thread = threading.Thread(target=self.run_server_read_loop, args=( | ||
plugin_schema, # Plugin schema | ||
work_start_msg["id"], # step ID | ||
decoder, # Decoder | ||
)) | ||
read_thread.start() | ||
|
||
output_id, output_data = plugin_schema.call_step( | ||
work_start_msg["id"], | ||
plugin_schema.unserialize_step_input(work_start_msg["id"], work_start_msg["config"]) | ||
) | ||
|
||
# Send WorkDoneMessage in a RuntimeMessage | ||
encoder.encode( | ||
{ | ||
"id": MessageType.WORK_DONE.value, | ||
Comment on lines +145 to +147
Explicitly pulling the value out of an enum kind of undermines the point of using an enum in the first place...won't the code still work if you remove the
This may be worth refactoring. We're using the integers.
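One possible direction for that refactoring, sketched under assumptions: `enum.IntEnum` members compare equal to plain integers, so comparisons against decoded IDs would not need `.value`. The PR itself uses `Enum`, and whether cbor2 encodes `IntEnum` members directly as integers is not verified here.

```python
from enum import IntEnum


class MessageType(IntEnum):
    # Same IDs as in the diff, but members are also ints.
    WORK_DONE = 1
    SIGNAL = 2
    CLIENT_DONE = 3


raw_id = 2  # as it might arrive in a decoded runtime message

# An IntEnum member compares equal to the raw integer directly.
is_signal = raw_id == MessageType.SIGNAL

# Converting the raw ID validates it: unknown IDs raise ValueError.
parsed = MessageType(raw_id)
```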
||
"data": { | ||
"output_id": output_id, | ||
"output_data": plugin_schema.serialize_output( | ||
work_start_msg["id"], output_id, output_data | ||
), | ||
"debug_logs": out_buffer.getvalue(), | ||
} | ||
} | ||
) | ||
self.stdout.flush() # Sends it to the ATP client immediately. Needed so it can realize it's done. | ||
|
||
read_thread.join() # Wait for the read thread to finish. | ||
# Don't reset stdout/stderr until after the read thread is done. | ||
sys.stdout = original_stdout | ||
sys.stderr = original_stderr | ||
except SystemExit: | ||
return 1 | ||
original_stdout = sys.stdout | ||
original_stderr = sys.stderr | ||
out_buffer = io.StringIO() | ||
sys.stdout = out_buffer | ||
sys.stderr = out_buffer | ||
output_id, output_data = s.call_step( | ||
message["id"], s.unserialize_input(message["id"], message["config"]) | ||
) | ||
sys.stdout = original_stdout | ||
sys.stderr = original_stderr | ||
encoder.encode( | ||
{ | ||
"output_id": output_id, | ||
"output_data": s.serialize_output( | ||
message["id"], output_id, output_data | ||
), | ||
"debug_logs": out_buffer.getvalue(), | ||
} | ||
) | ||
stdout.flush() | ||
except SystemExit: | ||
return 1 | ||
return 0 | ||
return 0 | ||
|
||
def run_server_read_loop( | ||
self, | ||
plugin_schema: schema.SchemaType, | ||
step_id: str, | ||
decoder: cbor2.decoder.CBORDecoder, | ||
) -> None: | ||
try: | ||
while True: | ||
|
||
# Decode the message | ||
runtime_msg = decoder.decode() | ||
msg_id = runtime_msg["id"] | ||
# Validate | ||
if msg_id is None: | ||
self.stderr.write("Runtime message is missing the 'id' field.") | ||
return | ||
# Then take action | ||
if msg_id == MessageType.SIGNAL.value: | ||
signal_msg = runtime_msg["data"] | ||
received_step_id = signal_msg["step_id"] | ||
received_signal_id = signal_msg["signal_id"] | ||
if received_step_id != step_id: # Ensure they match. | ||
self.stderr.write(f"Received step ID in the signal message '{received_step_id}'" | ||
f"does not match expected step ID '{step_id}'") | ||
return | ||
unserialized_data = plugin_schema.unserialize_signal_handler_input( | ||
received_step_id, | ||
received_signal_id, | ||
signal_msg["data"] | ||
) | ||
# The data is verified and unserialized. Now call the signal. | ||
plugin_schema.call_step_signal(step_id, received_signal_id, unserialized_data) | ||
elif msg_id == MessageType.CLIENT_DONE.value: | ||
return | ||
else: | ||
self.stderr.write(f"Unknown kind of runtime message: {msg_id}") | ||
|
||
except cbor2.CBORDecodeError as err: | ||
self.stderr.write(f"Error while decoding CBOR: {err}") | ||
Comment on lines +202 to +203
I don't see any attempts to catch this error from the other calls to
This is a new section I added, so I added it. I didn't check the old ones. It may be worth refactoring the others in the future.
||
|
||
|
||
class PluginClientStateException(Exception): | ||
|
@@ -158,7 +243,7 @@ def start_output(self) -> None: | |
|
||
def read_hello(self) -> HelloMessage: | ||
""" | ||
This function reads the intial "Hello" message from the plugin. | ||
This function reads the initial "Hello" message from the plugin. | ||
""" | ||
message = self.decoder.decode() | ||
return _HELLO_MESSAGE_SCHEMA.unserialize(message) | ||
|
@@ -174,21 +259,54 @@ def start_work(self, step_id: str, input_data: any): | |
} | ||
) | ||
|
||
def send_signal(self, step_id: str, signal_id: str, input_data: any): | ||
""" | ||
This function sends any signals to the plugin. | ||
This function looks like it sends only the one signal to the plugin. 😉
Is there something wrong with that? This is mostly used for testing.
I was referring to the text of the docstring...it's not quite as clear as it might be. 🙂
||
""" | ||
self.send_runtime_message(MessageType.SIGNAL, { | ||
"step_id": step_id, | ||
"signal_id": signal_id, | ||
"data": input_data, | ||
} | ||
) | ||
|
||
def send_client_done(self): | ||
self.send_runtime_message(MessageType.CLIENT_DONE, {}) | ||
|
||
def send_runtime_message(self, message_type: MessageType, data: any): | ||
self.encoder.encode( | ||
{ | ||
"id": message_type.value, | ||
"data": data, | ||
} | ||
) | ||
|
||
def read_results(self) -> (str, any, str): | ||
""" | ||
This function reads the results of an execution from the plugin. | ||
This function reads the signals and results of an execution from the plugin. | ||
""" | ||
message = self.decoder.decode() | ||
if message["output_id"] is None: | ||
raise PluginClientStateException( | ||
"Missing 'output_id' in CBOR message. Possibly wrong order of calls?" | ||
) | ||
if message["output_data"] is None: | ||
raise PluginClientStateException( | ||
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?" | ||
) | ||
if message["debug_logs"] is None: | ||
raise PluginClientStateException( | ||
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?" | ||
) | ||
return message["output_id"], message["output_data"], message["debug_logs"] | ||
while True: | ||
runtime_msg = self.decoder.decode() | ||
msg_id = runtime_msg["id"] | ||
if msg_id == MessageType.WORK_DONE.value: | ||
signal_msg = runtime_msg["data"] | ||
if signal_msg["output_id"] is None: | ||
raise PluginClientStateException( | ||
"Missing 'output_id' in CBOR message. Possibly wrong order of calls?" | ||
) | ||
if signal_msg["output_data"] is None: | ||
raise PluginClientStateException( | ||
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?" | ||
) | ||
if signal_msg["debug_logs"] is None: | ||
raise PluginClientStateException( | ||
"Missing 'output_data' in CBOR message. Possibly wrong order of calls?" | ||
) | ||
return signal_msg["output_id"], signal_msg["output_data"], signal_msg["debug_logs"] | ||
elif msg_id == MessageType.SIGNAL.value: | ||
# Do nothing. Should change in the future. | ||
continue | ||
Comment on lines +307 to +309
This check seems odd: shouldn't any signal messages have been fielded by the
This is now the client side. Signals are not fully implemented in this client.
So, the lingering question then is, if a signal were to show up here, should it be silently ignored or should it be announced with a klaxon. 🤣
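A middle ground between silently ignoring and failing loudly is to log the unexpected signal and keep reading. The sketch below shows that option on plain dictionaries; `handle_runtime_message`, the logger name, and the use of `ValueError` are hypothetical choices, and the two constants mirror the `MessageType` values from the diff.

```python
import logging

logger = logging.getLogger("atp_client_sketch")  # hypothetical logger name

# Mirror of the MessageType IDs used in the diff.
WORK_DONE = 1
SIGNAL = 2


def handle_runtime_message(runtime_msg: dict):
    """Return the work-done payload, or None if the message was skipped."""
    msg_id = runtime_msg["id"]
    if msg_id == WORK_DONE:
        return runtime_msg["data"]
    if msg_id == SIGNAL:
        # Not supported by this client: record it instead of dropping it silently.
        logger.warning("Ignoring signal message: %r", runtime_msg.get("data"))
        return None
    raise ValueError(f"Received unknown runtime message ID {msg_id}")


done_payload = handle_runtime_message({"id": WORK_DONE, "data": {"output_id": "success"}})
skipped = handle_runtime_message({"id": SIGNAL, "data": {"signal_id": "stop"}})
```

A caller would loop until the function returns a payload, which matches the `while True` structure in `read_results`.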
||
else: | ||
raise PluginClientStateException( | ||
f"Received unknown runtime message ID {msg_id}" | ||
) |
Do you keep your import lists ordered? (E.g., you might want to try out isort.) Should the import of enum be in this block?