Skip to content

Commit

Permalink
add rate limit (#308)
Browse files Browse the repository at this point in the history
* add rate limit

* i sort

* reset limit to 3600
  • Loading branch information
peterdudfield authored Nov 1, 2023
1 parent 829cfb4 commit 6a86bfe
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 3 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ fastapi-auth0==0.3.0
httpx
structlog
sentry-sdk
slowapi
8 changes: 7 additions & 1 deletion src/gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
LocationWithGSPYields,
OneDatetimeManyForecastValues,
)
from utils import format_datetime
from utils import N_CALLS_PER_HOUR, format_datetime, limiter

GSP_TOTAL = 317

Expand All @@ -44,6 +44,7 @@
dependencies=[Depends(get_auth_implicit_scheme())],
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_all_available_forecasts(
request: Request,
historic: Optional[bool] = True,
Expand Down Expand Up @@ -111,6 +112,7 @@ def get_all_available_forecasts(
responses={status.HTTP_204_NO_CONTENT: {"model": None}},
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_forecasts_for_a_specific_gsp_old_route(
request: Request,
gsp_id: int,
Expand All @@ -135,6 +137,7 @@ def get_forecasts_for_a_specific_gsp_old_route(
responses={status.HTTP_204_NO_CONTENT: {"model": None}},
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_forecasts_for_a_specific_gsp(
request: Request,
gsp_id: int,
Expand Down Expand Up @@ -203,6 +206,7 @@ def get_forecasts_for_a_specific_gsp(
dependencies=[Depends(get_auth_implicit_scheme())],
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_truths_for_all_gsps(
request: Request,
regime: Optional[str] = None,
Expand Down Expand Up @@ -257,6 +261,7 @@ def get_truths_for_all_gsps(
responses={status.HTTP_204_NO_CONTENT: {"model": None}},
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_truths_for_a_specific_gsp_old_route(
request: Request,
gsp_id: int,
Expand All @@ -282,6 +287,7 @@ def get_truths_for_a_specific_gsp_old_route(
responses={status.HTTP_204_NO_CONTENT: {"model": None}},
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_truths_for_a_specific_gsp(
request: Request,
gsp_id: int,
Expand Down
7 changes: 6 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import FileResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

from gsp import router as gsp_router
from national import router as national_router
from redoc_theme import get_redoc_html_with_theme
from status import router as status_router
from system import router as system_router
from utils import traces_sampler
from utils import limiter, traces_sampler

# flake8: noqa E501

Expand Down Expand Up @@ -183,6 +185,9 @@
allow_headers=["*"],
)

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
Expand Down
4 changes: 3 additions & 1 deletion src/national.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_truth_values_for_a_specific_gsp_from_database,
)
from pydantic_models import NationalForecast, NationalForecastValue, NationalYield
from utils import filter_forecast_values, format_datetime, format_plevels
from utils import N_CALLS_PER_HOUR, filter_forecast_values, format_datetime, format_plevels, limiter

logger = structlog.stdlib.get_logger()

Expand All @@ -33,6 +33,7 @@
dependencies=[Depends(get_auth_implicit_scheme())],
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_national_forecast(
request: Request,
session: Session = Depends(get_session),
Expand Down Expand Up @@ -156,6 +157,7 @@ def get_national_forecast(
dependencies=[Depends(get_auth_implicit_scheme())],
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_national_pvlive(
request: Request,
regime: Optional[str] = None,
Expand Down
3 changes: 3 additions & 0 deletions src/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from cache import cache_response
from database import get_latest_status_from_database, get_session, save_api_call_to_db
from utils import N_CALLS_PER_HOUR, limiter

logger = structlog.stdlib.get_logger()

Expand All @@ -20,6 +21,7 @@

@router.get("/status", response_model=Status)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_status(request: Request, session: Session = Depends(get_session)) -> Status:
"""### Get status for the database and forecasts
Expand All @@ -32,6 +34,7 @@ def get_status(request: Request, session: Session = Depends(get_session)) -> Sta


@router.get("/check_last_forecast_run", include_in_schema=False)
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def check_last_forecast(request: Request, session: Session = Depends(get_session)) -> datetime:
"""Check to that a forecast has run with in the last 2 hours"""

Expand Down
3 changes: 3 additions & 0 deletions src/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from auth_utils import get_auth_implicit_scheme, get_user
from cache import cache_response
from database import get_gsp_system, get_session
from utils import N_CALLS_PER_HOUR, limiter

# flake8: noqa: E501
logger = structlog.stdlib.get_logger()
Expand Down Expand Up @@ -43,6 +44,7 @@ def get_gsp_boundaries_from_eso_wgs84() -> gpd.GeoDataFrame:
dependencies=[Depends(get_auth_implicit_scheme())],
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_gsp_boundaries(
request: Request,
session: Session = Depends(get_session),
Expand Down Expand Up @@ -75,6 +77,7 @@ def get_gsp_boundaries(
dependencies=[Depends(get_auth_implicit_scheme())],
)
@cache_response
@limiter.limit(f"{N_CALLS_PER_HOUR}/hour")
def get_system_details(
request: Request,
session: Session = Depends(get_session),
Expand Down
4 changes: 4 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
import structlog
from nowcasting_datamodel.models import Forecast
from pytz import timezone
from slowapi import Limiter
from slowapi.util import get_remote_address

from pydantic_models import NationalForecastValue

logger = structlog.stdlib.get_logger()

europe_london_tz = timezone("Europe/London")
utc = timezone("UTC")
limiter = Limiter(key_func=get_remote_address)
N_CALLS_PER_HOUR = 3600


def floor_30_minutes_dt(dt):
Expand Down

0 comments on commit 6a86bfe

Please sign in to comment.