Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-84 Migrate Trigger Dag Run endpoint to FastAPI #43875

Merged
merged 21 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))


@mark_fastapi_migration_done
@security.requires_access_dag("POST", DagAccessEntity.RUN)
@action_logging
@provide_session
Expand Down
39 changes: 38 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
from datetime import datetime
from enum import Enum

from pydantic import Field
from fastapi import HTTPException, status
from pydantic import AwareDatetime, Field, computed_field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.models import DagRun
from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

Expand Down Expand Up @@ -73,3 +76,37 @@ class DAGRunCollectionResponse(BaseModel):

dag_runs: list[DAGRunResponse]
total_entries: int


class TriggerDAGRunPostBody(BaseModel):
"""Trigger DAG Run Serializer for POST body."""

dag_run_id: str | None = None
data_interval_start: AwareDatetime | None = None
data_interval_end: AwareDatetime | None = None

conf: dict = Field(default_factory=dict)
note: str | None = None

model_config = {"extra": "forbid"}
rawwar marked this conversation as resolved.
Show resolved Hide resolved

@model_validator(mode="after")
def check_data_intervals(cls, values):
if (values.data_interval_start is None) != (values.data_interval_end is None):
raise HTTPException(
status.HTTP_422_UNPROCESSABLE_ENTITY,
"Either both data_interval_start and data_interval_end must be provided or both must be None",
)
rawwar marked this conversation as resolved.
Show resolved Hide resolved
return values

@model_validator(mode="after")
def validate_dag_run_id(self):
if not self.dag_run_id:
self.dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.logical_date)
rawwar marked this conversation as resolved.
Show resolved Hide resolved
return self

# Mypy issue https://github.com/python/mypy/issues/1362
@computed_field # type: ignore[misc]
@property
def logical_date(self) -> datetime:
return timezone.utcnow()
92 changes: 92 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,67 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
post:
tags:
- DagRun
summary: Trigger Dag Run
description: Trigger a DAG.
operationId: trigger_dag_run
parameters:
- name: dag_id
in: path
required: true
schema:
title: Dag Id
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/TriggerDAGRunPostBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRunResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'409':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Conflict
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dagSources/{dag_id}:
get:
tags:
Expand Down Expand Up @@ -7881,6 +7942,37 @@ components:
- microseconds
title: TimeDelta
description: TimeDelta can be used to interact with datetime.timedelta objects.
TriggerDAGRunPostBody:
properties:
dag_run_id:
anyOf:
- type: string
- type: 'null'
title: Dag Run Id
data_interval_start:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Data Interval Start
data_interval_end:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Data Interval End
conf:
type: object
title: Conf
note:
anyOf:
- type: string
- type: 'null'
title: Note
additionalProperties: false
type: object
title: TriggerDAGRunPostBody
description: Trigger DAG Run Serializer for POST body.
TriggerResponse:
properties:
id:
Expand Down
76 changes: 75 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from typing import Annotated, cast

import pendulum
from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session
Expand All @@ -45,13 +46,18 @@
DAGRunPatchBody,
DAGRunPatchStates,
DAGRunResponse,
TriggerDAGRunPostBody,
)
from airflow.api_fastapi.core_api.datamodels.task_instances import (
TaskInstanceCollectionResponse,
TaskInstanceResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.models import DAG, DagRun
from airflow.models import DAG, DagModel, DagRun
from airflow.models.dag_version import DagVersion
from airflow.timetables.base import DataInterval
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

dag_run_router = AirflowRouter(tags=["DagRun"], prefix="/dags/{dag_id}/dagRuns")

Expand Down Expand Up @@ -296,3 +302,71 @@ def get_dag_runs(
dag_runs=dag_runs,
total_entries=total_entries,
)


@dag_run_router.post(
"",
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
status.HTTP_404_NOT_FOUND,
status.HTTP_409_CONFLICT,
rawwar marked this conversation as resolved.
Show resolved Hide resolved
]
),
)
def trigger_dag_run(
dag_id, body: TriggerDAGRunPostBody, request: Request, session: Annotated[Session, Depends(get_session)]
) -> DAGRunResponse:
"""Trigger a DAG."""
dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1))
if not dm:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: '{dag_id}' not found")

