Skip to content

Commit

Permalink
Try adding in security middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
jonaraphael committed Oct 18, 2023
1 parent 03e4fad commit fb07d8d
Showing 1 changed file with 80 additions and 1 deletion.
81 changes: 80 additions & 1 deletion cerulean_cloud/cloud_run_tipg/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
"""
import logging
import os
from typing import Any, List, Optional

import asyncpg
import jinja2
import pydantic
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from mangum import Mangum
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.templating import Jinja2Templates
Expand All @@ -33,6 +35,80 @@
db_settings = DatabaseSettings()


def extract_table_from_request(request: Request) -> Optional[str]:
"""
Extract the collection ID (table name) from the URL path of an incoming HTTP request.
Args:
request (Request): The incoming FastAPI request object.
Returns:
Optional[str]: The collection ID if present in the URL path; otherwise, None.
Example:
Given a request object with URL 'http://localhost:8000/collections/my_table/items',
this function will return 'my_table'.
"""
path_parts = request.url.split("/")

# Check if the request is related to collections
if "collections" in path_parts:
idx = path_parts.index("collections")

# The 'collectionId' should be the segment immediately following 'collections'
if len(path_parts) > idx + 1:
return path_parts[idx + 1]

# Return None if 'collectionId' is not found
return None


def get_env_list(env_var: str, default: List[str] = None) -> List[str]:
"""Get a list from an environment variable. Assumes values are comma-separated."""
raw_value = os.environ.get(env_var)
if raw_value is None:
return default if default is not None else []
return raw_value.split(",")


class AccessControlMiddleware(BaseHTTPMiddleware):
"""
Middleware to handle access control based on the collection ID and an API key.
This middleware calls `extract_table_from_request` to determine the collection ID
from the request. It then checks if this collection is in the list of excluded collections.
If so, it verifies the API key in the request headers. If the API key is invalid,
it raises an HTTP 403 exception.
"""

async def dispatch(self, request: Request, call_next):
"""
The dispatch method to handle the request and execute the middleware logic.
Args:
request (Request): The incoming FastAPI request object.
call_next: The next middleware or endpoint in the processing pipeline.
Raises:
HTTPException: If the collection is restricted and an invalid API key is provided.
Returns:
Response: The outgoing FastAPI response object.
"""
table = extract_table_from_request(request)
excluded_collections = get_env_list("TIPG_DB_EXCLUDE_TABLES") + get_env_list(
"TIPG_DB_EXCLUDE_FUNCTIONS"
)
if table in excluded_collections:
api_key = request.headers.get("X-API-Key")
if not api_key == "XXX_SECRET_API_KEY":
raise HTTPException(
status_code=403, detail="Access to table restricted"
)
response = await call_next(request)
return response


class PostgresSettings(pydantic.BaseSettings):
"""Postgres-specific API settings.
Expand Down Expand Up @@ -102,6 +178,9 @@ class Config:
allow_headers=["*"],
)

# Custom API key checking for restricted access
app.add_middleware(AccessControlMiddleware)

app.add_middleware(CacheControlMiddleware, cachecontrol=settings.cachecontrol)
app.add_middleware(CompressionMiddleware)
add_exception_handlers(app, DEFAULT_STATUS_CODES)
Expand Down

0 comments on commit fb07d8d

Please sign in to comment.