diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index f3789808071ff..31f762c0a0c0b 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -32,7 +32,7 @@
 from contextlib import suppress
 from datetime import datetime, timezone
 from socket import socket, socketpair
-from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload
+from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, TextIO, cast, overload
 from uuid import UUID
 
 import attrs
@@ -131,17 +131,17 @@ def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
     # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
     # pipes from the supervisor
 
-    for handle_name, sock, mode, close in (
-        ("stdin", child_stdin, "r", True),
-        ("stdout", child_stdout, "w", True),
-        ("stderr", child_stderr, "w", False),
+    for handle_name, sock, mode in (
+        ("stdin", child_stdin, "r"),
+        ("stdout", child_stdout, "w"),
+        ("stderr", child_stderr, "w"),
     ):
         handle = getattr(sys, handle_name)
         try:
             fd = handle.fileno()
             os.dup2(sock.fileno(), fd)
-            if close:
-                handle.close()
+            # dup2 creates another open copy of the fd, we can close the "socket" copy of it.
+            sock.close()
         except io.UnsupportedOperation:
             if "PYTEST_CURRENT_TEST" in os.environ:
                 # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need
@@ -154,6 +154,17 @@ def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
         setattr(sys, handle_name, handle)
 
 
+def _get_last_chance_stderr() -> TextIO:
+    stream = sys.__stderr__ or sys.stderr
+
+    try:
+        # We want to open another copy of the underlying filedescriptor if we can, to ensure it stays open!
+        return os.fdopen(os.dup(stream.fileno()), "w", buffering=1)
+    except Exception:
+        # If that didn't work, do the best we can
+        return stream
+
+
 def _fork_main(
     child_stdin: socket,
     child_stdout: socket,
@@ -180,7 +191,7 @@ def _fork_main(
     # TODO: Make this process a session leader
 
     # Store original stderr for last-chance exception handling
-    last_chance_stderr = sys.__stderr__ or sys.stderr
+    last_chance_stderr = _get_last_chance_stderr()
 
     _reset_signals()
     if log_fd:
@@ -244,8 +255,7 @@ class WatchedSubprocess:
     pid: int
 
     stdin: BinaryIO
-    stdout: socket
-    stderr: socket
+    """The handle connected to stdin of the child process"""
 
     client: Client
 
@@ -291,8 +301,6 @@ def start(
             ti_id=ti.id,
             pid=pid,
             stdin=feed_stdin,
-            stdout=read_stdout,
-            stderr=read_stderr,
             process=psutil.Process(pid),
             client=client,
         )
@@ -308,17 +316,19 @@ def start(
             proc.kill(signal.SIGKILL)
             raise
 
-        proc._register_pipes(read_msgs, read_logs)
+        proc._register_pipe_readers(
+            stdout=read_stdout, stderr=read_stderr, requests=read_msgs, logs=read_logs
+        )
 
         # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
         # other end of the pair open
-        proc._close_unused_sockets(child_stdout, child_stdin, child_comms, child_logs)
+        proc._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)
 
         # Tell the task process what it needs to do!
         proc._send_startup_message(ti, path, child_comms)
         return proc
 
-    def _register_pipes(self, read_msgs, read_logs):
+    def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socket, logs: socket):
         """Register handlers for subprocess communication channels."""
         # self.selector is a way of registering a handler/callback to be called when the given IO channel has
         # activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better
@@ -328,21 +338,19 @@ def _register_pipes(self, read_msgs, read_logs):
         # TODO: Use logging providers to handle the chunked upload for us
         logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind()
 
+        self.selector.register(stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout"))
         self.selector.register(
-            self.stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout")
-        )
-        self.selector.register(
-            self.stderr,
+            stderr,
             selectors.EVENT_READ,
             self._create_socket_handler(logger, "stderr", log_level=logging.ERROR),
         )
         self.selector.register(
-            read_logs,
+            logs,
             selectors.EVENT_READ,
             make_buffered_socket_reader(process_log_messages_from_subprocess(logger)),
        )
         self.selector.register(
-            read_msgs, selectors.EVENT_READ, make_buffered_socket_reader(self.handle_requests(log))
+            requests, selectors.EVENT_READ, make_buffered_socket_reader(self.handle_requests(log))
         )
 
     @staticmethod
@@ -468,13 +476,15 @@ def final_state(self):
         return TerminalTIState.FAILED
 
     def __rich_repr__(self):
+        yield "ti_id", self.ti_id
         yield "pid", self.pid
+        # only include this if it's not the default (third argument)
         yield "exit_code", self._exit_code, None
 
     __rich_repr__.angular = True  # type: ignore[attr-defined]
 
     def __repr__(self) -> str:
-        rep = f""
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 8d1117acd0e95..b0c8074b9f52d 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -163,6 +163,36 @@ def subprocess_main():
 
         assert rc == -9
 
+    def test_last_chance_exception_handling(self, capfd):
+        # Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards
+        structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
+
+        def subprocess_main():
+            # The real main() in task_runner catches exceptions! This is what would happen if we had a syntax
+            # or import error for instance - a very early exception
+            raise RuntimeError("Fake syntax error")
+
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(
+                id=uuid7(),
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=MagicMock(spec=sdk_client.Client),
+            target=subprocess_main,
+        )
+
+        rc = proc.wait()
+
+        assert rc == 126
+
+        captured = capfd.readouterr()
+        assert "Last chance exception handler" in captured.err
+        assert "RuntimeError: Fake syntax error" in captured.err
+
     def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch):
         """Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency"""
         import airflow.sdk.execution_time.supervisor
@@ -237,8 +267,6 @@ def watched_subprocess(self, mocker):
             ti_id=uuid7(),
             pid=12345,
             stdin=BytesIO(),
-            stdout=mocker.Mock(),  # Not used in these tests
-            stderr=mocker.Mock(),  # Not used in these tests
             client=mocker.Mock(),
             process=mocker.Mock(),
        )
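
Note (not part of the patch): the sketch below is a minimal, standalone illustration of the fd-duplication idea behind _get_last_chance_stderr. Duplicating the original stderr file descriptor before fd 2 is re-pointed at the supervisor's pipe keeps an independent writable handle to the original stream, so a last-chance handler can still report very early failures. The helper name get_last_chance_stderr and the pipe setup here are illustrative assumptions, not code from the patch.

import os
import sys


def get_last_chance_stderr():
    # Hypothetical stand-in for _get_last_chance_stderr(): dup() the underlying fd so this
    # copy stays open and keeps pointing at the original stderr, independently of fd 2.
    stream = sys.__stderr__ or sys.stderr
    try:
        return os.fdopen(os.dup(stream.fileno()), "w", buffering=1)
    except Exception:
        # Fall back to whatever stream object we have (e.g. under pytest's capture).
        return stream


if __name__ == "__main__":
    last_chance = get_last_chance_stderr()

    # Simulate what _reopen_std_io_handles does: re-point fd 2 at a pipe to a supervisor.
    read_end, write_end = os.pipe()
    os.dup2(write_end, 2)

    print("this goes to the redirected fd 2 (the pipe)", file=sys.stderr)
    print("this still reaches the original stderr", file=last_chance)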