Skip to content

Commit

Permalink
Allow private networks
Browse files Browse the repository at this point in the history
  • Loading branch information
zodecky authored Oct 21, 2024
1 parent 1bfb819 commit d89c066
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

from typing import Any, Iterable, Optional, Union

from .request import Request
from .response import Response


class CORSMiddleware(object):
"""CORS Middleware.
This middleware provides a simple out-of-the box CORS policy, including handling
of preflighted requests from the browser.
See also:
* https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
* https://www.w3.org/TR/cors/#resource-processing-model
Keyword Arguments:
allow_origins (Union[str, Iterable[str]]): List of origins to allow (case
sensitive). The string ``'*'`` acts as a wildcard, matching every origin.
(default ``'*'``).
expose_headers (Optional[Union[str, Iterable[str]]]): List of additional
response headers to expose via the ``Access-Control-Expose-Headers``
header. These headers are in addition to the CORS-safelisted ones:
``Cache-Control``, ``Content-Language``, ``Content-Length``,
``Content-Type``, ``Expires``, ``Last-Modified``, ``Pragma``.
(default ``None``).
See also:
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
allow_credentials (Optional[Union[str, Iterable[str]]]): List of origins
(case sensitive) for which to allow credentials via the
``Access-Control-Allow-Credentials`` header.
The string ``'*'`` acts as a wildcard, matching every allowed origin,
while ``None`` disallows all origins. This parameter takes effect only
if the origin is allowed by the ``allow_origins`` argument.
(Default ``None``).
"""

def __init__(
self,
allow_origins: Union[str, Iterable[str]] = '*',
expose_headers: Optional[Union[str, Iterable[str]]] = None,
allow_credentials: Optional[Union[str, Iterable[str]]] = None,
allow_private_network: bool = False,
):

if allow_origins == '*':
self.allow_origins = allow_origins
else:
if isinstance(allow_origins, str):
allow_origins = [allow_origins]
self.allow_origins = frozenset(allow_origins)
if '*' in self.allow_origins:
raise ValueError(
'The wildcard string "*" may only be passed to allow_origins as a '
'string literal, not inside an iterable.'
)

if expose_headers is not None and not isinstance(expose_headers, str):
expose_headers = ', '.join(expose_headers)
self.expose_headers = expose_headers

if allow_credentials is None:
allow_credentials = frozenset()
elif allow_credentials != '*':
if isinstance(allow_credentials, str):
allow_credentials = [allow_credentials]
allow_credentials = frozenset(allow_credentials)
if '*' in allow_credentials:
raise ValueError(
'The wildcard string "*" may only be passed to allow_credentials '
'as a string literal, not inside an iterable.'
)
self.allow_credentials = allow_credentials

self.allow_private_network = allow_private_network

def process_response(
self, req: Request, resp: Response, resource: object, req_succeeded: bool
) -> None:
"""Implement the CORS policy for all routes.
This middleware provides a simple out-of-the box CORS policy,
including handling of preflighted requests from the browser.
See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
See also: https://www.w3.org/TR/cors/#resource-processing-model
"""

origin = req.get_header('Origin')
if origin is None:
return

if self.allow_origins != '*' and origin not in self.allow_origins:
return

if resp.get_header('Access-Control-Allow-Origin') is None:
set_origin = '*' if self.allow_origins == '*' else origin
if self.allow_credentials == '*' or origin in self.allow_credentials:
set_origin = origin
resp.set_header('Access-Control-Allow-Credentials', 'true')
resp.set_header('Access-Control-Allow-Origin', set_origin)

if self.expose_headers:
resp.set_header('Access-Control-Expose-Headers', self.expose_headers)

if (
req_succeeded
and req.method == 'OPTIONS'
and req.get_header('Access-Control-Request-Method')
):
# NOTE(kgriffs): This is a CORS preflight request. Patch the
# response accordingly.

allow = resp.get_header('Allow')
resp.delete_header('Allow')

allow_headers = req.get_header(
'Access-Control-Request-Headers', default='*'
)

resp.set_header('Access-Control-Allow-Methods', allow)
resp.set_header('Access-Control-Allow-Headers', allow_headers)
resp.set_header('Access-Control-Max-Age', '86400') # 24 hours

if self.allow_private_network and req.get_header('Access-Control-Request-Private-Network') == 'true':
resp.set_header('Access-Control-Allow-Private-Network', 'true')


async def process_response_async(self, *args: Any) -> None:
self.process_response(*args)

0 comments on commit d89c066

Please sign in to comment.