Skip to content

Commit

Permalink
Change bulk annotations update on entity endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorjerse committed Nov 21, 2023
1 parent bfa374a commit c195ce6
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 57 deletions.
4 changes: 4 additions & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Fixed
-----
- Store random postfix to redis for use at cleaup time

Changed
-------
- Bulk annotations on entity endpoint now accept field path instead of id


===================
38.0.0 - 2023-11-13
Expand Down
19 changes: 13 additions & 6 deletions resolwe/flow/models/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@

from .base import AuditModel

VALIDATOR_LENGTH = 128
NAME_LENGTH = 128
LABEL_LENGTH = 128
DESCRIPTION_LENGTH = 256


class HandleMissingAnnotations(Enum):
"""How to handle missing annotations."""
Expand Down Expand Up @@ -217,10 +222,10 @@ class AnnotationGroup(models.Model):
"""Group of annotation fields."""

#: the name of the annotation group
name = models.CharField(max_length=128)
name = models.CharField(max_length=NAME_LENGTH)

#: the label of the annotation group
label = models.CharField(max_length=128)
label = models.CharField(max_length=LABEL_LENGTH)

#: the sorting order among annotation groups
sort_order = models.PositiveSmallIntegerField()
Expand All @@ -239,13 +244,13 @@ class AnnotationField(models.Model):
"""Annotation field."""

#: the name of the annotation fields
name = models.CharField(max_length=128)
name = models.CharField(max_length=NAME_LENGTH)

#: user visible field name
label = models.CharField(max_length=128)
label = models.CharField(max_length=LABEL_LENGTH)

#: user visible field description
description = models.CharField(max_length=256)
description = models.CharField(max_length=DESCRIPTION_LENGTH)

#: the type of the annotation field
type = models.CharField(max_length=16)
Expand All @@ -259,7 +264,9 @@ class AnnotationField(models.Model):
sort_order = models.PositiveSmallIntegerField()

#: optional regular expression for validation
validator_regex = models.CharField(max_length=128, null=True, blank=True)
validator_regex = models.CharField(
max_length=VALIDATOR_LENGTH, null=True, blank=True
)

#: optional map of valid values to labels
vocabulary = models.JSONField(null=True, blank=True)
Expand Down
9 changes: 6 additions & 3 deletions resolwe/flow/serializers/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from rest_framework import serializers
from rest_framework.fields import empty

from resolwe.flow.models.annotations import NAME_LENGTH as ANNOTATION_NAME_LENGTH
from resolwe.flow.models.annotations import (
AnnotationField,
AnnotationGroup,
Expand Down Expand Up @@ -73,11 +74,13 @@ class AnnotationFieldDictSerializer(serializers.Serializer):
confirm_action = serializers.BooleanField(default=False)


class AnnotationsSerializer(serializers.Serializer):
class AnnotationsByPathSerializer(serializers.Serializer):
"""Serializer that reads annotation field and its value."""

field = PrimaryKeyDictRelatedField(queryset=AnnotationField.objects.all())
value = serializers.JSONField()
# The field path contains the annotation group name and the annotation field name
# separated by spaces.
field_path = serializers.CharField(max_length=2 * ANNOTATION_NAME_LENGTH + 1)
value = serializers.JSONField(allow_null=True)


class AnnotationPresetSerializer(ResolweBaseSerializer):
Expand Down
48 changes: 33 additions & 15 deletions resolwe/flow/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def test_empty_vocabulary(self):
entity=self.entity1, field=self.annotation_field1, value="non_existing"
)

