
AIP-84 Migrate Trigger Dag Run endpoint to FastAPI #43875

Open · wants to merge 15 commits into base: main
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -303,6 +303,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))


@mark_fastapi_migration_done
@security.requires_access_dag("POST", DagAccessEntity.RUN)
@action_logging
@provide_session
39 changes: 38 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/dag_run.py
@@ -20,9 +20,12 @@
from datetime import datetime
from enum import Enum

from pydantic import Field
from fastapi import HTTPException, status
from pydantic import AwareDatetime, Field, computed_field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.models import DagRun
from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

@@ -73,3 +76,37 @@ class DAGRunCollectionResponse(BaseModel):

dag_runs: list[DAGRunResponse]
total_entries: int


class TriggerDAGRunPostBody(BaseModel):
"""Trigger DAG Run Serializer for POST body."""

dag_run_id: str | None = None
data_interval_start: AwareDatetime | None = None
data_interval_end: AwareDatetime | None = None

conf: dict = Field(default_factory=dict)
note: str | None = None

model_config = {"extra": "forbid"}
Collaborator:

I suppose it'll be great to keep this consistent across all APIs. The legacy API throws a 400 Bad Request when extra properties are passed, while the new APIs don't raise an error yet, since FastAPI and Pydantic ignore extra properties by default.

We could either forbid extra properties for all models or for none of them, so that users have a consistent view of the APIs. Among the new APIs, I believe forbidding is currently done only in CreateAssetEventsBody.

@rawwar (Collaborator, Author) commented on Nov 25, 2024:

Yeah. I created a draft to start a conversation on this: #44306 (forgot to link it earlier).

Member:

Nice. Maybe remove it in this PR so there is no special case; #44306 will introduce it for all models or none.


@model_validator(mode="after")
def check_data_intervals(cls, values):
if (values.data_interval_start is None) != (values.data_interval_end is None):
raise HTTPException(
status.HTTP_422_UNPROCESSABLE_ENTITY,
"Either both data_interval_start and data_interval_end must be provided or both must be None",
)
Comment on lines +96 to +99
Member:

Pydantic takes care of formatting 422 errors correctly; those have a whole specific structure.

You can just raise ValueError; Pydantic catches those.

return values

@model_validator(mode="after")
def validate_dag_run_id(self):
if not self.dag_run_id:
self.dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.logical_date)
@pierrejeambrun (Member) commented on Nov 22, 2024:

@uranusjr If we remove the logical_date, then we have None for the logical date when creating the dagrun with create_dagrun, and unfortunately it needs type + logical_date to infer the run_id. This is why we need to manually fill the run_id here in case it's not there, but I am not a big fan of it. (Is that OK?)

I assume this will be updated later when the logical_date change takes place, so that create_dagrun will be able to generate an appropriate run_id without providing logical_date?

Collaborator Author:

For now, shall I just remove logical_date from the request body and use the current time by default to generate the run_id? Once create_dagrun is updated, this can be updated as well.

Contributor:

Yes, let's remove it from the request body.

return self

# Mypy issue https://github.com/python/mypy/issues/1362
@computed_field # type: ignore[misc]
@property
def logical_date(self) -> datetime:
return timezone.utcnow()
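For context on the run_id discussion above: manual run ids in Airflow follow a `manual__<logical date>` naming pattern. A hypothetical standalone sketch of that scheme (the real generator is `DagRun.generate_run_id(DagRunType.MANUAL, ...)`, whose exact format may differ):

```python
# Hypothetical sketch of the "manual__<iso timestamp>" run-id naming scheme;
# the real implementation lives in DagRun.generate_run_id.
from datetime import datetime, timezone


def manual_run_id(logical_date: datetime) -> str:
    # Prefix with the run type, then the ISO-8601 logical date.
    return f"manual__{logical_date.isoformat()}"


run_id = manual_run_id(datetime(2024, 11, 22, tzinfo=timezone.utc))
```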
92 changes: 92 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1828,6 +1828,67 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
post:
tags:
- DagRun
summary: Trigger Dag Run
description: Trigger a DAG.
operationId: trigger_dag_run
parameters:
- name: dag_id
in: path
required: true
schema:
title: Dag Id
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/TriggerDAGRunPostBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRunResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'409':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Conflict
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dagSources/{dag_id}:
get:
tags:
@@ -7881,6 +7942,37 @@ components:
- microseconds
title: TimeDelta
description: TimeDelta can be used to interact with datetime.timedelta objects.
TriggerDAGRunPostBody:
properties:
dag_run_id:
anyOf:
- type: string
- type: 'null'
title: Dag Run Id
data_interval_start:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Data Interval Start
data_interval_end:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Data Interval End
conf:
type: object
title: Conf
note:
anyOf:
- type: string
- type: 'null'
title: Note
additionalProperties: false
type: object
title: TriggerDAGRunPostBody
description: Trigger DAG Run Serializer for POST body.
TriggerResponse:
properties:
id:
76 changes: 75 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -19,6 +19,7 @@

from typing import Annotated, cast

import pendulum
from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -45,13 +46,18 @@
DAGRunPatchBody,
DAGRunPatchStates,
DAGRunResponse,
TriggerDAGRunPostBody,
)
from airflow.api_fastapi.core_api.datamodels.task_instances import (
TaskInstanceCollectionResponse,
TaskInstanceResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.models import DAG, DagRun
from airflow.models import DAG, DagModel, DagRun
from airflow.models.dag_version import DagVersion
from airflow.timetables.base import DataInterval
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

dag_run_router = AirflowRouter(tags=["DagRun"], prefix="/dags/{dag_id}/dagRuns")

@@ -296,3 +302,71 @@ def get_dag_runs(
dag_runs=dag_runs,
total_entries=total_entries,
)


@dag_run_router.post(
"",
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
status.HTTP_404_NOT_FOUND,
status.HTTP_409_CONFLICT,
Member:

This 409 needs to stay though, so it is reflected in the documentation.

]
),
)
def trigger_dag_run(
dag_id, body: TriggerDAGRunPostBody, request: Request, session: Annotated[Session, Depends(get_session)]
) -> DAGRunResponse:
"""Trigger a DAG."""
dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1))
if not dm:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: '{dag_id}' not found")

if dm.has_import_errors:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"DAG with dag_id: '{dag_id}' has import errors and cannot be triggered",
)

run_id = body.dag_run_id
logical_date = pendulum.instance(body.logical_date)
dagrun_instance = session.scalar(
select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id).limit(1)
)

if not dagrun_instance:
try:
dag: DAG = request.app.state.dag_bag.get_dag(dag_id)

if body.data_interval_start and body.data_interval_end:
data_interval = DataInterval(
start=pendulum.instance(body.data_interval_start),
end=pendulum.instance(body.data_interval_end),
)
else:
data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
dag_version = DagVersion.get_latest_version(dag.dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
logical_date=logical_date,
data_interval=data_interval,
state=DagRunState.QUEUED,
conf=body.conf,
external_trigger=True,
dag_version=dag_version,
session=session,
triggered_by=DagRunTriggeredByType.REST_API,
)
dag_run_note = body.note
if dag_run_note:
current_user_id = None # refer to https://github.com/apache/airflow/issues/43534
dag_run.note = (dag_run_note, current_user_id)
return dag_run
except ValueError as e:
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))

raise HTTPException(
status.HTTP_409_CONFLICT,
f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{body.dag_run_id}' already exists",
)
Comment on lines +369 to +372
Member:

DB duplicate-entry exceptions are already handled by the application.

You can just always try to execute this code; if the database fails with a duplicate entry, a nice 409 error will automatically be returned.

3 changes: 3 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
@@ -1383,6 +1383,9 @@ export type ConnectionServiceTestConnectionMutationResult = Awaited<
export type DagRunServiceClearDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.clearDagRun>
>;
export type DagRunServiceTriggerDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.triggerDagRun>
>;
export type PoolServicePostPoolMutationResult = Awaited<
ReturnType<typeof PoolService.postPool>
>;
44 changes: 44 additions & 0 deletions airflow/ui/openapi-gen/queries/queries.ts
@@ -44,6 +44,7 @@ import {
PoolPostBody,
PoolPostBulkBody,
TaskInstancesBatchBody,
TriggerDAGRunPostBody,
VariableBody,
} from "../requests/types.gen";
import * as Common from "./common";
@@ -2416,6 +2417,49 @@ export const useDagRunServiceClearDagRun = <
}) as unknown as Promise<TData>,
...options,
});
/**
* Trigger Dag Run
* Trigger a DAG.
* @param data The data for the request.
* @param data.dagId
* @param data.requestBody
* @returns DAGRunResponse Successful Response
* @throws ApiError
*/
export const useDagRunServiceTriggerDagRun = <
TData = Common.DagRunServiceTriggerDagRunMutationResult,
TError = unknown,
TContext = unknown,
>(
options?: Omit<
UseMutationOptions<
TData,
TError,
{
dagId: unknown;
requestBody: TriggerDAGRunPostBody;
},
TContext
>,
"mutationFn"
>,
) =>
useMutation<
TData,
TError,
{
dagId: unknown;
requestBody: TriggerDAGRunPostBody;
},
TContext
>({
mutationFn: ({ dagId, requestBody }) =>
DagRunService.triggerDagRun({
dagId,
requestBody,
}) as unknown as Promise<TData>,
...options,
});
/**
* Post Pool
* Create a Pool.
59 changes: 59 additions & 0 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -4602,6 +4602,65 @@ export const $TimeDelta = {
"TimeDelta can be used to interact with datetime.timedelta objects.",
} as const;

export const $TriggerDAGRunPostBody = {
properties: {
dag_run_id: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Dag Run Id",
},
data_interval_start: {
anyOf: [
{
type: "string",
format: "date-time",
},
{
type: "null",
},
],
title: "Data Interval Start",
},
data_interval_end: {
anyOf: [
{
type: "string",
format: "date-time",
},
{
type: "null",
},
],
title: "Data Interval End",
},
conf: {
type: "object",
title: "Conf",
},
note: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Note",
},
},
additionalProperties: false,
type: "object",
title: "TriggerDAGRunPostBody",
description: "Trigger DAG Run Serializer for POST body.",
} as const;

export const $TriggerResponse = {
properties: {
id: {