Skip to content

Commit

Permalink
Simplify permission checks on AnnotationValue endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorjerse committed Nov 6, 2023
1 parent 9be0b51 commit 9f58e12
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 101 deletions.
3 changes: 3 additions & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ Fixed
-----
- Set ``value`` to ``AnnotationValue`` object on duplication when it is created
- Send ``post_duplicate`` signal only on successful duplication
Changed
-------
- Simplify permission checks on ``AnontationValue`` endpoint


===================
Expand Down
22 changes: 1 addition & 21 deletions resolwe/flow/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
import types
from copy import deepcopy
from functools import partial, partialmethod
from functools import partial
from typing import Callable, Union

from django_filters import rest_framework as filters
Expand Down Expand Up @@ -673,26 +673,6 @@ class AnnotationValueFilter(BaseResolweFilter, metaclass=AnnotationValueFieldMet

label = filters.CharFilter(method="filter_by_label")

def get_form_class(self):
"""Require at least one of the entity filters to be set."""

def clean(self, original_clean):
"""Override the clean method."""
cleaned_data = original_clean(self)
if not any(
cleaned_data[field]
for field in cleaned_data
if field.startswith("entity")
):
raise ValidationError("At least one of the entity filters must be set.")

form = super().get_form_class()
# Allow patch/delete without the entity filter.
if self.request.method not in ["PATCH", "DELETE"]:
form.clean = partialmethod(clean, original_clean=form.clean)

return form

def filter_by_label(self, queryset: QuerySet, name: str, value: str):
"""Filter by label."""
return queryset.filter(_value__label__icontains=value)
Expand Down
31 changes: 29 additions & 2 deletions resolwe/flow/models/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from django.db import models

from resolwe.flow.models.base import BaseModel
from resolwe.permissions.models import PermissionObject
from resolwe.permissions.models import Permission, PermissionObject

if TYPE_CHECKING:
from resolwe.flow.models import Entity, Collection
from resolwe.flow.models import Collection, Entity

from .base import AuditModel

Expand Down Expand Up @@ -367,6 +367,24 @@ def __str__(self):
return self.name


class AnnotationValueQuerySet(models.QuerySet):
"""Custom queryset for AnnotationValue."""

def filter_for_user(self, user) -> "AnnotationValueQuerySet":
"""Filter annotation values for user.
Return the annotation values on entities user has view permission for.
"""
# Avoid circular import.
from resolwe.flow.models import Entity

return self.filter(
entity__in=Entity.objects.filter(
pk__in=self.values("entity")
).filter_for_user(user)
)


class AnnotationValue(AuditModel):
"""The value of the annotation."""

Expand All @@ -380,6 +398,8 @@ class Meta:
]
ordering = ["field__group__sort_order", "field__sort_order"]

objects = AnnotationValueQuerySet.as_manager()

#: the entity this field belongs to
entity: "Entity" = models.ForeignKey(
"Entity", related_name="annotations", on_delete=models.CASCADE
Expand Down Expand Up @@ -465,6 +485,13 @@ def from_path(entity_id: int, path: str) -> Optional["AnnotationValue"]:
entity_id=entity_id, field_id=field_id
).first()

def has_permission(self, permission: Permission, user) -> bool:
"""Return if user permission on this object.
The permission is checked on the entity.
"""
return self.entity.has_permission(permission, user)

def __str__(self) -> str:
"""Return user-friendly string representation."""
return f"{self.label}"
91 changes: 53 additions & 38 deletions resolwe/flow/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,35 @@ def test_delete_annotation_value(self):
path = reverse(
"resolwe-api:annotationvalue-detail", args=[self.annotation_value1.pk]
)
# Unauthenticated request.
# Unauthenticated request, no view permission.
response = client.delete(path, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})

