Skip to content

Commit

Permalink
cache-time
Browse files Browse the repository at this point in the history
  • Loading branch information
HansKallekleiv committed Jan 30, 2025
1 parent d050598 commit 5c70469
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 2 deletions.
3 changes: 3 additions & 0 deletions backend_py/primary/primary/auth/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from primary import config
from primary.services.utils.authenticated_user import AuthenticatedUser
from primary.middleware.add_browser_cache import no_cache


class AuthHelper:
Expand All @@ -24,6 +25,7 @@ def __init__(self) -> None:
methods=["GET"],
)

@no_cache
async def _login_route(self, request: Request, redirect_url_after_login: Optional[str] = None) -> RedirectResponse:
# print("######################### _login_route()")

Expand Down Expand Up @@ -55,6 +57,7 @@ async def _login_route(self, request: Request, redirect_url_after_login: Optiona

return RedirectResponse(flow_dict["auth_uri"])

@no_cache
async def _authorized_callback_route(self, request: Request) -> Response:
# print("######################### _authorized_callback_route()")

Expand Down
4 changes: 2 additions & 2 deletions backend_py/primary/primary/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"ssdl": [SSDL_RESOURCE_SCOPE],
}

print(f"{RESOURCE_SCOPES_DICT=}")

DEFAULT_CACHE_MAX_AGE = 3600 # 1 hour
DEFAULT_STALE_WHILE_REVALIDATE = 3600 * 24 # 24 hour
REDIS_USER_SESSION_URL = "redis://redis-user-session:6379"
REDIS_CACHE_URL = "redis://redis-cache:6379"
4 changes: 4 additions & 0 deletions backend_py/primary/primary/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from primary.auth.auth_helper import AuthHelper
from primary.auth.enforce_logged_in_middleware import EnforceLoggedInMiddleware
from primary.middleware.add_process_time_to_server_timing_middleware import AddProcessTimeToServerTimingMiddleware

from primary.middleware.add_browser_cache import AddBrowserCacheMiddleware
from primary.routers.dev.router import router as dev_router
from primary.routers.explore.router import router as explore_router
from primary.routers.general import router as general_router
Expand Down Expand Up @@ -104,6 +106,7 @@ def custom_generate_unique_id(route: APIRoute) -> str:
# Also redirects to /login endpoint for some select paths
unprotected_paths = ["/logged_in_user", "/alive", "/openapi.json"]
paths_redirected_to_login = ["/", "/alive_protected"]

app.add_middleware(
EnforceLoggedInMiddleware,
unprotected_paths=unprotected_paths,
Expand All @@ -117,6 +120,7 @@ def custom_generate_unique_id(route: APIRoute) -> str:

# This middleware instance measures execution time of the endpoints, including the cost of other middleware
app.add_middleware(AddProcessTimeToServerTimingMiddleware, metric_name="total")
app.add_middleware(AddBrowserCacheMiddleware)


@app.get("/")
Expand Down
92 changes: 92 additions & 0 deletions backend_py/primary/primary/middleware/add_browser_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from functools import wraps
from contextvars import ContextVar
from typing import Dict, Any, Callable, Awaitable, Union, Never

from starlette.datastructures import MutableHeaders
from starlette.types import ASGIApp, Scope, Receive, Send, Message
from primary.config import DEFAULT_CACHE_MAX_AGE, DEFAULT_STALE_WHILE_REVALIDATE

# Initialize with a factory function to ensure a new dict for each context
def get_default_context() -> Dict[str, Any]:
return {"max_age": DEFAULT_CACHE_MAX_AGE, "stale_while_revalidate": DEFAULT_STALE_WHILE_REVALIDATE}


cache_context: ContextVar[Dict[str, Any]] = ContextVar("cache_context", default=get_default_context())


def add_custom_cache_time(max_age: int, stale_while_revalidate: int = 0) -> Callable:
"""
Decorator that sets a custom browser cache time for the endpoint response.
Args:
max_age (int): The maximum age in seconds for the cache
stale_while_revalidate (int): The stale-while-revalidate time in seconds
Example:
@add_custom_cache_time(300, 600) # 5 minutes max age, 10 minutes stale-while-revalidate
async def my_endpoint():
return {"data": "some_data"}
"""

def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
context = cache_context.get()
context["max_age"] = max_age
context["stale_while_revalidate"] = stale_while_revalidate

return await func(*args, **kwargs)

return wrapper

return decorator


def no_cache(func: Callable) -> Callable:
"""
Decorator that explicitly disables browser caching for the endpoint response.
Example:
@no_cache
async def my_endpoint():
return {"data": "some_data"}
"""

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
context = cache_context.get()
context["max_age"] = 0
context["stale_while_revalidate"] = 0

return await func(*args, **kwargs)

return wrapper


class AddBrowserCacheMiddleware:
"""
Adds cache-control to the response headers
"""

def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)

