Ensure that Task SDK supervisor closes all its subprocess handles correctly. #44263

Merged
merged 1 commit into from Nov 22, 2024
Changes from all commits
54 changes: 32 additions & 22 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -32,7 +32,7 @@
from contextlib import suppress
from datetime import datetime, timezone
from socket import socket, socketpair
-from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload
+from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, TextIO, cast, overload
from uuid import UUID

import attrs
@@ -131,17 +131,17 @@ def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
# Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
# pipes from the supervisor

-for handle_name, sock, mode, close in (
-("stdin", child_stdin, "r", True),
-("stdout", child_stdout, "w", True),
-("stderr", child_stderr, "w", False),
+for handle_name, sock, mode in (
+("stdin", child_stdin, "r"),
+("stdout", child_stdout, "w"),
+("stderr", child_stderr, "w"),
):
handle = getattr(sys, handle_name)
try:
fd = handle.fileno()
os.dup2(sock.fileno(), fd)
-if close:
-handle.close()
+# dup2 creates another open copy of the fd, we can close the "socket" copy of it.
+sock.close()
except io.UnsupportedOperation:
if "PYTEST_CURRENT_TEST" in os.environ:
# When we're running under pytest, the stdin is not a real filehandle with an fd, so we need
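
A quick standalone sketch (illustrative only, not part of this diff) of the pattern the change above relies on: os.dup2 leaves an independent copy of the descriptor behind the target fd, so the socket object itself can be closed unconditionally afterwards.

import os
from socket import socketpair

parent_end, child_end = socketpair()

os.dup2(child_end.fileno(), 1)  # fd 1 (stdout) now refers to the child end of the pair
child_end.close()               # safe: dup2 left an independent copy of the descriptor in fd 1

os.write(1, b"hello via the socketpair\n")
os.write(2, parent_end.recv(1024))  # echo it back on stderr, since stdout is now the socket

This is why the per-handle close flag is no longer needed: every socket copy can be dropped once its fd has been duplicated into place.
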
@@ -154,6 +154,17 @@ def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
setattr(sys, handle_name, handle)


+def _get_last_chance_stderr() -> TextIO:
+stream = sys.__stderr__ or sys.stderr
+
+try:
+# We want to open another copy of the underlying filedescriptor if we can, to ensure it stays open!
+return os.fdopen(os.dup(stream.fileno()), "w", buffering=1)
+except Exception:
+# If that didn't work, do the best we can
+return stream
+
+
def _fork_main(
child_stdin: socket,
child_stdout: socket,
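
For context on the new _get_last_chance_stderr helper, here is a hedged sketch (again not from the PR) of why duplicating the descriptor matters: the dup()'d copy keeps pointing at the original stderr even after fd 2 is redirected into the supervisor's socketpair, which is what _reopen_std_io_handles later does inside the child.

import os
import sys
from socket import socketpair

# Keep a private duplicate of the original stderr before anything touches fd 2.
stream = sys.__stderr__ or sys.stderr
last_chance = os.fdopen(os.dup(stream.fileno()), "w", buffering=1)

# Simulate the child redirecting fd 2 into the supervisor's socketpair.
supervisor_end, child_end = socketpair()
os.dup2(child_end.fileno(), 2)
child_end.close()

print("captured by the supervisor", file=sys.stderr)          # goes into the socketpair
print("still reaches the original stderr", file=last_chance)  # the last-chance path
os.write(1, supervisor_end.recv(1024))                        # prove the first line went to the pair
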
@@ -180,7 +191,7 @@ def _fork_main(
# TODO: Make this process a session leader

# Store original stderr for last-chance exception handling
-last_chance_stderr = sys.__stderr__ or sys.stderr
+last_chance_stderr = _get_last_chance_stderr()

_reset_signals()
if log_fd:
@@ -244,8 +255,7 @@ class WatchedSubprocess:
pid: int

stdin: BinaryIO
-stdout: socket
-stderr: socket
+"""The handle connected to stdin of the child process"""

client: Client

@@ -291,8 +301,6 @@ def start(
ti_id=ti.id,
pid=pid,
stdin=feed_stdin,
-stdout=read_stdout,
-stderr=read_stderr,
process=psutil.Process(pid),
client=client,
)
@@ -308,17 +316,19 @@
proc.kill(signal.SIGKILL)
raise

-proc._register_pipes(read_msgs, read_logs)
+proc._register_pipe_readers(
+stdout=read_stdout, stderr=read_stderr, requests=read_msgs, logs=read_logs
+)

# Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
# other end of the pair open
-proc._close_unused_sockets(child_stdout, child_stdin, child_comms, child_logs)
+proc._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)

# Tell the task process what it needs to do!
proc._send_startup_message(ti, path, child_comms)
return proc

-def _register_pipes(self, read_msgs, read_logs):
+def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socket, logs: socket):
"""Register handlers for subprocess communication channels."""
# self.selector is a way of registering a handler/callback to be called when the given IO channel has
# activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better
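
The _close_unused_sockets call above is the crux of the PR title: after the fork the parent must drop its copies of every child-end socket, and child_stderr was previously missed. A rough standalone sketch (hypothetical names) of what a leaked child end costs, namely that the parent never sees EOF on that pair:

import os
from socket import socketpair

# One pair per channel; the supervisor keeps stdin/stdout/stderr/comms/logs pairs.
parent_end, child_end = socketpair()

pid = os.fork()
if pid == 0:                  # child
    parent_end.close()        # the child only needs its own end
    child_end.sendall(b"output from the task process\n")
    os._exit(0)

child_end.close()             # parent: without this, the second recv() below never returns
print(parent_end.recv(1024))  # the task's output
print(parent_end.recv(1024))  # b"" (EOF), which only arrives because no copy of child_end is left open
parent_end.close()
os.waitpid(pid, 0)
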
@@ -328,21 +338,19 @@ def _register_pipes(self, read_msgs, read_logs):
# TODO: Use logging providers to handle the chunked upload for us
logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind()

-self.selector.register(
-self.stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout")
-)
+self.selector.register(stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout"))
self.selector.register(
-self.stderr,
+stderr,
selectors.EVENT_READ,
self._create_socket_handler(logger, "stderr", log_level=logging.ERROR),
)
self.selector.register(
-read_logs,
+logs,
selectors.EVENT_READ,
make_buffered_socket_reader(process_log_messages_from_subprocess(logger)),
)
self.selector.register(
-read_msgs, selectors.EVENT_READ, make_buffered_socket_reader(self.handle_requests(log))
+requests, selectors.EVENT_READ, make_buffered_socket_reader(self.handle_requests(log))
)

@staticmethod
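
A minimal sketch of the selectors pattern _register_pipe_readers uses (assumed names, not the supervisor's actual reader code): the callback is stored as the registration's data and invoked whenever the registered socket becomes readable.

import selectors
from socket import socketpair

sel = selectors.DefaultSelector()
read_end, write_end = socketpair()

def on_stdout(sock):
    # In the supervisor this would forward the chunk to the bound structlog logger.
    print("stdout chunk:", sock.recv(4096))

sel.register(read_end, selectors.EVENT_READ, on_stdout)

write_end.sendall(b"hello from the subprocess\n")
for key, _events in sel.select(timeout=1):
    key.data(key.fileobj)  # look up the stored callback and hand it the ready socket
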
@@ -468,13 +476,15 @@ def final_state(self):
return TerminalTIState.FAILED

def __rich_repr__(self):
yield "ti_id", self.ti_id
yield "pid", self.pid
# only include this if it's not the default (third argument)
yield "exit_code", self._exit_code, None

__rich_repr__.angular = True # type: ignore[attr-defined]

def __repr__(self) -> str:
rep = f"<WatchedSubprocess pid={self.pid}"
rep = f"<WatchedSubprocess ti_id={self.ti_id} pid={self.pid}"
if self._exit_code is not None:
rep += f" exit_code={self._exit_code}"
return rep + " >"
32 changes: 30 additions & 2 deletions task_sdk/tests/execution_time/test_supervisor.py
@@ -163,6 +163,36 @@ def subprocess_main():

assert rc == -9

+def test_last_chance_exception_handling(self, capfd):
+# Ignore anything lower than INFO for this test. Captured_logs resets things for us afterwards
+structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
+
+def subprocess_main():
+# The real main() in task_runner catches exceptions! This is what would happen if we had a syntax
+# or import error for instance - a very early exception
+raise RuntimeError("Fake syntax error")
+
+proc = WatchedSubprocess.start(
+path=os.devnull,
+ti=TaskInstance(
+id=uuid7(),
+task_id="b",
+dag_id="c",
+run_id="d",
+try_number=1,
+),
+client=MagicMock(spec=sdk_client.Client),
+target=subprocess_main,
+)
+
+rc = proc.wait()
+
+assert rc == 126
+
+captured = capfd.readouterr()
+assert "Last chance exception handler" in captured.err
+assert "RuntimeError: Fake syntax error" in captured.err
+
def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch):
"""Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency"""
import airflow.sdk.execution_time.supervisor
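
One note on the new test above: it uses capfd rather than capsys, presumably because the last-chance handler writes through a dup()'d file descriptor instead of through sys.stderr. A tiny hypothetical test showing the difference that makes:

import os

def test_capfd_sees_fd_level_writes(capfd):
    # capfd intercepts file descriptors 1 and 2 themselves, so output written
    # straight to fd 2 (bypassing sys.stderr entirely) still shows up here.
    os.write(2, b"written straight to fd 2\n")
    _out, err = capfd.readouterr()
    assert "written straight to fd 2" in err
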
@@ -237,8 +267,6 @@ def watched_subprocess(self, mocker):
ti_id=uuid7(),
pid=12345,
stdin=BytesIO(),
-stdout=mocker.Mock(), # Not used in these tests
-stderr=mocker.Mock(), # Not used in these tests
client=mocker.Mock(),
process=mocker.Mock(),
)