diff --git a/codeforlife/tests/model_serializer.py b/codeforlife/tests/model_serializer.py index 1af5f5ef..084866bc 100644 --- a/codeforlife/tests/model_serializer.py +++ b/codeforlife/tests/model_serializer.py @@ -10,8 +10,11 @@ from django.db.models import Model from django.test import TestCase from rest_framework.serializers import ValidationError +from rest_framework.test import APIRequestFactory from ..serializers import ModelSerializer +from ..types import KwArgs +from ..user.models import User AnyModel = t.TypeVar("AnyModel", bound=Model) @@ -21,6 +24,15 @@ class ModelSerializerTestCase(TestCase, t.Generic[AnyModel]): model_serializer_class: t.Type[ModelSerializer[AnyModel]] + request_factory = APIRequestFactory() + + @classmethod + def setUpClass(cls): + attr_name = "model_serializer_class" + assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + + return super().setUpClass() + @classmethod def get_model_class(cls) -> t.Type[AnyModel]: """Get the model view set's class. @@ -61,3 +73,97 @@ def __exit__(self, *args, **kwargs): return value return ContextWrapper(context) + + # pylint: disable-next=too-many-arguments + def _assert_validate( + self, + value, + error_code: str, + user: t.Optional[User], + request_kwargs: t.Optional[KwArgs], + get_validate: t.Callable[ + [ModelSerializer[AnyModel]], t.Callable[[t.Any], t.Any] + ], + **kwargs, + ): + kwargs = kwargs or {} + kwargs.setdefault("context", {}) + context: t.Dict[str, t.Any] = kwargs["context"] + + if "request" not in context: + request_kwargs = request_kwargs or {} + request_kwargs.setdefault("method", "POST") + request_kwargs.setdefault("path", "/") + request_kwargs.setdefault("data", "") + request_kwargs.setdefault("content_type", "application/json") + + request = self.request_factory.generic(**request_kwargs) + if user is not None: + request.user = user + + context["request"] = request + + serializer = self.model_serializer_class(**kwargs) + + with self.assert_raises_validation_error(error_code): + get_validate(serializer)(value) + + def assert_validate( + self, + attrs: t.Dict[str, t.Any], + error_code: str, + user: t.Optional[User] = None, + request_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + """Asserts that calling validate() raises the expected error code. + + Args: + attrs: The attributes to pass to validate(). + error_code: The expected error code to be raised. + user: The requesting user. + request_kwargs: The kwargs used to initialize the request. + """ + + self._assert_validate( + attrs, + error_code, + user, + request_kwargs, + get_validate=lambda serializer: serializer.validate, + **kwargs, + ) + + # pylint: disable-next=too-many-arguments + def assert_validate_field( + self, + name: str, + value, + error_code: str, + user: t.Optional[User] = None, + request_kwargs: t.Optional[KwArgs] = None, + **kwargs, + ): + """Asserts that calling validate_field() raises the expected error code. + + Args: + name: The name of the field. + value: The value to pass to validate_field(). + error_code: The expected error code to be raised. + user: The requesting user. + request_kwargs: The kwargs used to initialize the request. + """ + + def get_validate(serializer: ModelSerializer[AnyModel]): + validate_field = getattr(serializer, f"validate_{name}") + assert callable(validate_field) + return validate_field + + self._assert_validate( + value, + error_code, + user, + request_kwargs, + get_validate, + **kwargs, + ) diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index a53800f6..eee9dffa 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -682,6 +682,13 @@ def get_model_class(cls) -> t.Type[AnyModel]: 0 ] + @classmethod + def setUpClass(cls): + attr_name = "model_view_set_class" + assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.' + + return super().setUpClass() + def get_other_user( self, user: User, diff --git a/codeforlife/types.py b/codeforlife/types.py new file mode 100644 index 00000000..a51526d8 --- /dev/null +++ b/codeforlife/types.py @@ -0,0 +1,15 @@ +""" +© Ocado Group +Created on 15/01/2024 at 15:32:54(+00:00). + +Reusable type hints. +""" + +import typing as t + +Args = t.Tuple[t.Any, ...] +KwArgs = t.Dict[str, t.Any] + +JsonList = t.List["JsonValue"] +JsonDict = t.Dict[str, "JsonValue"] +JsonValue = t.Union[int, str, bool, JsonList, JsonDict]