From 1cd21b2bd16a6164f2cb9733abeaff83272a2282 Mon Sep 17 00:00:00 2001 From: Nikita Melkozerov Date: Thu, 5 Sep 2024 09:45:04 +0000 Subject: [PATCH] limit by IP too --- fixbackend/auth/rate_limiter.py | 8 ++++---- fixbackend/auth/router.py | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/fixbackend/auth/rate_limiter.py b/fixbackend/auth/rate_limiter.py index 64ef2f66..7fd03771 100644 --- a/fixbackend/auth/rate_limiter.py +++ b/fixbackend/auth/rate_limiter.py @@ -26,14 +26,14 @@ def _new_tokens(self, tokens: int, ttl: int) -> float: time_passed = self.window.total_seconds() - ttl return min(self.limit, tokens + time_passed * self.refill_rate) - async def check(self, username: str) -> bool: + async def check(self, key: str) -> bool: [tokens, ttl] = await self.redis.eval( """ local ttl = redis.call('TTL', KEYS[1]) local tokens = redis.call('GET', KEYS[1]) return {tokens, ttl} """, 1, - f"rate_limit:{username}", + f"rate_limit:{key}", ) # type: ignore if tokens is None: @@ -43,7 +43,7 @@ async def check(self, username: str) -> bool: new_tokens = self._new_tokens(tokens, ttl) return new_tokens >= 1 - async def consume(self, username: str) -> bool: + async def consume(self, key: str) -> bool: allowed: int = await self.redis.eval( dedent( """ @@ -80,7 +80,7 @@ async def consume(self, username: str) -> bool: """ ), 4, - f"rate_limit:{username}", + f"rate_limit:{key}", self.limit, int(self.window.total_seconds()), utc().timestamp(), diff --git a/fixbackend/auth/router.py b/fixbackend/auth/router.py index 7078cbd0..dee18760 100644 --- a/fixbackend/auth/router.py +++ b/fixbackend/auth/router.py @@ -187,7 +187,11 @@ async def login( user_manager: UserManager = Depends(get_user_manager), strategy: FixJWTStrategy = Depends(auth_backend.get_strategy), ) -> Response: - allowed = await login_rate_limiter.check(credentials.username) + rate_limiter_key = credentials.username + if request.client: + rate_limiter_key = f"{rate_limiter_key}:{request.client.host}" + + allowed = await login_rate_limiter.check(rate_limiter_key) if not allowed: raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, @@ -196,7 +200,7 @@ async def login( user = await user_manager.authenticate(credentials) if user is None: - await login_rate_limiter.consume(credentials.username) + await login_rate_limiter.consume(rate_limiter_key) if user is None or not user.is_active: metric = FailedLoginAttempts.labels(user_id=None)