Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate public endpoint Patch Task Instance to FastAPI #44223

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def set_mapped_task_instance_note(
return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index)


@mark_fastapi_migration_done
@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE)
@action_logging
@provide_session
Expand Down
29 changes: 29 additions & 0 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
ConfigDict,
Field,
NonNegativeInt,
StringConstraints,
ValidationError,
field_validator,
model_validator,
)

Expand Down Expand Up @@ -193,6 +195,33 @@ def validate_model(cls, data: Any) -> Any:
return data


class PatchTaskInstanceBody(BaseModel):
"""Request body for Clear Task Instances endpoint."""

dry_run: bool = True
new_state: str | None = None
note: Annotated[str, StringConstraints(max_length=1000)] | None = None
include_upstream: bool = False
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved
include_downstream: bool = False
include_future: bool = False
include_past: bool = False

@field_validator("new_state", mode="before")
@classmethod
def validate_new_state(cls, ns: str | None) -> str:
"""Validate new_state."""
valid_states = [
vs.name.lower()
for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
]
if ns is None:
raise ValueError("'new_state' should not be empty")
ns = ns.lower()
if ns not in valid_states:
raise ValueError(f"'{ns}' is not one of {valid_states}")
return ns


class TaskInstanceReferenceResponse(BaseModel):
"""Task Instance Reference serializer for responses."""

Expand Down
205 changes: 205 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3656,6 +3656,91 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
patch:
tags:
- Task Instance
summary: Patch Task Instance
description: Update the state of a task instance.
operationId: patch_task_instance
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: dag_run_id
in: path
required: true
schema:
type: string
title: Dag Run Id
- name: task_id
in: path
required: true
schema:
type: string
title: Task Id
- name: map_index
in: query
required: false
schema:
type: integer
default: -1
title: Map Index
- name: update_mask
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
title: Update Mask
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PatchTaskInstanceBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstanceResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped:
get:
tags:
Expand Down Expand Up @@ -4103,6 +4188,90 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
patch:
tags:
- Task Instance
summary: Patch Task Instance
description: Update the state of a task instance.
operationId: patch_task_instance
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: dag_run_id
in: path
required: true
schema:
type: string
title: Dag Run Id
- name: task_id
in: path
required: true
schema:
type: string
title: Task Id
- name: map_index
in: path
required: true
schema:
type: integer
title: Map Index
- name: update_mask
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
title: Update Mask
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PatchTaskInstanceBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/TaskInstanceResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances:
get:
tags:
Expand Down Expand Up @@ -7053,6 +7222,42 @@ components:
- unixname
title: JobResponse
description: Job serializer for responses.
PatchTaskInstanceBody:
properties:
dry_run:
type: boolean
title: Dry Run
default: true
new_state:
anyOf:
- type: string
- type: 'null'
title: New State
note:
anyOf:
- type: string
maxLength: 1000
- type: 'null'
title: Note
include_upstream:
type: boolean
title: Include Upstream
default: false
include_downstream:
type: boolean
title: Include Downstream
default: false
include_future:
type: boolean
title: Include Future
default: false
include_past:
type: boolean
title: Include Past
default: false
type: object
title: PatchTaskInstanceBody
description: Request body for Clear Task Instances endpoint.
PluginCollectionResponse:
properties:
plugins:
Expand Down
91 changes: 89 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

from typing import Annotated, Literal, cast

from fastapi import Depends, HTTPException, Request, status
from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.sql import or_, select
from sqlalchemy.sql.selectable import Select

from airflow.api_fastapi.common.db.common import get_session, paginated_select
Expand Down Expand Up @@ -50,6 +51,7 @@
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.task_instances import (
ClearTaskInstancesBody,
PatchTaskInstanceBody,
TaskDependencyCollectionResponse,
TaskInstanceCollectionResponse,
TaskInstanceHistoryCollectionResponse,
Expand Down Expand Up @@ -600,3 +602,88 @@ def post_clear_task_instances(
],
total_entries=len(task_instances),
)


@task_instances_router.patch(
task_instances_prefix + "/{task_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
@task_instances_router.patch(
task_instances_prefix + "/{task_id}/{map_index}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
def patch_task_instance(
dag_id: str,
dag_run_id: str,
task_id: str,
request: Request,
body: PatchTaskInstanceBody,
session: Annotated[Session, Depends(get_session)],
map_index: int = -1,
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved
update_mask: list[str] | None = Query(None),
) -> TaskInstanceResponse:
"""Update the state of a task instance."""
dag = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")

if not dag.has_task(task_id):
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'")

query = (
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.options(joinedload(TI.rendered_task_instance_fields))
)
if map_index == -1:
query = query.where(or_(TI.map_index == -1, TI.map_index is None))
else:
query = query.where(TI.map_index == map_index)

try:
ti = session.scalar(query)
except MultipleResultsFound:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
)

err_msg_404 = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}"
if ti is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)

fields_to_update = body.model_fields_set
if update_mask:
fields_to_update = fields_to_update.intersection(update_mask)

for field in fields_to_update:
if field == "new_state":
if not body.dry_run:
tis: list[TI] = dag.set_task_instance_state(
omkar-foss marked this conversation as resolved.
Show resolved Hide resolved
task_id=task_id,
run_id=dag_run_id,
map_indexes=[map_index],
state=body.new_state,
upstream=body.include_upstream,
downstream=body.include_downstream,
future=body.include_future,
past=body.include_past,
commit=True,
session=session,
)
if not ti:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
ti = tis[0] if isinstance(tis, list) else tis
elif field == "note":
if update_mask or body.note is not None:
# @TODO: replace None passed for user_id with actual user id when
# permissions and auth is in place.
if ti.task_instance_note is None:
ti.note = (body.note, None)
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = None
session.commit()

return TaskInstanceResponse.model_validate(ti, from_attributes=True)
Loading