Skip to content

Commit

Permalink
Add bulk annotations update
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorjerse committed Nov 21, 2023
1 parent c195ce6 commit 99a4f85
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Fixed
Changed
-------
- Bulk annotations on entity endpoint now accept field path instead of id
- Suport bulk create/update/delete on AnnotationValues endpoint


===================
Expand Down
3 changes: 3 additions & 0 deletions resolwe/flow/models/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,9 @@ def validate(self):
def save(self, *args, **kwargs):
"""Save the annotation value after validation has passed."""
annotation_value_validator.validate(self, raise_exception=True)
# Make sure the label is always set.
if "label" not in self._value:
self.recompute_label()
super().save(*args, **kwargs)

def recompute_label(self):
Expand Down
59 changes: 57 additions & 2 deletions resolwe/flow/serializers/annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Resolwe annotations serializer."""
from itertools import groupby
from typing import Any

from django.db import models
from django.db.models import Q

from rest_framework import serializers
from rest_framework.fields import empty
Expand Down Expand Up @@ -94,13 +98,63 @@ class Meta:
fields = read_only_fields + ("name", "fields", "contributor")


class AnnotationValueListSerializer(serializers.ListSerializer):
"""Perform bulk update of annotation values to speed up requests."""

def create(self, validated_data: Any) -> Any:
"""Perform efficient bulk create."""
return AnnotationValue.objects.bulk_create(
AnnotationValue(**data)
for data in validated_data
if data["_value"]["value"] is not None
)

def update(self, instance, validated_data: Any):
"""Perform efficient bulk create/update/delete."""
# Read existing annotations in a single query.
query = Q()
for data in validated_data:
query |= Q(field=data["field"], entity=data["entity"])
existing_annotations = AnnotationValue.objects.filter(query)

# Create a mapping between (field, entity) and the existing annotations.
annotation_map = {
(annotation.field, annotation.entity): annotation
for annotation in existing_annotations
}

self.instance = []
to_update = []
to_create = []
to_delete = []
for data in validated_data:
if value := annotation_map.get((data["field"], data["entity"])):
if data["_value"]["value"] is None:
to_delete.append(value.pk)
else:
to_update.append((value, data))
else:
to_create.append(data)

# Bulk create new annotations.
self.instance += self.create(to_create)
# Bulk delete annotations.
AnnotationValue.objects.filter(pk__in=to_delete).delete()
# Update annotations.
for value, data in to_update:
self.instance.append(self.child.update(value, data))
return self.instance


class AnnotationValueSerializer(ResolweBaseSerializer):
"""Serializer for AnnotationValue objects."""

def __init__(self, instance=None, data=empty, **kwargs):
"""Rewrite value -> _value."""
if data is not empty and "value" in data:
data["_value"] = {"value": data.pop("value", None)}
if data is not empty:
for entry in data if isinstance(data, list) else [data]:
if "value" in entry:
entry["_value"] = {"value": entry.pop("value")}
super().__init__(instance, data, **kwargs)

class Meta:
Expand All @@ -111,3 +165,4 @@ class Meta:
update_protected_fields = ("id", "entity", "field")
fields = read_only_fields + update_protected_fields + ("value", "_value")
extra_kwargs = {"_value": {"write_only": True}}
list_serializer_class = AnnotationValueListSerializer
157 changes: 150 additions & 7 deletions resolwe/flow/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Sequence

from django.core.exceptions import ValidationError
from django.db.models import F
from django.urls import reverse

from rest_framework import status
Expand Down Expand Up @@ -385,8 +386,7 @@ def test_create_annotation_value(self):

# Unauthenticated request.
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertAlmostEqual(values_count, AnnotationValue.objects.count())

# Authenticated request.
Expand All @@ -399,13 +399,35 @@ def test_create_annotation_value(self):
self.assertEqual(created_value.value, -1)
self.assertAlmostEqual(values_count + 1, AnnotationValue.objects.count())

# Bulk create, no permission on entity 2.
AnnotationValue.objects.all().delete()
values = [
{"entity": self.entity1.pk, "field": field.pk, "value": -1},
{"entity": self.entity2.pk, "field": field.pk, "value": -2},
]
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertAlmostEqual(0, AnnotationValue.objects.count())

# Bulk create, edit permission on both entities.
self.collection2.set_permission(Permission.EDIT, self.contributor)
self.client.force_authenticate(self.contributor)
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
attributes = ("entity", "field", "value")
created = AnnotationValue.objects.annotate(value=F("_value__value")).values(
*attributes
)
received = [{key: entry[key] for key in attributes} for entry in response.data]
self.assertCountEqual(received, values)
self.assertCountEqual(created, values)

# Authenticated request, no permission.
AnnotationValue.objects.all().delete()
self.entity1.collection.set_permission(Permission.NONE, self.contributor)
self.client.force_authenticate(self.contributor)
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})
self.assertAlmostEqual(values_count + 1, AnnotationValue.objects.count())
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertAlmostEqual(0, AnnotationValue.objects.count())

def test_update_annotation_value(self):
"""Test updating new annotation value objects."""
Expand All @@ -432,7 +454,7 @@ def test_update_annotation_value(self):

# Authenticated request, entity should not be changed
values = {
"id": self.annotation_value1.pk,
"field": self.annotation_field1.pk,
"value": "string",
"entity": self.entity2.pk,
}
Expand All @@ -448,6 +470,127 @@ def test_update_annotation_value(self):
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})

# Single / bulk update with put.
self.entity1.collection.set_permission(Permission.EDIT, self.contributor)
path = reverse("resolwe-api:annotationvalue-list")
client = APIClient()
values = [
{
"entity": self.entity1.pk,
"field": self.annotation_field1.pk,
"value": -1,
}
]

# Unauthenticated request.
response = client.put(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(
response.data,
{"detail": "You do not have permission to perform this action."},
)

# Authenticated request with validation error.
client.force_authenticate(self.contributor)
response = client.put(path, values, format="json")
self.annotation_value1.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.data["error"],
[
"The value '-1' is not of the expected type 'str'.",
f"The value '-1' is not valid for the field {self.annotation_field1}.",
],
)
self.assertEqual(self.annotation_value1.value, "string")
self.assertEqual(AnnotationValue.objects.count(), 2)

# Authenticated request with validation error.
values[0]["value"] = "another"
response = client.put(path, values, format="json")
self.annotation_value1.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data,
[
{
"label": "Another one",
"id": self.annotation_value1.pk,
"entity": self.entity1.pk,
"field": self.annotation_field1.pk,
"value": "another",
}
],
)
self.assertEqual(self.annotation_value1.value, "another")
self.assertEqual(self.annotation_value1.label, "Another one")
self.assertEqual(AnnotationValue.objects.count(), 2)

# Multi.
values = [
{"field": self.annotation_field2.pk, "value": 1, "entity": self.entity1.pk},
{
"field": self.annotation_field1.pk,
"value": "string",
"entity": self.entity1.pk,
},
]

response = client.put(path, values, format="json")
self.annotation_value1.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_200_OK)
created_value = AnnotationValue.objects.get(
entity=self.entity1, field=self.annotation_field2
)
expected = [
{
"label": "label string",
"id": self.annotation_value1.pk,
"entity": self.entity1.pk,
"field": self.annotation_field1.pk,
"value": "string",
},
{
"label": created_value.label,
"id": created_value.pk,
"entity": created_value.entity.pk,
"field": created_value.field.pk,
"value": created_value.value,
},
]
self.assertCountEqual(response.data, expected)
self.assertEqual(self.annotation_value1.value, "string")
self.assertEqual(self.annotation_value1.label, "label string")
self.assertEqual(AnnotationValue.objects.count(), 3)

# Multi + delete.
values = [
{"field": self.annotation_field2.pk, "value": 2, "entity": self.entity1.pk},
{
"field": self.annotation_field1.pk,
"value": None,
"entity": self.entity1.pk,
},
]
response = client.put(path, values, format="json")
created_value.refresh_from_db()
expected = [
{
"label": created_value.label,
"id": created_value.pk,
"entity": created_value.entity.pk,
"field": created_value.field.pk,
"value": created_value.label,
},
]
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertCountEqual(response.data, expected)
with self.assertRaises(AnnotationValue.DoesNotExist):
self.annotation_value1.refresh_from_db()
created_value.refresh_from_db()
self.assertEqual(created_value.value, 2)
self.assertEqual(created_value.label, 2)

def test_delete_annotation_value(self):
"""Test deleting annotation value objects."""
client = APIClient()
Expand Down
76 changes: 57 additions & 19 deletions resolwe/flow/views/annotations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
"""Annotations viewset."""


from rest_framework import exceptions, mixins, permissions, request, viewsets
from typing import Any

from rest_framework import (
exceptions,
generics,
mixins,
permissions,
response,
status,
viewsets,
)
from rest_framework.serializers import BaseSerializer

from resolwe.flow.filters import (
AnnotationFieldFilter,
Expand Down Expand Up @@ -63,9 +74,10 @@ class AnnotationFieldViewSet(
class AnnotationValueViewSet(
mixins.RetrieveModelMixin,
ResolweUpdateModelMixin,
ResolweCreateModelMixin,
mixins.CreateModelMixin,
mixins.ListModelMixin,
mixins.DestroyModelMixin,
generics.UpdateAPIView,
viewsets.GenericViewSet,
):
"""Annotation value viewset."""
Expand All @@ -75,25 +87,51 @@ class AnnotationValueViewSet(
queryset = AnnotationValue.objects.all()
permission_classes = (get_permissions_class(),)

def _get_entity(self, request: request.Request) -> AnnotationValue:
"""Get annotation value from request.
def get_serializer(self, *args: Any, **kwargs: Any) -> BaseSerializer:
"""Get serializer instance depending on the request type."""
kwargs_many = kwargs.get("many", False)
kwargs["many"] = isinstance(self.request.data, list) or kwargs_many
return super().get_serializer(*args, **kwargs)

