diff --git a/codeforlife/request.py b/codeforlife/request.py index 6c8e342..2e9d18a 100644 --- a/codeforlife/request.py +++ b/codeforlife/request.py @@ -18,7 +18,7 @@ from .user.models import User from .user.models.session import SessionStore - AnyUser = t.TypeVar("AnyUser", bound=User) +AnyUser = t.TypeVar("AnyUser") # pylint: disable-next=missing-class-docstring @@ -34,11 +34,11 @@ class HttpRequest(_HttpRequest): # pylint: disable-next=missing-class-docstring,abstract-method -class Request(_Request, t.Generic["AnyUser"]): +class Request(_Request, t.Generic[AnyUser]): session: "SessionStore" data: t.Any - def __init__(self, user_class: t.Type["AnyUser"], *args, **kwargs): + def __init__(self, user_class: t.Type[AnyUser], *args, **kwargs): super().__init__(*args, **kwargs) self.user_class = user_class @@ -48,11 +48,18 @@ def query_params(self) -> t.Dict[str, str]: # type: ignore[override] @property def user(self): - return t.cast(t.Union["AnyUser", AnonymousUser], super().user) + return t.cast(t.Union[AnyUser, AnonymousUser], super().user) @user.setter def user(self, value): - if isinstance(value, User) and not isinstance(value, self.user_class): + # pylint: disable-next=import-outside-toplevel + from .user.models import User + + if ( + isinstance(value, User) + and issubclass(self.user_class, User) + and not isinstance(value, self.user_class) + ): value = value.as_type(self.user_class) self._user = value @@ -66,7 +73,7 @@ def anon_user(self): @property def auth_user(self): """The authenticated user that made the request.""" - return t.cast("AnyUser", self.user) + return t.cast(AnyUser, self.user) @property def teacher_user(self):