def test_set_values(self):
def test_set_values_by_path(self):
def has_value(entity, field_id, value):
self.assertEqual(
value, entity.annotations.filter(field_id=field_id).get().value
Expand All @@ -1136,14 +1136,14 @@ def has_value(entity, field_id, value):
self.assertDictEqual(
response.data[0],
{
"field": ["This field is required."],
"field_path": ["This field is required."],
"value": ["This field is required."],
},
)

annotations = [
{"field": {"id": self.annotation_field1.pk}, "value": "new value"},
{"field": {"id": self.annotation_field2.pk}, "value": -1},
{"field_path": str(self.annotation_field1), "value": "new value"},
{"field_path": str(self.annotation_field2), "value": -1},
]

# Valid request without regex validation.
Expand All @@ -1156,8 +1156,26 @@ def has_value(entity, field_id, value):
has_value(self.entity1, self.annotation_field1.pk, "new value")
has_value(self.entity1, self.annotation_field2.pk, -1)

annotations = [
{"field_path": str(self.annotation_field1), "value": None},
{"field_path": str(self.annotation_field2), "value": 2},
]

# Valid request without regex validation, delete annotation.
request = factory.post("/", annotations, format="json")
force_authenticate(request, self.contributor)
response = viewset(request, pk=self.entity1.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(self.entity1.annotations.count(), 1)
has_value(self.entity1, self.annotation_field2.pk, 2)

# Re-create deleted annotation value.
self.annotation_value1 = AnnotationValue.objects.create(
entity=self.entity1, field=self.annotation_field1, value="new value"
)

# Wrong type.
annotations = [{"field": {"id": self.annotation_field1.pk}, "value": 10}]
annotations = [{"field_path": str(self.annotation_field1), "value": 10}]
request = factory.post("/", annotations, format="json")
force_authenticate(request, self.contributor)
response = viewset(request, pk=self.entity1.pk)
Expand All @@ -1169,12 +1187,12 @@ def has_value(entity, field_id, value):
)

has_value(self.entity1, self.annotation_field1.pk, "new value")
has_value(self.entity1, self.annotation_field2.pk, -1)
has_value(self.entity1, self.annotation_field2.pk, 2)

# Wrong regex.
self.annotation_field1.validator_regex = "b+"
self.annotation_field1.save()
annotations = [{"field": {"id": self.annotation_field1.pk}, "value": "aaa"}]
annotations = [{"field_path": str(self.annotation_field1), "value": "aaa"}]
request = factory.post("/", annotations, format="json")
force_authenticate(request, self.contributor)
response = viewset(request, pk=self.entity1.pk)
Expand All @@ -1189,10 +1207,10 @@ def has_value(entity, field_id, value):
},
)
has_value(self.entity1, self.annotation_field1.pk, "new value")
has_value(self.entity1, self.annotation_field2.pk, -1)
has_value(self.entity1, self.annotation_field2.pk, 2)

# Wrong regex and type.
annotations = [{"field": {"id": self.annotation_field1.pk}, "value": 10}]
annotations = [{"field_path": str(self.annotation_field1), "value": 10}]
request = factory.post("/", annotations, format="json")
force_authenticate(request, self.contributor)
response = viewset(request, pk=self.entity1.pk)
Expand All @@ -1205,12 +1223,12 @@ def has_value(entity, field_id, value):
],
)
has_value(self.entity1, self.annotation_field1.pk, "new value")
has_value(self.entity1, self.annotation_field2.pk, -1)
has_value(self.entity1, self.annotation_field2.pk, 2)

# Multiple fields validation error.
annotations = [
{"field": {"id": self.annotation_field1.pk}, "value": 10},
{"field": {"id": self.annotation_field2.pk}, "value": "string"},
{"field_path": str(self.annotation_field1), "value": 10},
{"field_path": str(self.annotation_field2), "value": "string"},
]
request = factory.post("/", annotations, format="json")
force_authenticate(request, self.contributor)
Expand All @@ -1225,15 +1243,15 @@ def has_value(entity, field_id, value):
],
)
has_value(self.entity1, self.annotation_field1.pk, "new value")
has_value(self.entity1, self.annotation_field2.pk, -1)
has_value(self.entity1, self.annotation_field2.pk, 2)

