Skip to content

Commit

Permalink
add partial_update and destroy
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jan 20, 2024
1 parent 96cd7f0 commit af5603a
Showing 1 changed file with 125 additions and 21 deletions.
146 changes: 125 additions & 21 deletions codeforlife/tests/model_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,34 @@ class ModelViewSetClient(
responses.
"""

basename: str
model_class: t.Type[AnyModel]
model_serializer_class: t.Type[AnyModelSerializer]
model_view_set_class: t.Type[AnyModelViewSet]
_test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]"

@property
def basename(self):
"""Shortcut to get basename."""

return self._test_case.basename

@property
def model_class(self):
"""Shortcut to get model class."""

# pylint: disable-next=no-member
return self._test_case.get_model_class()

@property
def model_serializer_class(self):
"""Shortcut to get model serializer class."""

# pylint: disable-next=no-member
return self._test_case.get_model_serializer_class()

@property
def model_view_set_class(self):
"""Shortcut to get model view set class."""

# pylint: disable-next=no-member
return self._test_case.get_model_view_set_class()

StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]]
ListFilters = t.Optional[t.Dict[str, str]]
Expand All @@ -62,18 +86,22 @@ def assert_data_equals_model(
self,
data: t.Dict[str, t.Any],
model: AnyModel,
contains_subset: bool = False,
):
# pylint: disable=line-too-long
"""Check if the data equals the current state of the model instance.
Args:
data: The data to check.
model: The model instance.
model_serializer_class: The serializer used to serialize the model's data.
contains_subset: A flag designating whether the data is a subset of the serialized model.
Returns:
A flag designating if the data equals the current state of the model
instance.
"""
# pylint: enable=line-too-long

def parse_data(data):
if isinstance(data, list):
Expand All @@ -84,9 +112,34 @@ def parse_data(data):
return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return data

assert data == parse_data(
self.model_serializer_class(model).data
), "Data does not equal serialized model."
actual_data = parse_data(self.model_serializer_class(model).data)

if contains_subset:
# pylint: disable-next=no-member
self._test_case.assertDictContainsSubset(
data,
actual_data,
"Data is not a subset of serialized model.",
)
else:
# pylint: disable-next=no-member
self._test_case.assertDictEqual(
data,
actual_data,
"Data does not equal serialized model.",
)

def _get_reverse_detail(self, model: AnyModel, **kwargs):
return reverse(
**kwargs,
viewname=kwargs.get("viewname", f"{self.basename}-detail"),
kwargs={
**kwargs.get("kwargs", {}),
self.model_view_set_class.lookup_field: getattr(
model, self.model_view_set_class.lookup_field
),
},
)

# pylint: disable-next=too-many-arguments
def generic(
Expand Down Expand Up @@ -137,7 +190,7 @@ def retrieve(
status_code_assertion: StatusCodeAssertion = None,
**kwargs,
):
"""Retrieve a model from the view set.
"""Retrieve a model.
Args:
model: The model to retrieve.
Expand All @@ -148,14 +201,7 @@ def retrieve(
"""

response: Response = self.get(
reverse(
f"{self.basename}-detail",
kwargs={
self.model_view_set_class.lookup_field: getattr(
model, self.model_view_set_class.lookup_field
)
},
),
self._get_reverse_detail(model),
status_code_assertion=status_code_assertion,
**kwargs,
)
Expand All @@ -175,7 +221,7 @@ def list(
filters: ListFilters = None,
**kwargs,
):
"""Retrieve a list of models from the view set.
"""Retrieve a list of models.
Args:
models: The model list to retrieve.
Expand Down Expand Up @@ -204,6 +250,66 @@ def list(

return response

def partial_update(
self,
model: AnyModel,
data: t.Dict[str, t.Any],
status_code_assertion: StatusCodeAssertion = None,
**kwargs,
):
"""Partially update a model.
Args:
model: The model to partially update.
status_code_assertion: The expected status code.
Returns:
The HTTP response.
"""

response: Response = self.patch(
self._get_reverse_detail(model),
data=data,
status_code_assertion=status_code_assertion,
**kwargs,
)

if self.status_code_is_ok(response.status_code):
model.refresh_from_db()
self.assert_data_equals_model(
response.json(), # type: ignore[attr-defined]
model,
contains_subset=True,
)

return response

def destroy(
self,
model: AnyModel,
status_code_assertion: StatusCodeAssertion = None,
**kwargs,
):
"""Destroy a model.
Args:
model: The model to destroy.
status_code_assertion: The expected status code.
Returns:
The HTTP response.
"""

response: Response = self.delete(
self._get_reverse_detail(model),
status_code_assertion=status_code_assertion,
**kwargs,
)

# TODO: add standard post-destroy assertions.

return response

def login(self, **credentials):
assert super().login(
**credentials
Expand Down Expand Up @@ -282,10 +388,8 @@ class ModelViewSetTestCase(

def _pre_setup(self):
super()._pre_setup()
self.client.basename = self.basename
self.client.model_view_set_class = self.get_model_view_set_class()
self.client.model_serializer_class = self.get_model_serializer_class()
self.client.model_class = self.get_model_class()
# pylint: disable-next=protected-access
self.client._test_case = self

@classmethod
def _get_generic_args(
Expand Down

0 comments on commit af5603a

Please sign in to comment.