From 27df3eeebbc3e486d49276906c2efa4c7702fc3b Mon Sep 17 00:00:00 2001 From: SKairinos Date: Wed, 6 Nov 2024 13:12:52 +0000 Subject: [PATCH] abstract model serializer test case --- codeforlife/tests/__init__.py | 6 +- codeforlife/tests/model_list_serializer.py | 88 ++++++++++++++++ codeforlife/tests/model_serializer.py | 117 +++++++++++---------- 3 files changed, 157 insertions(+), 54 deletions(-) create mode 100644 codeforlife/tests/model_list_serializer.py diff --git a/codeforlife/tests/__init__.py b/codeforlife/tests/__init__.py index 9943912..a2135dd 100644 --- a/codeforlife/tests/__init__.py +++ b/codeforlife/tests/__init__.py @@ -9,8 +9,12 @@ from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .cron import CronTestCase from .model import ModelTestCase -from .model_serializer import ( +from .model_list_serializer import ( + BaseModelListSerializerTestCase, ModelListSerializerTestCase, +) +from .model_serializer import ( + BaseModelSerializerTestCase, ModelSerializerTestCase, ) from .model_view_set import ModelViewSetClient, ModelViewSetTestCase diff --git a/codeforlife/tests/model_list_serializer.py b/codeforlife/tests/model_list_serializer.py new file mode 100644 index 0000000..f25da03 --- /dev/null +++ b/codeforlife/tests/model_list_serializer.py @@ -0,0 +1,88 @@ +""" +© Ocado Group +Created on 06/11/2024 at 12:45:33(+00:00). + +Base test case for all model list serializers. +""" + +import typing as t + +from django.db.models import Model + +from ..serializers import ( + BaseModelListSerializer, + BaseModelSerializer, + ModelListSerializer, + ModelSerializer, +) +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory +from .model_serializer import ( + BaseModelSerializerTestCase, + ModelSerializerTestCase, +) + +# pylint: disable=duplicate-code +if t.TYPE_CHECKING: + from ..user.models import User + + RequestUser = t.TypeVar("RequestUser", bound=User) +else: + RequestUser = t.TypeVar("RequestUser") + +AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer +) +AnyBaseModelListSerializer = t.TypeVar( + "AnyBaseModelListSerializer", bound=BaseModelListSerializer +) +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + + +class BaseModelListSerializerTestCase( + BaseModelSerializerTestCase[ + AnyBaseModelSerializer, + AnyBaseAPIRequestFactory, + AnyModel, + ], + t.Generic[ + AnyBaseModelListSerializer, + AnyBaseModelSerializer, + AnyBaseAPIRequestFactory, + AnyModel, + ], +): + """Base for all model serializer test cases.""" + + model_list_serializer_class: t.Type[AnyBaseModelListSerializer] + + REQUIRED_ATTRS = { + "model_list_serializer_class", + "model_serializer_class", + "request_factory_class", + } + + def _init_model_serializer(self, *args, parent=None, **kwargs): + kwargs.setdefault("child", self.model_serializer_class()) + serializer = self.model_list_serializer_class(*args, **kwargs) + if parent: + serializer.parent = parent + + return serializer + + +# pylint: disable-next=too-many-ancestors +class ModelListSerializerTestCase( + BaseModelListSerializerTestCase[ + ModelListSerializer[RequestUser, AnyModel], + ModelSerializer[RequestUser, AnyModel], + APIRequestFactory[RequestUser], + AnyModel, + ], + ModelSerializerTestCase[RequestUser, AnyModel], + t.Generic[RequestUser, AnyModel], +): + """Base for all model serializer test cases.""" diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index dcda7bb..8f01541 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -13,12 +13,16 @@ from django.forms.models import model_to_dict from rest_framework.serializers import BaseSerializer, ValidationError -from ..serializers import ModelListSerializer, ModelSerializer +from ..serializers import ( + BaseModelSerializer, + ModelListSerializer, + ModelSerializer, +) from ..types import DataDict -from .api_request_factory import APIRequestFactory +from .api_request_factory import APIRequestFactory, BaseAPIRequestFactory from .test import TestCase -# pylint: disable-next=duplicate-code +# pylint: disable=duplicate-code if t.TYPE_CHECKING: from ..user.models import User @@ -26,50 +30,45 @@ else: RequestUser = t.TypeVar("RequestUser") - AnyModel = t.TypeVar("AnyModel", bound=Model) +AnyBaseModelSerializer = t.TypeVar( + "AnyBaseModelSerializer", bound=BaseModelSerializer +) +AnyBaseAPIRequestFactory = t.TypeVar( + "AnyBaseAPIRequestFactory", bound=BaseAPIRequestFactory +) +# pylint: enable=duplicate-code + + +class BaseModelSerializerTestCase( + TestCase, + t.Generic[AnyBaseModelSerializer, AnyBaseAPIRequestFactory, AnyModel], +): + """Base for all model serializer test cases.""" + model_serializer_class: t.Type[AnyBaseModelSerializer] -class ModelSerializerTestCase(TestCase, t.Generic[RequestUser, AnyModel]): - """Base for all model serializer test cases.""" + request_factory: AnyBaseAPIRequestFactory + request_factory_class: t.Type[AnyBaseAPIRequestFactory] - model_serializer_class: t.Type[ModelSerializer[RequestUser, AnyModel]] + REQUIRED_ATTRS: t.Set[str] = { + "model_serializer_class", + "request_factory_class", + } - request_factory: APIRequestFactory[RequestUser] + @classmethod + def _initialize_request_factory(cls, **kwargs): + return cls.request_factory_class(**kwargs) @classmethod def setUpClass(cls): - attr_name = "model_serializer_class" - assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + for attr in cls.REQUIRED_ATTRS: + assert hasattr(cls, attr), f'Attribute "{attr}" must be set.' - cls.request_factory = APIRequestFactory(cls.get_request_user_class()) + cls.request_factory = cls._initialize_request_factory() return super().setUpClass() - @classmethod - def get_request_user_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. - - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 0 - ] - - @classmethod - def get_model_class(cls) -> t.Type[AnyModel]: - """Get the model view set's class. - - Returns: - The model view set's class. - """ - # pylint: disable-next=no-member - return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] - 1 - ] - # -------------------------------------------------------------------------- # Private helpers. # -------------------------------------------------------------------------- @@ -379,31 +378,43 @@ def assert_new_data_is_subset_of_data(new_data: DataDict, data): self._assert_data_is_subset_of_model(data, instance) -class ModelListSerializerTestCase( - ModelSerializerTestCase[RequestUser, AnyModel], +class ModelSerializerTestCase( + BaseModelSerializerTestCase[ + ModelSerializer[RequestUser, AnyModel], + APIRequestFactory[RequestUser], + AnyModel, + ], t.Generic[RequestUser, AnyModel], ): """Base for all model serializer test cases.""" - model_list_serializer_class: t.Type[ - ModelListSerializer[RequestUser, AnyModel] - ] + request_factory_class = APIRequestFactory @classmethod - def setUpClass(cls): - attr_name = "model_list_serializer_class" - assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + def get_request_user_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. - return super().setUpClass() + Returns: + The model view set's class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] - # -------------------------------------------------------------------------- - # Private helpers. - # -------------------------------------------------------------------------- + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. - def _init_model_serializer(self, *args, parent=None, **kwargs): - kwargs.setdefault("child", self.model_serializer_class()) - serializer = self.model_list_serializer_class(*args, **kwargs) - if parent: - serializer.parent = parent + Returns: + The model view set's class. + """ + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 1 + ] - return serializer + @classmethod + def _initialize_request_factory(cls, **kwargs): + kwargs["user_class"] = cls.get_request_user_class() + return super()._initialize_request_factory(**kwargs)