# Regular request with regex validation.
annotations = [{"field": {"id": self.annotation_field1.pk}, "value": "bbb"}]
annotations = [{"field_path": str(self.annotation_field1), "value": "bbb"}]
request = factory.post("/", annotations, format="json")
force_authenticate(request, self.contributor)
response = viewset(request, pk=self.entity1.pk)
self.annotation_value1.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(self.annotation_value1.value, "bbb")
has_value(self.entity1, self.annotation_field1.pk, "bbb")
has_value(self.entity1, self.annotation_field2.pk, -1)
has_value(self.entity1, self.annotation_field2.pk, 2)
72 changes: 39 additions & 33 deletions resolwe/flow/views/entity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Entity viewset."""
import re
from typing import Optional

from drf_spectacular.utils import extend_schema

Expand All @@ -10,13 +11,14 @@

from rest_framework import exceptions, serializers, status
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response

from resolwe.flow.filters import EntityFilter
from resolwe.flow.models import AnnotationValue, Data, DescriptorSchema, Entity
from resolwe.flow.models.annotations import AnnotationField, HandleMissingAnnotations
from resolwe.flow.serializers import EntitySerializer
from resolwe.flow.serializers.annotations import AnnotationsSerializer
from resolwe.flow.serializers.annotations import AnnotationsByPathSerializer
from resolwe.observers.mixins import ObservableMixin
from resolwe.process.descriptor import ValidationError

Expand Down Expand Up @@ -196,47 +198,51 @@ def patched_get_queryset():
return resp

@extend_schema(
request=AnnotationsSerializer(many=True), responses={status.HTTP_200_OK: None}
request=AnnotationsByPathSerializer(many=True),
responses={status.HTTP_200_OK: None},
)
@action(detail=True, methods=["post"])
def set_annotations(self, request, pk=None):
"""Add the given list of AnnotaitonFields to the given collection."""
def set_annotations(self, request: Request, pk: Optional[int] = None):
"""Add the given list of AnnotationFields to the given entity."""
# No need to check for permissions, since post requires edit by default.
entity = self.get_object()
# Read and validate the request data.
serializer = AnnotationsSerializer(data=request.data, many=True)
serializer = AnnotationsByPathSerializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
annotations = [
(entry["field"], entry["value"]) for entry in serializer.validated_data
]
annotation_fields = [e[0] for e in annotations]
# The following dict is a mapping from annotation field id to the annotation
# value id.
existing_annotations = dict(
entity.annotations.filter(field__in=annotation_fields).values_list(
"field_id", "id"
)
)

validation_errors = []
to_create = []
to_update = []
for field, value in annotations:
annotation_id = existing_annotations.get(field.id)
append_to = to_create if annotation_id is None else to_update
annotation = AnnotationValue(
entity_id=entity.id, field_id=field.id, value=value, id=annotation_id
)
try:
annotation.validate()
except DjangoValidationError as e:
validation_errors += e
append_to.append(annotation)
field_paths = {value["field_path"] for value in serializer.validated_data}
field_map = {
field_path: AnnotationField.field_from_path(field_path)
for field_path in field_paths
}

# Create annotation values and prepare list of fields which annotations should
# be deleted.
annotation_values: list[AnnotationValue] = []
validation_errors: list[DjangoValidationError] = []
fields_to_delete: list[int] = []
for value in serializer.validated_data:
if value["value"] is None:
fields_to_delete.append(field_map[value["field_path"]].pk)
else:
value["field"] = field_map[value.pop("field_path")]
value = AnnotationValue(**value, entity=entity)
annotation_values.append(value)
try:
value.validate()
except DjangoValidationError as e:
validation_errors += e
if validation_errors:
raise DjangoValidationError(validation_errors)

# Delete and update annotations in a transaction.
with transaction.atomic():
AnnotationValue.objects.bulk_create(to_create)
AnnotationValue.objects.bulk_update(to_update, ["_value"])
AnnotationValue.objects.filter(
entity=entity, field_id__in=fields_to_delete
).delete()
AnnotationValue.objects.bulk_create(
annotation_values,
update_conflicts=True,
update_fields=["_value"],
unique_fields=["entity", "field"],
)
return Response()

0 comments on commit c195ce6

Please sign in to comment.