Skip to content

Commit

Permalink
Remove deprecated SubDags (apache#41390)
Browse files Browse the repository at this point in the history
This PR removes SubDags in favor of TaskGroups for Airflow 3.0.

SubDags have been removed from the following locations:

- CLI
- API
- ``SubDagOperator``

This removal marks the end of SubDag support across all interfaces. Users
should transition to using TaskGroups as a more efficient and maintainable
alternative.

---------
Co-authored-by: Brent Bovenzi <[email protected]>
  • Loading branch information
kaxil authored Aug 13, 2024
1 parent 736ebfe commit 6570c6d
Show file tree
Hide file tree
Showing 88 changed files with 1,397 additions and 3,918 deletions.
26 changes: 3 additions & 23 deletions airflow/api/common/delete_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import logging
from typing import TYPE_CHECKING

from sqlalchemy import and_, delete, or_, select
from sqlalchemy import delete, select

from airflow import models
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DagModel, TaskFail
from airflow.models import DagModel
from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.db import get_sqla_model_classes
Expand Down Expand Up @@ -64,18 +64,6 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
if dag is None:
raise DagNotFound(f"Dag id {dag_id} not found")

# deleting a DAG should also delete all of its subdags
dags_to_delete_query = session.execute(
select(DagModel.dag_id).where(
or_(
DagModel.dag_id == dag_id,
and_(DagModel.dag_id.like(f"{dag_id}.%"), DagModel.is_subdag),
)
)
)

dags_to_delete = [dag_id for (dag_id,) in dags_to_delete_query]

# Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
# There may be a lag, so explicitly removes serialized DAG here.
if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
Expand All @@ -86,15 +74,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session =
for model in get_sqla_model_classes():
if hasattr(model, "dag_id") and (not keep_records_in_log or model.__name__ != "Log"):
count += session.execute(
delete(model)
.where(model.dag_id.in_(dags_to_delete))
.execution_options(synchronize_session="fetch")
).rowcount
if dag.is_subdag:
parent_dag_id, task_id = dag_id.rsplit(".", 1)
for model in TaskFail, models.TaskInstance:
count += session.execute(
delete(model).where(model.dag_id == parent_dag_id, model.task_id == task_id)
delete(model).where(model.dag_id == dag_id).execution_options(synchronize_session="fetch")
).rowcount

# Delete entries in Import Errors table for a deleted DAG
Expand Down
106 changes: 3 additions & 103 deletions airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@

from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.operators.subdag import SubDagOperator
from airflow.utils import timezone
from airflow.utils.helpers import exactly_one
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from datetime import datetime
Expand All @@ -40,6 +38,7 @@

from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.utils.types import DagRunType


class _DagRunInfo(NamedTuple):
Expand Down Expand Up @@ -101,14 +100,14 @@ def set_state(
Can set state for future tasks (calculated from run_id) and retroactively
for past tasks. Will verify integrity of past dag runs in order to create
tasks that did not exist. It will not create dag runs that are missing
on the schedule (but it will, as for subdag, dag runs if needed).
on the schedule.
:param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
``task.dag`` needs to be set
:param run_id: the run_id of the dagrun to start looking from
:param execution_date: the execution date from which to start looking (deprecated)
:param upstream: Mark all parents (upstream tasks)
:param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
:param downstream: Mark all siblings (downstream tasks) of task_id
:param future: Mark all future tasks on the interval of the dag up until
last execution date.
:param past: Retroactively mark all tasks starting from start_date of the DAG
Expand Down Expand Up @@ -140,54 +139,20 @@ def set_state(

dag_run_ids = get_run_ids(dag, run_id, future, past, session=session)
task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
task_ids = [task_id if isinstance(task_id, str) else task_id[0] for task_id in task_id_map_index_list]

confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids, session=session))
confirmed_dates = [info.logical_date for info in confirmed_infos]

sub_dag_run_ids = (
list(
_iter_subdag_run_ids(dag, session, DagRunState(state), task_ids, commit, confirmed_infos),
)
if not state == TaskInstanceState.SKIPPED
else []
)

# now look for the task instances that are affected

qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, dag_run_ids)

if commit:
tis_altered = session.scalars(qry_dag.with_for_update()).all()
if sub_dag_run_ids:
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
tis_altered += session.scalars(qry_sub_dag.with_for_update()).all()
for task_instance in tis_altered:
task_instance.set_state(state, session=session)
session.flush()
else:
tis_altered = session.scalars(qry_dag).all()
if sub_dag_run_ids:
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
tis_altered += session.scalars(qry_sub_dag).all()
return tis_altered


