Skip to content

Commit

Permalink
Uplift type hints (#780)
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee authored Jun 13, 2024
1 parent c2dd278 commit 870f628
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 172 deletions.
6 changes: 3 additions & 3 deletions docs/cookbook/optimize_relationship_loading.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ class ParentAdmin(ModelView, model=Parent):
form_excluded_columns = [Parent.children]
```

### Using `edit_form_query` to customize the edit form data
### Using `form_edit_query` to customize the edit form data

If you would like to fully customize the query to populate the edit object form, you may override
the `edit_form_query` function with your own SQLAlchemy query. In the following example, overriding
the `form_edit_query` function with your own SQLAlchemy query. In the following example, overriding
the default query will allow you to filter relationships to show only related children of the parent.

```py
class ParentAdmin(ModelView, model=Parent):
def edit_form_query(self, request: Request) -> Select:
def form_edit_query(self, request: Request) -> Select:
parent_id = request.path_params["pk"]
return (
self._stmt_by_identifier(parent_id)
Expand Down
20 changes: 11 additions & 9 deletions sqladmin/_menu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, List, Optional, Union
from __future__ import annotations

from typing import TYPE_CHECKING

from starlette.datastructures import URL
from starlette.requests import Request
Expand All @@ -8,11 +10,11 @@


class ItemMenu:
def __init__(self, name: str, icon: Optional[str] = None) -> None:
def __init__(self, name: str, icon: str | None = None) -> None:
self.name = name
self.icon = icon
self.parent: Optional["ItemMenu"] = None
self.children: List["ItemMenu"] = []
self.parent: "ItemMenu" | None = None
self.children: list["ItemMenu"] = []

def add_child(self, item: "ItemMenu") -> None:
item.parent = self
Expand All @@ -27,7 +29,7 @@ def is_accessible(self, request: Request) -> bool:
def is_active(self, request: Request) -> bool:
return False

def url(self, request: Request) -> Union[str, URL]:
def url(self, request: Request) -> str | URL:
return "#"

@property
Expand All @@ -53,9 +55,9 @@ def type_(self) -> str:
class ViewMenu(ItemMenu):
def __init__(
self,
view: Union["BaseView", "ModelView"],
view: "BaseView" | "ModelView",
name: str,
icon: Optional[str] = None,
icon: str | None = None,
) -> None:
super().__init__(name=name, icon=icon)
self.view = view
Expand All @@ -69,7 +71,7 @@ def is_accessible(self, request: Request) -> bool:
def is_active(self, request: Request) -> bool:
return self.view.identity == request.path_params.get("identity")

def url(self, request: Request) -> Union[str, URL]:
def url(self, request: Request) -> str | URL:
if self.view.is_model:
return request.url_for("admin:list", identity=self.view.identity)
return request.url_for(f"admin:{self.view.identity}")
Expand All @@ -85,7 +87,7 @@ def type_(self) -> str:

class Menu:
def __init__(self) -> None:
self.items: List[ItemMenu] = []
self.items: list[ItemMenu] = []

def add(self, item: ItemMenu) -> None:
# Only works for one-level menu
Expand Down
14 changes: 8 additions & 6 deletions sqladmin/_queries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import anyio
from sqlalchemy import select
Expand All @@ -24,7 +26,7 @@ class Query:
def __init__(self, model_view: "ModelView") -> None:
self.model_view = model_view

def _get_to_many_stmt(self, relation: MODEL_PROPERTY, values: List[Any]) -> Select:
def _get_to_many_stmt(self, relation: MODEL_PROPERTY, values: list[Any]) -> Select:
target = relation.mapper.class_

target_pks = get_primary_keys(target)
Expand Down Expand Up @@ -131,7 +133,7 @@ async def _set_attributes_async(
setattr(obj, key, value)
return obj

def _update_sync(self, pk: Any, data: Dict[str, Any], request: Request) -> Any:
def _update_sync(self, pk: Any, data: dict[str, Any], request: Request) -> Any:
stmt = self.model_view._stmt_by_identifier(pk)

with self.model_view.session_maker(expire_on_commit=False) as session:
Expand All @@ -147,7 +149,7 @@ def _update_sync(self, pk: Any, data: Dict[str, Any], request: Request) -> Any:
return obj

async def _update_async(
self, pk: Any, data: Dict[str, Any], request: Request
self, pk: Any, data: dict[str, Any], request: Request
) -> Any:
stmt = self.model_view._stmt_by_identifier(pk)

Expand Down Expand Up @@ -187,7 +189,7 @@ async def _delete_async(self, pk: str, request: Request) -> None:
await session.commit()
await self.model_view.after_model_delete(obj, request)

def _insert_sync(self, data: Dict[str, Any], request: Request) -> Any:
def _insert_sync(self, data: dict[str, Any], request: Request) -> Any:
obj = self.model_view.model()

with self.model_view.session_maker(expire_on_commit=False) as session:
Expand All @@ -202,7 +204,7 @@ def _insert_sync(self, data: Dict[str, Any], request: Request) -> Any:
)
return obj

async def _insert_async(self, data: Dict[str, Any], request: Request) -> Any:
async def _insert_async(self, data: dict[str, Any], request: Request) -> Any:
obj = self.model_view.model()

async with self.model_view.session_maker(expire_on_commit=False) as session:
Expand Down
8 changes: 5 additions & 3 deletions sqladmin/ajax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from sqlalchemy import String, cast, inspect, or_, select

Expand Down Expand Up @@ -52,13 +54,13 @@ def _process_fields(self) -> list:

return remote_fields

def format(self, model: type) -> Dict[str, Any]:
def format(self, model: type) -> dict[str, Any]:
if not model:
return {}

return {"id": str(get_object_identifier(model)), "text": str(model)}

async def get_list(self, term: str, limit: int = DEFAULT_PAGE_SIZE) -> List[Any]:
async def get_list(self, term: str, limit: int = DEFAULT_PAGE_SIZE) -> list[Any]:
stmt = select(self.model)

# no type casting to string if a ColumnAssociationProxyInstance is given
Expand Down
67 changes: 32 additions & 35 deletions sqladmin/application.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
import io
import logging
Expand All @@ -7,12 +9,7 @@
Any,
Awaitable,
Callable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
no_type_check,
)
Expand Down Expand Up @@ -66,14 +63,14 @@ class BaseAdmin:
def __init__(
self,
app: Starlette,
engine: Optional[ENGINE_TYPE] = None,
session_maker: Optional[sessionmaker] = None,
engine: ENGINE_TYPE | None = None,
session_maker: sessionmaker | None = None,
base_url: str = "/admin",
title: str = "Admin",
logo_url: Optional[str] = None,
logo_url: str | None = None,
templates_dir: str = "templates",
middlewares: Optional[Sequence[Middleware]] = None,
authentication_backend: Optional[AuthenticationBackend] = None,
middlewares: Sequence[Middleware] | None = None,
authentication_backend: AuthenticationBackend | None = None,
) -> None:
self.app = app
self.engine = engine
Expand All @@ -100,7 +97,7 @@ def __init__(

self.admin = Starlette(middleware=middlewares)
self.templates = self.init_templating_engine()
self._views: List[Union[BaseView, ModelView]] = []
self._views: list[BaseView | ModelView] = []
self._menu = Menu()

def init_templating_engine(self) -> Jinja2Templates:
Expand All @@ -120,7 +117,7 @@ def init_templating_engine(self) -> Jinja2Templates:
return templates

@property
def views(self) -> List[Union[BaseView, ModelView]]:
def views(self) -> list[BaseView | ModelView]:
"""Get list of ModelView and BaseView instances lazily.
Returns:
Expand All @@ -136,7 +133,7 @@ def _find_model_view(self, identity: str) -> ModelView:

raise HTTPException(status_code=404)

def add_view(self, view: Union[Type[ModelView], Type[BaseView]]) -> None:
def add_view(self, view: type[ModelView] | type[BaseView]) -> None:
"""Add ModelView or BaseView classes to Admin.
This is a shortcut that will handle both `add_model_view` and `add_base_view`.
"""
Expand All @@ -149,10 +146,10 @@ def add_view(self, view: Union[Type[ModelView], Type[BaseView]]) -> None:

def _find_decorated_funcs(
self,
view: Type[Union[BaseView, ModelView]],
view_instance: Union[BaseView, ModelView],
view: type[BaseView | ModelView],
view_instance: BaseView | ModelView,
handle_fn: Callable[
[MethodType, Type[Union[BaseView, ModelView]], Union[BaseView, ModelView]],
[MethodType, type[BaseView | ModelView], BaseView | ModelView],
None,
],
) -> None:
Expand All @@ -164,8 +161,8 @@ def _find_decorated_funcs(
def _handle_action_decorated_func(
self,
func: MethodType,
view: Type[Union[BaseView, ModelView]],
view_instance: Union[BaseView, ModelView],
view: type[BaseView | ModelView],
view_instance: BaseView | ModelView,
) -> None:
if hasattr(func, "_action"):
view_instance = cast(ModelView, view_instance)
Expand Down Expand Up @@ -194,8 +191,8 @@ def _handle_action_decorated_func(
def _handle_expose_decorated_func(
self,
func: MethodType,
view: Type[Union[BaseView, ModelView]],
view_instance: Union[BaseView, ModelView],
view: type[BaseView | ModelView],
view_instance: BaseView | ModelView,
) -> None:
if hasattr(func, "_exposed"):
self.admin.add_route(
Expand All @@ -208,7 +205,7 @@ def _handle_expose_decorated_func(

view.identity = getattr(func, "_identity")

def add_model_view(self, view: Type[ModelView]) -> None:
def add_model_view(self, view: type[ModelView]) -> None:
"""Add ModelView to the Admin.
???+ usage
Expand Down Expand Up @@ -237,7 +234,7 @@ class UserAdmin(ModelView, model=User):
self._views.append(view_instance)
self._build_menu(view_instance)

def add_base_view(self, view: Type[BaseView]) -> None:
def add_base_view(self, view: type[BaseView]) -> None:
"""Add BaseView to the Admin.
???+ usage
Expand Down Expand Up @@ -265,7 +262,7 @@ async def test_page(self, request: Request):
self._views.append(view_instance)
self._build_menu(view_instance)

def _build_menu(self, view: Union[ModelView, BaseView]) -> None:
def _build_menu(self, view: ModelView | BaseView) -> None:
if view.category:
menu = CategoryMenu(name=view.category)
menu.add_child(ViewMenu(view=view, name=view.name, icon=view.icon))
Expand Down Expand Up @@ -338,15 +335,15 @@ class UserAdmin(ModelView, model=User):
def __init__(
self,
app: Starlette,
engine: Optional[ENGINE_TYPE] = None,
session_maker: Optional[Union[sessionmaker, "async_sessionmaker"]] = None,
engine: ENGINE_TYPE | None = None,
session_maker: sessionmaker | "async_sessionmaker" | None = None,
base_url: str = "/admin",
title: str = "Admin",
logo_url: Optional[str] = None,
middlewares: Optional[Sequence[Middleware]] = None,
logo_url: str | None = None,
middlewares: Sequence[Middleware] | None = None,
debug: bool = False,
templates_dir: str = "templates",
authentication_backend: Optional[AuthenticationBackend] = None,
authentication_backend: AuthenticationBackend | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -374,7 +371,7 @@ def __init__(

async def http_exception(
request: Request, exc: Exception
) -> Union[Response, Awaitable[Response]]:
) -> Response | Awaitable[Response]:
assert isinstance(exc, HTTPException)
context = {
"status_code": exc.status_code,
Expand Down Expand Up @@ -662,7 +659,7 @@ async def ajax_lookup(self, request: Request) -> Response:

def get_save_redirect_url(
self, request: Request, form: FormData, model_view: ModelView, obj: Any
) -> Union[str, URL]:
) -> str | URL:
"""
Get the redirect URL after a save action
which is triggered from create/edit page.
Expand All @@ -687,7 +684,7 @@ async def _handle_form_data(self, request: Request, obj: Any = None) -> FormData
"""

form = await request.form()
form_data: List[Tuple[str, Union[str, UploadFile]]] = []
form_data: list[tuple[str, str | UploadFile]] = []
for key, value in form.multi_items():
if not isinstance(value, UploadFile):
form_data.append((key, value))
Expand Down Expand Up @@ -728,8 +725,8 @@ def _denormalize_wtform_data(self, form_data: dict, obj: Any) -> dict:
def expose(
path: str,
*,
methods: List[str] = ["GET"],
identity: Optional[str] = None,
methods: list[str] = ["GET"],
identity: str | None = None,
include_in_schema: bool = True,
) -> Callable[..., Any]:
"""Expose View with information."""
Expand All @@ -748,8 +745,8 @@ def wrap(func):

def action(
name: str,
label: Optional[str] = None,
confirmation_message: Optional[str] = None,
label: str | None = None,
confirmation_message: str | None = None,
*,
include_in_schema: bool = True,
add_in_detail: bool = True,
Expand Down
6 changes: 4 additions & 2 deletions sqladmin/authentication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import functools
import inspect
from typing import Any, Callable, Union
from typing import Any, Callable

from starlette.middleware import Middleware
from starlette.requests import Request
Expand Down Expand Up @@ -33,7 +35,7 @@ async def logout(self, request: Request) -> bool:
"""
raise NotImplementedError()

async def authenticate(self, request: Request) -> Union[Response, bool]:
async def authenticate(self, request: Request) -> Response | bool:
"""Implement authenticate logic here.
This method will be called for each incoming request
to validate the authentication.
Expand Down
Loading

0 comments on commit 870f628

Please sign in to comment.