Skip to content

Commit

Permalink
fix: add assertions (#67)
Browse files Browse the repository at this point in the history
* fix: add assertions

* fix: always dict
  • Loading branch information
SKairinos authored Jan 31, 2024
1 parent b5479b3 commit ae3abdf
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
105 changes: 105 additions & 0 deletions codeforlife/tests/model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from django.db.models import Model
from django.test import TestCase
from rest_framework.serializers import ValidationError
from rest_framework.test import APIRequestFactory

from ..serializers import ModelSerializer
from ..types import KwArgs
from ..user.models import User

AnyModel = t.TypeVar("AnyModel", bound=Model)

Expand All @@ -21,6 +24,15 @@ class ModelSerializerTestCase(TestCase, t.Generic[AnyModel]):

model_serializer_class: t.Type[ModelSerializer[AnyModel]]

request_factory = APIRequestFactory()

@classmethod
def setUpClass(cls):
attr_name = "model_serializer_class"
assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.'

return super().setUpClass()

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Expand Down Expand Up @@ -61,3 +73,96 @@ def __exit__(self, *args, **kwargs):
return value

return ContextWrapper(context)

# pylint: disable-next=too-many-arguments
def _assert_validate(
self,
value,
error_code: str,
user: t.Optional[User],
request_kwargs: t.Optional[KwArgs],
get_validate: t.Callable[
[ModelSerializer[AnyModel]], t.Callable[[t.Any], t.Any]
],
**kwargs,
):
kwargs.setdefault("context", {})
context: t.Dict[str, t.Any] = kwargs["context"]

if "request" not in context:
request_kwargs = request_kwargs or {}
request_kwargs.setdefault("method", "POST")
request_kwargs.setdefault("path", "/")
request_kwargs.setdefault("data", "")
request_kwargs.setdefault("content_type", "application/json")

request = self.request_factory.generic(**request_kwargs)
if user is not None:
request.user = user

context["request"] = request

serializer = self.model_serializer_class(**kwargs)

with self.assert_raises_validation_error(error_code):
get_validate(serializer)(value)

def assert_validate(
self,
attrs: t.Dict[str, t.Any],
error_code: str,
user: t.Optional[User] = None,
request_kwargs: t.Optional[KwArgs] = None,
**kwargs,
):
"""Asserts that calling validate() raises the expected error code.
Args:
attrs: The attributes to pass to validate().
error_code: The expected error code to be raised.
user: The requesting user.
request_kwargs: The kwargs used to initialize the request.
"""

self._assert_validate(
attrs,
error_code,
user,
request_kwargs,
get_validate=lambda serializer: serializer.validate,
**kwargs,
)

# pylint: disable-next=too-many-arguments
def assert_validate_field(
self,
name: str,
value,
error_code: str,
user: t.Optional[User] = None,
request_kwargs: t.Optional[KwArgs] = None,
**kwargs,
):
"""Asserts that calling validate_field() raises the expected error code.
Args:
name: The name of the field.
value: The value to pass to validate_field().
error_code: The expected error code to be raised.
user: The requesting user.
request_kwargs: The kwargs used to initialize the request.
"""

def get_validate(serializer: ModelSerializer[AnyModel]):
validate_field = getattr(serializer, f"validate_{name}")
assert callable(validate_field)
return validate_field

self._assert_validate(
value,
error_code,
user,
request_kwargs,
get_validate,
**kwargs,
)
7 changes: 7 additions & 0 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,13 @@ def get_model_class(cls) -> t.Type[AnyModel]:
0
]

@classmethod
def setUpClass(cls):
attr_name = "model_view_set_class"
assert hasattr(cls, attr_name), f'Attribute "{attr_name}" must be set.'

return super().setUpClass()

def get_other_user(
self,
user: User,
Expand Down
15 changes: 15 additions & 0 deletions codeforlife/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
© Ocado Group
Created on 15/01/2024 at 15:32:54(+00:00).
Reusable type hints.
"""

import typing as t

Args = t.Tuple[t.Any, ...]
KwArgs = t.Dict[str, t.Any]

JsonList = t.List["JsonValue"]
JsonDict = t.Dict[str, "JsonValue"]
JsonValue = t.Union[int, str, bool, JsonList, JsonDict]

0 comments on commit ae3abdf

Please sign in to comment.