# Set initial context and store token
cache_context.set(get_default_context())

async def send_with_cache_header(message: Message) -> None:
if message["type"] == "http.response.start":
headers = MutableHeaders(scope=message)
context = cache_context.get()
cache_control_str = (
f"max-age={context['max_age']}, stale-while-revalidate={context['stale_while_revalidate']}, private"
)
headers.append("cache-control", cache_control_str)

await send(message)

await self.app(scope, receive, send_with_cache_header)
5 changes: 5 additions & 0 deletions backend_py/primary/primary/routers/explore/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from primary.services.sumo_access.case_inspector import CaseInspector
from primary.services.sumo_access.sumo_inspector import SumoInspector
from primary.services.utils.authenticated_user import AuthenticatedUser
from primary.middleware.add_browser_cache import no_cache

from . import schemas

router = APIRouter()


@router.get("/fields")
@no_cache
async def get_fields(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
) -> List[schemas.FieldInfo]:
Expand All @@ -27,6 +29,7 @@ async def get_fields(


@router.get("/cases")
@no_cache
async def get_cases(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
field_identifier: str = Query(description="Field identifier"),
Expand All @@ -43,6 +46,7 @@ async def get_cases(


@router.get("/cases/{case_uuid}/ensembles")
@no_cache
async def get_ensembles(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
case_uuid: str = Path(description="Sumo case uuid"),
Expand All @@ -55,6 +59,7 @@ async def get_ensembles(


@router.get("/cases/{case_uuid}/ensembles/{ensemble_name}")
@no_cache
async def get_ensemble_details(
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
case_uuid: str = Path(description="Sumo case uuid"),
Expand Down
4 changes: 4 additions & 0 deletions backend_py/primary/primary/routers/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from primary.auth.auth_helper import AuthHelper
from primary.services.graph_access.graph_access import GraphApiAccess
from primary.middleware.add_browser_cache import no_cache

LOGGER = logging.getLogger(__name__)

Expand All @@ -25,18 +26,21 @@ class UserInfo(BaseModel):


@router.get("/alive")
@no_cache
def get_alive() -> str:
print("entering alive route")
return f"ALIVE: Backend is alive at this time: {datetime.datetime.now()}"


@router.get("/alive_protected")
@no_cache
def get_alive_protected() -> str:
print("entering alive_protected route")
return f"ALIVE_PROTECTED: Backend is alive at this time: {datetime.datetime.now()}"


@router.get("/logged_in_user", response_model=UserInfo)
@no_cache
async def get_logged_in_user(
request: Request,
includeGraphApiInfo: bool = Query( # pylint: disable=invalid-name
Expand Down
3 changes: 3 additions & 0 deletions backend_py/primary/primary/routers/well/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from primary.services.ssdl_access.well_access import WellAccess as SsdlWellAccess


from primary.middleware.add_browser_cache import add_custom_cache_time
from . import schemas
from . import converters

Expand Down Expand Up @@ -40,6 +42,7 @@ async def get_drilled_wellbore_headers(


@router.get("/well_trajectories/")
@add_custom_cache_time(3600 * 24 * 7, 3600 * 24 * 7 * 10) # 1 week cache, 10 week stale-while-revalidate
async def get_well_trajectories(
# fmt:off
authenticated_user: AuthenticatedUser = Depends(AuthHelper.get_authenticated_user),
Expand Down

0 comments on commit 5c70469

Please sign in to comment.