Skip to content

Commit

Permalink
abstract model serializer test case
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Nov 6, 2024
1 parent b805632 commit 27df3ee
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 54 deletions.
6 changes: 5 additions & 1 deletion codeforlife/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory
from .cron import CronTestCase
from .model import ModelTestCase
from .model_serializer import (
from .model_list_serializer import (
BaseModelListSerializerTestCase,
ModelListSerializerTestCase,
)
from .model_serializer import (
BaseModelSerializerTestCase,
ModelSerializerTestCase,
)
from .model_view_set import ModelViewSetClient, ModelViewSetTestCase
Expand Down
88 changes: 88 additions & 0 deletions codeforlife/tests/model_list_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
© Ocado Group
Created on 06/11/2024 at 12:45:33(+00:00).
Base test case for all model list serializers.
"""

import typing as t

from django.db.models import Model

from ..serializers import (
BaseModelListSerializer,
BaseModelSerializer,
ModelListSerializer,
ModelSerializer,
)
from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory
from .model_serializer import (
BaseModelSerializerTestCase,
ModelSerializerTestCase,
)

# pylint: disable=duplicate-code
if t.TYPE_CHECKING:
from ..user.models import User

RequestUser = t.TypeVar("RequestUser", bound=User)
else:
RequestUser = t.TypeVar("RequestUser")

AnyModel = t.TypeVar("AnyModel", bound=Model)
AnyBaseModelSerializer = t.TypeVar(
"AnyBaseModelSerializer", bound=BaseModelSerializer
)
AnyBaseModelListSerializer = t.TypeVar(
"AnyBaseModelListSerializer", bound=BaseModelListSerializer
)
AnyBaseAPIRequestFactory = t.TypeVar(
"AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory
)
# pylint: enable=duplicate-code


class BaseModelListSerializerTestCase(
BaseModelSerializerTestCase[
AnyBaseModelSerializer,
AnyBaseAPIRequestFactory,
AnyModel,
],
t.Generic[
AnyBaseModelListSerializer,
AnyBaseModelSerializer,
AnyBaseAPIRequestFactory,
AnyModel,
],
):
"""Base for all model serializer test cases."""

model_list_serializer_class: t.Type[AnyBaseModelListSerializer]

REQUIRED_ATTRS = {
"model_list_serializer_class",
"model_serializer_class",
"request_factory_class",
}

def _init_model_serializer(self, *args, parent=None, **kwargs):
kwargs.setdefault("child", self.model_serializer_class())
serializer = self.model_list_serializer_class(*args, **kwargs)
if parent:
serializer.parent = parent

return serializer


# pylint: disable-next=too-many-ancestors
class ModelListSerializerTestCase(
BaseModelListSerializerTestCase[
ModelListSerializer[RequestUser, AnyModel],
ModelSerializer[RequestUser, AnyModel],
APIRequestFactory[RequestUser],
AnyModel,
],
ModelSerializerTestCase[RequestUser, AnyModel],
t.Generic[RequestUser, AnyModel],
):
"""Base for all model serializer test cases."""
117 changes: 64 additions & 53 deletions codeforlife/tests/model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,63 +13,62 @@
from django.forms.models import model_to_dict
from rest_framework.serializers import BaseSerializer, ValidationError

from ..serializers import ModelListSerializer, ModelSerializer
from ..serializers import (
BaseModelSerializer,
ModelListSerializer,
ModelSerializer,
)
from ..types import DataDict
from .api_request_factory import APIRequestFactory
from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory
from .test import TestCase

# pylint: disable-next=duplicate-code
# pylint: disable=duplicate-code
if t.TYPE_CHECKING:
from ..user.models import User

RequestUser = t.TypeVar("RequestUser", bound=User)
else:
RequestUser = t.TypeVar("RequestUser")


AnyModel = t.TypeVar("AnyModel", bound=Model)
AnyBaseModelSerializer = t.TypeVar(
"AnyBaseModelSerializer", bound=BaseModelSerializer
)
AnyBaseAPIRequestFactory = t.TypeVar(
"AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory
)
# pylint: enable=duplicate-code


class BaseModelSerializerTestCase(
TestCase,
t.Generic[AnyBaseModelSerializer, AnyBaseAPIRequestFactory, AnyModel],
):
"""Base for all model serializer test cases."""

model_serializer_class: t.Type[AnyBaseModelSerializer]

class ModelSerializerTestCase(TestCase, t.Generic[RequestUser, AnyModel]):
"""Base for all model serializer test cases."""
request_factory: AnyBaseAPIRequestFactory
request_factory_class: t.Type[AnyBaseAPIRequestFactory]

model_serializer_class: t.Type[ModelSerializer[RequestUser, AnyModel]]
REQUIRED_ATTRS: t.Set[str] = {
"model_serializer_class",
"request_factory_class",
}

request_factory: APIRequestFactory[RequestUser]
@classmethod
def _initialize_request_factory(cls, **kwargs):
return cls.request_factory_class(**kwargs)

@classmethod
def setUpClass(cls):
attr_name = "model_serializer_class"
assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.'
for attr in cls.REQUIRED_ATTRS:
assert hasattr(cls, attr), f'Attribute "{attr}" must be set.'

cls.request_factory = APIRequestFactory(cls.get_request_user_class())
cls.request_factory = cls._initialize_request_factory()

return super().setUpClass()

@classmethod
def get_request_user_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
]

# --------------------------------------------------------------------------
# Private helpers.
# --------------------------------------------------------------------------
Expand Down Expand Up @@ -379,31 +378,43 @@ def assert_new_data_is_subset_of_data(new_data: DataDict, data):
self._assert_data_is_subset_of_model(data, instance)


class ModelListSerializerTestCase(
ModelSerializerTestCase[RequestUser, AnyModel],
class ModelSerializerTestCase(
BaseModelSerializerTestCase[
ModelSerializer[RequestUser, AnyModel],
APIRequestFactory[RequestUser],
AnyModel,
],
t.Generic[RequestUser, AnyModel],
):
"""Base for all model serializer test cases."""

model_list_serializer_class: t.Type[
ModelListSerializer[RequestUser, AnyModel]
]
request_factory_class = APIRequestFactory

@classmethod
def setUpClass(cls):
attr_name = "model_list_serializer_class"
assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.'
def get_request_user_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
return super().setUpClass()
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

# --------------------------------------------------------------------------
# Private helpers.
# --------------------------------------------------------------------------
@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
def _init_model_serializer(self, *args, parent=None, **kwargs):
kwargs.setdefault("child", self.model_serializer_class())
serializer = self.model_list_serializer_class(*args, **kwargs)
if parent:
serializer.parent = parent
Returns:
The model view set's class.
"""
# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
1
]

return serializer
@classmethod
def _initialize_request_factory(cls, **kwargs):
kwargs["user_class"] = cls.get_request_user_class()
return super()._initialize_request_factory(**kwargs)

0 comments on commit 27df3ee

Please sign in to comment.