From 4d2e7c407037714bc2041aa8a9fb4a8c6b1d2b76 Mon Sep 17 00:00:00 2001 From: Stefan Kairinos Date: Fri, 26 Jan 2024 17:05:55 +0000 Subject: [PATCH] feat: support bulk actions (#63) --- .../autoDocstring/docstring.mustache | 35 +++ codeforlife/serializers/__init__.py | 2 +- codeforlife/serializers/base.py | 25 -- codeforlife/serializers/model.py | 124 ++++++++++ codeforlife/tests/model_view_set.py | 222 ++++++++++++++++-- codeforlife/views/__init__.py | 2 +- codeforlife/views/base.py | 39 --- codeforlife/views/model.py | 213 +++++++++++++++++ 8 files changed, 570 insertions(+), 92 deletions(-) create mode 100644 .vscode/extensions/autoDocstring/docstring.mustache delete mode 100644 codeforlife/serializers/base.py create mode 100644 codeforlife/serializers/model.py delete mode 100644 codeforlife/views/base.py create mode 100644 codeforlife/views/model.py diff --git a/.vscode/extensions/autoDocstring/docstring.mustache b/.vscode/extensions/autoDocstring/docstring.mustache new file mode 100644 index 00000000..64a5b91a --- /dev/null +++ b/.vscode/extensions/autoDocstring/docstring.mustache @@ -0,0 +1,35 @@ +{{! Based off of: https://github.com/NilsJPWerner/autoDocstring/blob/master/src/docstring/templates/google-notypes.mustache }} +{{summaryPlaceholder}} + +{{extendedSummaryPlaceholder}} +{{#parametersExist}} + +Args: +{{#args}} + {{var}}: {{descriptionPlaceholder}} +{{/args}} +{{#kwargs}} + {{var}}: {{descriptionPlaceholder}} +{{/kwargs}} +{{/parametersExist}} +{{#exceptionsExist}} + +Raises: +{{#exceptions}} + {{type}}: {{descriptionPlaceholder}} +{{/exceptions}} +{{/exceptionsExist}} +{{#returnsExist}} + +Returns: +{{#returns}} + {{descriptionPlaceholder}} +{{/returns}} +{{/returnsExist}} +{{#yieldsExist}} + +Yields: +{{#yields}} + {{descriptionPlaceholder}} +{{/yields}} +{{/yieldsExist}} \ No newline at end of file diff --git a/codeforlife/serializers/__init__.py b/codeforlife/serializers/__init__.py index 4a685c70..2c63875b 100644 --- a/codeforlife/serializers/__init__.py +++ b/codeforlife/serializers/__init__.py @@ -3,4 +3,4 @@ Created on 20/01/2024 at 11:19:12(+00:00). """ -from .base import * +from .model import * diff --git a/codeforlife/serializers/base.py b/codeforlife/serializers/base.py deleted file mode 100644 index 951692ee..00000000 --- a/codeforlife/serializers/base.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -© Ocado Group -Created on 20/01/2024 at 11:19:24(+00:00). - -Base model serializers. -""" - -import typing as t - -from django.db.models import Model -from rest_framework.serializers import ModelSerializer as _ModelSerializer - -AnyModel = t.TypeVar("AnyModel", bound=Model) - - -class ModelSerializer(_ModelSerializer[AnyModel], t.Generic[AnyModel]): - """Base model serializer for all model serializers.""" - - # pylint: disable-next=useless-parent-delegation - def update(self, instance, validated_data: t.Dict[str, t.Any]): - return super().update(instance, validated_data) - - # pylint: disable-next=useless-parent-delegation - def create(self, validated_data: t.Dict[str, t.Any]): - return super().create(validated_data) diff --git a/codeforlife/serializers/model.py b/codeforlife/serializers/model.py new file mode 100644 index 00000000..dfc64b51 --- /dev/null +++ b/codeforlife/serializers/model.py @@ -0,0 +1,124 @@ +""" +© Ocado Group +Created on 20/01/2024 at 11:19:24(+00:00). + +Base model serializers. +""" + +import typing as t + +from django.db.models import Model +from rest_framework.serializers import ListSerializer as _ListSerializer +from rest_framework.serializers import ModelSerializer as _ModelSerializer +from rest_framework.serializers import ValidationError as _ValidationError + +AnyModel = t.TypeVar("AnyModel", bound=Model) + + +class ModelSerializer(_ModelSerializer[AnyModel], t.Generic[AnyModel]): + """Base model serializer for all model serializers.""" + + # pylint: disable-next=useless-parent-delegation + def update(self, instance, validated_data: t.Dict[str, t.Any]): + return super().update(instance, validated_data) + + # pylint: disable-next=useless-parent-delegation + def create(self, validated_data: t.Dict[str, t.Any]): + return super().create(validated_data) + + +class ModelListSerializer( + t.Generic[AnyModel], + _ListSerializer[t.List[AnyModel]], +): + """Base model list serializer for all model list serializers. + + Inherit this class if you wish to custom handle bulk create and/or update. + + class UserListSerializer(ModelListSerializer[User]): + def create(self, validated_data): + ... + + def update(self, instance, validated_data): + ... + + class UserSerializer(ModelSerializer[User]): + class Meta: + model = User + list_serializer_class = UserListSerializer + """ + + batch_size: t.Optional[int] = None + + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. + + Returns: + The model view set's class. + """ + + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + def create(self, validated_data: t.List[t.Dict[str, t.Any]]): + """Bulk create many instances of a model. + + https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-create + + Args: + validated_data: The data used to create the models. + + Returns: + The models. + """ + + model_class = self.get_model_class() + return model_class.objects.bulk_create( # type: ignore[attr-defined] + objs=[model_class(**data) for data in validated_data], + batch_size=self.batch_size, + ) + + def update(self, instance, validated_data: t.List[t.Dict[str, t.Any]]): + """Bulk update many instances of a model. + + https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update + + Args: + instance: The models to update. + validated_data: The field-value pairs to update for each model. + + Returns: + The models. + """ + + # Models and data must have equal length and be ordered the same! + for model, data in zip(instance, validated_data): + for field, value in data.items(): + setattr(model, field, value) + + model_class = self.get_model_class() + model_class.objects.bulk_update( # type: ignore[attr-defined] + objs=instance, + fields={field for data in validated_data for field in data.keys()}, + batch_size=self.batch_size, + ) + + return instance + + def validate(self, attrs: t.List[t.Dict[str, t.Any]]): + # If performing a bulk create. + if self.instance is None: + if len(attrs) == 0: + raise _ValidationError("Nothing to create.") + + # Else, performing a bulk update. + else: + if len(attrs) == 0: + raise _ValidationError("Nothing to update.") + if len(attrs) != len(self.instance): + raise _ValidationError("Some models do not exist.") + + return attrs diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index 5c99fc9b..1d9e0fd1 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -5,6 +5,7 @@ Base test case for all model view sets. """ +import json import typing as t from datetime import datetime from unittest.mock import patch @@ -55,10 +56,51 @@ def _model_view_set_class(self): # pylint: disable-next=no-member return self._test_case.model_view_set_class + @property + def _lookup_field(self): + """Resolves the field to lookup the model.""" + + lookup_field = self._model_view_set_class.lookup_field + return ( + self._model_class._meta.pk.attname + if lookup_field == "pk" + else lookup_field + ) + Data = t.Dict[str, t.Any] StatusCodeAssertion = t.Optional[t.Union[int, t.Callable[[int], bool]]] ListFilters = t.Optional[t.Dict[str, str]] + def _assert_response(self, response: Response, make_assertions: t.Callable): + if self.status_code_is_ok(response.status_code): + make_assertions() + + def _assert_response_json( + self, + response: Response, + make_assertions: t.Callable[[Data], None], + ): + self._assert_response( + response, + make_assertions=lambda: make_assertions( + response.json(), # type: ignore[attr-defined] + ), + ) + + def _assert_response_json_bulk( + self, + response: Response, + make_assertions: t.Callable[[t.List[Data]], None], + data: t.List[Data], + ): + def _make_assertions(): + response_json = response.json() # type: ignore[attr-defined] + assert isinstance(response_json, list) + assert len(response_json) == len(data) + make_assertions(response_json) + + self._assert_response(response, _make_assertions) + @staticmethod def status_code_is_ok(status_code: int): """Check if the status code is greater than or equal to 200 and less @@ -224,17 +266,52 @@ def create( response: Response = self.post( self.reverse("list"), - data=data, + data=json.dumps(data), + content_type="application/json", status_code_assertion=status_code_assertion, **kwargs, ) - if self.status_code_is_ok(response.status_code): - # pylint: disable-next=no-member - self._test_case.assertDictContainsSubset( - data, - response.json(), # type: ignore[attr-defined] - ) + self._assert_response_json( + response, + make_assertions=lambda actual_data: ( + # pylint: disable-next=no-member + self._test_case.assertDictContainsSubset(data, actual_data) + ), + ) + + return response + + def bulk_create( + self, + data: t.List[Data], + status_code_assertion: StatusCodeAssertion = status.HTTP_201_CREATED, + **kwargs, + ): + """Bulk create many instances of a model. + + Args: + data: The values for each field, for each model. + status_code_assertion: The expected status code. + + Returns: + The HTTP response. + """ + + response: Response = self.post( + self.reverse("bulk"), + data=json.dumps(data), + content_type="application/json", + status_code_assertion=status_code_assertion, + **kwargs, + ) + + def make_assertions(actual_data: t.List[self.Data]): + for model, actual_model in zip(data, actual_data): + # pylint: disable-next=no-member + self._test_case.assertDictContainsSubset(model, actual_model) + + self._assert_response_json_bulk(response, make_assertions, data) return response @@ -266,12 +343,14 @@ def retrieve( **kwargs, ) - if self.status_code_is_ok(response.status_code): - self.assert_data_equals_model( - response.json(), # type: ignore[attr-defined] + self._assert_response_json( + response, + make_assertions=lambda actual_data: self.assert_data_equals_model( + actual_data, model, model_serializer_class, - ) + ), + ) return response @@ -311,17 +390,16 @@ def list( **kwargs, ) - if self.status_code_is_ok(response.status_code): - for data, model in zip( - response.json()["data"], # type: ignore[attr-defined] - models, - ): + def _make_assertions(actual_data: self.Data): + for data, model in zip(actual_data["data"], models): self.assert_data_equals_model( data, model, model_serializer_class, ) + self._assert_response_json(response, _make_assertions) + return response def partial_update( @@ -350,20 +428,71 @@ def partial_update( response: Response = self.patch( self.reverse("detail", model), - data=data, + data=json.dumps(data), + content_type="application/json", status_code_assertion=status_code_assertion, **kwargs, ) - if self.status_code_is_ok(response.status_code): + def _make_assertions(actual_data: self.Data): model.refresh_from_db() self.assert_data_equals_model( - response.json(), # type: ignore[attr-defined] + actual_data, model, model_serializer_class, contains_subset=True, ) + self._assert_response_json(response, _make_assertions) + + return response + + def bulk_partial_update( + self, + models: t.List[AnyModel], + data: t.List[Data], + status_code_assertion: StatusCodeAssertion = status.HTTP_200_OK, + model_serializer_class: t.Optional[ + t.Type[ModelSerializer[AnyModel]] + ] = None, + **kwargs, + ): + # pylint: disable=line-too-long + """Bulk partially update many instances of a model. + + Args: + models: The models to partially update. + data: The values for each field, for each model. + status_code_assertion: The expected status code. + model_serializer_class: The serializer used to serialize the model's data. + + Returns: + The HTTP response. + """ + # pylint: enable=line-too-long + + response: Response = self.patch( + self.reverse("bulk"), + data=json.dumps(data), + content_type="application/json", + status_code_assertion=status_code_assertion, + **kwargs, + ) + + def make_assertions(actual_data: t.List[self.Data]): + models.sort(key=lambda model: getattr(model, self._lookup_field)) + + for data, model in zip(actual_data, models): + model.refresh_from_db() + self.assert_data_equals_model( + data, + model, + model_serializer_class, + contains_subset=True, + ) + + self._assert_response_json_bulk(response, make_assertions, data) + return response def destroy( @@ -390,12 +519,53 @@ def destroy( **kwargs, ) - if not anonymized and self.status_code_is_ok(response.status_code): - # pylint: disable-next=no-member - with self._test_case.assertRaises( - model.DoesNotExist # type: ignore[attr-defined] - ): - model.refresh_from_db() + if not anonymized: + + def _make_assertions(): + # pylint: disable-next=no-member + with self._test_case.assertRaises( + model.DoesNotExist # type: ignore[attr-defined] + ): + model.refresh_from_db() + + self._assert_response(response, _make_assertions) + + return response + + def bulk_destroy( + self, + lookup_values: t.List[t.Any], + status_code_assertion: StatusCodeAssertion = status.HTTP_204_NO_CONTENT, + anonymized: bool = False, + **kwargs, + ): + """Bulk destroy many instances of a model. + + Args: + lookup_values: The models to lookup and destroy. + status_code_assertion: The expected status code. + anonymized: Whether or not the data is anonymized. + + Returns: + The HTTP response. + """ + + response: Response = self.delete( + self.reverse("bulk"), + data=json.dumps(lookup_values), + content_type="application/json", + status_code_assertion=status_code_assertion, + **kwargs, + ) + + if not anonymized: + + def _make_assertions(): + assert not self._model_class.objects.filter( + **{f"{self._lookup_field}__in": lookup_values} + ).exists() + + self._assert_response(response, _make_assertions) return response @@ -483,7 +653,7 @@ def _pre_setup(self): self.client._test_case = self @classmethod - def get_model_class(cls): + def get_model_class(cls) -> t.Type[AnyModel]: """Get the model view set's class. Returns: diff --git a/codeforlife/views/__init__.py b/codeforlife/views/__init__.py index 5b2d49eb..558094e6 100644 --- a/codeforlife/views/__init__.py +++ b/codeforlife/views/__init__.py @@ -3,4 +3,4 @@ Created on 24/01/2024 at 13:07:38(+00:00). """ -from .base import ModelViewSet +from .model import ModelViewSet diff --git a/codeforlife/views/base.py b/codeforlife/views/base.py deleted file mode 100644 index cac85d4a..00000000 --- a/codeforlife/views/base.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -© Ocado Group -Created on 24/01/2024 at 13:08:23(+00:00). -""" - -import typing as t - -from django.db.models import Model -from rest_framework.viewsets import ModelViewSet as DrfModelViewSet - -from ..serializers import ModelSerializer - -AnyModel = t.TypeVar("AnyModel", bound=Model) - - -# pylint: disable-next=too-few-public-methods -class _ModelViewSet(t.Generic[AnyModel]): - pass - - -if t.TYPE_CHECKING: - # pylint: disable-next=too-few-public-methods - class ModelViewSet( - DrfModelViewSet[AnyModel], - _ModelViewSet[AnyModel], - t.Generic[AnyModel], - ): - """Base model view set for all model view sets.""" - - serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] - -else: - # pylint: disable-next=missing-class-docstring,too-many-ancestors - class ModelViewSet( - DrfModelViewSet, - _ModelViewSet[AnyModel], - t.Generic[AnyModel], - ): - pass diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py new file mode 100644 index 00000000..80560a0a --- /dev/null +++ b/codeforlife/views/model.py @@ -0,0 +1,213 @@ +""" +© Ocado Group +Created on 24/01/2024 at 13:08:23(+00:00). +""" + +import typing as t + +from django.db.models import Model +from django.db.models.query import QuerySet +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.serializers import ListSerializer +from rest_framework.viewsets import ModelViewSet as DrfModelViewSet + +from ..serializers import ModelListSerializer, ModelSerializer + +AnyModel = t.TypeVar("AnyModel", bound=Model) + +if t.TYPE_CHECKING: + # NOTE: This raises an error during runtime. + # pylint: disable-next=too-few-public-methods + class _ModelViewSet(DrfModelViewSet[AnyModel], t.Generic[AnyModel]): + pass + +else: + # pylint: disable-next=too-many-ancestors + class _ModelViewSet(DrfModelViewSet, t.Generic[AnyModel]): + pass + + +# pylint: disable-next=too-many-ancestors +class ModelViewSet(_ModelViewSet[AnyModel], t.Generic[AnyModel]): + """Base model view set for all model view sets.""" + + serializer_class: t.Optional[t.Type[ModelSerializer[AnyModel]]] + + @classmethod + def get_model_class(cls) -> t.Type[AnyModel]: + """Get the model view set's class. + + Returns: + The model view set's class. + """ + + # pylint: disable-next=no-member + return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined] + 0 + ] + + def get_serializer(self, *args, **kwargs): + serializer = super().get_serializer(*args, **kwargs) + + if self.action == "bulk": + list_serializer = t.cast(ListSerializer, serializer) + + meta = getattr(list_serializer.child, "Meta", None) + if meta is None: + # pylint: disable-next=missing-class-docstring,too-few-public-methods + class Meta: + pass + + meta = Meta + setattr(list_serializer.child, "Meta", meta) + + if getattr(meta, "list_serializer_class", None) is None: + model_class = self.get_model_class() + + # pylint: disable-next=too-few-public-methods + class _ModelListSerializer( + ModelListSerializer[model_class] # type: ignore[valid-type] + ): + pass + + # Set list_serializer_class to default if not set. + setattr(meta, "list_serializer_class", _ModelListSerializer) + + # Get default list_serializer_class. + serializer = super().get_serializer(*args, **kwargs) + + return serializer + + def bulk_create(self, request: Request): + """Bulk create many instances of a model. + + This is an extension of the default create action: + https://www.django-rest-framework.org/api-guide/generic-views/#createmodelmixin + + Args: + request: A HTTP request containing a list of models to create. + + Returns: + A HTTP response containing a list of created models. + """ + + serializer = t.cast( + ModelListSerializer[AnyModel], + self.get_serializer(data=request.data, many=True), + ) + serializer.is_valid(raise_exception=True) + self.perform_bulk_create(serializer) + return Response( + serializer.data, + status=status.HTTP_201_CREATED, + headers=self.get_success_headers(serializer.data), + ) + + def perform_bulk_create(self, serializer: ModelListSerializer[AnyModel]): + """Bulk create many instances of a model. + + Args: + serializer: A model serializer for the specific model. + """ + + serializer.save() + + def bulk_partial_update(self, request: Request): + # pylint: disable=line-too-long + """Partially bulk update many instances of a model. + + This is an extension of the default partial-update action: + https://www.django-rest-framework.org/api-guide/generic-views/#updatemodelmixin + + Args: + request: A HTTP request containing a list of models to partially update. + + Returns: + A HTTP response containing a list of partially updated models. + """ + # pylint: enable=line-too-long + + model_class = self.get_model_class() + lookup_field = ( + model_class._meta.pk.attname # type: ignore[union-attr] + if self.lookup_field == "pk" + else self.lookup_field + ) + + data = t.cast(t.List[t.Dict[str, t.Any]], request.data) + data.sort(key=lambda model: model[lookup_field]) + + queryset = model_class.objects.filter( # type: ignore[attr-defined] + **{f"{lookup_field}__in": [model[lookup_field] for model in data]} + ).order_by(lookup_field) + + serializer = t.cast( + ModelListSerializer[AnyModel], + self.get_serializer( + list(queryset), + data=data, + many=True, + partial=True, + ), + ) + serializer.is_valid(raise_exception=True) + self.perform_bulk_update(serializer) + return Response(serializer.data) + + def perform_bulk_update(self, serializer: ModelListSerializer[AnyModel]): + """Partially bulk update many instances of a model. + + Args: + serializer: A model serializer for the specific model. + """ + + serializer.save() + + def bulk_destroy(self, request: Request): + """Bulk destroy many instances of a model. + + This is an extension of the default destroy action: + https://www.django-rest-framework.org/api-guide/generic-views/#destroymodelmixin + + Args: + request: A HTTP request containing a list of models to destroy. + + Returns: + A HTTP response containing a list of destroyed models. + """ + + model_class = self.get_model_class() + queryset = model_class.objects.filter( # type: ignore[attr-defined] + **{f"{self.lookup_field}__in": request.data} + ) + self.perform_bulk_destroy(queryset) + return Response(status=status.HTTP_204_NO_CONTENT) + + def perform_bulk_destroy(self, queryset: QuerySet[AnyModel]): + """Bulk destroy many instances of a model. + + Args: + queryset: A queryset of the models to delete. + """ + + queryset.delete() + + @action(detail=False, methods=["post", "patch", "delete"]) + def bulk(self, request: Request): + """Entry point for all bulk actions. + + Args: + request: A HTTP request. + + Returns: + A HTTP response. + """ + + return { + "POST": self.bulk_create, + "PATCH": self.bulk_partial_update, + "DELETE": self.bulk_destroy, + }[t.cast(str, request.method)](request)