Skip to content

Commit

Permalink
Add bulk annotations update
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorjerse committed Nov 21, 2023
1 parent c195ce6 commit 6bcb50e
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Fixed
Changed
-------
- Bulk annotations on entity endpoint now accept field path instead of id
- Suport bulk create/update/delete on AnnotationValues endpoint


===================
Expand Down
3 changes: 3 additions & 0 deletions resolwe/flow/models/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,9 @@ def validate(self):
def save(self, *args, **kwargs):
"""Save the annotation value after validation has passed."""
annotation_value_validator.validate(self, raise_exception=True)
# Make sure the label is always set.
if "label" not in self._value:
self.recompute_label()
super().save(*args, **kwargs)

def recompute_label(self):
Expand Down
66 changes: 64 additions & 2 deletions resolwe/flow/serializers/annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Resolwe annotations serializer."""
from typing import Any

from django.db import models
from django.db.models import Q

from rest_framework import serializers
from rest_framework.fields import empty
Expand Down Expand Up @@ -94,13 +97,71 @@ class Meta:
fields = read_only_fields + ("name", "fields", "contributor")


class AnnotationValueListSerializer(serializers.ListSerializer):
"""Perform bulk update of annotation values to speed up requests."""

def create(self, validated_data: Any) -> Any:
"""Perform efficient bulk create."""
return AnnotationValue.objects.bulk_create(
AnnotationValue(**data)
for data in validated_data
if data["_value"]["value"] is not None
)

def update(self, instance, validated_data: Any):
"""Perform efficient bulk create/update/delete."""
# Read existing annotations in a single query.
query = Q()
for data in validated_data:
query |= Q(field=data["field"], entity=data["entity"])
existing_annotations = AnnotationValue.objects.filter(query)

# Create a mapping between (field, entity) and the existing annotations.
annotation_map = {
(annotation.field, annotation.entity): annotation
for annotation in existing_annotations
}

self.instance = []
to_update = []
to_create = []
to_delete = []
for data in validated_data:
if value := annotation_map.get((data["field"], data["entity"])):
if data["_value"]["value"] is None:
to_delete.append(value.pk)
else:
to_update.append((value, data))
else:
to_create.append(data)

# Bulk create new annotations.
self.instance += self.create(to_create)
# Update annotations.
for value, data in to_update:
self.instance.append(self.child.update(value, data))
# Bulk delete annotations.
AnnotationValue.objects.filter(pk__in=to_delete).delete()
return self.instance

def validate(self, attrs: Any) -> Any:
"""Validate list of annotation values."""
if len(set((attr["field"], attr["entity"]) for attr in attrs)) != len(attrs):
raise serializers.ValidationError(
"Duplicate annotation values for the same entity and field."
)
return super().validate(attrs)


class AnnotationValueSerializer(ResolweBaseSerializer):
"""Serializer for AnnotationValue objects."""

def __init__(self, instance=None, data=empty, **kwargs):
"""Rewrite value -> _value."""
if data is not empty and "value" in data:
data["_value"] = {"value": data.pop("value", None)}
if data is not empty:
for entry in data if isinstance(data, list) else [data]:
if "value" in entry:
entry["_value"] = {"value": entry.pop("value")}
super().__init__(instance, data, **kwargs)

class Meta:
Expand All @@ -111,3 +172,4 @@ class Meta:
update_protected_fields = ("id", "entity", "field")
fields = read_only_fields + update_protected_fields + ("value", "_value")
extra_kwargs = {"_value": {"write_only": True}}
list_serializer_class = AnnotationValueListSerializer
197 changes: 190 additions & 7 deletions resolwe/flow/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Sequence

from django.core.exceptions import ValidationError
from django.db.models import F
from django.urls import reverse

from rest_framework import status
Expand Down Expand Up @@ -385,8 +386,7 @@ def test_create_annotation_value(self):

# Unauthenticated request.
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertAlmostEqual(values_count, AnnotationValue.objects.count())

# Authenticated request.
Expand All @@ -399,13 +399,58 @@ def test_create_annotation_value(self):
self.assertEqual(created_value.value, -1)
self.assertAlmostEqual(values_count + 1, AnnotationValue.objects.count())

# Bulk create, no permission on entity 2.
AnnotationValue.objects.all().delete()
values = [
{"entity": self.entity1.pk, "field": field.pk, "value": -1},
{"entity": self.entity2.pk, "field": field.pk, "value": -2},
]
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertAlmostEqual(0, AnnotationValue.objects.count())

# Bulk create, edit permission on both entities.
self.collection2.set_permission(Permission.EDIT, self.contributor)
self.client.force_authenticate(self.contributor)
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
attributes = ("entity", "field", "value")
created = AnnotationValue.objects.annotate(value=F("_value__value")).values(
*attributes
)
received = [{key: entry[key] for key in attributes} for entry in response.data]
self.assertCountEqual(received, values)
self.assertCountEqual(created, values)

# Bulk create, request for same entity and field.
values = [
{"entity": self.entity1.pk, "field": field.pk, "value": -10},
{"entity": self.entity1.pk, "field": field.pk, "value": -20},
]
response = self.client.post(path, values, format="json")
self.assertContains(
response,
"Duplicate annotation values for the same entity and field.",
status_code=status.HTTP_400_BAD_REQUEST,
)
self.assertEqual(
AnnotationValue.objects.get(entity=self.entity1, field=field).value, -1
)
self.assertEqual(
AnnotationValue.objects.get(entity=self.entity2, field=field).value, -2
)
self.assertEqual(AnnotationValue.objects.count(), 2)

# Authenticated request, no permission.
values = [
{"entity": self.entity1.pk, "field": field.pk, "value": -10},
{"entity": self.entity2.pk, "field": field.pk, "value": -20},
]
AnnotationValue.objects.all().delete()
self.entity1.collection.set_permission(Permission.NONE, self.contributor)
self.client.force_authenticate(self.contributor)
response = self.client.post(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})
self.assertAlmostEqual(values_count + 1, AnnotationValue.objects.count())
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertAlmostEqual(0, AnnotationValue.objects.count())

def test_update_annotation_value(self):
"""Test updating new annotation value objects."""
Expand All @@ -432,7 +477,7 @@ def test_update_annotation_value(self):

# Authenticated request, entity should not be changed
values = {
"id": self.annotation_value1.pk,
"field": self.annotation_field1.pk,
"value": "string",
"entity": self.entity2.pk,
}
Expand All @@ -448,6 +493,143 @@ def test_update_annotation_value(self):
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})

# Single / bulk update with put.
self.entity1.collection.set_permission(Permission.EDIT, self.contributor)
path = reverse("resolwe-api:annotationvalue-list")
client = APIClient()
values = [
{
"entity": self.entity1.pk,
"field": self.annotation_field1.pk,
"value": -1,
}
]

# Unauthenticated request.
response = client.put(path, values, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(
response.data,
{"detail": "You do not have permission to perform this action."},
)

# Authenticated request with validation error.
client.force_authenticate(self.contributor)
response = client.put(path, values, format="json")
self.annotation_value1.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.data["error"],
[
"The value '-1' is not of the expected type 'str'.",
f"The value '-1' is not valid for the field {self.annotation_field1}.",
],
)
self.assertEqual(self.annotation_value1.value, "string")
self.assertEqual(AnnotationValue.objects.count(), 2)

# Authenticated request with validation error.
values[0]["value"] = "another"
response = client.put(path, values, format="json")
self.annotation_value1.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data,
[
{
"label": "Another one",
"id": self.annotation_value1.pk,
"entity": self.entity1.pk,
"field": self.annotation_field1.pk,
"value": "another",
}
],
)
self.assertEqual(self.annotation_value1.value, "another")
self.assertEqual(self.annotation_value1.label, "Another one")
self.assertEqual(AnnotationValue.objects.count(), 2)

# Multi with validation error.
values = [
{"field": self.annotation_field2.pk, "value": 1, "entity": self.entity1.pk},
{
"field": self.annotation_field2.pk,
"value": "string",
"entity": self.entity1.pk,
},
]
response = client.put(path, values, format="json")
self.assertContains(
response,
"Duplicate annotation values for the same entity and field.",
status_code=status.HTTP_400_BAD_REQUEST,
)

# Multi.
values = [
{"field": self.annotation_field2.pk, "value": 1, "entity": self.entity1.pk},
{
"field": self.annotation_field1.pk,
"value": "string",
"entity": self.entity1.pk,
},
]

response = client.put(path, values, format="json")
self.annotation_value1.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_200_OK)
created_value = AnnotationValue.objects.get(
entity=self.entity1, field=self.annotation_field2
)
expected = [
{
"label": "label string",
"id": self.annotation_value1.pk,
"entity": self.entity1.pk,
"field": self.annotation_field1.pk,
"value": "string",
},
{
"label": created_value.label,
"id": created_value.pk,
"entity": created_value.entity.pk,
"field": created_value.field.pk,
"value": created_value.value,
},
]
self.assertCountEqual(response.data, expected)
self.assertEqual(self.annotation_value1.value, "string")
self.assertEqual(self.annotation_value1.label, "label string")
self.assertEqual(AnnotationValue.objects.count(), 3)

# Multi + delete.
values = [
{"field": self.annotation_field2.pk, "value": 2, "entity": self.entity1.pk},
{
"field": self.annotation_field1.pk,
"value": None,
"entity": self.entity1.pk,
},
]
response = client.put(path, values, format="json")
created_value.refresh_from_db()
expected = [
{
"label": created_value.label,
"id": created_value.pk,
"entity": created_value.entity.pk,
"field": created_value.field.pk,
"value": created_value.label,
},
]
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertCountEqual(response.data, expected)
with self.assertRaises(AnnotationValue.DoesNotExist):
self.annotation_value1.refresh_from_db()
created_value.refresh_from_db()
self.assertEqual(created_value.value, 2)
self.assertEqual(created_value.label, 2)

def test_delete_annotation_value(self):
"""Test deleting annotation value objects."""
client = APIClient()
Expand Down Expand Up @@ -1255,3 +1437,4 @@ def has_value(entity, field_id, value):
self.assertEqual(self.annotation_value1.value, "bbb")
has_value(self.entity1, self.annotation_field1.pk, "bbb")
has_value(self.entity1, self.annotation_field2.pk, 2)
has_value(self.entity1, self.annotation_field2.pk, 2)
Loading

0 comments on commit 6bcb50e

Please sign in to comment.