class PatchTaskInstanceBody(BaseModel):
    """Request body for Clear Task Instances endpoint."""

    # NOTE(review): the docstring above is emitted verbatim as the schema
    # ``description`` in the generated OpenAPI spec and UI client. It reads like
    # a copy-paste (this model is the body of the PATCH task-instance route);
    # rewording it requires regenerating those artefacts in the same change.

    # When true (the default) a requested state change is not applied.
    dry_run: bool = True
    # Requested target state; normalised and restricted by ``validate_new_state``.
    new_state: str | None = None
    # Free-text note, limited to 1000 characters.
    note: Annotated[str, StringConstraints(max_length=1000)] | None = None
    # Passed by the route to ``DAG.set_task_instance_state`` as the
    # ``upstream`` / ``downstream`` / ``future`` / ``past`` arguments.
    include_upstream: bool = False
    include_downstream: bool = False
    include_future: bool = False
    include_past: bool = False

    @field_validator("new_state", mode="before")
    @classmethod
    def validate_new_state(cls, ns: Any) -> str:
        """
        Validate new_state.

        Runs before pydantic's own type coercion, so ``ns`` is the raw value taken
        from the request body and may be of any JSON type.

        :param ns: raw ``new_state`` value supplied by the client
        :return: the lower-cased state name
        :raises ValueError: if ``ns`` is ``None``, is not a string, or is not one of
            ``success``, ``failed`` or ``skipped`` (compared case-insensitively)
        """
        valid_states = [
            vs.name.lower()
            for vs in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.SKIPPED)
        ]
        if ns is None:
            raise ValueError("'new_state' should not be empty")
        if not isinstance(ns, str):
            # Calling ``.lower()`` on a number/bool/list would raise AttributeError,
            # which is not reported as a validation error. Reject it explicitly so
            # the client receives a 422 rather than an internal server error.
            raise ValueError(f"'{ns}' is not one of {valid_states}")
        ns = ns.lower()
        if ns not in valid_states:
            raise ValueError(f"'{ns}' is not one of {valid_states}")
        return ns
Reference serializer for responses.""" diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index e0ce72573bb76..f1cc5d1c521dd 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3656,6 +3656,91 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - Task Instance + summary: Patch Task Instance + description: Update the state of a task instance. + operationId: patch_task_instance + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: map_index + in: query + required: false + schema: + type: integer + default: -1 + title: Map Index + - name: update_mask + in: query + required: false + schema: + anyOf: + - type: array + items: + type: string + - type: 'null' + title: Update Mask + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PatchTaskInstanceBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + 
schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped: get: tags: @@ -4103,6 +4188,90 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + patch: + tags: + - Task Instance + summary: Patch Task Instance + description: Update the state of a task instance. + operationId: patch_task_instance + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: map_index + in: path + required: true + schema: + type: integer + title: Map Index + - name: update_mask + in: query + required: false + schema: + anyOf: + - type: array + items: + type: string + - type: 'null' + title: Update Mask + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PatchTaskInstanceBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances: get: tags: @@ -7053,6 +7222,42 @@ components: - unixname title: JobResponse 
@task_instances_router.patch(
    task_instances_prefix + "/{task_id}",
    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
@task_instances_router.patch(
    task_instances_prefix + "/{task_id}/{map_index}",
    responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]),
)
def patch_task_instance(
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    request: Request,
    body: PatchTaskInstanceBody,
    session: Annotated[Session, Depends(get_session)],
    map_index: int = -1,
    update_mask: list[str] | None = Query(None),
) -> TaskInstanceResponse:
    """Update the state of a task instance."""
    # NOTE: the docstring above is published verbatim as the OpenAPI operation
    # description, so the detailed contract is kept in comments instead.
    #
    # Registered twice: without ``map_index`` in the path (query parameter,
    # default -1) and with it as a path segment.
    #
    # Only fields that were explicitly present in the request body are
    # considered; when ``update_mask`` is given they are further restricted to
    # the names it lists.
    #
    # Raises HTTPException:
    #   404 - DAG not in the dag bag, task not in the DAG, task instance row
    #         missing, or the state change returned no task instance
    #   400 - more than one task instance matched the lookup
    dag = request.app.state.dag_bag.get_dag(dag_id)
    if not dag:
        raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")

    if not dag.has_task(task_id):
        raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found in DAG '{dag_id}'")

    query = (
        select(TI)
        .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
        .join(TI.dag_run)
        .options(joinedload(TI.rendered_task_instance_fields))
    )
    if map_index == -1:
        # ``TI.map_index is None`` would be evaluated by Python (always False) and
        # never reach the database; ``.is_(None)`` renders ``map_index IS NULL``.
        query = query.where(or_(TI.map_index == -1, TI.map_index.is_(None)))
    else:
        query = query.where(TI.map_index == map_index)

    try:
        # ``Session.scalar()`` silently returns the first row, so it can never
        # raise MultipleResultsFound; ``one_or_none()`` does. ``unique()`` keeps
        # the call valid should the joined eager load ever repeat the parent row.
        ti = session.scalars(query).unique().one_or_none()
    except MultipleResultsFound as err:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
        ) from err

    err_msg_404 = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}"
    if ti is None:
        raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)

    fields_to_update = body.model_fields_set
    if update_mask:
        fields_to_update = fields_to_update.intersection(update_mask)

    # The two updates are applied in a fixed order rather than by iterating the
    # set (whose order is arbitrary): the note is written first so that it always
    # lands on the task instance addressed by the URL, before ``ti`` may be
    # rebound to the result of the state change below.
    if "note" in fields_to_update and (update_mask or body.note is not None):
        # With an explicit update_mask a null note is applied as-is; without a
        # mask a null note is ignored.
        # @TODO: replace None passed for user_id with actual user id when
        # permissions and auth is in place.
        if ti.task_instance_note is None:
            ti.note = (body.note, None)
        else:
            ti.task_instance_note.content = body.note
            ti.task_instance_note.user_id = None
        session.commit()

    if "new_state" in fields_to_update and not body.dry_run:
        tis: list[TI] = dag.set_task_instance_state(
            task_id=task_id,
            run_id=dag_run_id,
            map_indexes=[map_index],
            state=body.new_state,
            upstream=body.include_upstream,
            downstream=body.include_downstream,
            future=body.include_future,
            past=body.include_past,
            commit=True,
            session=session,
        )
        # Guard the *result* of the state change (not the already validated
        # ``ti``) so that an empty result cannot reach ``tis[0]``.
        if not tis:
            raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
        ti = tis[0] if isinstance(tis, list) else tis

    return TaskInstanceResponse.model_validate(ti, from_attributes=True)
+ * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.requestBody + * @param data.mapIndex + * @param data.updateMask + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServicePatchTaskInstance = < + TData = Common.TaskInstanceServicePatchTaskInstanceMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex?: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex?: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >({ + mutationFn: ({ + dagId, + dagRunId, + mapIndex, + requestBody, + taskId, + updateMask, + }) => + TaskInstanceService.patchTaskInstance({ + dagId, + dagRunId, + mapIndex, + requestBody, + taskId, + updateMask, + }) as unknown as Promise, + ...options, + }); +/** + * Patch Task Instance + * Update the state of a task instance. + * @param data The data for the request. 
+ * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @param data.requestBody + * @param data.updateMask + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServicePatchTaskInstance1 = < + TData = Common.TaskInstanceServicePatchTaskInstance1MutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: string; + dagRunId: string; + mapIndex: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: string[]; + }, + TContext + >({ + mutationFn: ({ + dagId, + dagRunId, + mapIndex, + requestBody, + taskId, + updateMask, + }) => + TaskInstanceService.patchTaskInstance1({ + dagId, + dagRunId, + mapIndex, + requestBody, + taskId, + updateMask, + }) as unknown as Promise, + ...options, + }); /** * Patch Pool * Update a Pool. 
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 660dab68b8c4b..f9a745d83673d 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -2895,6 +2895,62 @@ export const $JobResponse = { description: "Job serializer for responses.", } as const; +export const $PatchTaskInstanceBody = { + properties: { + dry_run: { + type: "boolean", + title: "Dry Run", + default: true, + }, + new_state: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "New State", + }, + note: { + anyOf: [ + { + type: "string", + maxLength: 1000, + }, + { + type: "null", + }, + ], + title: "Note", + }, + include_upstream: { + type: "boolean", + title: "Include Upstream", + default: false, + }, + include_downstream: { + type: "boolean", + title: "Include Downstream", + default: false, + }, + include_future: { + type: "boolean", + title: "Include Future", + default: false, + }, + include_past: { + type: "boolean", + title: "Include Past", + default: false, + }, + }, + type: "object", + title: "PatchTaskInstanceBody", + description: "Request body for Clear Task Instances endpoint.", +} as const; + export const $PluginCollectionResponse = { properties: { plugins: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index cbb74b4395ca9..88737c9e3a4cf 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -98,6 +98,8 @@ import type { GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, + PatchTaskInstanceData, + PatchTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesData, @@ -108,6 +110,8 @@ import type { GetTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, + PatchTaskInstance1Data, + PatchTaskInstance1Response, 
GetTaskInstancesData, GetTaskInstancesResponse, GetTaskInstancesBatchData, @@ -1692,6 +1696,46 @@ export class TaskInstanceService { }); } + /** + * Patch Task Instance + * Update the state of a task instance. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.requestBody + * @param data.mapIndex + * @param data.updateMask + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ + public static patchTaskInstance( + data: PatchTaskInstanceData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}", + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + }, + query: { + map_index: data.mapIndex, + update_mask: data.updateMask, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } + /** * Get Mapped Task Instances * Get list of mapped task instances. @@ -1886,6 +1930,46 @@ export class TaskInstanceService { }); } + /** + * Patch Task Instance + * Update the state of a task instance. + * @param data The data for the request. 
+ * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @param data.requestBody + * @param data.updateMask + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ + public static patchTaskInstance1( + data: PatchTaskInstance1Data, + ): CancelablePromise { + return __request(OpenAPI, { + method: "PATCH", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}", + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + map_index: data.mapIndex, + }, + query: { + update_mask: data.updateMask, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } + /** * Get Task Instances * Get list of task instances. diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index a434871181ea4..960525a1d68cc 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -719,6 +719,19 @@ export type JobResponse = { unixname: string | null; }; +/** + * Request body for Clear Task Instances endpoint. + */ +export type PatchTaskInstanceBody = { + dry_run?: boolean; + new_state?: string | null; + note?: string | null; + include_upstream?: boolean; + include_downstream?: boolean; + include_future?: boolean; + include_past?: boolean; +}; + /** * Plugin Collection serializer. 
*/ @@ -1574,6 +1587,17 @@ export type GetTaskInstanceData = { export type GetTaskInstanceResponse = TaskInstanceResponse; +export type PatchTaskInstanceData = { + dagId: string; + dagRunId: string; + mapIndex?: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: Array | null; +}; + +export type PatchTaskInstanceResponse = TaskInstanceResponse; + export type GetMappedTaskInstancesData = { dagId: string; dagRunId: string; @@ -1637,6 +1661,17 @@ export type GetMappedTaskInstanceData = { export type GetMappedTaskInstanceResponse = TaskInstanceResponse; +export type PatchTaskInstance1Data = { + dagId: string; + dagRunId: string; + mapIndex: number; + requestBody: PatchTaskInstanceBody; + taskId: string; + updateMask?: Array | null; +}; + +export type PatchTaskInstance1Response = TaskInstanceResponse; + export type GetTaskInstancesData = { dagId: string; dagRunId: string; @@ -3139,6 +3174,35 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: PatchTaskInstanceData; + res: { + /** + * Successful Response + */ + 200: TaskInstanceResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; }; "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped": { get: { @@ -3274,6 +3338,35 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + patch: { + req: PatchTaskInstance1Data; + res: { + /** + * Successful Response + */ + 200: TaskInstanceResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; }; 
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances": { get: { diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 50409fa2783bf..2cae2869b3a72 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -25,6 +25,7 @@ import pendulum import pytest import sqlalchemy +from sqlalchemy import select from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner @@ -2441,3 +2442,563 @@ def test_raises_404_for_nonexistent_task_instance(self, test_client, session): assert response.json() == { "detail": "The Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `non_existent_task` and map_index: `-1` was not found" } + + +class TestPatchTaskInstance(TestTaskInstanceEndpoint): + ENDPOINT_URL = ( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + NEW_STATE = "failed" + DAG_ID = "example_python_operator" + TASK_ID = "print_the_context" + RUN_ID = "TEST_DAG_RUN_ID" + + @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + def test_should_call_mocked_api(self, mock_set_ti_state, test_client, session): + self.create_task_instances(session) + + mock_set_ti_state.return_value = session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == self.DAG_ID, + TaskInstance.task_id == self.TASK_ID, + TaskInstance.run_id == self.RUN_ID, + TaskInstance.map_index == -1, + ) + ).one_or_none() + + response = test_client.patch( + self.ENDPOINT_URL, + json={ + "dry_run": False, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 200 + assert response.json() == { + "dag_id": self.DAG_ID, + "dag_run_id": self.RUN_ID, + "logical_date": "2020-01-01T00:00:00Z", + "task_id": self.TASK_ID, + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + 
"executor": None, + "executor_config": "{}", + "hostname": "", + "id": mock.ANY, + "map_index": -1, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_display_name": self.TASK_ID, + "try_number": 0, + "unixname": getuser(), + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + + mock_set_ti_state.assert_called_once() + # We check args individually instead of direct matching using + # assert_called_once_with(), because the session objects don't match + # and can't be skipped using mock.ANY. + mock_set_ti_state.assert_called_once() + args, kwargs = mock_set_ti_state.call_args + assert len(args) == 0 + assert len(kwargs) == 10 + # 1st keyword argument + assert kwargs["task_id"] == self.TASK_ID + # 2nd keyword argument + assert kwargs["run_id"] == self.RUN_ID + # 3rd keyword argument + assert kwargs["map_indexes"] == [-1] + # 4th keyword argument + assert kwargs["state"] == self.NEW_STATE + # 5th keyword argument + assert kwargs["upstream"] is False + # 6th keyword argument + assert kwargs["downstream"] is False + # 7th keyword argument + assert kwargs["future"] is False + # 8th keyword argument + assert kwargs["past"] is False + # 9th keyword argument + assert kwargs["commit"] is True + # 10th keyword argument + assert isinstance(kwargs["session"], sqlalchemy.orm.session.Session) + + @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_state, test_client, session): + self.create_task_instances(session) + + mock_set_task_instance_state.return_value = session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == self.DAG_ID, + TaskInstance.task_id == self.TASK_ID, + TaskInstance.run_id == 
self.RUN_ID, + TaskInstance.map_index == -1, + ) + ).one_or_none() + + response = test_client.patch( + self.ENDPOINT_URL, + json={ + "dry_run": True, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 200 + assert response.json() == { + "dag_id": self.DAG_ID, + "dag_run_id": self.RUN_ID, + "logical_date": "2020-01-01T00:00:00Z", + "task_id": self.TASK_ID, + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "executor": None, + "executor_config": "{}", + "hostname": "", + "id": mock.ANY, + "map_index": -1, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_display_name": self.TASK_ID, + "try_number": 0, + "unixname": getuser(), + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + + mock_set_task_instance_state.assert_not_called() + + def test_should_update_task_instance_state(self, test_client, session): + self.create_task_instances(session) + + test_client.patch( + self.ENDPOINT_URL, + json={ + "dry_run": False, + "new_state": self.NEW_STATE, + }, + ) + + response2 = test_client.get(self.ENDPOINT_URL) + assert response2.status_code == 200 + assert response2.json()["state"] == self.NEW_STATE + + def test_should_update_task_instance_state_default_dry_run_to_true(self, test_client, session): + self.create_task_instances(session) + + test_client.patch( + self.ENDPOINT_URL, + json={ + "new_state": self.NEW_STATE, + }, + ) + + response2 = test_client.get(self.ENDPOINT_URL) + assert response2.status_code == 200 + assert response2.json()["state"] == "running" # no change in state + + def test_should_update_mapped_task_instance_state(self, test_client, session): + map_index = 1 + tis = self.create_task_instances(session) + ti = TaskInstance(task=tis[0].task, 
run_id=tis[0].run_id, map_index=map_index) + ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) + session.add(ti) + session.commit() + + response = test_client.patch( + f"{self.ENDPOINT_URL}/{map_index}", + json={ + "dry_run": False, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 200 + + response2 = test_client.get(f"{self.ENDPOINT_URL}/{map_index}") + assert response2.status_code == 200 + assert response2.json()["state"] == self.NEW_STATE + + @pytest.mark.parametrize( + "error, code, payload", + [ + [ + ( + "Task Instance not found for dag_id=example_python_operator" + ", run_id=TEST_DAG_RUN_ID, task_id=print_the_context" + ), + 404, + { + "dry_run": True, + "new_state": "failed", + }, + ] + ], + ) + def test_should_handle_errors(self, error, code, payload, test_client, session): + response = test_client.patch( + self.ENDPOINT_URL, + json=payload, + ) + assert response.status_code == code + assert response.json()["detail"] == error + + def test_should_200_for_unknown_fields(self, test_client, session): + self.create_task_instances(session) + response = test_client.patch( + self.ENDPOINT_URL, + json={ + "dryrun": True, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 200 + + def test_should_raise_404_for_non_existent_dag(self, test_client): + response = test_client.patch( + "/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + json={ + "dry_run": False, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + assert response.json() == {"detail": "DAG non-existent-dag not found"} + + def test_should_raise_404_for_non_existent_task_in_dag(self, test_client): + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task", + json={ + "dry_run": False, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + assert response.json() == { + "detail": "Task 
'non_existent_task' not found in DAG 'example_python_operator'" + } + + def test_should_raise_404_not_found_dag(self, test_client): + response = test_client.patch( + self.ENDPOINT_URL, + json={ + "dry_run": True, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + + def test_should_raise_404_not_found_task(self, test_client): + response = test_client.patch( + self.ENDPOINT_URL, + json={ + "dry_run": True, + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 404 + + @pytest.mark.parametrize( + "payload, expected", + [ + ( + { + "dry_run": True, + "new_state": "failede", + }, + f"'failede' is not one of ['{State.SUCCESS}', '{State.FAILED}', '{State.SKIPPED}']", + ), + ( + { + "dry_run": True, + "new_state": "queued", + }, + f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}', '{State.SKIPPED}']", + ), + ], + ) + def test_should_raise_422_for_invalid_task_instance_state(self, payload, expected, test_client, session): + self.create_task_instances(session) + response = test_client.patch( + self.ENDPOINT_URL, + json=payload, + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "value_error", + "loc": ["body", "new_state"], + "msg": f"Value error, {expected}", + "input": payload["new_state"], + "ctx": {"error": {}}, + } + ] + } + + @pytest.mark.parametrize( + "new_state,expected_status_code,expected_json,set_ti_state_call_count", + [ + ( + "failed", + 200, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID", + "logical_date": "2020-01-01T00:00:00Z", + "task_id": "print_the_context", + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "executor": None, + "executor_config": "{}", + "hostname": "", + "id": mock.ANY, + "map_index": -1, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": 
None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_display_name": "print_the_context", + "try_number": 0, + "unixname": getuser(), + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + }, + 1, + ), + ( + None, + 422, + { + "detail": [ + { + "type": "value_error", + "loc": ["body", "new_state"], + "msg": "Value error, 'new_state' should not be empty", + "input": None, + "ctx": {"error": {}}, + } + ] + }, + 0, + ), + ], + ) + @mock.patch("airflow.models.dag.DAG.set_task_instance_state") + def test_update_mask_should_call_mocked_api( + self, + mock_set_ti_state, + test_client, + session, + new_state, + expected_status_code, + expected_json, + set_ti_state_call_count, + ): + self.create_task_instances(session) + + mock_set_ti_state.return_value = session.scalars( + select(TaskInstance).where( + TaskInstance.dag_id == self.DAG_ID, + TaskInstance.task_id == self.TASK_ID, + TaskInstance.run_id == self.RUN_ID, + TaskInstance.map_index == -1, + ) + ).one_or_none() + + response = test_client.patch( + self.ENDPOINT_URL, + params={"update_mask": "new_state"}, + json={ + "dry_run": False, + "new_state": new_state, + }, + ) + assert response.status_code == expected_status_code + assert response.json() == expected_json + assert mock_set_ti_state.call_count == set_ti_state_call_count + + @pytest.mark.parametrize( + "new_note_value", + [ + "My super cool TaskInstance note.", + None, + ], + ) + def test_update_mask_set_note_should_respond_200(self, test_client, session, new_note_value): + self.create_task_instances(session) + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + params={"update_mask": "note"}, + json={"note": new_note_value}, + ) + assert response.status_code == 200, response.text + assert response.json() == { + "dag_id": self.DAG_ID, + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "logical_date": 
"2020-01-01T00:00:00Z", + "id": mock.ANY, + "executor": None, + "executor_config": "{}", + "hostname": "", + "map_index": -1, + "max_tries": 0, + "note": new_note_value, + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_id": self.TASK_ID, + "task_display_name": self.TASK_ID, + "try_number": 0, + "unixname": getuser(), + "dag_run_id": self.RUN_ID, + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + + def test_set_note_should_respond_200(self, test_client, session): + self.create_task_instances(session) + new_note_value = "My super cool TaskInstance note." + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + json={"note": new_note_value}, + ) + assert response.status_code == 200, response.text + assert response.json() == { + "dag_id": self.DAG_ID, + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "logical_date": "2020-01-01T00:00:00Z", + "id": mock.ANY, + "executor": None, + "executor_config": "{}", + "hostname": "", + "map_index": -1, + "max_tries": 0, + "note": new_note_value, + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_id": self.TASK_ID, + "task_display_name": self.TASK_ID, + "try_number": 0, + "unixname": getuser(), + "dag_run_id": self.RUN_ID, + "rendered_fields": {}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + + def test_set_note_should_respond_200_mapped_task_instance_with_rtif(self, test_client, session): + """Verify we don't duplicate rows through join to RTIF""" + tis = self.create_task_instances(session) + 
old_ti = tis[0] + for idx in (1, 2): + ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) + for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: + setattr(ti, attr, getattr(old_ti, attr)) + session.add(ti) + session.commit() + + # in each loop, we should get the right mapped TI back + for map_index in (1, 2): + new_note_value = f"My super cool TaskInstance note {map_index}" + response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" + f"print_the_context/{map_index}", + json={"note": new_note_value}, + ) + assert response.status_code == 200, response.text + + assert response.json() == { + "dag_id": self.DAG_ID, + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "logical_date": "2020-01-01T00:00:00Z", + "id": mock.ANY, + "executor": None, + "executor_config": "{}", + "hostname": "", + "map_index": map_index, + "max_tries": 0, + "note": new_note_value, + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_id": self.TASK_ID, + "task_display_name": self.TASK_ID, + "try_number": 0, + "unixname": getuser(), + "dag_run_id": self.RUN_ID, + "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + + def test_set_note_should_respond_200_when_note_is_empty(self, test_client, session): + tis = self.create_task_instances(session) + for ti in tis: + ti.task_instance_note = None + session.add(ti) + session.commit() + new_note_value = "My super cool TaskInstance note." 
+ response = test_client.patch( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + json={"note": new_note_value}, + ) + assert response.status_code == 200, response.text + assert response.json()["note"] == new_note_value