Skip to content

Commit

Permalink
Remove join and add flush
Browse files Browse the repository at this point in the history
Also did a mild refactor of function
  • Loading branch information
jaredoconnell committed Sep 15, 2023
1 parent 7cf6628 commit b68efc1
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def run_plugin(
plugin_schema, # Plugin schema
work_start_msg["id"], # step ID
decoder, # Decoder
self.stderr, # Stderr
))
read_thread.start()
# Run the step
Expand Down Expand Up @@ -145,7 +144,8 @@ def run_plugin(
}
)
self.stdout.flush()
read_thread.join()
self.stdin.flush()
#read_thread.join()
except SystemExit:
return 1
return 0
Expand All @@ -155,7 +155,6 @@ def run_server_read_loop(
plugin_schema: schema.SchemaType,
step_id: str,
decoder: cbor2.decoder.CBORDecoder,
stderr: io.FileIO
) -> None:
try:
while True:
Expand All @@ -164,15 +163,15 @@ def run_server_read_loop(
msg_id = runtime_msg["id"]
# Validate
if msg_id is None:
stderr.write("Runtime message is missing the 'id' field.")
self.stderr.write("Runtime message is missing the 'id' field.")
return
# Then take action
if msg_id == MessageType.SIGNAL.value:
signal_msg = runtime_msg["data"]
received_step_id = signal_msg["step_id"]
received_signal_id = signal_msg["signal_id"]
if received_step_id != step_id:
stderr.write(f"Received step ID in the signal message '{received_step_id}'"
self.stderr.write(f"Received step ID in the signal message '{received_step_id}'"
f"does not match expected step ID '{step_id}'")
return
unserialized_data = plugin_schema.unserialize_signal_handler_input(
Expand All @@ -184,10 +183,10 @@ def run_server_read_loop(
elif msg_id == MessageType.CLIENT_DONE.value:
return
else:
stderr.write(f"Unknown kind of runtime message: {msg_id}")
self.stderr.write(f"Unknown kind of runtime message: {msg_id}")

except cbor2.CBORDecodeError as err:
stderr.write(f"Error while decoding CBOR: {err}")
self.stderr.write(f"Error while decoding CBOR: {err}")


class PluginClientStateException(Exception):
Expand Down Expand Up @@ -229,7 +228,7 @@ def start_output(self) -> None:

def read_hello(self) -> HelloMessage:
"""
This function reads the intial "Hello" message from the plugin.
This function reads the initial "Hello" message from the plugin.
"""
message = self.decoder.decode()
return _HELLO_MESSAGE_SCHEMA.unserialize(message)
Expand Down

0 comments on commit b68efc1

Please sign in to comment.