Skip to content

Commit

Permalink
get arg helper
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Nov 7, 2024
1 parent 882de73 commit 4675ee2
Show file tree
Hide file tree
Showing 12 changed files with 43 additions and 64 deletions.
6 changes: 2 additions & 4 deletions codeforlife/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.core.handlers.wsgi import WSGIRequest

from .models import AbstractBaseUser
from .types import get_arg

AnyAbstractBaseUser = t.TypeVar("AnyAbstractBaseUser", bound=AbstractBaseUser)

Expand All @@ -23,10 +24,7 @@ class BaseLoginForm(forms.Form, t.Generic[AnyAbstractBaseUser]):
@classmethod
def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]:
"""Get the user class."""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

Check warning on line 27 in codeforlife/forms.py

View check run for this annotation

Codecov / codecov/patch

codeforlife/forms.py#L27

Added line #L27 was not covered by tests

def __init__(self, request: WSGIRequest, *args, **kwargs):
self.request = request
Expand Down
12 changes: 4 additions & 8 deletions codeforlife/models/base_session_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from django.contrib.sessions.backends.db import SessionStore
from django.utils import timezone

from ..types import get_arg

if t.TYPE_CHECKING:
from .abstract_base_session import AbstractBaseSession
from .abstract_base_user import AbstractBaseUser

Check warning on line 16 in codeforlife/models/base_session_store.py

View check run for this annotation

Codecov / codecov/patch

codeforlife/models/base_session_store.py#L15-L16

Added lines #L15 - L16 were not covered by tests
Expand All @@ -35,18 +37,12 @@ class BaseSessionStore(

@classmethod
def get_model_class(cls) -> t.Type[AnyAbstractBaseSession]:
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

@classmethod
def get_user_class(cls) -> t.Type[AnyAbstractBaseUser]:
"""Get the user class."""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
]
return get_arg(cls, 1)

def associate_session_to_user(
self, session: AnyAbstractBaseSession, user_id: int
Expand Down
7 changes: 2 additions & 5 deletions codeforlife/serializers/model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from rest_framework.serializers import ValidationError as _ValidationError

from ..request import BaseRequest, Request
from ..types import DataDict, OrderedDataDict
from ..types import DataDict, OrderedDataDict, get_arg
from .base import BaseSerializer

# pylint: disable=duplicate-code
Expand Down Expand Up @@ -75,10 +75,7 @@ def get_model_class(cls) -> t.Type[AnyModel]:
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

Check warning on line 78 in codeforlife/serializers/model_list.py

View check run for this annotation

Codecov / codecov/patch

codeforlife/serializers/model_list.py#L78

Added line #L78 was not covered by tests

def __init__(self, *args, **kwargs):
instance = args[0] if args else kwargs.pop("instance", None)
Expand Down
6 changes: 2 additions & 4 deletions codeforlife/tests/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import typing as t

from ..types import get_arg
from .api_client import APIClient, BaseAPIClient
from .test import TestCase

Expand Down Expand Up @@ -47,10 +48,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]:
Returns:
The request's user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

def _get_client_class(self):
# pylint: disable-next=too-few-public-methods
Expand Down
7 changes: 2 additions & 5 deletions codeforlife/tests/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from rest_framework.response import Response
from rest_framework.test import APIClient as _APIClient

from ..types import DataDict, JsonDict
from ..types import DataDict, JsonDict, get_arg
from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory

# pylint: disable=duplicate-code
Expand Down Expand Up @@ -329,10 +329,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]:
Returns:
The request's user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

# --------------------------------------------------------------------------
# Login Helpers
Expand Down
6 changes: 2 additions & 4 deletions codeforlife/tests/api_request_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from rest_framework.test import APIRequestFactory as _APIRequestFactory

from ..request import BaseRequest, Request
from ..types import get_arg

# pylint: disable=duplicate-code
if t.TYPE_CHECKING:
Expand Down Expand Up @@ -247,10 +248,7 @@ def get_user_class(cls) -> t.Type[AnyUser]:
Returns:
The user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

