diff --git a/codeforlife/permissions/allow_none.py b/codeforlife/permissions/allow_none.py index 6bc21734..ef2cbb83 100644 --- a/codeforlife/permissions/allow_none.py +++ b/codeforlife/permissions/allow_none.py @@ -14,5 +14,8 @@ class AllowNone(BasePermission): https://www.django-rest-framework.org/api-guide/permissions/#allowany """ + def __eq__(self, other): + return isinstance(other, self.__class__) + def has_permission(self, request, view): return False diff --git a/codeforlife/permissions/is_cron_request_from_google.py b/codeforlife/permissions/is_cron_request_from_google.py index b8f49f7c..c6cdf254 100644 --- a/codeforlife/permissions/is_cron_request_from_google.py +++ b/codeforlife/permissions/is_cron_request_from_google.py @@ -14,6 +14,9 @@ class IsCronRequestFromGoogle(BasePermission): https://cloud.google.com/appengine/docs/flexible/scheduling-jobs-with-cron-yaml#securing_urls_for_cron """ + def __eq__(self, other): + return isinstance(other, self.__class__) + def has_permission(self, request, view): return ( settings.DEBUG diff --git a/codeforlife/tests/model_view_set.py b/codeforlife/tests/model_view_set.py index eee9dffa..36db8841 100644 --- a/codeforlife/tests/model_view_set.py +++ b/codeforlife/tests/model_view_set.py @@ -17,6 +17,7 @@ from django.utils.http import urlencode from pyotp import TOTP from rest_framework import status +from rest_framework.permissions import BasePermission from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory, APITestCase @@ -689,6 +690,22 @@ def setUpClass(cls): return super().setUpClass() + def assert_get_permissions( + self, + permissions: t.List[BasePermission], + *args, + **kwargs, + ): + """Assert that we get the expected permissions. + + Args: + permissions: The expected permissions. + """ + + model_view_set = self.model_view_set_class(*args, **kwargs) + actual_permissions = model_view_set.get_permissions() + self.assertListEqual(permissions, actual_permissions) + def get_other_user( self, user: User, diff --git a/codeforlife/user/permissions/in_class.py b/codeforlife/user/permissions/in_class.py index c030e6d0..2c5ceac6 100644 --- a/codeforlife/user/permissions/in_class.py +++ b/codeforlife/user/permissions/in_class.py @@ -26,6 +26,12 @@ def __init__(self, class_id: t.Optional[str] = None): super().__init__() self.class_id = class_id + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.class_id == other.class_id + ) + def has_permission(self, request: Request, view: APIView): user = request.user if super().has_permission(request, view) and isinstance(user, User): diff --git a/codeforlife/user/permissions/in_school.py b/codeforlife/user/permissions/in_school.py index 1b866d21..2fde40c2 100644 --- a/codeforlife/user/permissions/in_school.py +++ b/codeforlife/user/permissions/in_school.py @@ -26,6 +26,12 @@ def __init__(self, school_id: t.Optional[int] = None): super().__init__() self.school_id = school_id + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.school_id == other.school_id + ) + def has_permission(self, request: Request, view: APIView): def in_school(school_id: int): return self.school_id is None or self.school_id == school_id diff --git a/codeforlife/user/permissions/is_independent.py b/codeforlife/user/permissions/is_independent.py index 5d0f5a8e..0ad94565 100644 --- a/codeforlife/user/permissions/is_independent.py +++ b/codeforlife/user/permissions/is_independent.py @@ -13,6 +13,9 @@ class IsIndependent(IsAuthenticated): """Request's user must be independent.""" + def __eq__(self, other): + return isinstance(other, self.__class__) + def has_permission(self, request: Request, view: APIView): user = request.user return ( diff --git a/codeforlife/user/permissions/is_student.py b/codeforlife/user/permissions/is_student.py index a4a43e9e..2f0a1f9b 100644 --- a/codeforlife/user/permissions/is_student.py +++ b/codeforlife/user/permissions/is_student.py @@ -26,6 +26,12 @@ def __init__(self, student_id: t.Optional[int] = None): super().__init__() self.student_id = student_id + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.student_id == other.student_id + ) + def has_permission(self, request: Request, view: APIView): user = request.user return ( diff --git a/codeforlife/user/permissions/is_teacher.py b/codeforlife/user/permissions/is_teacher.py index 255d9c6d..c7a8ec65 100644 --- a/codeforlife/user/permissions/is_teacher.py +++ b/codeforlife/user/permissions/is_teacher.py @@ -34,6 +34,13 @@ def __init__( self.teacher_id = teacher_id self.is_admin = is_admin + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.teacher_id == other.teacher_id + and self.is_admin == other.is_admin + ) + def has_permission(self, request: Request, view: APIView): user = request.user return ( diff --git a/codeforlife/views/model.py b/codeforlife/views/model.py index 80560a0a..5e568e9f 100644 --- a/codeforlife/views/model.py +++ b/codeforlife/views/model.py @@ -9,6 +9,7 @@ from django.db.models.query import QuerySet from rest_framework import status from rest_framework.decorators import action +from rest_framework.permissions import BasePermission from rest_framework.request import Request from rest_framework.response import Response from rest_framework.serializers import ListSerializer @@ -49,6 +50,9 @@ def get_model_class(cls) -> t.Type[AnyModel]: 0 ] + def get_permissions(self): + return t.cast(t.List[BasePermission], super().get_permissions()) + def get_serializer(self, *args, **kwargs): serializer = super().get_serializer(*args, **kwargs)