def _check_permissions(self, serializer):
"""Check if user has edit permission on entities."""
validated_data = (
serializer.validated_data
if isinstance(serializer.validated_data, list)
else [serializer.validated_data]
)
# Check permissions on entities.
if not all(
entity.has_permission(Permission.EDIT, self.request.user)
for entity in {value["entity"] for value in validated_data}
):
raise exceptions.PermissionDenied()

:raises ValidationError: if the annotation value is not valid.
:raises NotFound: if the user is not authenticated.
def perform_create(self, serializer: BaseSerializer) -> None:
"""Perform create annotation value(s).
The permission on entities must be checked.
"""
serializer = self.get_serializer(data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
return serializer.validated_data["entity"]
self._check_permissions(serializer)
return super().perform_create(serializer)

def create(self, request, *args, **kwargs):
"""Create annotation value.
def update(self, request, *args, pk=None, **kwargs):
"""Update annotation value(s).
Authenticated users with edit permissions on the entity can create annotations.
When posting multiple values, the request is treated as a bulk update. The bulk
update can create, update or delete values. Values are deleted when the value
is set no None.
"""
entity = self._get_entity(request)
if entity.has_permission(Permission.EDIT, request.user):
return super().create(request, *args, **kwargs)
elif entity.has_permission(Permission.VIEW, request.user):
raise exceptions.PermissionDenied()
else:
raise exceptions.NotFound()
# Regular update on a detail view.
if pk is not None:
return super().update(request, *args, pk=pk, **kwargs)

# Bulk update / create / delete.
serializer = self.get_serializer(data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
self._check_permissions(serializer)
serializer.update(None, serializer.validated_data)
headers = self.get_success_headers(serializer.data)
return response.Response(
serializer.data, status=status.HTTP_200_OK, headers=headers
)

0 comments on commit 99a4f85

Please sign in to comment.