# Unauthenticated request, view permission.
self.annotation_value1.entity.collection.set_permission(
Permission.VIEW, get_anonymous_user()
)
response = client.delete(path, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(
response.data,
{"detail": "You do not have permission to perform this action."},
{"detail": "Authentication credentials were not provided."},
)

# Unauthenticated request, edit permission.
self.annotation_value1.entity.collection.set_permission(
Permission.EDIT, get_anonymous_user()
)
response = client.delete(path, format="json")
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
with self.assertRaises(AnnotationValue.DoesNotExist):
self.annotation_value1.refresh_from_db()
self.annotation_value1: AnnotationValue = AnnotationValue.objects.create(
entity=self.entity1, field=self.annotation_field1, value="string"
)
path = reverse(
"resolwe-api:annotationvalue-detail", args=[self.annotation_value1.pk]
)

# Authenticated request.
Expand All @@ -464,12 +487,12 @@ def test_delete_annotation_value(self):
with self.assertRaises(AnnotationValue.DoesNotExist):
self.annotation_value1.refresh_from_db()

# Authenticated request, no permission.
# Authenticated request, view permission.
path = reverse(
"resolwe-api:annotationvalue-detail", args=[self.annotation_value2.pk]
)
self.annotation_value2.entity.collection.set_permission(
Permission.NONE, self.contributor
Permission.VIEW, self.contributor
)
response = client.delete(path, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
Expand All @@ -478,6 +501,14 @@ def test_delete_annotation_value(self):
{"detail": "You do not have permission to perform this action."},
)

# Authenticated request, no permission.
self.annotation_value2.entity.collection.set_permission(
Permission.NONE, self.contributor
)
response = client.delete(path, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {"detail": "Not found."})

def test_annotate_path(self):
"""Test annotate entity queryset."""
entities = Entity.objects.all().annotate_path("group1.field1")
Expand All @@ -487,27 +518,31 @@ def test_annotate_path(self):
self.assertIsNone(second.group1_field1)

def test_filter_value_by_group_name(self):
# Unauthenticated request without entity filter.
# Unauthenticated request, no permissions.
request = factory.get(
"/", {"field__group__name": self.annotation_group1.name}, format="json"
)
response: Response = self.annotationvalue_viewset(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.data["__all__"][0],
"At least one of the entity filters must be set.",
response = self.annotationvalue_viewset(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 0)

self.annotation_value1.entity.collection.set_permission(
Permission.VIEW, get_anonymous_user()
)
response = self.annotationvalue_viewset(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)

# Authenticated request without entity filter.
force_authenticate(request, self.contributor)
response = self.annotationvalue_viewset(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.data["__all__"][0],
"At least one of the entity filters must be set.",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)

# Proper unauthenticated request.
# Requests with entity filter.
self.annotation_value1.entity.collection.set_permission(
Permission.NONE, get_anonymous_user()
)
request = factory.get(
"/",
{
Expand Down Expand Up @@ -544,7 +579,7 @@ def test_filter_field_by_entity(self):

def test_list_filter_preset(self):
request = factory.get("/", {}, format="json")
response: Response = self.preset_viewset(request)
response = self.preset_viewset(request)

# Unauthenticated request, no permissions.
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down Expand Up @@ -606,7 +641,7 @@ def test_annotation_field(self):

# No authentication is necessary to access the annotation field endpoint.
request = factory.get("/", {}, format="json")
response: Response = self.annotationfield_viewset(request)
response = self.annotationfield_viewset(request)
self.assertEqual(len(response.data), 2)
self.assertEqual(response.data[0]["name"], "field2")
self.assertEqual(response.data[0]["label"], "Annotation field 2")
Expand Down Expand Up @@ -796,7 +831,6 @@ def test_annotation_field(self):
type="INTEGER",
)
request = factory.get("/", {}, format="json")
force_authenticate(request, self.contributor)
response = self.annotationfield_viewset(request)
self.assertEqual(len(response.data), 3)
self.assertEqual(response.data[0]["name"], "field1")
Expand All @@ -809,7 +843,6 @@ def test_annotation_field(self):
# Change the field sort order within the group.
field.sort_order = self.annotation_field1.sort_order - 1
field.save()
force_authenticate(request, self.contributor)
response = self.annotationfield_viewset(request)
self.assertEqual(len(response.data), 3)
self.assertEqual(response.data[1]["name"], "field1")
Expand All @@ -823,7 +856,6 @@ def test_annotation_field(self):
self.annotation_group1.sort_order = self.annotation_group2.sort_order + 1
self.annotation_group1.save()
field.save()
force_authenticate(request, self.contributor)
response = self.annotationfield_viewset(request)
self.assertEqual(len(response.data), 3)
self.assertEqual(response.data[2]["name"], "field1")
Expand Down Expand Up @@ -913,23 +945,6 @@ def test_required_fields(self):
def test_list_filter_values(self):
request = factory.get("/", {}, format="json")

# Unauthenticated request without entity filter.
response: Response = self.annotationvalue_viewset(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.data["__all__"][0],
"At least one of the entity filters must be set.",
)

# Authenticated request without entity filter.
force_authenticate(request, self.contributor)
response = self.annotationvalue_viewset(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.data["__all__"][0],
"At least one of the entity filters must be set.",
)

# Unauthenticated request without permissions.
request = factory.get("/", {"entity": self.entity1.pk}, format="json")
response = self.annotationvalue_viewset(request)
Expand Down
48 changes: 9 additions & 39 deletions resolwe/flow/views/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
AnnotationPresetFilter,
AnnotationValueFilter,
)
from resolwe.flow.models import AnnotationPreset, Entity
from resolwe.flow.models import AnnotationPreset
from resolwe.flow.models.annotations import AnnotationField, AnnotationValue
from resolwe.flow.serializers.annotations import (
AnnotationFieldSerializer,
Expand Down Expand Up @@ -61,43 +61,26 @@ class AnnotationFieldViewSet(


class AnnotationValueViewSet(
ResolweCreateModelMixin,
mixins.RetrieveModelMixin,
ResolweUpdateModelMixin,
ResolweCreateModelMixin,
mixins.ListModelMixin,
mixins.DestroyModelMixin,
viewsets.GenericViewSet,
):
"""Annotation value viewset."""

# Users can only see the annotation values on the entities they have permission to
# access. The actual permissions are checked in the AnnotationValueFilter. The
# filter assures at least one of the entity filters is applied and the permissions
# filters are applied inside filter_permissions method defined in
# AnnotationValueMetaclass.
# This behaviour is tested in test_list_filter_values.
permission_classes = (permissions.AllowAny,)
serializer_class = AnnotationValueSerializer
filterset_class = AnnotationValueFilter
queryset = AnnotationValue.objects.all()

def _has_permissions_on_entity(self, entity: Entity) -> bool:
"""Has the authenticated user EDIT permission on the associated entity."""
return (
Entity.objects.filter(pk=entity.pk)
.filter_for_user(self.request.user, Permission.EDIT)
.exists()
)
permission_classes = (get_permissions_class(),)

def _get_entity(self, request: request.Request) -> AnnotationValue:
"""Get annotation value from request.
:raises ValidationError: if the annotation value is not valid.
:raises NotFound: if the user is not authenticated.
"""
if not request.user.is_authenticated:
raise exceptions.NotFound

serializer = self.get_serializer(data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
return serializer.validated_data["entity"]
Expand All @@ -107,23 +90,10 @@ def create(self, request, *args, **kwargs):
Authenticated users with edit permissions on the entity can create annotations.
"""
if self._has_permissions_on_entity(self._get_entity(request)):
entity = self._get_entity(request)
if entity.has_permission(Permission.EDIT, request.user):
return super().create(request, *args, **kwargs)
raise exceptions.NotFound()

def update(self, request, *args, **kwargs):
"""Update annotation values.
Authenticated users with edit permission on the entity can update annotations.
"""
entity = AnnotationValue.objects.get(pk=kwargs["pk"]).entity
if self._has_permissions_on_entity(entity):
return super().update(request, *args, **kwargs)
raise exceptions.NotFound()

def destroy(self, request, *args, **kwargs):
"""Destroy the annotation value."""
entity = AnnotationValue.objects.get(pk=kwargs["pk"]).entity
if self._has_permissions_on_entity(entity):
return super().destroy(request, *args, **kwargs)
raise exceptions.PermissionDenied()
elif entity.has_permission(Permission.VIEW, request.user):
raise exceptions.PermissionDenied()
else:
raise exceptions.NotFound()
2 changes: 1 addition & 1 deletion resolwe/permissions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_user(user: User) -> User:

def model_has_permissions(obj: models.Model) -> bool:
"""Check whether model has object level permissions."""
additional_labels = ["flow.Storage"]
additional_labels = ["flow.Storage", "flow.AnnotationValue"]
return hasattr(obj, "permission_group") or obj._meta.label in additional_labels


Expand Down

0 comments on commit 9f58e12

Please sign in to comment.