Skip to content

Commit

Permalink
Use a buffered stdout substitute
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredoconnell committed Aug 14, 2024
1 parent 9ea3692 commit 75c504a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
12 changes: 7 additions & 5 deletions src/arcaflow_plugin_sdk/atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ATPServer:
step_ids: typing.Dict[str, str] # Run ID to step IDs
encoder: cbor2.CBOREncoder
decoder: cbor2.CBORDecoder
user_out_buffer: io.StringIO
user_out_wrapper: io.TextIOWrapper
encoder_lock: threading.Lock
plugin_schema: schema.SchemaType
running_threads: typing.List[threading.Thread]
Expand Down Expand Up @@ -114,9 +114,9 @@ def run_plugin(
# potentially interfering with the atp pipes.
original_stdout = sys.stdout
original_stderr = sys.stderr
self.user_out_buffer = io.StringIO()
sys.stdout = self.user_out_buffer
sys.stderr = self.user_out_buffer
self.user_out_wrapper = io.TextIOWrapper(io.BytesIO(), sys.stdout.encoding)
sys.stdout = self.user_out_wrapper
sys.stderr = self.user_out_wrapper

# Run the read loop. This blocks to wait for the loop to finish.
self.run_server_read_loop()
Expand Down Expand Up @@ -324,6 +324,8 @@ def start_step(self, run_id: str, step_id: str, config: typing.Any):
self.plugin_schema.unserialize_step_input(step_id, config),
)

self.user_out_wrapper.flush()
self.user_out_wrapper.seek(0) # go to start so that we can read stdout.
# Send WorkDoneMessage
self.send_runtime_message(
MessageType.WORK_DONE,
Expand All @@ -333,7 +335,7 @@ def start_step(self, run_id: str, step_id: str, config: typing.Any):
"output_data": self.plugin_schema.serialize_output(
step_id, output_id, output_data
),
"debug_logs": self.user_out_buffer.getvalue(),
"debug_logs": self.user_out_wrapper.read(),
},
)
except Exception as e:
Expand Down
7 changes: 5 additions & 2 deletions src/arcaflow_plugin_sdk/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,8 @@ def _execute_file(
data = serialization.load_from_file(filename)
original_stdout = sys.stdout
original_stderr = sys.stderr
out_buffer = io.StringIO()
buffer = io.BytesIO()
out_buffer = io.TextIOWrapper(buffer)
if options.debug:
# Redirect stdout to stderr for debug logging
sys.stdout = stderr
Expand All @@ -469,10 +470,12 @@ def _execute_file(
sys.stderr = out_buffer
try:
output_id, output_data = s("file_run", step_id, data)
out_buffer.flush()
out_buffer.seek(0) # go to start so that we can read stdout.
output = {
"output_id": output_id,
"output_data": output_data,
"debug_logs": out_buffer.getvalue(),
"debug_logs": out_buffer.read(),
}
stdout.write(yaml.dump(output, sort_keys=False))
return 0
Expand Down
4 changes: 2 additions & 2 deletions src/arcaflow_plugin_sdk/test_atp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Output:
outputs={"success": Output},
)
def hello_world(params: Input) -> Tuple[str, Output]:
print("Hello world!")
print("printed message")
return "success", Output("Hello, {}!".format(params.name))


Expand Down Expand Up @@ -182,7 +182,7 @@ def test_step_simple(self):
self.assertEqual(result.run_id, self.id())
client.send_client_done()
self.assertEqual(result.output_id, "success")
self.assertEqual("Hello world!\n", result.debug_logs)
self.assertEqual("printed message\n", result.debug_logs)
finally:
self._cleanup(pid, stdin_writer, stdout_reader)

Expand Down

0 comments on commit 75c504a

Please sign in to comment.