Skip to content

Commit

Permalink
Use StringConstraints, add update_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
omkar-foss committed Nov 21, 2024
1 parent f6c3f24 commit 3cb7311
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 47 deletions.
27 changes: 6 additions & 21 deletions airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Annotated, Any
from typing import Annotated

from pydantic import (
AliasPath,
Expand All @@ -27,8 +27,8 @@
ConfigDict,
Field,
NonNegativeInt,
StringConstraints,
field_validator,
model_validator,
)

from airflow.api_fastapi.core_api.datamodels.job import JobResponse
Expand Down Expand Up @@ -159,33 +159,18 @@ class PatchTaskInstanceBody(BaseModel):

dry_run: bool = True
new_state: str | None = None
note: str | None = None

@model_validator(mode="before")
@classmethod
def validate_model(cls, data: Any) -> Any:
if data.get("note") is None and data.get("new_state") is None:
raise ValueError("new_state is required.")
return data

@field_validator("note", mode="before")
@classmethod
def validate_note(cls, note: str | None) -> str | None:
"""Validate note."""
if note is None:
return None
if len(note) > 1000:
raise ValueError("Note length should not exceed 1000 characters.")
return note
note: Annotated[str, StringConstraints(max_length=1000)] | None = None

@field_validator("new_state", mode="before")
@classmethod
def validate_new_state(cls, ns: str) -> str:
def validate_new_state(cls, ns: str | None) -> str:
"""Validate new_state."""
valid_states = [
vs.name.lower()
for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
]
if ns is None:
raise ValueError("'new_state' should not be empty")
ns = ns.lower()
if ns not in valid_states:
raise ValueError(f"'{ns}' is not one of {valid_states}")
Expand Down
21 changes: 21 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3422,6 +3422,16 @@ paths:
type: integer
default: -1
title: Map Index
- name: update_mask
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
title: Update Mask
requestBody:
required: true
content:
Expand Down Expand Up @@ -3886,6 +3896,16 @@ paths:
schema:
type: integer
title: Map Index
- name: update_mask
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
title: Update Mask
requestBody:
required: true
content:
Expand Down Expand Up @@ -6572,6 +6592,7 @@ components:
note:
anyOf:
- type: string
maxLength: 1000
- type: 'null'
title: Note
type: object
Expand Down
54 changes: 30 additions & 24 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from typing import Annotated, Literal

from fastapi import Depends, HTTPException, Request, status
from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, joinedload
Expand Down Expand Up @@ -502,6 +502,7 @@ def patch_task_instance(
body: PatchTaskInstanceBody,
session: Annotated[Session, Depends(get_session)],
map_index: int = -1,
update_mask: list[str] | None = Query(None),
) -> TaskInstanceResponse:
"""Update the state of a task instance."""
dag = request.app.state.dag_bag.get_dag(dag_id)
Expand Down Expand Up @@ -534,28 +535,33 @@ def patch_task_instance(
if ti is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)

if not body.dry_run:
tis: list[TI] = dag.set_task_instance_state(
task_id=task_id,
run_id=dag_run_id,
map_indexes=[map_index],
state=body.new_state,
commit=True,
session=session,
)
if not ti:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
ti = tis[0] if isinstance(tis, list) else tis

# Set new note to the task instance if available in body.
if body.note is not None:
# @TODO: replace None passed for user_id with actual user id when
# permissions and auth is in place.
if ti.task_instance_note is None:
ti.note = (body.note, None)
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = None
session.commit()
fields_to_update = body.model_fields_set
if update_mask:
fields_to_update = fields_to_update.intersection(update_mask)

for field in fields_to_update:
if field == "new_state":
if not body.dry_run:
tis: list[TI] = dag.set_task_instance_state(
task_id=task_id,
run_id=dag_run_id,
map_indexes=[map_index],
state=body.new_state,
commit=True,
session=session,
)
if not ti:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
ti = tis[0] if isinstance(tis, list) else tis
elif field == "note":
if update_mask or body.note is not None:
# @TODO: replace None passed for user_id with actual user id when
# permissions and auth is in place.
if ti.task_instance_note is None:
ti.note = (body.note, None)
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = None
session.commit()

