From 75c504ae5a2ac444f16d9c6ddf841c08ec510e9b Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Wed, 14 Aug 2024 17:05:51 -0400 Subject: [PATCH] Use a buffered stdout substitute --- src/arcaflow_plugin_sdk/atp.py | 12 +++++++----- src/arcaflow_plugin_sdk/plugin.py | 7 +++++-- src/arcaflow_plugin_sdk/test_atp.py | 4 ++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/arcaflow_plugin_sdk/atp.py b/src/arcaflow_plugin_sdk/atp.py index 5bdc81e..7e6d03f 100644 --- a/src/arcaflow_plugin_sdk/atp.py +++ b/src/arcaflow_plugin_sdk/atp.py @@ -76,7 +76,7 @@ class ATPServer: step_ids: typing.Dict[str, str] # Run ID to step IDs encoder: cbor2.CBOREncoder decoder: cbor2.CBORDecoder - user_out_buffer: io.StringIO + user_out_wrapper: io.TextIOWrapper encoder_lock: threading.Lock plugin_schema: schema.SchemaType running_threads: typing.List[threading.Thread] @@ -114,9 +114,9 @@ def run_plugin( # potentially interfering with the atp pipes. original_stdout = sys.stdout original_stderr = sys.stderr - self.user_out_buffer = io.StringIO() - sys.stdout = self.user_out_buffer - sys.stderr = self.user_out_buffer + self.user_out_wrapper = io.TextIOWrapper(io.BytesIO(), sys.stdout.encoding) + sys.stdout = self.user_out_wrapper + sys.stderr = self.user_out_wrapper # Run the read loop. This blocks to wait for the loop to finish. self.run_server_read_loop() @@ -324,6 +324,8 @@ def start_step(self, run_id: str, step_id: str, config: typing.Any): self.plugin_schema.unserialize_step_input(step_id, config), ) + self.user_out_wrapper.flush() + self.user_out_wrapper.seek(0) # go to start so that we can read stdout. # Send WorkDoneMessage self.send_runtime_message( MessageType.WORK_DONE, @@ -333,7 +335,7 @@ def start_step(self, run_id: str, step_id: str, config: typing.Any): "output_data": self.plugin_schema.serialize_output( step_id, output_id, output_data ), - "debug_logs": self.user_out_buffer.getvalue(), + "debug_logs": self.user_out_wrapper.read(), }, ) except Exception as e: diff --git a/src/arcaflow_plugin_sdk/plugin.py b/src/arcaflow_plugin_sdk/plugin.py index ec8c4bb..4a270be 100644 --- a/src/arcaflow_plugin_sdk/plugin.py +++ b/src/arcaflow_plugin_sdk/plugin.py @@ -459,7 +459,8 @@ def _execute_file( data = serialization.load_from_file(filename) original_stdout = sys.stdout original_stderr = sys.stderr - out_buffer = io.StringIO() + buffer = io.BytesIO() + out_buffer = io.TextIOWrapper(buffer) if options.debug: # Redirect stdout to stderr for debug logging sys.stdout = stderr @@ -469,10 +470,12 @@ def _execute_file( sys.stderr = out_buffer try: output_id, output_data = s("file_run", step_id, data) + out_buffer.flush() + out_buffer.seek(0) # go to start so that we can read stdout. output = { "output_id": output_id, "output_data": output_data, - "debug_logs": out_buffer.getvalue(), + "debug_logs": out_buffer.read(), } stdout.write(yaml.dump(output, sort_keys=False)) return 0 diff --git a/src/arcaflow_plugin_sdk/test_atp.py b/src/arcaflow_plugin_sdk/test_atp.py index fb6448c..e2584fa 100644 --- a/src/arcaflow_plugin_sdk/test_atp.py +++ b/src/arcaflow_plugin_sdk/test_atp.py @@ -26,7 +26,7 @@ class Output: outputs={"success": Output}, ) def hello_world(params: Input) -> Tuple[str, Output]: - print("Hello world!") + print("printed message") return "success", Output("Hello, {}!".format(params.name)) @@ -182,7 +182,7 @@ def test_step_simple(self): self.assertEqual(result.run_id, self.id()) client.send_client_done() self.assertEqual(result.output_id, "success") - self.assertEqual("Hello world!\n", result.debug_logs) + self.assertEqual("printed message\n", result.debug_logs) finally: self._cleanup(pid, stdin_writer, stdout_reader)