Skip to content

Commit

Permalink
Can update RenderedTaskInstanceFields over RPC (#38565)
Browse files Browse the repository at this point in the history
What I'm doing is, i'm separating out the "collecting" of the rendered fields from the "saving" of the rendered fields. This way they can be collected on the worker side, then sent over the wire where the rpc server just stores them. previously, collection and storage were all done on the RTIF object itself.
  • Loading branch information
dstandish authored Apr 2, 2024
1 parent c443971 commit 40dbe4b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 14 deletions.
3 changes: 2 additions & 1 deletion airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flask import Response

from airflow.jobs.job import Job, most_recent_job
from airflow.models.taskinstance import _get_template_context
from airflow.models.taskinstance import _get_template_context, _update_rtif
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.session import create_session

Expand All @@ -50,6 +50,7 @@ def _initialize_map() -> dict[str, Callable]:

functions: list[Callable] = [
_get_template_context,
_update_rtif,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
Expand Down
21 changes: 16 additions & 5 deletions airflow/models/renderedtifields.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,23 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import FromClause

from airflow.models import Operator
from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic


def get_serialized_template_fields(task: Operator):
"""
Get and serialize the template fields for a task.
Used in preparing to store them in RTIF table.
:param task: Operator instance with rendered template fields
:meta private:
"""
return {field: serialize_template_field(getattr(task, field), field) for field in task.template_fields}


class RenderedTaskInstanceFields(TaskInstanceDependencies):
"""Save Rendered Template Fields."""

Expand Down Expand Up @@ -101,7 +115,7 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):

execution_date = association_proxy("dag_run", "execution_date")

def __init__(self, ti: TaskInstance, render_templates=True):
def __init__(self, ti: TaskInstance, render_templates=True, rendered_fields=None):
self.dag_id = ti.dag_id
self.task_id = ti.task_id
self.run_id = ti.run_id
Expand All @@ -120,10 +134,7 @@ def __init__(self, ti: TaskInstance, render_templates=True):
from airflow.providers.cncf.kubernetes.template_rendering import render_k8s_pod_yaml

self.k8s_pod_yaml = render_k8s_pod_yaml(ti)
self.rendered_fields = {
field: serialize_template_field(getattr(self.task, field), field)
for field in self.task.template_fields
}
self.rendered_fields = rendered_fields or get_serialized_template_fields(task=ti.task)

self._redact()

Expand Down
19 changes: 13 additions & 6 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
from airflow.models.renderedtifields import get_serialized_template_fields
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskmap import TaskMap
Expand Down Expand Up @@ -1278,6 +1279,16 @@ def _get_previous_ti(
return dagrun.get_task_instance(task_instance.task_id, session=session)


@internal_api_call
@provide_session
def _update_rtif(ti, rendered_fields, session: Session | None = None):
from airflow.models.renderedtifields import RenderedTaskInstanceFields

rtif = RenderedTaskInstanceFields(ti=ti, render_templates=False, rendered_fields=rendered_fields)
RenderedTaskInstanceFields.write(rtif, session=session)
RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)


class TaskInstance(Base, LoggingMixin):
"""
Task instances store the state of a task instance.
Expand Down Expand Up @@ -2628,8 +2639,6 @@ def _register_dataset_changes(self, *, events: DatasetEventAccessors, session: S

def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
"""Prepare Task for Execution."""
from airflow.models.renderedtifields import RenderedTaskInstanceFields

if TYPE_CHECKING:
assert self.task

Expand Down Expand Up @@ -2670,10 +2679,8 @@ def signal_handler(signum, frame):
task_orig = self.render_templates(context=context, jinja_env=jinja_env)

if not test_mode:
rtif = RenderedTaskInstanceFields(ti=self, render_templates=False)
RenderedTaskInstanceFields.write(rtif)
RenderedTaskInstanceFields.delete_old_records(self.task_id, self.dag_id)

rendered_fields = get_serialized_template_fields(task=self.task)
_update_rtif(ti=self, rendered_fields=rendered_fields)
# Export context to make it available for operators to use.
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
os.environ.update(airflow_context_vars)
Expand Down
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,10 +658,10 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
if type_ == DAT.TASK_CONTEXT:
d = {}
for k, v in var.items():
if k == "task": # todo: add TaskPydantic so we don't need this?
if k == "task": # todo: add `_encode` of Operator so we don't need this
continue
d[k] = cls.deserialize(v, use_pydantic_models=True)
d["task"] = d["task_instance"].task # todo: add TaskPydantic so we don't need this?
d["task"] = d["task_instance"].task # todo: add `_encode` of Operator so we don't need this
return Context(**d)
if type_ == DAT.DICT:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
Expand Down

0 comments on commit 40dbe4b

Please sign in to comment.