Skip to content

Commit

Permalink
First caching implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ab-smith committed Sep 21, 2024
1 parent fee4878 commit af65dd6
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions backend/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
BUILD,
VERSION,
)

from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from django.views.decorators.vary import vary_on_cookie, vary_on_headers
from django.core.cache import cache

from django.contrib.auth.models import Permission
from django.core.files.storage import default_storage
from django.db import models
Expand Down Expand Up @@ -57,6 +63,10 @@

User = get_user_model()

SHORT_CACHE_TTL = 2 # mn
MED_CACHE_TTL = 5 # mn
LONG_CACHE_TTL = 60 # mn


class BaseModelViewSet(viewsets.ModelViewSet):
filter_backends = [
Expand Down Expand Up @@ -133,6 +143,10 @@ def partial_update(self, request: Request, *args, **kwargs) -> Response:
self._process_request_data(request)
return super().partial_update(request, *args, **kwargs)

def destroy(self, request: Request, *args, **kwargs) -> Response:
self._process_request_data(request)
return super().destroy(request, *args, **kwargs)

class Meta:
abstract = True

Expand Down Expand Up @@ -248,6 +262,12 @@ class ThreatViewSet(BaseModelViewSet):
filterset_fields = ["folder", "risk_scenarios"]
search_fields = ["name", "provider", "description"]

def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

def retrieve(self, request, *args, **kwargs):
return super().retrieve(request, *args, **kwargs)

@action(detail=False, name="Get threats count")
def threats_count(self, request):
return Response({"results": threats_count_per_name(request.user)})
Expand Down Expand Up @@ -276,10 +296,12 @@ class ReferenceControlViewSet(BaseModelViewSet):
filterset_fields = ["folder", "category", "csf_function"]
search_fields = ["name", "description", "provider"]

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get category choices")
def category(self, request):
return Response(dict(ReferenceControl.CATEGORY))

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get function choices")
def csf_function(self, request):
return Response(dict(ReferenceControl.CSF_FUNCTION))
Expand Down Expand Up @@ -340,6 +362,7 @@ def per_status(self, request):
data = assessment_per_status(request.user, RiskAssessment)
return Response({"results": data})

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get status choices")
def status(self, request):
return Response(dict(RiskAssessment.Status.choices))
Expand Down Expand Up @@ -624,18 +647,22 @@ class AppliedControlViewSet(BaseModelViewSet):
]
search_fields = ["name", "description", "risk_scenarios", "requirement_assessments"]

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get status choices")
def status(self, request):
return Response(dict(AppliedControl.Status.choices))

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get category choices")
def category(self, request):
return Response(dict(AppliedControl.CATEGORY))

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get csf_function choices")
def csf_function(self, request):
return Response(dict(AppliedControl.CSF_FUNCTION))

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get effort choices")
def effort(self, request):
return Response(dict(AppliedControl.EFFORT))
Expand Down Expand Up @@ -761,6 +788,7 @@ class PolicyViewSet(AppliedControlViewSet):
]
search_fields = ["name", "description", "risk_scenarios", "requirement_assessments"]

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get csf_function choices")
def csf_function(self, request):
return Response(dict(AppliedControl.CSF_FUNCTION))
Expand All @@ -782,14 +810,17 @@ class RiskScenarioViewSet(BaseModelViewSet):
"applied_controls",
]

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get treatment choices")
def treatment(self, request):
return Response(dict(RiskScenario.TREATMENT_OPTIONS))

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get qualification choices")
def qualifications(self, request):
return Response(dict(RiskScenario.QUALIFICATIONS))

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=True, name="Get probability choices")
def probability(self, request, pk):
undefined = {-1: "--"}
Expand All @@ -802,6 +833,7 @@ def probability(self, request, pk):
choices = undefined | _choices
return Response(choices)

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=True, name="Get impact choices")
def impact(self, request, pk):
undefined = dict([(-1, "--")])
Expand All @@ -814,6 +846,7 @@ def impact(self, request, pk):
choices = undefined | _choices
return Response(choices)

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=True, name="Get strength of knowledge choices")
def strength_of_knowledge(self, request, pk):
undefined = {-1: RiskScenario.DEFAULT_SOK_OPTIONS[-1]}
Expand Down Expand Up @@ -1104,7 +1137,16 @@ def perform_create(self, serializer):
is_recursive=True,
)
ra4.perimeter_folders.add(folder)
# Clear the cache after a new folder is created - purposely clearing everything