return TaskInstanceResponse.model_validate(ti, from_attributes=True)
26 changes: 24 additions & 2 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2851,6 +2851,7 @@ export const usePoolServicePatchPool = <
* @param data.taskId
* @param data.requestBody
* @param data.mapIndex
* @param data.updateMask
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
Expand All @@ -2869,6 +2870,7 @@ export const useTaskInstanceServicePatchTaskInstance = <
mapIndex?: number;
requestBody: PatchTaskInstanceBody;
taskId: string;
updateMask?: string[];
},
TContext
>,
Expand All @@ -2884,16 +2886,25 @@ export const useTaskInstanceServicePatchTaskInstance = <
mapIndex?: number;
requestBody: PatchTaskInstanceBody;
taskId: string;
updateMask?: string[];
},
TContext
>({
mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId }) =>
mutationFn: ({
dagId,
dagRunId,
mapIndex,
requestBody,
taskId,
updateMask,
}) =>
TaskInstanceService.patchTaskInstance({
dagId,
dagRunId,
mapIndex,
requestBody,
taskId,
updateMask,
}) as unknown as Promise<TData>,
...options,
});
Expand All @@ -2906,6 +2917,7 @@ export const useTaskInstanceServicePatchTaskInstance = <
* @param data.taskId
* @param data.mapIndex
* @param data.requestBody
* @param data.updateMask
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
Expand All @@ -2924,6 +2936,7 @@ export const useTaskInstanceServicePatchTaskInstance1 = <
mapIndex: number;
requestBody: PatchTaskInstanceBody;
taskId: string;
updateMask?: string[];
},
TContext
>,
Expand All @@ -2939,16 +2952,25 @@ export const useTaskInstanceServicePatchTaskInstance1 = <
mapIndex: number;
requestBody: PatchTaskInstanceBody;
taskId: string;
updateMask?: string[];
},
TContext
>({
mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId }) =>
mutationFn: ({
dagId,
dagRunId,
mapIndex,
requestBody,
taskId,
updateMask,
}) =>
TaskInstanceService.patchTaskInstance1({
dagId,
dagRunId,
mapIndex,
requestBody,
taskId,
updateMask,
}) as unknown as Promise<TData>,
...options,
});
Expand Down
1 change: 1 addition & 0 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2674,6 +2674,7 @@ export const $PatchTaskInstanceBody = {
anyOf: [
{
type: "string",
maxLength: 1000,
},
{
type: "null",
Expand Down
6 changes: 6 additions & 0 deletions airflow/ui/openapi-gen/requests/services.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,7 @@ export class TaskInstanceService {
* @param data.taskId
* @param data.requestBody
* @param data.mapIndex
* @param data.updateMask
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
Expand All @@ -1885,6 +1886,7 @@ export class TaskInstanceService {
},
query: {
map_index: data.mapIndex,
update_mask: data.updateMask,
},
body: data.requestBody,
mediaType: "application/json",
Expand Down Expand Up @@ -2071,6 +2073,7 @@ export class TaskInstanceService {
* @param data.taskId
* @param data.mapIndex
* @param data.requestBody
* @param data.updateMask
* @returns TaskInstanceResponse Successful Response
* @throws ApiError
*/
Expand All @@ -2086,6 +2089,9 @@ export class TaskInstanceService {
task_id: data.taskId,
map_index: data.mapIndex,
},
query: {
update_mask: data.updateMask,
},
body: data.requestBody,
mediaType: "application/json",
errors: {
Expand Down
2 changes: 2 additions & 0 deletions airflow/ui/openapi-gen/requests/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,7 @@ export type PatchTaskInstanceData = {
mapIndex?: number;
requestBody: PatchTaskInstanceBody;
taskId: string;
updateMask?: Array<string> | null;
};

export type PatchTaskInstanceResponse = TaskInstanceResponse;
Expand Down Expand Up @@ -1604,6 +1605,7 @@ export type PatchTaskInstance1Data = {
mapIndex: number;
requestBody: PatchTaskInstanceBody;
taskId: string;
updateMask?: Array<string> | null;
};

export type PatchTaskInstance1Response = TaskInstanceResponse;
Expand Down
Loading

0 comments on commit 3cb7311

Please sign in to comment.