Skip to content

Commit

Permalink
AIP-72: Allow retrieving Connection from Task Context (apache#45043)
Browse files Browse the repository at this point in the history
part of apache#44481

- Added a minimal Connection user-facing object in Task SDK definition for use in the DAG file
- Added logic to get Connections in the context. Fixed some bugs in the way related to Connection parsing/serializing!


Now, we have following Connection related objects:
- `ConnectionResponse` is auto-generated and tightly coupled with the API schema.
- `ConnectionResult` is runtime-specific and meant for internal communication between Supervisor & Task Runner.
- `Connection` class here is where the public-facing, user-relevant aspects are exposed, hiding internal details.

**Next up**:

- Same for XCom & Variable
- Implementation of BaseHook.get_conn

Tested it with a DAG:

<img width="1711" alt="image" src="https://github.com/user-attachments/assets/14d28fb7-f6c5-4fbe-b226-46873af2d0f3" />

DAG:

```py
from __future__ import annotations

from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import dag


class CustomOperator(BaseOperator):
    def execute(self, context):
        import os
        os.environ["AIRFLOW_CONN_AIRFLOW_DB"] = "sqlite:///home/airflow/airflow.db"
        task_id = context["task_instance"].task_id
        print(f"Hello World {task_id}!")
        print(context)
        print(context["conn"].airflow_db)
        assert context["conn"].airflow_db.conn_id == "airflow_db"


@dag()
def super_basic_run():
    CustomOperator(task_id="hello")


super_basic_run()

```

For case where a **connection is not found**

<img width="1435" alt="image" src="https://github.com/user-attachments/assets/7c5e0cb4-6ed4-41aa-9a57-e5641adce954" />
  • Loading branch information
kaxil authored Dec 19, 2024
1 parent 088242a commit 4de24a1
Show file tree
Hide file tree
Showing 11 changed files with 373 additions and 9 deletions.
3 changes: 3 additions & 0 deletions task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
"Label",
"TaskGroup",
"dag",
"Connection",
"__version__",
]

__version__ = "1.0.0.dev1"

if TYPE_CHECKING:
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.dag import DAG, dag
from airflow.sdk.definitions.edges import EdgeModifier, Label
from airflow.sdk.definitions.taskgroup import TaskGroup
Expand All @@ -43,6 +45,7 @@
"TaskGroup": ".definitions.taskgroup",
"EdgeModifier": ".definitions.edges",
"Label": ".definitions.edges",
"Connection": ".definitions.connection",
}


Expand Down
17 changes: 15 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import sys
import uuid
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, TypeVar

import httpx
Expand All @@ -43,6 +44,8 @@
VariableResponse,
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser

Expand Down Expand Up @@ -161,9 +164,19 @@ class ConnectionOperations:
def __init__(self, client: Client):
self.client = client

def get(self, conn_id: str) -> ConnectionResponse:
def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
"""Get a connection from the API server."""
resp = self.client.get(f"connections/{conn_id}")
try:
resp = self.client.get(f"connections/{conn_id}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.error(
"Connection not found",
conn_id=conn_id,
detail=e.detail,
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
return ConnectionResponse.model_validate_json(resp.read())


Expand Down
52 changes: 52 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import attrs


@attrs.define
class Connection:
"""
A connection to an external data source.
:param conn_id: The connection ID.
:param conn_type: The connection type.
:param description: The connection description.
:param host: The host.
:param login: The login.
:param password: The password.
:param schema: The schema.
:param port: The port number.
:param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON
encoded object.
"""

conn_id: str
conn_type: str
description: str | None = None
host: str | None = None
schema: str | None = None
login: str | None = None
password: str | None = None
port: int | None = None
extra: str | None = None

def get_uri(self): ...

def get_hook(self): ...
21 changes: 21 additions & 0 deletions task_sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from airflow.sdk.execution_time.comms import ErrorResponse


class AirflowRuntimeError(Exception):
def __init__(self, error: ErrorResponse):
self.error = error
super().__init__(f"{error.error.value}: {error.detail}")


class ErrorType(enum.Enum):
CONNECTION_NOT_FOUND = "CONNECTION_NOT_FOUND"
VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
GENERIC_ERROR = "GENERIC_ERROR"
16 changes: 15 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
VariableResponse,
XComResponse,
)
from airflow.sdk.exceptions import ErrorType