def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

def retrieve(self, request, *args, **kwargs):
return super().retrieve(request, *args, **kwargs)

@method_decorator(cache_page(60 * MED_CACHE_TTL))
@method_decorator(vary_on_cookie)
@action(detail=False, methods=["get"])
def org_tree(self, request):
"""
Expand Down Expand Up @@ -1150,6 +1192,8 @@ def org_tree(self, request):
return Response(tree)


@cache_page(60 * SHORT_CACHE_TTL)
@vary_on_cookie
@api_view(["GET"])
@permission_classes([permissions.IsAuthenticated])
def get_counters_view(request):
Expand All @@ -1159,6 +1203,8 @@ def get_counters_view(request):
return Response({"results": get_counters(request.user)})


@cache_page(60 * SHORT_CACHE_TTL)
@vary_on_cookie
@api_view(["GET"])
@permission_classes([permissions.IsAuthenticated])
def get_metrics_view(request):
Expand All @@ -1171,6 +1217,8 @@ def get_metrics_view(request):
# TODO: Add all the proper docstrings for the following list of functions


@cache_page(60 * SHORT_CACHE_TTL)
@vary_on_cookie
@api_view(["GET"])
@permission_classes([permissions.IsAuthenticated])
def get_agg_data(request):
Expand Down Expand Up @@ -1254,6 +1302,8 @@ class FrameworkViewSet(BaseModelViewSet):
search_fields = ["name", "description"]
ordering_fields = ["name", "description"]

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@method_decorator(vary_on_cookie)
@action(detail=False, methods=["get"])
def names(self, request):
uuid_list = request.query_params.getlist("id[]", [])
Expand All @@ -1266,6 +1316,12 @@ def names(self, request):
}
)

def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

def retrieve(self, request, *args, **kwargs):
return super().retrieve(request, *args, **kwargs)

@action(detail=True, methods=["get"])
def tree(self, request, pk):
_framework = Framework.objects.get(id=pk)
Expand Down Expand Up @@ -1322,6 +1378,9 @@ class RequirementNodeViewSet(BaseModelViewSet):
filterset_fields = ["framework", "urn"]
search_fields = ["name", "description"]

def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class RequirementViewSet(BaseModelViewSet):
"""
Expand All @@ -1332,6 +1391,9 @@ class RequirementViewSet(BaseModelViewSet):
filterset_fields = ["framework", "urn"]
search_fields = ["name"]

def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class EvidenceViewSet(BaseModelViewSet):
"""
Expand Down Expand Up @@ -1415,10 +1477,13 @@ class ComplianceAssessmentViewSet(BaseModelViewSet):
search_fields = ["name", "description"]
ordering_fields = ["name", "description"]

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get status choices")
def status(self, request):
return Response(dict(ComplianceAssessment.Status.choices))

@method_decorator(cache_page(60 * MED_CACHE_TTL))
@method_decorator(vary_on_cookie)
@action(detail=True, name="Get implementation group choices")
def selected_implementation_groups(self, request, pk):
compliance_assessment = self.get_object()
Expand Down Expand Up @@ -1648,6 +1713,8 @@ def quality_check(self, request):
]
return Response({"results": res})

@method_decorator(cache_page(60 * SHORT_CACHE_TTL))
@method_decorator(vary_on_cookie)
@action(detail=True, methods=["get"])
def global_score(self, request, pk):
"""Returns the global score of the compliance assessment"""
Expand Down Expand Up @@ -1745,6 +1812,8 @@ def export(self, request, pk):
else:
return Response({"error": "Permission denied"})

@method_decorator(cache_page(60 * SHORT_CACHE_TTL))
@method_decorator(vary_on_cookie)
@action(detail=True, methods=["get"])
def donut_data(self, request, pk):
compliance_assessment = ComplianceAssessment.objects.get(id=pk)
Expand Down Expand Up @@ -1825,10 +1894,12 @@ def to_review(self, request):

return Response({"results": measures})

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get status choices")
def status(self, request):
return Response(dict(RequirementAssessment.Status.choices))

@method_decorator(cache_page(60 * LONG_CACHE_TTL))
@action(detail=False, name="Get result choices")
def result(self, request):
return Response(dict(RequirementAssessment.Result.choices))
Expand Down

0 comments on commit af65dd6

Please sign in to comment.