From 4675ee289a551205abbb0cf8d94e1a1eb202b1f2 Mon Sep 17 00:00:00 2001 From: SKairinos Date: Thu, 7 Nov 2024 17:27:05 +0000 Subject: [PATCH] get arg helper --- codeforlife/forms.py | 6 ++---- codeforlife/models/base_session_store.py | 12 ++++-------- codeforlife/serializers/model_list.py | 7 ++----- codeforlife/tests/api.py | 6 ++---- codeforlife/tests/api_client.py | 7 ++----- codeforlife/tests/api_request_factory.py | 6 ++---- codeforlife/tests/model.py | 6 ++---- codeforlife/tests/model_serializer.py | 12 +++--------- codeforlife/tests/model_view_set.py | 16 ++++------------ codeforlife/types.py | 16 ++++++++++++++++ codeforlife/views/api.py | 6 ++---- codeforlife/views/model.py | 7 ++----- 12 files changed, 43 insertions(+), 64 deletions(-) diff --git a/codeforlife/forms.py b/codeforlife/forms.py index 12f3101b..b93046a7 100644 --- a/codeforlife/forms.py +++ b/codeforlife/forms.py @@ -11,6 +11,7 @@ from django.core.handlers.wsgi import WSGIRequest from .models import AbstractBaseUser +from .types import get_arg AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser) @@ -23,10 +24,7 @@ class BaseLoginForm(forms.Form, t.Generic[AnyAbstractBaseUser]): @classmethod def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: """Get the user class.""" - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def __init__(self, request: WSGIRequest, *args, **kwargs): self.request = request diff --git a/codeforlife/models/base_session_store.py b/codeforlife/models/base_session_store.py index 332a5de2..b6e5e648 100644 --- a/codeforlife/models/base_session_store.py +++ b/codeforlife/models/base_session_store.py @@ -9,6 +9,8 @@ from django.contrib.sessions.backends.db import SessionStore from django.utils import timezone +from ..types import get_arg + if t.TYPE_CHECKING: from .abstract_base_session import AbstractBaseSession from .abstract_base_user import AbstractBaseUser @@ -35,18 +37,12 @@ class BaseSessionStore( @classmethod def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]: - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @classmethod def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]: """Get the user class.""" - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + return get_arg(cls, 1) def associate_session_to_user( self, session: AnyAbstractBaseSession, user_id: int diff --git a/codeforlife/serializers/model_list.py b/codeforlife/serializers/model_list.py index 99be97ac..ecff4273 100644 --- a/codeforlife/serializers/model_list.py +++ b/codeforlife/serializers/model_list.py @@ -12,7 +12,7 @@ from rest_framework.serializers import ValidationError as _ValidationError from ..request import BaseRequest, Request -from ..types import DataDict, OrderedDataDict +from ..types import DataDict, OrderedDataDict, get_arg from .base import BaseSerializer # pylint: disable=duplicate-code @@ -75,10 +75,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def __init__(self, *args, **kwargs): instance = args[0] if args else kwargs.pop("instance", None) diff --git a/codeforlife/tests/api.py b/codeforlife/tests/api.py index 02420746..ac54e987 100644 --- a/codeforlife/tests/api.py +++ b/codeforlife/tests/api.py @@ -5,6 +5,7 @@ import typing as t +from ..types import get_arg from .api_client import APIClient, BaseAPIClient from .test import TestCase @@ -47,10 +48,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def _get_client_class(self): # pylint: disable-next=too-few-public-methods diff --git a/codeforlife/tests/api_client.py b/codeforlife/tests/api_client.py index 4835c7c9..3dfa4450 100644 --- a/codeforlife/tests/api_client.py +++ b/codeforlife/tests/api_client.py @@ -12,7 +12,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient as _APIClient -from ..types import DataDict, JsonDict +from ..types import DataDict, JsonDict, get_arg from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory # pylint: disable=duplicate-code @@ -329,10 +329,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) # -------------------------------------------------------------------------- # Login Helpers diff --git a/codeforlife/tests/api_request_factory.py b/codeforlife/tests/api_request_factory.py index 0758e1a5..fbce7e56 100644 --- a/codeforlife/tests/api_request_factory.py +++ b/codeforlife/tests/api_request_factory.py @@ -16,6 +16,7 @@ from rest_framework.test import APIRequestFactory as _APIRequestFactory from ..request import BaseRequest, Request +from ..types import get_arg # pylint: disable=duplicate-code if t.TYPE_CHECKING: @@ -247,10 +248,7 @@ def get_user_class(cls) -> t.Type[AnyUser]: Returns: The user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def _init_request(self, wsgi_request): return Request[AnyUser]( diff --git a/codeforlife/tests/model.py b/codeforlife/tests/model.py index e04eaa00..a61edb1c 100644 --- a/codeforlife/tests/model.py +++ b/codeforlife/tests/model.py @@ -10,6 +10,7 @@ from django.db.models import Model from django.db.utils import IntegrityError +from ..types import get_arg from .test import TestCase AnyModel = t.TypeVar("AnyModel", bound=Model) @@ -25,10 +26,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def assert_raises_integrity_error(self, *args, **kwargs): """Assert the code block raises an integrity error. diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 84f52052..1f41dec1 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -18,7 +18,7 @@ BaseModelSerializer, ModelSerializer, ) -from ..types import DataDict +from ..types import DataDict, get_arg from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .test import TestCase @@ -397,10 +397,7 @@ def get_request_user_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -409,10 +406,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + return get_arg(cls, 1) @classmethod def _initialize_request_factory(cls, **kwargs): diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index d52b344e..2304a926 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -16,7 +16,7 @@ from ..models import AbstractBaseUser from ..permissions import Permission from ..serializers import BaseSerializer -from ..types import DataDict, JsonDict, KwArgs +from ..types import DataDict, JsonDict, KwArgs, get_arg from ..views import BaseModelViewSet, ModelViewSet from .api import APITestCase, BaseAPITestCase from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient @@ -65,9 +65,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: The model view set's class. """ # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 2 - ] + return get_arg(cls, 2) @classmethod def setUpClass(cls): @@ -272,10 +270,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @classmethod def get_model_class(cls) -> t.Type[AnyModel]: @@ -284,10 +279,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] + return get_arg(cls, 1) def _get_client_class(self): # TODO: unpack type args in index after moving to python 3.11 diff --git a/codeforlife/types.py b/codeforlife/types.py index 2d151263..a2c85b36 100644 --- a/codeforlife/types.py +++ b/codeforlife/types.py @@ -7,6 +7,8 @@ import typing as t +T = t.TypeVar("T") + Args = t.Tuple[t.Any, ...] KwArgs = t.Dict[str, t.Any] @@ -16,3 +18,17 @@ DataDict = t.Dict[str, t.Any] OrderedDataDict = t.OrderedDict[str, t.Any] + + +def get_arg(cls: t.Type[t.Any], index: int, orig_base: int = 0): + """Get a type arg from a class. + + Args: + cls: The class to get the type arg from. + index: The index of the type arg to get. + orig_base: The base class to get the type arg from. + + Returns: + The type arg from the class. + """ + return t.get_args(cls.__orig_bases__[orig_base])[index] diff --git a/codeforlife/views/api.py b/codeforlife/views/api.py index ff143590..1b95ead1 100644 --- a/codeforlife/views/api.py +++ b/codeforlife/views/api.py @@ -9,6 +9,7 @@ from rest_framework.views import APIView as _APIView from ..request import BaseRequest, Request +from ..types import get_arg # pylint: disable=duplicate-code if t.TYPE_CHECKING: @@ -55,10 +56,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]: Returns: The request's user class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) def _initialize_request(self, request, **kwargs): kwargs["user_class"] = self.get_request_user_class() diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 3f6dd195..1ac326e5 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -15,7 +15,7 @@ from ..permissions import Permission from ..request import BaseRequest, Request -from ..types import KwArgs +from ..types import KwArgs, get_arg from .api import APIView, BaseAPIView from .decorators import action @@ -71,10 +71,7 @@ def get_model_class(cls) -> t.Type[AnyModel]: Returns: The model view set's class. """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] + return get_arg(cls, 0) @cached_property def lookup_field_name(self):