def all_subdag_tasks_query(
sub_dag_run_ids: list[str],
session: SASession,
state: TaskInstanceState,
confirmed_dates: Iterable[datetime],
):
"""Get *all* tasks of the sub dags."""
qry_sub_dag = (
select(TaskInstance)
.where(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
.where(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
)
return qry_sub_dag


def get_all_dag_task_query(
dag: DAG,
session: SASession,
Expand All @@ -208,71 +173,6 @@ def get_all_dag_task_query(
return qry_dag


def _iter_subdag_run_ids(
dag: DAG,
session: SASession,
state: DagRunState,
task_ids: list[str],
commit: bool,
confirmed_infos: Iterable[_DagRunInfo],
) -> Iterator[str]:
"""
Go through subdag operators and create dag runs.
We only work within the scope of the subdag. A subdag does not propagate to
its parent DAG, but parent propagates to subdags.
"""
dags = [dag]
while dags:
current_dag = dags.pop()
for task_id in task_ids:
if not current_dag.has_task(task_id):
continue

current_task = current_dag.get_task(task_id)
if isinstance(current_task, SubDagOperator) or current_task.task_type == "SubDagOperator":
# this works as a kind of integrity check
# it creates missing dag runs for subdag operators,
# maybe this should be moved to dagrun.verify_integrity
if TYPE_CHECKING:
assert current_task.subdag
dag_runs = _create_dagruns(
current_task.subdag,
infos=confirmed_infos,
state=DagRunState.RUNNING,
run_type=DagRunType.BACKFILL_JOB,
)

verify_dagruns(dag_runs, commit, state, session, current_task)

dags.append(current_task.subdag)
yield current_task.subdag.dag_id


def verify_dagruns(
dag_runs: Iterable[DagRun],
commit: bool,
state: DagRunState,
session: SASession,
current_task: Operator,
):
"""
Verify integrity of dag_runs.
:param dag_runs: dag runs to verify
:param commit: whether dag runs state should be updated
:param state: state of the dag_run to set if commit is True
:param session: session to use
:param current_task: current task
"""
for dag_run in dag_runs:
dag_run.dag = current_task.subdag
dag_run.verify_integrity()
if commit:
dag_run.state = state
session.merge(dag_run)


def _iter_existing_dag_run_infos(dag: DAG, run_ids: list[str], session: SASession) -> Iterator[_DagRunInfo]:
for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids, session=session):
dag_run.dag = dag
Expand Down
30 changes: 13 additions & 17 deletions airflow/api/common/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _trigger_dag(
conf: dict | str | None = None,
execution_date: datetime | None = None,
replace_microseconds: bool = True,
) -> list[DagRun | None]:
) -> DagRun | None:
"""
Triggers DAG run.
Expand Down Expand Up @@ -90,21 +90,17 @@ def _trigger_dag(
if conf:
run_conf = conf if isinstance(conf, dict) else json.loads(conf)

dag_runs = []
dags_to_run = [dag, *dag.subdags]
for _dag in dags_to_run:
dag_run = _dag.create_dagrun(
run_id=run_id,
execution_date=execution_date,
state=DagRunState.QUEUED,
conf=run_conf,
external_trigger=True,
dag_hash=dag_bag.dags_hash.get(dag_id),
data_interval=data_interval,
)
dag_runs.append(dag_run)
dag_run = dag.create_dagrun(
run_id=run_id,
execution_date=execution_date,
state=DagRunState.QUEUED,
conf=run_conf,
external_trigger=True,
dag_hash=dag_bag.dags_hash.get(dag_id),
data_interval=data_interval,
)

return dag_runs
return dag_run


@internal_api_call
Expand Down Expand Up @@ -133,7 +129,7 @@ def trigger_dag(
raise DagNotFound(f"Dag id {dag_id} not found in DagModel")

dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
triggers = _trigger_dag(
dr = _trigger_dag(
dag_id=dag_id,
dag_bag=dagbag,
run_id=run_id,
Expand All @@ -142,4 +138,4 @@ def trigger_dag(
replace_microseconds=replace_microseconds,
)

return triggers[0] if triggers else None
return dr if dr else None
7 changes: 3 additions & 4 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_dags(
) -> APIResponse:
"""Get all DAGs."""
allowed_attrs = ["dag_id"]
dags_query = select(DagModel).where(~DagModel.is_subdag)
dags_query = select(DagModel)
if only_active:
dags_query = dags_query.where(DagModel.is_active)
if paused is not None:
Expand Down Expand Up @@ -179,10 +179,9 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
update_mask = update_mask[0]
patch_body_[update_mask] = patch_body[update_mask]
patch_body = patch_body_
dags_query = select(DagModel)
if only_active:
dags_query = select(DagModel).where(~DagModel.is_subdag, DagModel.is_active)
else:
dags_query = select(DagModel).where(~DagModel.is_subdag)
dags_query = dags_query.where(DagModel.is_active)

if dag_id_pattern == "~":
dag_id_pattern = "%"
Expand Down
4 changes: 0 additions & 4 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,6 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
start_date=start_date,
end_date=end_date,
task_ids=None,
include_subdags=True,
include_parentdag=True,
only_failed=False,
dry_run=True,
)
Expand All @@ -438,8 +436,6 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
start_date=start_date,
end_date=end_date,
task_ids=None,
include_subdags=True,
include_parentdag=True,
only_failed=False,
)
dag_run = session.execute(select(DagRun).where(DagRun.id == dag_run.id)).scalar_one()
Expand Down
17 changes: 0 additions & 17 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3106,11 +3106,6 @@ components:
Human centric display text for the DAG.
*New in version 2.9.0*
root_dag_id:
type: string
readOnly: true
nullable: true
description: If the DAG is SubDAG then it is the top level DAG identifier. Otherwise, null.
is_paused:
type: boolean
nullable: true
Expand All @@ -3125,10 +3120,6 @@ components:
nullable: true
readOnly: true
type: boolean
is_subdag:
description: Whether the DAG is SubDAG.
type: boolean
readOnly: true
last_parsed_time:
type: string
format: date-time
Expand Down Expand Up @@ -4903,14 +4894,6 @@ components:
type: boolean
default: false

include_subdags:
description: Clear tasks in subdags and clear external tasks indicated by ExternalTaskMarker.
type: boolean

include_parentdag:
description: Clear tasks in the parent dag of the subdag.
type: boolean

reset_dag_runs:
description: Set state of DAG runs to RUNNING.
type: boolean
Expand Down
2 changes: 0 additions & 2 deletions airflow/api_connexion/schemas/dag_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ class Meta:

dag_id = auto_field(dump_only=True)
dag_display_name = fields.String(attribute="dag_display_name", dump_only=True)
root_dag_id = auto_field(dump_only=True)
is_paused = auto_field()
is_active = auto_field(dump_only=True)
is_subdag = auto_field(dump_only=True)
last_parsed_time = auto_field(dump_only=True)
last_pickled = auto_field(dump_only=True)
last_expired = auto_field(dump_only=True)
Expand Down
2 changes: 0 additions & 2 deletions airflow/api_connexion/schemas/task_instance_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ class ClearTaskInstanceFormSchema(Schema):
end_date = fields.DateTime(load_default=None, validate=validate_istimezone)
only_failed = fields.Boolean(load_default=True)
only_running = fields.Boolean(load_default=False)
include_subdags = fields.Boolean(load_default=False)
include_parentdag = fields.Boolean(load_default=False)
reset_dag_runs = fields.Boolean(load_default=False)
task_ids = fields.List(fields.String(), validate=validate.Length(min=1))
dag_run_id = fields.Str(load_default=None)
Expand Down
2 changes: 0 additions & 2 deletions airflow/api_connexion/schemas/task_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
TimeDeltaSchema,
WeightRuleField,
)
from airflow.api_connexion.schemas.dag_schema import DAGSchema
from airflow.models.mappedoperator import MappedOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,7 +60,6 @@ class TaskSchema(Schema):
ui_color = ColorField(dump_only=True)
ui_fgcolor = ColorField(dump_only=True)
template_fields = fields.List(fields.String(), dump_only=True)
sub_dag = fields.Nested(DAGSchema, dump_only=True)
downstream_task_ids = fields.List(fields.String(), dump_only=True)
params = fields.Method("_get_params", dump_only=True)
is_mapped = fields.Method("_get_is_mapped", dump_only=True)
Expand Down
Loading

0 comments on commit 6570c6d

Please sign in to comment.