From 826e143b32054bc155f5dd040552f8bfdbf34466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gregor=20Jer=C5=A1e?= Date: Tue, 14 Nov 2023 07:31:26 +0100 Subject: [PATCH] Update or create --- resolwe/flow/models/annotations.py | 19 +- resolwe/flow/serializers/annotations.py | 8 + resolwe/flow/tests/test_annotations.py | 226 +++++++++++++++++++++++- resolwe/flow/views/annotations.py | 73 ++++++-- resolwe/flow/views/entity.py | 89 +++++++--- 5 files changed, 359 insertions(+), 56 deletions(-) diff --git a/resolwe/flow/models/annotations.py b/resolwe/flow/models/annotations.py index 1048f35a6..c22b6b89e 100644 --- a/resolwe/flow/models/annotations.py +++ b/resolwe/flow/models/annotations.py @@ -27,6 +27,11 @@ from .base import AuditModel +VALIDATOR_LENGTH = 128 +NAME_LENGTH = 128 +LABEL_LENGTH = 128 +DESCRIPTION_LENGTH = 256 + class HandleMissingAnnotations(Enum): """How to handle missing annotations.""" @@ -217,10 +222,10 @@ class AnnotationGroup(models.Model): """Group of annotation fields.""" #: the name of the annotation group - name = models.CharField(max_length=128) + name = models.CharField(max_length=NAME_LENGTH) #: the label of the annotation group - label = models.CharField(max_length=128) + label = models.CharField(max_length=LABEL_LENGTH) #: the sorting order among annotation groups sort_order = models.PositiveSmallIntegerField() @@ -239,13 +244,13 @@ class AnnotationField(models.Model): """Annotation field.""" #: the name of the annotation fields - name = models.CharField(max_length=128) + name = models.CharField(max_length=NAME_LENGTH) #: user visible field name - label = models.CharField(max_length=128) + label = models.CharField(max_length=LABEL_LENGTH) #: user visible field description - description = models.CharField(max_length=256) + description = models.CharField(max_length=DESCRIPTION_LENGTH) #: the type of the annotation field type = models.CharField(max_length=16) @@ -259,7 +264,9 @@ class AnnotationField(models.Model): sort_order = 
models.PositiveSmallIntegerField() #: optional regular expression for validation - validator_regex = models.CharField(max_length=128, null=True, blank=True) + validator_regex = models.CharField( + max_length=VALIDATOR_LENGTH, null=True, blank=True + ) #: optional map of valid values to labels vocabulary = models.JSONField(null=True, blank=True) diff --git a/resolwe/flow/serializers/annotations.py b/resolwe/flow/serializers/annotations.py index 8e951b4a3..5e883f373 100644 --- a/resolwe/flow/serializers/annotations.py +++ b/resolwe/flow/serializers/annotations.py @@ -4,6 +4,7 @@ from rest_framework import serializers from rest_framework.fields import empty +from resolwe.flow.models.annotations import NAME_LENGTH as ANNOTATION_NAME_LENGTH from resolwe.flow.models.annotations import ( AnnotationField, AnnotationGroup, @@ -80,6 +81,13 @@ class AnnotationsSerializer(serializers.Serializer): value = serializers.JSONField() +class AnnotationsByPathSerializer(serializers.Serializer): + """Serializer that reads annotation field and its value.""" + + field_path = serializers.CharField(max_length=2 * ANNOTATION_NAME_LENGTH + 1) + value = serializers.JSONField() + + class AnnotationPresetSerializer(ResolweBaseSerializer): """Serializer for AnnotationPreset objects.""" diff --git a/resolwe/flow/tests/test_annotations.py b/resolwe/flow/tests/test_annotations.py index fdc388703..920020032 100644 --- a/resolwe/flow/tests/test_annotations.py +++ b/resolwe/flow/tests/test_annotations.py @@ -414,10 +414,13 @@ def test_create_annotation_value(self): self.client.force_authenticate(self.contributor) response = self.client.post(path, values, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) - expected = AnnotationValue.objects.annotate(value=F("_value__value")).values( - "entity", "field", "value" + attributes = ("entity", "field", "value") + created = AnnotationValue.objects.annotate(value=F("_value__value")).values( + *attributes ) - 
self.assertCountEqual(values, expected) + received = [{key: entry[key] for key in attributes} for entry in response.data] + self.assertCountEqual(received, values) + self.assertCountEqual(created, values) # Authenticated request, no permission. AnnotationValue.objects.all().delete() @@ -451,7 +454,7 @@ def test_update_annotation_value(self): # Authenticated request, entity should not be changed values = { - "id": self.annotation_value1.pk, + "field": self.annotation_field1.pk, "value": "string", "entity": self.entity2.pk, } @@ -467,6 +470,95 @@ def test_update_annotation_value(self): self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.data, {"detail": "Not found."}) + # Single / bulk update with put. + self.entity1.collection.set_permission(Permission.EDIT, self.contributor) + path = reverse("resolwe-api:annotationvalue-list") + client = APIClient() + values = { + "entity": self.entity1.pk, + "field": self.annotation_field1.pk, + "value": -1, + } + + # Unauthenticated request. + response = client.put(path, values, format="json") + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + response.data, + {"detail": "You do not have permission to perform this action."}, + ) + + # Authenticated request with validation error. + client.force_authenticate(self.contributor) + response = client.put(path, values, format="json") + self.annotation_value1.refresh_from_db() + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual( + response.data["error"], + [ + "The value '-1' is not of the expected type 'str'.", + f"The value '-1' is not valid for the field {self.annotation_field1}.", + ], + ) + self.assertEqual(self.annotation_value1.value, "string") + self.assertEqual(AnnotationValue.objects.count(), 2) + + # Authenticated request with a valid value.
+ values["value"] = "another" + response = client.put(path, values, format="json") + self.annotation_value1.refresh_from_db() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + response.data, + { + "label": "Another one", + "id": self.annotation_value1.pk, + "entity": self.entity1.pk, + "field": self.annotation_field1.pk, + "value": "another", + }, + ) + self.assertEqual(self.annotation_value1.value, "another") + self.assertEqual(self.annotation_value1.label, "Another one") + self.assertEqual(AnnotationValue.objects.count(), 2) + + # Multi. + values = [ + {"field": self.annotation_field2.pk, "value": 1, "entity": self.entity1.pk}, + { + "field": self.annotation_field1.pk, + "value": "string", + "entity": self.entity1.pk, + }, + ] + + response = client.put(path, values, format="json") + self.annotation_value1.refresh_from_db() + self.assertEqual(response.status_code, status.HTTP_200_OK) + created_value = AnnotationValue.objects.get( + entity=self.entity1, field=self.annotation_field2 + ) + expected = [ + { + "label": "label string", + "id": self.annotation_value1.pk, + "entity": self.entity1.pk, + "field": self.annotation_field1.pk, + "value": "string", + }, + { + "label": created_value.label, + "id": created_value.pk, + "entity": created_value.entity.pk, + "field": created_value.field.pk, + "value": created_value.value, + }, + ] + self.assertCountEqual(response.data, expected) + self.assertEqual(self.annotation_value1.value, "string") + self.assertEqual(self.annotation_value1.label, "label string") + self.assertEqual(AnnotationValue.objects.count(), 3) + def test_delete_annotation_value(self): """Test deleting annotation value objects.""" client = APIClient() @@ -1256,3 +1348,129 @@ def has_value(entity, field_id, value): self.assertEqual(self.annotation_value1.value, "bbb") has_value(self.entity1, self.annotation_field1.pk, "bbb") has_value(self.entity1, self.annotation_field2.pk, -1) + + def test_set_values_by_path(self): + def 
has_value(entity, field_id, value): + self.assertEqual( + value, entity.annotations.filter(field_id=field_id).get().value + ) + + # Remove vocabulary to simplify testing. + self.annotation_field1.vocabulary = None + self.annotation_field1.save() + viewset = EntityViewSet.as_view(actions={"post": "set_annotations_by_path"}) + request = factory.post("/", {}, format="json") + + # Unauthenticated request, no permissions. + response: Response = viewset(request, pk=self.entity1.pk) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Request without required parameter. + request = factory.post("/", [{}], format="json") + force_authenticate(request, self.contributor) + response = viewset(request, pk=self.entity1.pk) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual( + response.data[0], + { + "field_path": ["This field is required."], + "value": ["This field is required."], + }, + ) + + annotations = [ + {"field_path": str(self.annotation_field1), "value": "new value"}, + {"field_path": str(self.annotation_field2), "value": -1}, + ] + + # Valid request without regex validation. + request = factory.post("/", annotations, format="json") + force_authenticate(request, self.contributor) + response = viewset(request, pk=self.entity1.pk) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self.entity1.annotations.count(), 2) + has_value(self.entity1, self.annotation_field1.pk, "new value") + has_value(self.entity1, self.annotation_field2.pk, -1) + + # Wrong type. 
+ annotations = [{"field_path": str(self.annotation_field1), "value": 10}] + request = factory.post("/", annotations, format="json") + force_authenticate(request, self.contributor) + response = viewset(request, pk=self.entity1.pk) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(response.data), 1) + self.assertEqual( + response.data, + {"error": ["The value '10' is not of the expected type 'str'."]}, + ) + + has_value(self.entity1, self.annotation_field1.pk, "new value") + has_value(self.entity1, self.annotation_field2.pk, -1) + + # Wrong regex. + self.annotation_field1.validator_regex = "b+" + self.annotation_field1.save() + annotations = [{"field_path": str(self.annotation_field1), "value": "aaa"}] + request = factory.post("/", annotations, format="json") + force_authenticate(request, self.contributor) + response = viewset(request, pk=self.entity1.pk) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(len(response.data), 1) + self.assertEqual( + response.data, + { + "error": [ + f"The value 'aaa' for the field '{self.annotation_field1.pk}' does not match the regex 'b+'." + ], + }, + ) + has_value(self.entity1, self.annotation_field1.pk, "new value") + has_value(self.entity1, self.annotation_field2.pk, -1) + + # Wrong regex and type. 
+ annotations = [{"field_path": str(self.annotation_field1), "value": 10}] + request = factory.post("/", annotations, format="json") + force_authenticate(request, self.contributor) + response = viewset(request, pk=self.entity1.pk) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertCountEqual( + response.data["error"], + [ + f"The value '10' for the field '{self.annotation_field1.pk}' does not match the regex 'b+'.", + "The value '10' is not of the expected type 'str'.", + ], + ) + has_value(self.entity1, self.annotation_field1.pk, "new value") + has_value(self.entity1, self.annotation_field2.pk, -1) + + # Multiple fields validation error. + annotations = [ + {"field_path": str(self.annotation_field1), "value": 10}, + {"field_path": str(self.annotation_field2), "value": "string"}, + ] + request = factory.post("/", annotations, format="json") + force_authenticate(request, self.contributor) + response = viewset(request, pk=self.entity1.pk) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertCountEqual( + response.data["error"], + [ + "The value '10' is not of the expected type 'str'.", + f"The value '10' for the field '{self.annotation_field1.pk}' does not match the regex 'b+'.", + "The value 'string' is not of the expected type 'int'.", + ], + ) + has_value(self.entity1, self.annotation_field1.pk, "new value") + has_value(self.entity1, self.annotation_field2.pk, -1) + + # Regular request with regex validation. 
+ annotations = [{"field_path": str(self.annotation_field1), "value": "bbb"}] + request = factory.post("/", annotations, format="json") + force_authenticate(request, self.contributor) + response = viewset(request, pk=self.entity1.pk) + self.annotation_value1.refresh_from_db() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(self.annotation_value1.value, "bbb") + has_value(self.entity1, self.annotation_field1.pk, "bbb") + has_value(self.entity1, self.annotation_field2.pk, -1) diff --git a/resolwe/flow/views/annotations.py b/resolwe/flow/views/annotations.py index 4772dcd24..57255c661 100644 --- a/resolwe/flow/views/annotations.py +++ b/resolwe/flow/views/annotations.py @@ -1,15 +1,11 @@ """Annotations viewset.""" -from rest_framework import ( - exceptions, - mixins, - permissions, - request, - response, - status, - viewsets, -) +from typing import Iterable + +from django.core.exceptions import ValidationError + +from rest_framework import exceptions, mixins, permissions, response, status, viewsets from resolwe.flow.filters import ( AnnotationFieldFilter, @@ -83,13 +79,36 @@ class AnnotationValueViewSet( queryset = AnnotationValue.objects.all() permission_classes = (get_permissions_class(),) - def get_serializer(self, *args, **kwargs): - """ - Return the serializer instance that should be used for validating and - deserializing input, and for serializing output. + def _validate_values(self, annotation_values: Iterable[AnnotationValue]): + """Validate annotation values. + + :raises ValidationError: if validation does not pass. 
""" - kwargs["many"] = isinstance(self.request.data, list) - return super().get_serializer(*args, **kwargs) + validation_errors: list[ValidationError] = [] + for value in annotation_values: + try: + value.validate() + except ValidationError as e: + validation_errors.append(e) + if validation_errors: + raise ValidationError(validation_errors) + + def _bulk_create(self, annotation_values: Iterable[AnnotationValue]): + """Bulk create or update annotation values.""" + also_update = self.request.method == "PUT" + inserted_values: list[AnnotationValue] = [] + if also_update: + for annotation_value in annotation_values: + inserted_values.append( + AnnotationValue.objects.update_or_create( + field=annotation_value.field, + entity=annotation_value.entity, + defaults={"_value": annotation_value._value}, + )[0] + ) + else: + inserted_values = AnnotationValue.objects.bulk_create(annotation_values) + return inserted_values def create(self, request, *args, **kwargs): """Create annotation value(s). @@ -97,7 +116,8 @@ def create(self, request, *args, **kwargs): Create one or more annotation values. The request data can be either dict or list (in case of bulk create). """ - serializer = self.get_serializer(data=request.data) + many = isinstance(self.request.data, list) + serializer = self.get_serializer(data=request.data, many=many) serializer.is_valid(raise_exception=True) raw_data = serializer.validated_data @@ -108,8 +128,27 @@ def create(self, request, *args, **kwargs): entity.has_permission(Permission.EDIT, request.user) for entity in entities ): raise exceptions.PermissionDenied() - self.perform_create(serializer) + + # Create annotation values. 
+ annotation_values = [AnnotationValue(**value) for value in validated_data] + self._validate_values(annotation_values) + inserted_values = self._bulk_create(annotation_values) + + to_serialize = inserted_values if many else inserted_values[0] + serializer = self.get_serializer(to_serialize, many=many) headers = self.get_success_headers(serializer.data) return response.Response( serializer.data, status=status.HTTP_201_CREATED, headers=headers ) + + def put(self, request, *args, **kwargs): + """Update annotation value(s). + + Update or create annotation values. The request data can be either dict or + list (in case of bulk update). + """ + # Call create, which calls _bulk_create which will also allow update, since our + # method is PUT and not POST: + response = self.create(request, *args, **kwargs) + response.status_code = status.HTTP_200_OK + return response diff --git a/resolwe/flow/views/entity.py b/resolwe/flow/views/entity.py index f5ffd02b2..32d7251c5 100644 --- a/resolwe/flow/views/entity.py +++ b/resolwe/flow/views/entity.py @@ -1,22 +1,26 @@ """Entity viewset.""" import re +from typing import Optional from drf_spectacular.utils import extend_schema from django.core.exceptions import ValidationError as DjangoValidationError -from django.db import transaction from django.db.models import F, Func, OuterRef, Prefetch, Subquery from django.db.models.functions import Coalesce from rest_framework import exceptions, serializers, status from rest_framework.decorators import action +from rest_framework.request import Request from rest_framework.response import Response from resolwe.flow.filters import EntityFilter from resolwe.flow.models import AnnotationValue, Data, DescriptorSchema, Entity from resolwe.flow.models.annotations import AnnotationField, HandleMissingAnnotations from resolwe.flow.serializers import EntitySerializer -from resolwe.flow.serializers.annotations import AnnotationsSerializer +from resolwe.flow.serializers.annotations import ( +
AnnotationsByPathSerializer, + AnnotationsSerializer, +) from resolwe.observers.mixins import ObservableMixin from resolwe.process.descriptor import ValidationError @@ -199,44 +203,71 @@ def patched_get_queryset(): request=AnnotationsSerializer(many=True), responses={status.HTTP_200_OK: None} ) @action(detail=True, methods=["post"]) - def set_annotations(self, request, pk=None): - """Add the given list of AnnotaitonFields to the given collection.""" + def set_annotations(self, request: Request, pk: Optional[int] = None): + """Add the given list of AnnotationFields to the given entity.""" # No need to check for permissions, since post requires edit by default. entity = self.get_object() # Read and validate the request data. serializer = AnnotationsSerializer(data=request.data, many=True) serializer.is_valid(raise_exception=True) - annotations = [ - (entry["field"], entry["value"]) for entry in serializer.validated_data - ] - annotation_fields = [e[0] for e in annotations] - # The following dict is a mapping from annotation field id to the annotation - # value id. - existing_annotations = dict( - entity.annotations.filter(field__in=annotation_fields).values_list( - "field_id", "id" - ) - ) - validation_errors = [] - to_create = [] - to_update = [] - for field, value in annotations: - annotation_id = existing_annotations.get(field.id) - append_to = to_create if annotation_id is None else to_update - annotation = AnnotationValue( - entity_id=entity.id, field_id=field.id, value=value, id=annotation_id - ) + # Create annotation values. 
+ annotation_values: list[AnnotationValue] = [] + validation_errors: list[DjangoValidationError] = [] + for value in serializer.validated_data: + value = AnnotationValue(**value, entity=entity) + annotation_values.append(value) try: - annotation.validate() + value.validate() except DjangoValidationError as e: validation_errors += e - append_to.append(annotation) + if validation_errors: + raise DjangoValidationError(validation_errors) + AnnotationValue.objects.bulk_create( + annotation_values, + update_conflicts=True, + update_fields=["_value"], + unique_fields=["entity", "field"], + ) + return Response() + + @extend_schema( + request=AnnotationsByPathSerializer(many=True), + responses={status.HTTP_200_OK: None}, + ) + @action(detail=True, methods=["post"]) + def set_annotations_by_path(self, request: Request, pk: Optional[int] = None): + """Add the given list of AnnotationFields to the given entity.""" + # No need to check for permissions, since post requires edit by default. + entity = self.get_object() + # Read and validate the request data. + serializer = AnnotationsByPathSerializer(data=request.data, many=True) + serializer.is_valid(raise_exception=True) + field_paths = {value["field_path"] for value in serializer.validated_data} + field_map = { + field_path: AnnotationField.field_from_path(field_path) + for field_path in field_paths + } + + # Create annotation values. 
+ annotation_values: list[AnnotationValue] = [] + validation_errors: list[DjangoValidationError] = [] + for value in serializer.validated_data: + value["field"] = field_map[value.pop("field_path")] + value = AnnotationValue(**value, entity=entity) + annotation_values.append(value) + try: + value.validate() + except DjangoValidationError as e: + validation_errors += e if validation_errors: raise DjangoValidationError(validation_errors) - with transaction.atomic(): - AnnotationValue.objects.bulk_create(to_create) - AnnotationValue.objects.bulk_update(to_update, ["_value"]) + AnnotationValue.objects.bulk_create( + annotation_values, + update_conflicts=True, + update_fields=["_value"], + unique_fields=["entity", "field"], + ) return Response()