Skip to content

Commit

Permalink
Make _get_template_context an RPC call (#38567)
Browse files Browse the repository at this point in the history
Provide a way of serializing the template context over RPC
  • Loading branch information
dstandish authored Apr 2, 2024
1 parent ab5aabe commit 0010bf1
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 2 deletions.
2 changes: 2 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from flask import Response

from airflow.jobs.job import Job, most_recent_job
from airflow.models.taskinstance import _get_template_context
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.session import create_session

Expand All @@ -48,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.utils.log.file_task_handler import FileTaskHandler

functions: list[Callable] = [
_get_template_context,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
Expand Down
23 changes: 22 additions & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
task_instance.next_kwargs = None


@internal_api_call
def _get_template_context(
*,
task_instance: TaskInstance | TaskInstancePydantic,
Expand All @@ -623,10 +624,30 @@ def _get_template_context(

task = task_instance.task
if TYPE_CHECKING:
assert task_instance.task
assert task
assert task.dag
dag: DAG = task.dag
try:
dag: DAG = task.dag
except AirflowException:
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

if isinstance(task_instance, TaskInstancePydantic):
ti = session.scalar(
select(TaskInstance).where(
TaskInstance.task_id == task_instance.task_id,
TaskInstance.dag_id == task_instance.dag_id,
TaskInstance.run_id == task_instance.run_id,
TaskInstance.map_index == task_instance.map_index,
)
)
dag = ti.dag_model.serialized_dag.dag
if hasattr(task_instance.task, "_dag"): # BaseOperator
task_instance.task._dag = dag
else: # MappedOperator
task_instance.task.dag = dag
else:
raise
dag_run = task_instance.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)

Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,5 @@ class DagAttributeTypes(str, Enum):
DATA_SET = "data_set"
LOG_TEMPLATE = "log_template"
CONNECTION = "connection"
TASK_CONTEXT = "task_context"
ARG_NOT_SET = "arg_not_set"
16 changes: 15 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
airflow_priority_weight_strategies_classes,
)
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import Context
from airflow.utils.docs import get_docs_url
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
Expand Down Expand Up @@ -602,6 +603,12 @@ def serialize(
)
elif isinstance(var, Connection):
return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION)
elif var.__class__ == Context:
d = {}
for k, v in var._context.items():
obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
d[str(k)] = obj
return cls._encode(d, type_=DAT.TASK_CONTEXT)
elif use_pydantic_models and _ENABLE_AIP_44:

def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]:
Expand Down Expand Up @@ -648,7 +655,14 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}")
var = encoded_var[Encoding.VAR]
type_ = encoded_var[Encoding.TYPE]

if type_ == DAT.TASK_CONTEXT:
d = {}
for k, v in var.items():
if k == "task": # todo: add TaskPydantic so we don't need this?
continue
d[k] = cls.deserialize(v, use_pydantic_models=True)
d["task"] = d["task_instance"].task # todo: add TaskPydantic so we don't need this?
return Context(**d)
if type_ == DAT.DICT:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.DAG:
Expand Down

0 comments on commit 0010bf1

Please sign in to comment.