diff --git a/cerulean_cloud/cloud_run_tipg/handler.py b/cerulean_cloud/cloud_run_tipg/handler.py index 1073c30a..4012aead 100644 --- a/cerulean_cloud/cloud_run_tipg/handler.py +++ b/cerulean_cloud/cloud_run_tipg/handler.py @@ -10,13 +10,15 @@ """ import logging +import os from typing import Any, List, Optional import asyncpg import jinja2 import pydantic -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from mangum import Mangum +from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.templating import Jinja2Templates @@ -33,6 +35,80 @@ db_settings = DatabaseSettings() +def extract_table_from_request(request: Request) -> Optional[str]: + """ + Extract the collection ID (table name) from the URL path of an incoming HTTP request. + + Args: + request (Request): The incoming FastAPI request object. + + Returns: + Optional[str]: The collection ID if present in the URL path; otherwise, None. + + Example: + Given a request object with URL 'http://localhost:8000/collections/my_table/items', + this function will return 'my_table'. + """ + path_parts = request.url.split("/") + + # Check if the request is related to collections + if "collections" in path_parts: + idx = path_parts.index("collections") + + # The 'collectionId' should be the segment immediately following 'collections' + if len(path_parts) > idx + 1: + return path_parts[idx + 1] + + # Return None if 'collectionId' is not found + return None + + +def get_env_list(env_var: str, default: List[str] = None) -> List[str]: + """Get a list from an environment variable. Assumes values are comma-separated.""" + raw_value = os.environ.get(env_var) + if raw_value is None: + return default if default is not None else [] + return raw_value.split(",") + + +class AccessControlMiddleware(BaseHTTPMiddleware): + """ + Middleware to handle access control based on the collection ID and an API key. + + This middleware calls `extract_table_from_request` to determine the collection ID + from the request. It then checks if this collection is in the list of excluded collections. + If so, it verifies the API key in the request headers. If the API key is invalid, + it raises an HTTP 403 exception. + """ + + async def dispatch(self, request: Request, call_next): + """ + The dispatch method to handle the request and execute the middleware logic. + + Args: + request (Request): The incoming FastAPI request object. + call_next: The next middleware or endpoint in the processing pipeline. + + Raises: + HTTPException: If the collection is restricted and an invalid API key is provided. + + Returns: + Response: The outgoing FastAPI response object. + """ + table = extract_table_from_request(request) + excluded_collections = get_env_list("TIPG_DB_EXCLUDE_TABLES") + get_env_list( + "TIPG_DB_EXCLUDE_FUNCTIONS" + ) + if table in excluded_collections: + api_key = request.headers.get("X-API-Key") + if not api_key == "XXX_SECRET_API_KEY": + raise HTTPException( + status_code=403, detail="Access to table restricted" + ) + response = await call_next(request) + return response + + class PostgresSettings(pydantic.BaseSettings): """Postgres-specific API settings. @@ -102,6 +178,9 @@ class Config: allow_headers=["*"], ) +# Custom API key checking for restricted access +app.add_middleware(AccessControlMiddleware) + app.add_middleware(CacheControlMiddleware, cachecontrol=settings.cachecontrol) app.add_middleware(CompressionMiddleware) add_exception_handlers(app, DEFAULT_STATUS_CODES)