class StartupDetails(BaseModel):
Expand All @@ -85,13 +86,26 @@ class XComResult(XComResponse):
class ConnectionResult(ConnectionResponse):
type: Literal["ConnectionResult"] = "ConnectionResult"

@classmethod
def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
# Exclude defaults to avoid sending unnecessary data
# Pass the type as ConnectionResult explicitly so we can then call model_dump_json with exclude_unset=True
# to avoid sending unset fields (which are defaults in our case).
return cls(**connection_response.model_dump(exclude_defaults=True), type="ConnectionResult")


class VariableResult(VariableResponse):
type: Literal["VariableResult"] = "VariableResult"


class ErrorResponse(BaseModel):
error: ErrorType = ErrorType.GENERIC_ERROR
detail: dict | None = None
type: Literal["ErrorResponse"] = "ErrorResponse"


ToTask = Annotated[
Union[StartupDetails, XComResult, ConnectionResult, VariableResult],
Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse],
Field(discriminator="type"),
]

Expand Down
78 changes: 78 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import structlog

from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType

if TYPE_CHECKING:
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.execution_time.comms import ConnectionResult


def _convert_connection_result_conn(conn_result: ConnectionResult):
from airflow.sdk.definitions.connection import Connection

# `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))


def _get_connection(conn_id: str) -> Connection:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.connection`
# A reason to not move it to `airflow.sdk.execution_time.comms` is that it
# will make that module depend on Task SDK, which is not ideal because we intend to
# keep Task SDK as a separate package than execution time mods.
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

if TYPE_CHECKING:
assert isinstance(msg, ConnectionResult)
return _convert_connection_result_conn(msg)


class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

def __getattr__(self, conn_id: str) -> Any:
return _get_connection(conn_id)

def __repr__(self) -> str:
return "<ConnectionAccessor (dynamic access)>"

def __eq__(self, other):
if not isinstance(other, ConnectionAccessor):
return False
# All instances of ConnectionAccessor are equal since it is a stateless dynamic accessor
return True

def get(self, conn_id: str, default_conn: Any = None) -> Any:
try:
return _get_connection(conn_id)
except AirflowRuntimeError as e:
if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
return default_conn
raise
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@

from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
IntermediateTIState,
TaskInstance,
TerminalTIState,
)
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -689,7 +692,11 @@ def _handle_request(self, msg, log):
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
Expand Down
13 changes: 11 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ToSupervisor,
ToTask,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand All @@ -53,6 +54,9 @@ class RuntimeTaskInstance(TaskInstance):
"""The Task Instance context from the API server, if any."""

def get_template_context(self):
# TODO: Move this to `airflow.sdk.execution_time.context`
# once we port the entire context logic from airflow/utils/context.py ?

# TODO: Assess if we need to it through airflow.utils.timezone.coerce_datetime()
context: dict[str, Any] = {
# From the Task Execution interface
Expand All @@ -63,6 +67,8 @@ def get_template_context(self):
"run_id": self.run_id,
"task": self.task,
"task_instance": self,
# TODO: Ensure that ti.log_url and such are available to use in context
# especially after removal of `conf` from Context.
"ti": self,
# "outlet_events": OutletEventAccessors(),
# "expanded_ti_count": expanded_ti_count,
Expand All @@ -73,14 +79,13 @@ def get_template_context(self):
# "prev_data_interval_end_success": get_prev_data_interval_end_success(),
# "prev_start_date_success": get_prev_start_date_success(),
# "prev_end_date_success": get_prev_end_date_success(),
# "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
# "test_mode": task_instance.test_mode,
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
# "var": {
# "json": VariableAccessor(deserialize_json=True),
# "value": VariableAccessor(deserialize_json=False),
# },
# "conn": ConnectionAccessor(),
"conn": ConnectionAccessor(),
}
if self._ti_context_from_server:
dag_run = self._ti_context_from_server.dag_run
Expand Down Expand Up @@ -108,6 +113,10 @@ def get_template_context(self):
context.update(context_from_server)
return context

def xcom_pull(self, *args, **kwargs): ...

def xcom_push(self, *args, **kwargs): ...


def parse(what: StartupDetails) -> RuntimeTaskInstance:
# TODO: Task-SDK:
Expand Down
Loading

0 comments on commit 4de24a1

Please sign in to comment.