Skip to content

Commit

Permalink
feat: support bulk actions
Browse files Browse the repository at this point in the history
  • Loading branch information
SKairinos committed Jan 26, 2024
1 parent db3cfe8 commit f64ba03
Show file tree
Hide file tree
Showing 8 changed files with 570 additions and 92 deletions.
35 changes: 35 additions & 0 deletions .vscode/extensions/autoDocstring/docstring.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{{! Based off of: https://github.com/NilsJPWerner/autoDocstring/blob/master/src/docstring/templates/google-notypes.mustache }}
{{summaryPlaceholder}}

{{extendedSummaryPlaceholder}}
{{#parametersExist}}

Args:
{{#args}}
{{var}}: {{descriptionPlaceholder}}
{{/args}}
{{#kwargs}}
{{var}}: {{descriptionPlaceholder}}
{{/kwargs}}
{{/parametersExist}}
{{#exceptionsExist}}

Raises:
{{#exceptions}}
{{type}}: {{descriptionPlaceholder}}
{{/exceptions}}
{{/exceptionsExist}}
{{#returnsExist}}

Returns:
{{#returns}}
{{descriptionPlaceholder}}
{{/returns}}
{{/returnsExist}}
{{#yieldsExist}}

Yields:
{{#yields}}
{{descriptionPlaceholder}}
{{/yields}}
{{/yieldsExist}}
2 changes: 1 addition & 1 deletion codeforlife/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
Created on 20/01/2024 at 11:19:12(+00:00).
"""

from .base import *
from .model import *
25 changes: 0 additions & 25 deletions codeforlife/serializers/base.py

This file was deleted.

124 changes: 124 additions & 0 deletions codeforlife/serializers/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
© Ocado Group
Created on 20/01/2024 at 11:19:24(+00:00).
Base model serializers.
"""

import typing as t

from django.db.models import Model
from rest_framework.serializers import ListSerializer as _ListSerializer
from rest_framework.serializers import ModelSerializer as _ModelSerializer
from rest_framework.serializers import ValidationError as _ValidationError

AnyModel = t.TypeVar("AnyModel", bound=Model)


class ModelSerializer(_ModelSerializer[AnyModel], t.Generic[AnyModel]):
"""Base model serializer for all model serializers."""

# pylint: disable-next=useless-parent-delegation
def update(self, instance, validated_data: t.Dict[str, t.Any]):
return super().update(instance, validated_data)

# pylint: disable-next=useless-parent-delegation
def create(self, validated_data: t.Dict[str, t.Any]):
return super().create(validated_data)


class ModelListSerializer(
t.Generic[AnyModel],
_ListSerializer[t.List[AnyModel]],
):
"""Base model list serializer for all model list serializers.
Inherit this class if you wish to custom handle bulk create and/or update.
class UserListSerializer(ModelListSerializer[User]):
def create(self, validated_data):
...
def update(self, instance, validated_data):
...
class UserSerializer(ModelSerializer[User]):
class Meta:
model = User
list_serializer_class = UserListSerializer
"""

batch_size: t.Optional[int] = None

@classmethod
def get_model_class(cls) -> t.Type[AnyModel]:
"""Get the model view set's class.
Returns:
The model view set's class.
"""

# pylint: disable-next=no-member
return t.get_args(cls.__orig_bases__[0])[ # type: ignore[attr-defined]
0
]

def create(self, validated_data: t.List[t.Dict[str, t.Any]]):
"""Bulk create many instances of a model.
https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-create
Args:
validated_data: The data used to create the models.
Returns:
The models.
"""

model_class = self.get_model_class()
return model_class.objects.bulk_create( # type: ignore[attr-defined]
objs=[model_class(**data) for data in validated_data],
batch_size=self.batch_size,
)

def update(self, instance, validated_data: t.List[t.Dict[str, t.Any]]):
"""Bulk update many instances of a model.
https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update
Args:
instance: The models to update.
validated_data: The field-value pairs to update for each model.
Returns:
The models.
"""

# Models and data must have equal length and be ordered the same!
for model, data in zip(instance, validated_data):
for field, value in data.items():
setattr(model, field, value)

model_class = self.get_model_class()
model_class.objects.bulk_update( # type: ignore[attr-defined]
objs=instance,
fields={field for data in validated_data for field in data.keys()},
batch_size=self.batch_size,
)

return instance

def validate(self, attrs: t.List[t.Dict[str, t.Any]]):
# If performing a bulk create.
if self.instance is None:
if len(attrs) == 0:
raise _ValidationError("Nothing to create.")

# Else, performing a bulk update.
else:
if len(attrs) == 0:
raise _ValidationError("Nothing to update.")
if len(attrs) != len(self.instance):
raise _ValidationError("Some models do not exist.")

return attrs
Loading

0 comments on commit f64ba03

Please sign in to comment.