From c84719ee350100286ab2725bd3dcebd3728bc6d3 Mon Sep 17 00:00:00 2001 From: Murray Christopherson Date: Mon, 15 May 2023 04:53:05 -0700 Subject: [PATCH] Add ability to specify custom actions (#486) Co-authored-by: Amin Alaee --- docs/api_reference/application.md | 3 + docs/configurations.md | 35 +++ sqladmin/__init__.py | 3 +- sqladmin/application.py | 131 +++++++++- sqladmin/helpers.py | 10 + sqladmin/models.py | 21 +- sqladmin/statics/js/main.js | 11 + sqladmin/templates/details.html | 22 ++ sqladmin/templates/list.html | 29 ++- .../modals/details_action_confirmation.html | 28 +++ .../modals/list_action_confirmation.html | 28 +++ tests/test_helpers.py | 16 +- tests/test_models_action.py | 226 ++++++++++++++++++ 13 files changed, 543 insertions(+), 20 deletions(-) create mode 100644 sqladmin/templates/modals/details_action_confirmation.html create mode 100644 sqladmin/templates/modals/list_action_confirmation.html create mode 100644 tests/test_models_action.py diff --git a/docs/api_reference/application.md b/docs/api_reference/application.md index 78fac527..bb595f82 100644 --- a/docs/api_reference/application.md +++ b/docs/api_reference/application.md @@ -12,3 +12,6 @@ - add_view - add_model_view - add_base_view + +::: sqladmin.application.action + handler: python diff --git a/docs/configurations.md b/docs/configurations.md index 2697dec2..c3c10bbe 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -279,3 +279,38 @@ By default these methods do nothing. # Perform some other action ... ``` + +## Custom Action + +To add custom action on models to the Admin interface, you can use the `action` annotation. + +For example: + +!!! example + + ```python + from sqladmin import BaseView, action + + class UserAdmin(ModelView, model=User): + @action( + name="approve_users", + label="Approve", + confirmation_message="Are you sure?", + add_in_detail=True, + add_in_list=True, + ) + async def approve_users(self, request: Request): + pks = request.query_params.get("pks", "").split(",") + if pks: + for pk in pks: + model: User = await self.get_object_for_edit(pk) + ... + + referer = request.headers.get("Referer") + if referer: + return RedirectResponse(referer) + else: + return RedirectResponse(request.url_for("admin:list", identity=self.identity)) + + admin.add_view(UserAdmin) + ``` diff --git a/sqladmin/__init__.py b/sqladmin/__init__.py index 0ce4d170..a6aba11a 100644 --- a/sqladmin/__init__.py +++ b/sqladmin/__init__.py @@ -1,4 +1,4 @@ -from sqladmin.application import Admin, expose +from sqladmin.application import Admin, action, expose from sqladmin.models import BaseView, ModelAdmin, ModelView __version__ = "0.10.3" @@ -6,6 +6,7 @@ __all__ = [ "Admin", "expose", + "action", "BaseView", "ModelAdmin", "ModelView", diff --git a/sqladmin/application.py b/sqladmin/application.py index 70da2093..482aae5f 100644 --- a/sqladmin/application.py +++ b/sqladmin/application.py @@ -1,6 +1,7 @@ import inspect import io import logging +from types import MethodType from typing import ( Any, Callable, @@ -10,6 +11,7 @@ Tuple, Type, Union, + cast, no_type_check, ) @@ -30,11 +32,13 @@ from sqladmin._types import ENGINE_TYPE from sqladmin.ajax import QueryAjaxModelLoader from sqladmin.authentication import AuthenticationBackend, login_required +from sqladmin.helpers import slugify_action_name from sqladmin.models import BaseView, ModelView __all__ = [ "Admin", "expose", + "action", ] logger = logging.getLogger(__name__) @@ -117,6 +121,67 @@ def add_view(self, view: Union[Type[ModelView], Type[BaseView]]) -> None: else: self.add_base_view(view) + def _find_decorated_funcs( + self, + view: Type[Union[BaseView, ModelView]], + view_instance: Union[BaseView, ModelView], + handle_fn: Callable[ + [MethodType, Type[Union[BaseView, ModelView]], Union[BaseView, ModelView]], + None, + ], + ) -> None: + funcs = inspect.getmembers(view_instance, predicate=inspect.ismethod) + + for _, func in funcs[::-1]: + handle_fn(func, view, view_instance) + + def _handle_action_decorated_func( + self, + func: MethodType, + view: Type[Union[BaseView, ModelView]], + view_instance: Union[BaseView, ModelView], + ) -> None: + if hasattr(func, "_action"): + view_instance = cast(ModelView, view_instance) + self.admin.add_route( + route=func, + path="/{identity}/action/" + getattr(func, "_slug"), + methods=["GET"], + name=f"{view_instance.identity}-{getattr(func, '_slug')}", + include_in_schema=getattr(func, "_include_in_schema"), + ) + + if getattr(func, "_add_in_list"): + view_instance._custom_actions_in_list[getattr(func, "_slug")] = getattr( + func, "_label" + ) + if getattr(func, "_add_in_detail"): + view_instance._custom_actions_in_detail[ + getattr(func, "_slug") + ] = getattr(func, "_label") + + if getattr(func, "_confirmation_message"): + view_instance._custom_actions_confirmation[ + getattr(func, "_slug") + ] = getattr(func, "_confirmation_message") + + def _handle_expose_decorated_func( + self, + func: MethodType, + view: Type[Union[BaseView, ModelView]], + view_instance: Union[BaseView, ModelView], + ) -> None: + if hasattr(func, "_exposed"): + self.admin.add_route( + route=func, + path=getattr(func, "_path"), + methods=getattr(func, "_methods"), + name=getattr(func, "_identity"), + include_in_schema=getattr(func, "_include_in_schema"), + ) + + view.identity = getattr(func, "_identity") + def add_model_view(self, view: Type[ModelView]) -> None: """Add ModelView to the Admin. @@ -152,7 +217,14 @@ class UserAdmin(ModelView, model=User): ) view.async_engine = True - self._views.append((view())) + view_instance = view() + + self._find_decorated_funcs( + view, view_instance, self._handle_action_decorated_func + ) + + view.templates = self.templates + self._views.append((view_instance)) def add_base_view(self, view: Type[BaseView]) -> None: """Add BaseView to the Admin. @@ -177,19 +249,10 @@ def test_page(self, request: Request): """ view_instance = view() - funcs = inspect.getmembers(view_instance, predicate=inspect.ismethod) - - for _, func in funcs[::-1]: - if hasattr(func, "_exposed"): - self.admin.add_route( - route=func, - path=func._path, - methods=func._methods, - name=func._identity, - include_in_schema=func._include_in_schema, - ) - view.identity = func._identity + self._find_decorated_funcs( + view, view_instance, self._handle_expose_decorated_func + ) view.templates = self.templates self._views.append(view_instance) @@ -638,3 +701,45 @@ def wrap(func): return func return wrap + + +def action( + name: str, + label: Optional[str] = None, + confirmation_message: Optional[str] = None, + *, + include_in_schema: bool = True, + add_in_detail: bool = True, + add_in_list: bool = True, +) -> Callable[..., Any]: + """Decorate a [`ModelView`][sqladmin.models.ModelView] function + with this to: + + * expose it as a custom "action" route + * add a button to the admin panel to invoke the action + + When invoked from the admin panel, the following query parameter(s) are passed: + + * `pks`: the comma-separated list of selected object PKs - can be empty + + Args: + name: Unique name for the action - must match `^[A-Za-z0-9 \-_]+$` regex + label: Human-readable text describing action + confirmation_message: Message to show before confirming action + include_in_schema: Should the endpoint be included in the schema? + add_in_detail: Should action be invocable from the "Detail" view? + add_in_list: Should action be invocable from the "List" view? + """ + + @no_type_check + def wrap(func): + func._action = True + func._slug = slugify_action_name(name) + func._label = label if label is not None else name + func._confirmation_message = confirmation_message + func._include_in_schema = include_in_schema + func._add_in_detail = add_in_detail + func._add_in_list = add_in_list + return func + + return wrap diff --git a/sqladmin/helpers.py b/sqladmin/helpers.py index 5abc3378..574c02b7 100644 --- a/sqladmin/helpers.py +++ b/sqladmin/helpers.py @@ -93,6 +93,16 @@ def slugify_class_name(name: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1-\2", dashed).lower() +def slugify_action_name(name: str) -> str: + if not re.search(r"^[A-Za-z0-9 \-_]+$", name): + raise ValueError( + "name must be non-empty and contain only allowed characters" + " - use `label` for more expressive names" + ) + + return re.sub(r"[_ ]", "-", name).lower() + + def secure_filename(filename: str) -> str: """Ported from Werkzeug. diff --git a/sqladmin/models.py b/sqladmin/models.py index bed5ee96..76062280 100644 --- a/sqladmin/models.py +++ b/sqladmin/models.py @@ -724,6 +724,10 @@ def __init__(self) -> None: model_admin=self, name=name, options=options ) + self._custom_actions_in_list: Dict[str, str] = {} + self._custom_actions_in_detail: Dict[str, str] = {} + self._custom_actions_confirmation: Dict[str, str] = {} + def _run_query_sync(self, stmt: ClauseElement) -> Any: with self.sessionmaker(expire_on_commit=False) as session: result = session.execute(stmt) @@ -738,7 +742,7 @@ async def _run_query(self, stmt: ClauseElement) -> Any: return await anyio.to_thread.run_sync(self._run_query_sync, stmt) def _url_for_details(self, request: Request, obj: Any) -> Union[str, URL]: - pk = getattr(obj, get_primary_key(obj).name) + pk = self._get_pk(obj) return request.url_for( "admin:details", identity=slugify_class_name(obj.__class__.__name__), @@ -746,7 +750,7 @@ def _url_for_details(self, request: Request, obj: Any) -> Union[str, URL]: ) def _url_for_edit(self, request: Request, obj: Any) -> Union[str, URL]: - pk = getattr(obj, get_primary_key(obj).name) + pk = self._get_pk(obj) return request.url_for( "admin:edit", identity=slugify_class_name(obj.__class__.__name__), @@ -754,7 +758,7 @@ def _url_for_edit(self, request: Request, obj: Any) -> Union[str, URL]: ) def _url_for_delete(self, request: Request, obj: Any) -> str: - pk = getattr(obj, get_primary_key(obj).name) + pk = self._get_pk(obj) query_params = urlencode({"pks": pk}) url = request.url_for( "admin:delete", identity=slugify_class_name(obj.__class__.__name__) @@ -775,6 +779,17 @@ def _url_for_details_with_prop( pk=pk, ) + def _url_for_action(self, request: Request, action_name: str) -> str: + return str( + request.url_for( + f"admin:{self.identity}-{action_name}", + identity=self.identity, + ) + ) + + def _get_pk(self, obj: Any) -> Any: + return getattr(obj, get_primary_key(obj).name) + def _get_default_sort(self) -> List[Tuple[str, bool]]: if self.column_default_sort: if isinstance(self.column_default_sort, list): diff --git a/sqladmin/statics/js/main.js b/sqladmin/statics/js/main.js index de98b647..c3428904 100644 --- a/sqladmin/statics/js/main.js +++ b/sqladmin/statics/js/main.js @@ -134,6 +134,17 @@ $("#action-delete").click(function () { }); }); +$("[id^='action-custom-']").click(function () { + var pks = []; + $('.select-box').each(function () { + if ($(this).is(':checked')) { + pks.push($(this).siblings().get(0).value); + } + }); + + window.location.href = $(this).attr('data-url') + '?pks=' + pks.join(","); +}); + // Select2 Tags $(':input[data-role="select2-tags"]').each(function () { $(this).select2({ diff --git a/sqladmin/templates/details.html b/sqladmin/templates/details.html index d1cd00aa..b3b05efd 100644 --- a/sqladmin/templates/details.html +++ b/sqladmin/templates/details.html @@ -58,6 +58,19 @@

{{ model_view.pk_column.name }}: {{ model_view.get_prop_v {% endif %} + {% for custom_action,label in model_view._custom_actions_in_detail.items() %} +
+ {% if custom_action in model_view._custom_actions_confirmation %} + + {{ label }} + + {% else %} + + {{ label }} + + {% endif %} +
+ {% endfor %} @@ -66,4 +79,13 @@

{{ model_view.pk_column.name }}: {{ model_view.get_prop_v {% if model_view.can_delete %} {% include 'modals/delete.html' %} {% endif %} + +{% for custom_action in model_view._custom_actions_in_detail %} +{% if custom_action in model_view._custom_actions_confirmation %} +{% with confirmation_message = model_view._custom_actions_confirmation[custom_action], custom_action=custom_action, url=model_view._url_for_action(request, custom_action) + '?pks=' + (model_view._get_pk(model) | string) %} +{% include 'modals/details_action_confirmation.html' %} +{% endwith %} +{% endif %} +{% endfor %} + {% endblock %} diff --git a/sqladmin/templates/list.html b/sqladmin/templates/list.html index 2d507f48..52567fb0 100644 --- a/sqladmin/templates/list.html +++ b/sqladmin/templates/list.html @@ -37,12 +37,27 @@

{{ model_view.name_plural }}

@@ -167,6 +182,16 @@

{{ model_view.name_plural }}

+ {% if model_view.can_delete %} {% include 'modals/delete.html' %} + {% endif %} + + {% for custom_action in model_view._custom_actions_in_list %} + {% if custom_action in model_view._custom_actions_confirmation %} + {% with confirmation_message = model_view._custom_actions_confirmation[custom_action], custom_action=custom_action, url=model_view._url_for_action(request, custom_action) %} + {% include 'modals/list_action_confirmation.html' %} + {% endwith %} + {% endif %} + {% endfor %} {% endblock %} diff --git a/sqladmin/templates/modals/details_action_confirmation.html b/sqladmin/templates/modals/details_action_confirmation.html new file mode 100644 index 00000000..a27e07e6 --- /dev/null +++ b/sqladmin/templates/modals/details_action_confirmation.html @@ -0,0 +1,28 @@ + diff --git a/sqladmin/templates/modals/list_action_confirmation.html b/sqladmin/templates/modals/list_action_confirmation.html new file mode 100644 index 00000000..db78563f --- /dev/null +++ b/sqladmin/templates/modals/list_action_confirmation.html @@ -0,0 +1,28 @@ + diff --git a/tests/test_helpers.py b/tests/test_helpers.py index d6526d8e..5147edbf 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,6 +1,13 @@ from datetime import timedelta -from sqladmin.helpers import is_falsy_value, parse_interval, secure_filename +import pytest + +from sqladmin.helpers import ( + is_falsy_value, + parse_interval, + secure_filename, + slugify_action_name, +) def test_secure_filename(monkeypatch): @@ -26,3 +33,10 @@ def test_is_falsy_values(): assert is_falsy_value("") is True assert is_falsy_value(0) is False assert is_falsy_value("example") is False + + +def test_slugify_action_name(): + assert slugify_action_name("custom action") == "custom-action" + + with pytest.raises(ValueError): + slugify_action_name("custom action !@#$%") diff --git a/tests/test_models_action.py b/tests/test_models_action.py new file mode 100644 index 00000000..e72afcd4 --- /dev/null +++ b/tests/test_models_action.py @@ -0,0 +1,226 @@ +from typing import Any, Generator, List +from unittest.mock import Mock + +import pytest +from sqlalchemy import Column, Integer +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response +from starlette.testclient import TestClient + +from sqladmin import Admin, ModelView +from sqladmin.application import action +from tests.common import sync_engine as engine + +Base: Any = declarative_base() + +Session = sessionmaker(bind=engine) + +app = Starlette() +admin = Admin(app=app, engine=engine) + + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + + def __repr__(self) -> str: + return f"User: {self.id}" + + +class UserAdmin(ModelView, model=User): + async def _action_stub(self, request: Request) -> Response: + pks = request.query_params.get("pks", "") + + obj_strs: List[str] = [] + for pk in pks.split(","): + obj = await self.get_object_for_edit(pk) + + obj_strs.append(repr(obj)) + + response = RedirectResponse( + request.url_for("admin:list", identity=self.identity) + ) + response.headers["X-Objs"] = ",".join(obj_strs) + return response + + @action(name="details", add_in_detail=True, add_in_list=False) + async def action_details(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action(name="list", add_in_detail=False, add_in_list=True) + async def action_list(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action(name="details_list", add_in_detail=True, add_in_list=True) + async def action_details_list(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="details_confirm", + confirmation_message="!Details Confirm?!", + add_in_detail=True, + add_in_list=False, + ) + async def action_details_confirm(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="list_confirm", + confirmation_message="!List Confirm?!", + add_in_detail=False, + add_in_list=True, + ) + async def action_list_confirm(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="details_list_confirm", + confirmation_message="!Details List Confirm?!", + add_in_detail=True, + add_in_list=True, + ) + async def action_details_list_confirm(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="label_details", + label="Label Details", + add_in_detail=True, + add_in_list=False, + ) + async def action_label_details(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="label_list", label="Label List", add_in_detail=False, add_in_list=True + ) + async def action_label_list(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="label_details_list", + label="Label Details List", + add_in_detail=True, + add_in_list=True, + ) + async def action_label_details_list(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="label_details_confirm", + label="Label Details Confirm", + confirmation_message="!Label Details Confirm?!", + add_in_detail=True, + add_in_list=False, + ) + async def action_label_details_confirm(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="label_list_confirm", + label="Label List Confirm", + confirmation_message="!Label List Confirm?!", + add_in_detail=False, + add_in_list=True, + ) + async def action_label_list_confirm(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + @action( + name="label_details_list_confirm", + label="Label Details List Confirm", + confirmation_message="!Label Details List Confirm?!", + add_in_detail=True, + add_in_list=True, + ) + async def action_label_details_list_confirm(self, request: Request) -> Response: + return await self._action_stub(request) # pragma: no cover + + +@pytest.fixture(autouse=True) +def prepare_database() -> Generator[None, None, None]: + Base.metadata.create_all(engine) + yield + Base.metadata.drop_all(engine) + + +@pytest.fixture +def client() -> Generator[TestClient, None, None]: + with TestClient(app=app, base_url="http://testserver") as c: + yield c + + +def test_model_action(client: TestClient) -> None: + admin.add_view(UserAdmin) + + assert admin.views[0]._custom_actions_in_list == { + "list": "list", + "details-list": "details_list", + "label-list": "Label List", + "label-details-list": "Label Details List", + "list-confirm": "list_confirm", + "details-list-confirm": "details_list_confirm", + "label-list-confirm": "Label List Confirm", + "label-details-list-confirm": "Label Details List Confirm", + } + + assert admin.views[0]._custom_actions_in_detail == { + "details": "details", + "details-confirm": "details_confirm", + "details-list": "details_list", + "details-list-confirm": "details_list_confirm", + "label-details": "Label Details", + "label-details-confirm": "Label Details Confirm", + "label-details-list": "Label Details List", + "label-details-list-confirm": "Label Details List Confirm", + } + + assert admin.views[0]._custom_actions_confirmation == { + "details-confirm": "!Details Confirm?!", + "details-list-confirm": "!Details List Confirm?!", + "label-details-confirm": "!Label Details Confirm?!", + "label-details-list-confirm": "!Label Details List Confirm?!", + "label-list-confirm": "!Label List Confirm?!", + "list-confirm": "!List Confirm?!", + } + + request = Mock(Request) + request.url_for = Mock() + + admin.views[0]._url_for_action(request, "test") + request.url_for.assert_called_with("admin:user-test", identity="user") + + with Session() as session: + user1 = User() + user2 = User() + session.add(user1) + session.add(user2) + session.commit() + + response = client.get( + f"/admin/user/action/details?pks={user1.id},{user2.id}", + follow_redirects=False, + ) + assert response.status_code == 307 + assert f"User: {user1.id}" in response.headers["X-Objs"] + assert f"User: {user2.id}" in response.headers["X-Objs"] + + response = client.get("/admin/user/list") + assert response.text.count("!Details Confirm?!") == 0 + assert response.text.count("!List Confirm?!") == 1 + assert response.text.count("!Details List Confirm?!") == 1 + assert response.text.count("!Label Details Confirm?!") == 0 + assert response.text.count("!Label List Confirm?!") == 1 + assert response.text.count("!Label Details List Confirm?!") == 1 + + response = client.get(f"/admin/user/details/{user1.id}") + assert response.text.count("!Details Confirm?!") == 1 + assert response.text.count("!List Confirm?!") == 0 + assert response.text.count("!Details List Confirm?!") == 1 + assert response.text.count("!Label Details Confirm?!") == 1 + assert response.text.count("!Label List Confirm?!") == 0 + assert response.text.count("!Label Details List Confirm?!") == 1