diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index cd9ee69c520..488d536f25b 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -24,7 +24,7 @@ from asyncio.log import logger from datetime import datetime, timedelta, timezone from genericpath import isfile -from typing import Any, List +from typing import Any, List, Set from anyio import to_thread from fastapi import FastAPI, HTTPException, Request @@ -32,9 +32,11 @@ from fastapi.responses import ORJSONResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.base import ( + BaseHTTPMiddleware, + RequestResponseEndpoint, +) from starlette.middleware.cors import CORSMiddleware -from starlette.requests import Request from starlette.responses import FileResponse, JSONResponse, Response from starlette.types import ASGIApp @@ -147,24 +149,61 @@ def validation_exception_handler( class RequestBodyLimit(BaseHTTPMiddleware): + """Limits the size of the request body.""" + def __init__(self, app: ASGIApp, max_bytes: int) -> None: + """Limits the size of the request body. + + Args: + app: The FastAPI app. + max_bytes: The maximum size of the request body. + """ super().__init__(app) self.max_bytes = max_bytes - async def dispatch(self, request: Request, call_next): - if request.headers.get("content-length"): - content_length = int(request.headers.get("content-length")) - if content_length > self.max_bytes: + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + """Limits the size of the request body. + + Args: + request: The incoming request. + call_next: The next function to be called. + + Returns: + The response to the request. + """ + if content_length := request.headers.get("content-length"): + if int(content_length) > self.max_bytes: return Response(status_code=413) # Request Entity Too Large return await call_next(request) class RestrictFileUploadsMiddleware(BaseHTTPMiddleware): - def __init__(self, app: FastAPI, allowed_paths: set[str]): + """Restrict file uploads to certain paths.""" + + def __init__(self, app: FastAPI, allowed_paths: Set[str]): + """Restrict file uploads to certain paths. + + Args: + app: The FastAPI app. + allowed_paths: The allowed paths. + """ super().__init__(app) self.allowed_paths = allowed_paths - async def dispatch(self, request: Request, call_next): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + """Restrict file uploads to certain paths. + + Args: + request: The incoming request. + call_next: The next function to be called. + + Returns: + The response to the request. + """ if request.method == "POST": content_type = request.headers.get("content-type", "") if ( @@ -180,7 +219,7 @@ async def dispatch(self, request: Request, call_next): return await call_next(request) -ALLOWED_FOR_FILE_UPLOAD = set() +ALLOWED_FOR_FILE_UPLOAD: Set[str] = set() app.add_middleware( CORSMiddleware,