From c5237c1bc539fcd01e712c23e78e2fe1660e11a3 Mon Sep 17 00:00:00 2001 From: Chris Bunney <4267911+crbunney@users.noreply.github.com> Date: Wed, 10 Jul 2024 12:26:32 +0100 Subject: [PATCH] refactor(middleware): Refactor internals of CSPMiddleware so that it's easier to extend existing logic without copy/pasting it into subclass --- csp/contrib/rate_limiting.py | 41 +++++++++--------------------- csp/middleware.py | 48 +++++++++++++++++++++++++++--------- 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/csp/contrib/rate_limiting.py b/csp/contrib/rate_limiting.py index 9633a87..1b45739 100644 --- a/csp/contrib/rate_limiting.py +++ b/csp/contrib/rate_limiting.py @@ -5,8 +5,7 @@ from django.conf import settings -from csp.middleware import CSPMiddleware -from csp.utils import build_policy +from csp.middleware import CSPMiddleware, PolicyParts if TYPE_CHECKING: from django.http import HttpRequest, HttpResponseBase @@ -16,38 +15,20 @@ class RateLimitedCSPMiddleware(CSPMiddleware): """A CSP middleware that rate-limits the number of violation reports sent to report-uri by excluding it from some requests.""" - def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str: - config = getattr(response, "_csp_config", None) - update = getattr(response, "_csp_update", None) - replace = getattr(response, "_csp_replace", {}) - nonce = getattr(request, "_csp_nonce", None) - - policy = getattr(settings, "CONTENT_SECURITY_POLICY", None) - - if policy is None: - return "" - - report_percentage = policy.get("REPORT_PERCENTAGE", 100) - include_report_uri = random.randint(0, 100) < report_percentage - if not include_report_uri: - replace["report-uri"] = None - - return build_policy(config=config, update=update, replace=replace, nonce=nonce) - - def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str: - config = getattr(response, "_csp_config_ro", None) - update = getattr(response, "_csp_update_ro", None) - replace = getattr(response, "_csp_replace_ro", {}) - nonce = getattr(request, "_csp_nonce", None) - - policy = getattr(settings, "CONTENT_SECURITY_POLICY_REPORT_ONLY", None) + def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts: + policy_parts = super().get_policy_parts(request, response, report_only) + csp_setting_name = "CONTENT_SECURITY_POLICY_REPORT_ONLY" if report_only else "CONTENT_SECURITY_POLICY" + policy = getattr(settings, csp_setting_name, None) if policy is None: - return "" + return policy_parts report_percentage = policy.get("REPORT_PERCENTAGE", 100) include_report_uri = random.randint(0, 100) < report_percentage if not include_report_uri: - replace["report-uri"] = None + if policy_parts.replace is None: + policy_parts.replace = {"report-uri": None} + else: + policy_parts.replace["report-uri"] = None - return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True) + return policy_parts diff --git a/csp/middleware.py b/csp/middleware.py index 754c80d..bda0072 100644 --- a/csp/middleware.py +++ b/csp/middleware.py @@ -3,6 +3,8 @@ import base64 import http.client as http_client import os +import warnings +from dataclasses import asdict, dataclass from functools import partial from typing import TYPE_CHECKING @@ -11,12 +13,21 @@ from django.utils.functional import SimpleLazyObject from csp.constants import HEADER, HEADER_REPORT_ONLY -from csp.utils import build_policy +from csp.utils import _DIRECTIVES, build_policy if TYPE_CHECKING: from django.http import HttpRequest, HttpResponseBase +@dataclass +class PolicyParts: + # A dataclass is used rather than a namedtuple so that the attributes are mutable + config: _DIRECTIVES = None + update: _DIRECTIVES = None + replace: _DIRECTIVES = None + nonce: str | None = None + + class CSPMiddleware(MiddlewareMixin): """ Implements the Content-Security-Policy response header, which @@ -25,6 +36,7 @@ class CSPMiddleware(MiddlewareMixin): See http://www.w3.org/TR/CSP/ + Can be customised by subclassing and extending the get_policy_parts method. """ def _make_nonce(self, request: HttpRequest) -> str: @@ -49,7 +61,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) -> if response.status_code in exempted_debug_codes and settings.DEBUG: return response - csp = self.build_policy(request, response) + policy_parts = self.get_policy_parts(request=request, response=response) + csp = build_policy(**asdict(policy_parts)) if csp: # Only set header if not already set and not an excluded prefix and not exempted. is_not_exempt = getattr(response, "_csp_exempt", False) is False @@ -60,7 +73,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) -> if no_header and is_not_exempt and is_not_excluded: response[HEADER] = csp - csp_ro = self.build_policy_ro(request, response) + policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True) + csp_ro = build_policy(**asdict(policy_parts_ro), report_only=True) if csp_ro: # Only set header if not already set and not an excluded prefix and not exempted. is_not_exempt = getattr(response, "_csp_exempt_ro", False) is False @@ -74,15 +88,25 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) -> return response def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str: - config = getattr(response, "_csp_config", None) - update = getattr(response, "_csp_update", None) - replace = getattr(response, "_csp_replace", None) - nonce = getattr(request, "_csp_nonce", None) - return build_policy(config=config, update=update, replace=replace, nonce=nonce) + warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning) + policy_parts = self.get_policy_parts(request=request, response=response, report_only=False) + return build_policy(**asdict(policy_parts)) def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str: - config = getattr(response, "_csp_config_ro", None) - update = getattr(response, "_csp_update_ro", None) - replace = getattr(response, "_csp_replace_ro", None) + warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning) + policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True) + return build_policy(**asdict(policy_parts_ro), report_only=True) + + def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts: + if report_only: + config = getattr(response, "_csp_config_ro", None) + update = getattr(response, "_csp_update_ro", None) + replace = getattr(response, "_csp_replace_ro", None) + else: + config = getattr(response, "_csp_config", None) + update = getattr(response, "_csp_update", None) + replace = getattr(response, "_csp_replace", None) + nonce = getattr(request, "_csp_nonce", None) - return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True) + + return PolicyParts(config, update, replace, nonce)