def _init_request(self, wsgi_request):
return Request[AnyUser](
Expand Down
6 changes: 2 additions & 4 deletions codeforlife/tests/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from django.db.models import Model
from django.db.utils import IntegrityError

from ..types import get_arg
from .test import TestCase

AnyModel = t.TypeVar("AnyModel", bound=Model)
Expand All @@ -25,10 +26,7 @@ def get_model_class(cls) -> t.Type[AnyModel]:
Returns:
The model's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

def assert_raises_integrity_error(self, *args, **kwargs):
"""Assert the code block raises an integrity error.
Expand Down
12 changes: 3 additions & 9 deletions codeforlife/tests/model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
BaseModelSerializer,
ModelSerializer,
)
from ..types import DataDict
from ..types import DataDict, get_arg
from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory
from .test import TestCase

Expand Down Expand Up @@ -397,10 +397,7 @@ def get_request_user_class(cls) -> t.Type[AnyModel]:
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
Expand All @@ -409,10 +406,7 @@ def get_model_class(cls) -> t.Type[AnyModel]:
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
]
return get_arg(cls, 1)

@classmethod
def _initialize_request_factory(cls, **kwargs):
Expand Down
16 changes: 4 additions & 12 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..models import AbstractBaseUser
from ..permissions import Permission
from ..serializers import BaseSerializer
from ..types import DataDict, JsonDict, KwArgs
from ..types import DataDict, JsonDict, KwArgs, get_arg
from ..views import BaseModelViewSet, ModelViewSet
from .api import APITestCase, BaseAPITestCase
from .model_view_set_client import BaseModelViewSetClient, ModelViewSetClient
Expand Down Expand Up @@ -65,9 +65,7 @@ def get_model_class(cls) -> t.Type[AnyModel]:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
2
]
return get_arg(cls, 2)

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -272,10 +270,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]:
Returns:
The request's user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
Expand All @@ -284,10 +279,7 @@ def get_model_class(cls) -> t.Type[AnyModel]:
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
]
return get_arg(cls, 1)

def _get_client_class(self):
# TODO: unpack type args in index after moving to python 3.11
Expand Down
16 changes: 16 additions & 0 deletions codeforlife/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import typing as t

T = t.TypeVar("T")

Args = t.Tuple[t.Any, ...]
KwArgs = t.Dict[str, t.Any]

Expand All @@ -16,3 +18,17 @@

DataDict = t.Dict[str, t.Any]
OrderedDataDict = t.OrderedDict[str, t.Any]


def get_arg(cls: t.Type[t.Any], index: int, orig_base: int = 0):
"""Get a type arg from a class.
Args:
cls: The class to get the type arg from.
index: The index of the type arg to get.
orig_base: The base class to get the type arg from.
Returns:
The type arg from the class.
"""
return t.get_args(cls.__orig_bases__[orig_base])[index]
6 changes: 2 additions & 4 deletions codeforlife/views/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rest_framework.views import APIView as _APIView

from ..request import BaseRequest, Request
from ..types import get_arg

# pylint: disable=duplicate-code
if t.TYPE_CHECKING:
Expand Down Expand Up @@ -55,10 +56,7 @@ def get_request_user_class(cls) -> t.Type[RequestUser]:
Returns:
The request's user class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

def _initialize_request(self, request, **kwargs):
kwargs["user_class"] = self.get_request_user_class()
Expand Down
7 changes: 2 additions & 5 deletions codeforlife/views/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..permissions import Permission
from ..request import BaseRequest, Request
from ..types import KwArgs
from ..types import KwArgs, get_arg
from .api import APIView, BaseAPIView
from .decorators import action

Expand Down Expand Up @@ -71,10 +71,7 @@ def get_model_class(cls) -> t.Type[AnyModel]:
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]
return get_arg(cls, 0)

Check warning on line 74 in codeforlife/views/model.py

View check run for this annotation

Codecov / codecov/patch

codeforlife/views/model.py#L74

Added line #L74 was not covered by tests

@cached_property
def lookup_field_name(self):
Expand Down

0 comments on commit 4675ee2

Please sign in to comment.