Skip to content

Commit

Permalink
fix: type as vars
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Nov 5, 2024
1 parent 77bd91c commit 4922af6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 21 deletions.
4 changes: 3 additions & 1 deletion codeforlife/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from .user.models import User
from .user.models.session import SessionStore

AnyUser = t.TypeVar("AnyUser")
AnyUser = t.TypeVar("AnyUser", bound=User)
else:
AnyUser = t.TypeVar("AnyUser")


# pylint: disable-next=missing-class-docstring
Expand Down
32 changes: 12 additions & 20 deletions codeforlife/tests/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,16 @@
from .test import TestCase

if t.TYPE_CHECKING:
from ..user.models import (
AdminSchoolTeacherUser,
AuthFactor,
IndependentUser,
NonAdminSchoolTeacherUser,
NonSchoolTeacherUser,
SchoolTeacherUser,
StudentUser,
TeacherUser,
TypedUser,
User,
)
from ..user.models import TypedUser, User

RequestUser = t.TypeVar("RequestUser", bound=User)
LoginUser = t.TypeVar("LoginUser", bound=User)
else:
RequestUser = t.TypeVar("RequestUser")
LoginUser = t.TypeVar("LoginUser")


class APIClient(_APIClient, t.Generic["RequestUser"]):
class APIClient(_APIClient, t.Generic[RequestUser]):
"""Base API client to be inherited by all other API clients."""

_test_case: "APITestCase[RequestUser]"
Expand All @@ -58,7 +50,7 @@ def __init__(
)

@classmethod
def get_request_user_class(cls) -> t.Type["RequestUser"]:
def get_request_user_class(cls) -> t.Type[RequestUser]:
"""Get the request's user class.
Returns:
Expand Down Expand Up @@ -120,7 +112,7 @@ def _make_assertions():
# Login Helpers
# --------------------------------------------------------------------------

def _login_user_type(self, user_type: t.Type["LoginUser"], **credentials):
def _login_user_type(self, user_type: t.Type[LoginUser], **credentials):
# pylint: disable-next=import-outside-toplevel
from ..user.models import AuthFactor

Expand All @@ -140,7 +132,7 @@ def _login_user_type(self, user_type: t.Type["LoginUser"], **credentials):
with patch.object(timezone, "now", return_value=now):
assert super().login(
request=self.request_factory.post(
user=t.cast("RequestUser", user)
user=t.cast(RequestUser, user)
),
otp=otp,
), f'Failed to login with OTP "{otp}" at {now}.'
Expand Down Expand Up @@ -524,14 +516,14 @@ def options( # type: ignore[override]
# pylint: enable=too-many-arguments,redefined-builtin


class APITestCase(TestCase, t.Generic["RequestUser"]):
class APITestCase(TestCase, t.Generic[RequestUser]):
"""Base API test case to be inherited by all other API test cases."""

client: APIClient["RequestUser"]
client_class: t.Type[APIClient["RequestUser"]] = APIClient
client: APIClient[RequestUser]
client_class: t.Type[APIClient[RequestUser]] = APIClient

@classmethod
def get_request_user_class(cls) -> t.Type["RequestUser"]:
def get_request_user_class(cls) -> t.Type[RequestUser]:
"""Get the request's user class.
Returns:
Expand Down

0 comments on commit 4922af6

Please sign in to comment.