Skip to content

Commit

Permalink
AIP-72: Push XCom on Task Return (apache#45245)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored Dec 27, 2024
1 parent 60cd5ad commit af130c0
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 113 deletions.
48 changes: 45 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import os
import sys
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
Expand Down Expand Up @@ -197,7 +197,11 @@ def xcom_pull(

value = msg.value
if value is not None:
return value
from airflow.models.xcom import XCom

# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
return XCom.deserialize_value(value)
return default

def xcom_push(self, key: str, value: Any):
Expand All @@ -207,6 +211,12 @@ def xcom_push(self, key: str, value: Any):
:param key: Key to store the value under.
:param value: Value to store. Only be JSON-serializable may be used otherwise.
"""
from airflow.models.xcom import XCom

# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
value = XCom.serialize_value(value)

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
Expand Down Expand Up @@ -381,7 +391,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# - Update RTIF
# - Pre Execute
# etc
ti.task.execute(context) # type: ignore[attr-defined]
result = ti.task.execute(context) # type: ignore[attr-defined]
_push_xcom_if_needed(result, ti)

msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
Expand Down Expand Up @@ -436,6 +448,36 @@ def run(ti: RuntimeTaskInstance, log: Logger):
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
"""Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
if ti.task.do_xcom_push:
xcom_value = result
else:
xcom_value = None

# If the task returns a result, push an XCom containing it.
if xcom_value is None:
return

# If the task has multiple outputs, push each output as a separate XCom.
if ti.task.multiple_outputs:
if not isinstance(xcom_value, Mapping):
raise TypeError(
f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
)
for key in xcom_value.keys():
if not isinstance(key, str):
raise TypeError(
"Returned dictionary keys must be strings when using "
f"multiple_outputs, found {key} ({type(key)}) instead"
)
for k, v in result.items():
ti.xcom_push(k, v)

# TODO: Use constant for XCom return key & use serialize_value from Task SDK
ti.xcom_push("return_value", result)


def finalize(log: Logger): ...


Expand Down
9 changes: 9 additions & 0 deletions task_sdk/tests/execution_time/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import sys
from unittest import mock

import pytest

Expand All @@ -31,3 +32,11 @@ def disable_capturing():
sys.stderr = sys.__stderr__
yield
sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err


@pytest.fixture
def mock_supervisor_comms():
with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as supervisor_comms:
yield supervisor_comms
39 changes: 14 additions & 25 deletions task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

from __future__ import annotations

from unittest import mock

from airflow.sdk.definitions.connection import Connection
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
Expand Down Expand Up @@ -51,7 +49,7 @@ def test_convert_connection_result_conn():


class TestConnectionAccessor:
def test_getattr_connection(self):
def test_getattr_connection(self, mock_supervisor_comms):
"""
Test that the connection is fetched when accessed via __getattr__.
Expand All @@ -62,42 +60,33 @@ def test_getattr_connection(self):
# Conn from the supervisor / API Server
conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
assert conn == expected_conn
expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
assert conn == expected_conn

def test_get_method_valid_connection(self):
def test_get_method_valid_connection(self, mock_supervisor_comms):
"""Test that the get method returns the requested connection using `conn.get`."""
accessor = ConnectionAccessor()
conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result
mock_supervisor_comms.get_message.return_value = conn_result

conn = accessor.get("mysql_conn")
assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
conn = accessor.get("mysql_conn")
assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)

def test_get_method_with_default(self):
def test_get_method_with_default(self, mock_supervisor_comms):
"""Test that the get method returns the default connection when the requested connection is not found."""
accessor = ConnectionAccessor()
default_conn = {"conn_id": "default_conn", "conn_type": "sqlite"}
error_response = ErrorResponse(
error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": "nonexistent_conn"}
)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = error_response
mock_supervisor_comms.get_message.return_value = error_response

conn = accessor.get("nonexistent_conn", default_conn=default_conn)
assert conn == default_conn
conn = accessor.get("nonexistent_conn", default_conn=default_conn)
assert conn == default_conn
Loading

0 comments on commit af130c0

Please sign in to comment.