diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 5074504b8d7fb..243fcfa2847d1 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -25,6 +25,7 @@ from flask import Response from airflow.jobs.job import Job, most_recent_job +from airflow.models.taskinstance import _get_template_context from airflow.serialization.serialized_objects import BaseSerialization from airflow.utils.session import create_session @@ -48,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]: from airflow.utils.log.file_task_handler import FileTaskHandler functions: list[Callable] = [ + _get_template_context, DagFileProcessor.update_import_errors, DagFileProcessor.manage_slas, DagFileProcessorManager.deactivate_stale_dags, diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 14fc0fc8f7ffb..e7fdc5bec179c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -597,6 +597,7 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti task_instance.next_kwargs = None +@internal_api_call def _get_template_context( *, task_instance: TaskInstance | TaskInstancePydantic, @@ -623,10 +624,30 @@ def _get_template_context( task = task_instance.task if TYPE_CHECKING: + assert task_instance.task assert task assert task.dag - dag: DAG = task.dag + try: + dag: DAG = task.dag + except AirflowException: + from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic + if isinstance(task_instance, TaskInstancePydantic): + ti = session.scalar( + select(TaskInstance).where( + TaskInstance.task_id == task_instance.task_id, + TaskInstance.dag_id == task_instance.dag_id, + TaskInstance.run_id == task_instance.run_id, + TaskInstance.map_index == task_instance.map_index, + ) + ) + dag = ti.dag_model.serialized_dag.dag + if hasattr(task_instance.task, "_dag"): # BaseOperator + task_instance.task._dag = dag + else: # MappedOperator + task_instance.task.dag = dag + else: + raise dag_run = task_instance.get_dagrun(session) data_interval = dag.get_run_data_interval(dag_run) diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index 2a4387eeb4809..9b7cdbcc738ad 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -61,4 +61,5 @@ class DagAttributeTypes(str, Enum): DATA_SET = "data_set" LOG_TEMPLATE = "log_template" CONNECTION = "connection" + TASK_CONTEXT = "task_context" ARG_NOT_SET = "arg_not_set" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 16a5c9e481d1d..98d3d3a654214 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -67,6 +67,7 @@ airflow_priority_weight_strategies_classes, ) from airflow.utils.code_utils import get_python_source +from airflow.utils.context import Context from airflow.utils.docs import get_docs_url from airflow.utils.module_loading import import_string, qualname from airflow.utils.operator_resources import Resources @@ -602,6 +603,12 @@ def serialize( ) elif isinstance(var, Connection): return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION) + elif var.__class__ == Context: + d = {} + for k, v in var._context.items(): + obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) + d[str(k)] = obj + return cls._encode(d, type_=DAT.TASK_CONTEXT) elif use_pydantic_models and _ENABLE_AIP_44: def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]: @@ -648,7 +655,14 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}") var = encoded_var[Encoding.VAR] type_ = encoded_var[Encoding.TYPE] - + if type_ == DAT.TASK_CONTEXT: + d = {} + for k, v in var.items(): + if k == "task": # todo: add TaskPydantic so we don't need this? + continue + d[k] = cls.deserialize(v, use_pydantic_models=True) + d["task"] = d["task_instance"].task # todo: add TaskPydantic so we don't need this? + return Context(**d) if type_ == DAT.DICT: return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()} elif type_ == DAT.DAG: