From 32dc9e837912dc1b56052b29bc24f38c57d32fb1 Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:59:54 -0500 Subject: [PATCH] Do not use core Airflow Flask related resources in FAB provider (package `api_connexion`) (#45473) --- .../role_and_permission_endpoint.py | 8 +- .../api_endpoints/user_endpoint.py | 8 +- .../fab/auth_manager/schemas/user_schema.py | 2 +- .../fab/www/api_connexion/__init__.py | 17 ++ .../fab/www/api_connexion/exceptions.py | 197 ++++++++++++++++++ .../fab/www/api_connexion/parameters.py | 131 ++++++++++++ .../fab/www/api_connexion/security.py | 82 ++++++++ .../providers/fab/www/api_connexion/types.py | 30 +++ .../api_endpoints/test_asset_endpoint.py | 2 +- .../api_endpoints/test_dag_endpoint.py | 2 +- .../test_role_and_permission_endpoint.py | 2 +- .../test_task_instance_endpoint.py | 2 +- .../api_endpoints/test_user_endpoint.py | 2 +- 13 files changed, 471 insertions(+), 14 deletions(-) create mode 100644 providers/src/airflow/providers/fab/www/api_connexion/__init__.py create mode 100644 providers/src/airflow/providers/fab/www/api_connexion/exceptions.py create mode 100644 providers/src/airflow/providers/fab/www/api_connexion/parameters.py create mode 100644 providers/src/airflow/providers/fab/www/api_connexion/security.py create mode 100644 providers/src/airflow/providers/fab/www/api_connexion/types.py diff --git a/providers/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py b/providers/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py index aa68da0000424..e6eb18214b90c 100644 --- a/providers/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py +++ b/providers/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py @@ -24,9 +24,6 @@ from marshmallow import ValidationError from sqlalchemy import asc, desc, func, select -from 
airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound -from airflow.api_connexion.parameters import check_limit, format_parameters -from airflow.api_connexion.security import requires_access_custom_view from airflow.api_fastapi.app import get_auth_manager from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager from airflow.providers.fab.auth_manager.models import Action, Role @@ -37,11 +34,14 @@ role_collection_schema, role_schema, ) +from airflow.providers.fab.www.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound +from airflow.providers.fab.www.api_connexion.parameters import check_limit, format_parameters +from airflow.providers.fab.www.api_connexion.security import requires_access_custom_view from airflow.security import permissions if TYPE_CHECKING: - from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride + from airflow.providers.fab.www.api_connexion.types import APIResponse, UpdateMask def _check_action_and_resource(sm: FabAirflowSecurityManagerOverride, perms: list[tuple[str, str]]) -> None: diff --git a/providers/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py b/providers/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py index 142918a27c6c7..187ddc3c6a686 100644 --- a/providers/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py +++ b/providers/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py @@ -25,9 +25,6 @@ from sqlalchemy import asc, desc, func, select from werkzeug.security import generate_password_hash -from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound, Unknown -from airflow.api_connexion.parameters import check_limit, format_parameters -from airflow.api_connexion.security import requires_access_custom_view from airflow.api_fastapi.app import get_auth_manager from 
airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager from airflow.providers.fab.auth_manager.models import User @@ -37,11 +34,14 @@ user_collection_schema, user_schema, ) +from airflow.providers.fab.www.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound, Unknown +from airflow.providers.fab.www.api_connexion.parameters import check_limit, format_parameters +from airflow.providers.fab.www.api_connexion.security import requires_access_custom_view from airflow.security import permissions if TYPE_CHECKING: - from airflow.api_connexion.types import APIResponse, UpdateMask from airflow.providers.fab.auth_manager.models import Role + from airflow.providers.fab.www.api_connexion.types import APIResponse, UpdateMask @requires_access_custom_view("GET", permissions.RESOURCE_USER) diff --git a/providers/src/airflow/providers/fab/auth_manager/schemas/user_schema.py b/providers/src/airflow/providers/fab/auth_manager/schemas/user_schema.py index 4155667d56766..120698706ea16 100644 --- a/providers/src/airflow/providers/fab/auth_manager/schemas/user_schema.py +++ b/providers/src/airflow/providers/fab/auth_manager/schemas/user_schema.py @@ -21,9 +21,9 @@ from marshmallow import Schema, fields from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field -from airflow.api_connexion.parameters import validate_istimezone from airflow.providers.fab.auth_manager.models import User from airflow.providers.fab.auth_manager.schemas.role_and_permission_schema import RoleSchema +from airflow.providers.fab.www.api_connexion.parameters import validate_istimezone class UserCollectionItemSchema(SQLAlchemySchema): diff --git a/providers/src/airflow/providers/fab/www/api_connexion/__init__.py b/providers/src/airflow/providers/fab/www/api_connexion/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/providers/src/airflow/providers/fab/www/api_connexion/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py b/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py new file mode 100644 index 0000000000000..ef2e2ab9b4bbc --- /dev/null +++ b/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Any
+
+import werkzeug
+from connexion import FlaskApi, ProblemException, problem
+
+from airflow.utils.docs import get_docs_url
+
+if TYPE_CHECKING:
+    import flask
+
+# Documentation URL of the REST API reference page; the links below are anchors on it.
+doc_link = get_docs_url("stable-rest-api-ref.html")
+
+# HTTP status code -> documentation link used as the ``type`` of the error response.
+EXCEPTIONS_LINK_MAP = {
+    400: f"{doc_link}#section/Errors/BadRequest",
+    404: f"{doc_link}#section/Errors/NotFound",
+    405: f"{doc_link}#section/Errors/MethodNotAllowed",
+    401: f"{doc_link}#section/Errors/Unauthenticated",
+    409: f"{doc_link}#section/Errors/AlreadyExists",
+    403: f"{doc_link}#section/Errors/PermissionDenied",
+    500: f"{doc_link}#section/Errors/Unknown",
+}
+
+
+def common_error_handler(exception: BaseException) -> flask.Response:
+    """
+    Use to capture connexion exceptions and add link to the type field.
+
+    :param exception: the exception raised while handling the request
+    :return: the response built by ``FlaskApi.get_response`` from the resulting problem
+    """
+    if isinstance(exception, ProblemException):
+        # Statuses listed in EXCEPTIONS_LINK_MAP get the documentation link as ``type``;
+        # any other status keeps the ``type`` the exception already carries.
+        link = EXCEPTIONS_LINK_MAP.get(exception.status)
+        response = problem(
+            status=exception.status,
+            title=exception.title,
+            detail=exception.detail,
+            type=link or exception.type,
+            instance=exception.instance,
+            headers=exception.headers,
+            ext=exception.ext,
+        )
+    else:
+        # Anything that is not a werkzeug HTTP exception is replaced by a generic
+        # InternalServerError before the problem is built.
+        if not isinstance(exception, werkzeug.exceptions.HTTPException):
+            exception = werkzeug.exceptions.InternalServerError()
+
+        response = problem(title=exception.name, detail=exception.description, status=exception.code)
+
+    return FlaskApi.get_response(response)
+
+
+class NotFound(ProblemException):
+    """Raise when the object cannot be found."""
+
+    def __init__(
+        self,
+        title: str = "Not Found",
+        detail: str | None = None,
+        headers: dict | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            status=HTTPStatus.NOT_FOUND,
+            type=EXCEPTIONS_LINK_MAP[404],
+            title=title,
+            detail=detail,
+            headers=headers,
+            **kwargs,
+        )
+
+
+class BadRequest(ProblemException):
+    """Raise when the server processes a bad request."""
+
+    def __init__(
+        self,
+        title: str = "Bad Request",
+        detail: str | None = None,
+        headers: dict | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            status=HTTPStatus.BAD_REQUEST,
+            type=EXCEPTIONS_LINK_MAP[400],
+            title=title,
+            detail=detail,
+            headers=headers,
+            **kwargs,
+        )
+
+
+class Unauthenticated(ProblemException):
+    """Raise when the user is not authenticated."""
+
+    def __init__(
+        self,
+        title: str = "Unauthorized",
+        detail: str | None = None,
+        headers: dict | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            status=HTTPStatus.UNAUTHORIZED,
+            type=EXCEPTIONS_LINK_MAP[401],
+            title=title,
+            detail=detail,
+            headers=headers,
+            **kwargs,
+        )
+
+
+class PermissionDenied(ProblemException):
+    """Raise when the user does not have the required permissions."""
+
+    def __init__(
+        self,
+        title: str = "Forbidden",
+        detail: str | None = None,
+        headers: dict | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            status=HTTPStatus.FORBIDDEN,
+            type=EXCEPTIONS_LINK_MAP[403],
+            title=title,
+            detail=detail,
+            headers=headers,
+            **kwargs,
+        )
+
+
+class Conflict(ProblemException):
+    """Raise when there is some conflict."""
+
+    def __init__(
+        self,
+        title: str = "Conflict",
+        detail: str | None = None,
+        headers: dict | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            status=HTTPStatus.CONFLICT,
+            type=EXCEPTIONS_LINK_MAP[409],
+            title=title,
+            detail=detail,
+            headers=headers,
+            **kwargs,
+        )
+
+
+class AlreadyExists(Conflict):
+    """Raise when the object already exists."""
+
+
+class Unknown(ProblemException):
+    """Returns a response body and status code for HTTP 500 exception."""
+
+    def __init__(
+        self,
+        title: str = "Internal Server Error",
+        detail: str | None = None,
+        headers: dict | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            status=HTTPStatus.INTERNAL_SERVER_ERROR,
+            type=EXCEPTIONS_LINK_MAP[500],
+            title=title,
+            detail=detail,
+            headers=headers,
+            **kwargs,
+        )
diff --git a/providers/src/airflow/providers/fab/www/api_connexion/parameters.py b/providers/src/airflow/providers/fab/www/api_connexion/parameters.py
new file mode 100644
index 0000000000000..8edb8efa7bb79
--- /dev/null
+++ b/providers/src/airflow/providers/fab/www/api_connexion/parameters.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from collections.abc import Container
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
+
+from pendulum.parsing import ParserError
+from sqlalchemy import text
+
+from airflow.configuration import conf
+from airflow.providers.fab.www.api_connexion.exceptions import BadRequest
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+    from datetime import datetime
+
+    from sqlalchemy.sql import Select
+
+log = logging.getLogger(__name__)
+
+
+def validate_istimezone(value: datetime) -> None:
+    """
+    Validate that a datetime is not naive.
+
+    :param value: the datetime to check
+    :raises BadRequest: if ``value`` carries no timezone information
+    """
+    if not value.tzinfo:
+        raise BadRequest("Invalid datetime format", detail="Naive datetime is disallowed")
+
+
+def format_datetime(value: str) -> datetime:
+    """
+    Format datetime objects.
+
+    Datetime format parser for args since connexion doesn't parse datetimes
+    https://github.com/zalando/connexion/issues/476
+
+    This should only be used within connexion views because it raises 400
+
+    :param value: the raw datetime string taken from the request arguments
+    :return: the datetime parsed by ``airflow.utils.timezone.parse``
+    :raises BadRequest: if the parser rejects the value
+    """
+    value = value.strip()
+    # ``endswith`` rather than ``value[-1]``: an empty argument must not fail here with
+    # an IndexError; it is left to the parser below to reject it.
+    if not value.endswith("Z"):
+        # NOTE(review): presumably restores a "+" offset sign that URL decoding turned
+        # into a space -- confirm against the callers.
+        value = value.replace(" ", "+")
+    try:
+        return timezone.parse(value)
+    except (ParserError, TypeError) as err:
+        raise BadRequest("Incorrect datetime argument", detail=str(err)) from err
+
+
+def check_limit(value: int) -> int:
+    """
+    Check the limit does not exceed configured value.
+
+    :param value: the page limit passed to the view
+    :return: ``value`` itself, the configured maximum page limit when ``value`` exceeds it,
+        or the configured fallback page limit when ``value`` is ``0``
+    :raises BadRequest: if ``value`` is negative
+    """
+    max_val = conf.getint("api", "maximum_page_limit")  # user configured max page limit
+    fallback = conf.getint("api", "fallback_page_limit")
+
+    if value > max_val:
+        log.warning(
+            "The limit param value %s passed in API exceeds the configured maximum page limit %s",
+            value,
+            max_val,
+        )
+        return max_val
+    if value == 0:
+        return fallback
+    if value < 0:
+        raise BadRequest("Page limit must be a positive integer")
+    return value
+
+
+T = TypeVar("T", bound=Callable)
+
+
+def format_parameters(params_formatters: dict[str, Callable[[Any], Any]]) -> Callable[[T], T]:
+    """
+    Create a decorator to convert parameters using given formatters.
+
+    Using it allows you to separate parameter formatting from endpoint logic.
+
+    :param params_formatters: Map of key name and formatter function
+    :return: a decorator that applies each formatter to the matching keyword argument, when present
+    """
+
+    def format_parameters_decorator(func: T) -> T:
+        @wraps(func)
+        def wrapped_function(*args, **kwargs):
+            # Only keyword arguments are converted; positional arguments pass through untouched.
+            for key, formatter in params_formatters.items():
+                if key in kwargs:
+                    kwargs[key] = formatter(kwargs[key])
+            return func(*args, **kwargs)
+
+        return cast(T, wrapped_function)
+
+    return format_parameters_decorator
+
+
+def apply_sorting(
+    query: Select,
+    order_by: str,
+    to_replace: dict[str, str] | None = None,
+    allowed_attrs: Container[str] | None = None,
+) -> Select:
+    """
+    Apply sorting to query.
+
+    :param query: the select statement to order
+    :param order_by: attribute to sort on; a leading ``-`` selects descending order
+    :param to_replace: optional mapping from an attribute name to the name used in the ORDER BY clause
+    :param allowed_attrs: if given and non-empty, the attribute names that may be used for ordering
+    :return: ``query`` with the ORDER BY clause applied
+    :raises BadRequest: if the attribute is empty, or is not in a non-empty ``allowed_attrs``
+    """
+    lstriped_orderby = order_by.lstrip("-")
+    if allowed_attrs and lstriped_orderby not in allowed_attrs:
+        raise BadRequest(
+            detail=f"Ordering with '{lstriped_orderby}' is disallowed or "
+            f"the attribute does not exist on the model"
+        )
+    if not lstriped_orderby:
+        # An empty attribute ("" or only dashes) would otherwise produce the invalid clause " asc".
+        raise BadRequest(detail="Ordering attribute must not be empty")
+    if to_replace:
+        lstriped_orderby = to_replace.get(lstriped_orderby, lstriped_orderby)
+    # NOTE(review): the attribute name is interpolated into raw SQL text, so callers should
+    # always pass ``allowed_attrs`` when ``order_by`` comes from a request.
+    direction = "desc" if order_by.startswith("-") else "asc"
+    return query.order_by(text(f"{lstriped_orderby} {direction}"))
diff --git a/providers/src/airflow/providers/fab/www/api_connexion/security.py
b/providers/src/airflow/providers/fab/www/api_connexion/security.py new file mode 100644 index 0000000000000..a130265b6de12 --- /dev/null +++ b/providers/src/airflow/providers/fab/www/api_connexion/security.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
+
+from flask import Response, current_app
+
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.providers.fab.www.api_connexion.exceptions import PermissionDenied, Unauthenticated
+from airflow.utils.airflow_flask_app import AirflowApp
+
+if TYPE_CHECKING:
+    from airflow.auth.managers.base_auth_manager import ResourceMethod
+
+T = TypeVar("T", bound=Callable)
+
+
+def check_authentication() -> None:
+    """
+    Check that the request has valid authorization information.
+
+    Each backend in ``current_app.api_auth`` is tried in order; the first one answering
+    with status 200 authenticates the request.
+
+    :raises Unauthenticated: if no backend accepts the request
+    """
+    # Stays ``None`` when no auth backend is configured, so the raise below still yields
+    # a 401 instead of failing on an unbound ``response``.
+    response = None
+    for auth in cast(AirflowApp, current_app).api_auth:
+        response = auth.requires_authentication(Response)()
+        if response.status_code == 200:
+            return
+
+    # since this handler only checks authentication, not authorization,
+    # we should always return 401
+    raise Unauthenticated(headers=response.headers if response is not None else None)
+
+
+def _requires_access(*, is_authorized_callback: Callable[[], bool], func: Callable, args, kwargs) -> Any:
+    """
+    Define the behavior whether the user is authorized to access the resource.
+
+    :param is_authorized_callback: callback to execute to figure whether the user is authorized to access
+        the resource
+    :param func: the function to call if the user is authorized
+    :param args: the arguments of ``func``
+    :param kwargs: the keyword arguments of ``func``
+    :return: whatever ``func`` returns
+    :raises Unauthenticated: if the request is not authenticated
+    :raises PermissionDenied: if ``is_authorized_callback`` returns a falsy value
+
+    :meta private:
+    """
+    check_authentication()
+    if is_authorized_callback():
+        return func(*args, **kwargs)
+    raise PermissionDenied()
+
+
+def requires_access_custom_view(
+    method: ResourceMethod,
+    resource_name: str,
+) -> Callable[[T], T]:
+    """
+    Create a decorator restricting a view to users authorized on a custom view.
+
+    :param method: the method passed to the auth manager's ``is_authorized_custom_view``
+    :param resource_name: the resource name passed to ``is_authorized_custom_view``
+    :return: a decorator that runs the authentication and authorization checks before the view
+    """
+
+    def requires_access_decorator(func: T) -> T:
+        @wraps(func)
+        def decorated(*args, **kwargs):
+            return _requires_access(
+                # Evaluated lazily, after authentication has succeeded.
+                is_authorized_callback=lambda: get_auth_manager().is_authorized_custom_view(
+                    method=method, resource_name=resource_name
+                ),
+                func=func,
+                args=args,
+                kwargs=kwargs,
+            )
+
+        return cast(T, decorated)
+
+    return requires_access_decorator
diff --git a/providers/src/airflow/providers/fab/www/api_connexion/types.py b/providers/src/airflow/providers/fab/www/api_connexion/types.py
new file mode 100644
index 0000000000000..f17f2a0d2712b
--- /dev/null
+++ b/providers/src/airflow/providers/fab/www/api_connexion/types.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union + +from flask import Response + +APIResponse = Union[ + Response, + tuple[object, int], # For '(NoContent, 201)'. + Mapping[str, Any], # JSON. +] + +UpdateMask = Optional[Sequence[str]] diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py index 79949b4cd8df0..b25fb1f68c8da 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_asset_endpoint.py @@ -21,7 +21,7 @@ import pytest import time_machine -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.security import permissions from airflow.utils import timezone diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py index 853ba3e643606..ac1e16ebfdb7f 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_dag_endpoint.py @@ -19,8 +19,8 @@ import pendulum import pytest -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import DagModel +from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.security import permissions from airflow.utils.session import provide_session diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index b2e259fee28d1..e9373c6d56e88 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ 
b/providers/tests/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -18,7 +18,7 @@ import pytest -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.security import permissions from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py index 5f755109a0644..f146f8f4337d3 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py @@ -21,8 +21,8 @@ import pytest -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import DagRun, TaskInstance +from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.security import permissions from airflow.utils.session import provide_session from airflow.utils.state import State diff --git a/providers/tests/fab/auth_manager/api_endpoints/test_user_endpoint.py b/providers/tests/fab/auth_manager/api_endpoints/test_user_endpoint.py index 7801fdf08111c..91a7e3749d95b 100644 --- a/providers/tests/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/providers/tests/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -21,7 +21,7 @@ import pytest from sqlalchemy.sql.functions import count -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session