if dm.has_import_errors:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"DAG with dag_id: '{dag_id}' has import errors and cannot be triggered",
)

run_id = body.dag_run_id
logical_date = pendulum.instance(body.logical_date)
dagrun_instance = session.scalar(
select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id).limit(1)
)

if not dagrun_instance:
try:
dag: DAG = request.app.state.dag_bag.get_dag(dag_id)

if body.data_interval_start and body.data_interval_end:
data_interval = DataInterval(
start=pendulum.instance(body.data_interval_start),
end=pendulum.instance(body.data_interval_end),
)
else:
data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
dag_version = DagVersion.get_latest_version(dag.dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
logical_date=logical_date,
data_interval=data_interval,
state=DagRunState.QUEUED,
conf=body.conf,
external_trigger=True,
dag_version=dag_version,
session=session,
triggered_by=DagRunTriggeredByType.REST_API,
)
dag_run_note = body.note
if dag_run_note:
current_user_id = None # refer to https://github.com/apache/airflow/issues/43534
dag_run.note = (dag_run_note, current_user_id)
return dag_run
except ValueError as e:
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))

raise HTTPException(
status.HTTP_409_CONFLICT,
f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{body.dag_run_id}' already exists",
)
pierrejeambrun marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 3 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,9 @@ export type ConnectionServiceTestConnectionMutationResult = Awaited<
export type DagRunServiceClearDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.clearDagRun>
>;
export type DagRunServiceTriggerDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.triggerDagRun>
>;
export type PoolServicePostPoolMutationResult = Awaited<
ReturnType<typeof PoolService.postPool>
>;
Expand Down
44 changes: 44 additions & 0 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
PoolPostBody,
PoolPostBulkBody,
TaskInstancesBatchBody,
TriggerDAGRunPostBody,
VariableBody,
} from "../requests/types.gen";
import * as Common from "./common";
Expand Down Expand Up @@ -2416,6 +2417,49 @@ export const useDagRunServiceClearDagRun = <
}) as unknown as Promise<TData>,
...options,
});
/**
* Trigger Dag Run
* Trigger a DAG.
* @param data The data for the request.
* @param data.dagId
* @param data.requestBody
* @returns DAGRunResponse Successful Response
* @throws ApiError
*/
export const useDagRunServiceTriggerDagRun = <
TData = Common.DagRunServiceTriggerDagRunMutationResult,
TError = unknown,
TContext = unknown,
>(
options?: Omit<
UseMutationOptions<
TData,
TError,
{
dagId: unknown;
requestBody: TriggerDAGRunPostBody;
},
TContext
>,
"mutationFn"
>,
) =>
useMutation<
TData,
TError,
{
dagId: unknown;
requestBody: TriggerDAGRunPostBody;
},
TContext
>({
mutationFn: ({ dagId, requestBody }) =>
DagRunService.triggerDagRun({
dagId,
requestBody,
}) as unknown as Promise<TData>,
...options,
});
/**
* Post Pool
* Create a Pool.
Expand Down
59 changes: 59 additions & 0 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4602,6 +4602,65 @@ export const $TimeDelta = {
"TimeDelta can be used to interact with datetime.timedelta objects.",
} as const;

export const $TriggerDAGRunPostBody = {
properties: {
dag_run_id: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Dag Run Id",
},
data_interval_start: {
anyOf: [
{
type: "string",
format: "date-time",
},
{
type: "null",
},
],
title: "Data Interval Start",
},
data_interval_end: {
anyOf: [
{
type: "string",
format: "date-time",
},
{
type: "null",
},
],
title: "Data Interval End",
},
conf: {
type: "object",
title: "Conf",
},
note: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Note",
},
},
additionalProperties: false,
type: "object",
title: "TriggerDAGRunPostBody",
description: "Trigger DAG Run Serializer for POST body.",
} as const;

export const $TriggerResponse = {
properties: {
id: {
Expand Down
Loading