diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 4d71ec85..1216c7b2 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -36,10 +36,34 @@ class ModelViewSetClient( responses. """ - basename: str - model_class: t.Type[AnyModel] - model_serializer_class: t.Type[AnyModelSerializer] - model_view_set_class: t.Type[AnyModelViewSet] + _test_case: "ModelViewSetTestCase[AnyModelViewSet, AnyModelSerializer, AnyModel]" + + @property + def basename(self): + """Shortcut to get basename.""" + + return self._test_case.basename + + @property + def model_class(self): + """Shortcut to get model class.""" + + # pylint: disable-next=no-member + return self._test_case.get_model_class() + + @property + def model_serializer_class(self): + """Shortcut to get model serializer class.""" + + # pylint: disable-next=no-member + return self._test_case.get_model_serializer_class() + + @property + def model_view_set_class(self): + """Shortcut to get model view set class.""" + + # pylint: disable-next=no-member + return self._test_case.get_model_view_set_class() StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] ListFilters = t.Optional[t.Dict[str, str]] @@ -62,18 +86,22 @@ def assert_data_equals_model( self, data: t.Dict[str, t.Any], model: AnyModel, + contains_subset: bool = False, ): + # pylint: disable=line-too-long """Check if the data equals the current state of the model instance. Args: data: The data to check. model: The model instance. model_serializer_class: The serializer used to serialize the model's data. + contains_subset: A flag designating whether the data is a subset of the serialized model. Returns: A flag designating if the data equals the current state of the model instance. """ + # pylint: enable=line-too-long def parse_data(data): if isinstance(data, list): @@ -84,9 +112,34 @@ def parse_data(data): return data.strftime("%Y-%m-%dT%H:%M:%S.%fZ") return data - assert data == parse_data( - self.model_serializer_class(model).data - ), "Data does not equal serialized model." + actual_data = parse_data(self.model_serializer_class(model).data) + + if contains_subset: + # pylint: disable-next=no-member + self._test_case.assertDictContainsSubset( + data, + actual_data, + "Data is not a subset of serialized model.", + ) + else: + # pylint: disable-next=no-member + self._test_case.assertDictEqual( + data, + actual_data, + "Data does not equal serialized model.", + ) + + def _get_reverse_detail(self, model: AnyModel, **kwargs): + return reverse( + **kwargs, + viewname=kwargs.get("viewname", f"{self.basename}-detail"), + kwargs={ + **kwargs.get("kwargs", {}), + self.model_view_set_class.lookup_field: getattr( + model, self.model_view_set_class.lookup_field + ), + }, + ) # pylint: disable-next=too-many-arguments def generic( @@ -137,7 +190,7 @@ def retrieve( status_code_assertion: StatusCodeAssertion = None, **kwargs, ): - """Retrieve a model from the view set. + """Retrieve a model. Args: model: The model to retrieve. @@ -148,14 +201,7 @@ def retrieve( """ response: Response = self.get( - reverse( - f"{self.basename}-detail", - kwargs={ - self.model_view_set_class.lookup_field: getattr( - model, self.model_view_set_class.lookup_field - ) - }, - ), + self._get_reverse_detail(model), status_code_assertion=status_code_assertion, **kwargs, ) @@ -175,7 +221,7 @@ def list( filters: ListFilters = None, **kwargs, ): - """Retrieve a list of models from the view set. + """Retrieve a list of models. Args: models: The model list to retrieve. @@ -204,6 +250,66 @@ def list( return response + def partial_update( + self, + model: AnyModel, + data: t.Dict[str, t.Any], + status_code_assertion: StatusCodeAssertion = None, + **kwargs, + ): + """Partially update a model. + + Args: + model: The model to partially update. + status_code_assertion: The expected status code. + + Returns: + The HTTP response. + """ + + response: Response = self.patch( + self._get_reverse_detail(model), + data=data, + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if self.status_code_is_ok(response.status_code): + model.refresh_from_db() + self.assert_data_equals_model( + response.json(), # type: ignore[attr-defined] + model, + contains_subset=True, + ) + + return response + + def destroy( + self, + model: AnyModel, + status_code_assertion: StatusCodeAssertion = None, + **kwargs, + ): + """Destroy a model. + + Args: + model: The model to destroy. + status_code_assertion: The expected status code. + + Returns: + The HTTP response. + """ + + response: Response = self.delete( + self._get_reverse_detail(model), + status_code_assertion=status_code_assertion, + **kwargs, + ) + + # TODO: add standard post-destroy assertions. + + return response + def login(self, **credentials): assert super().login( **credentials @@ -282,10 +388,8 @@ class ModelViewSetTestCase( def _pre_setup(self): super()._pre_setup() - self.client.basename = self.basename - self.client.model_view_set_class = self.get_model_view_set_class() - self.client.model_serializer_class = self.get_model_serializer_class() - self.client.model_class = self.get_model_class() + # pylint: disable-next=protected-access + self.client._test_case = self @classmethod def _get_generic_args(