Skip to content

Commit

Permalink
Rate limiting: decrement on auth check (#620)
Browse files Browse the repository at this point in the history
  • Loading branch information
meln1k authored Sep 12, 2024
1 parent 960cb2f commit d0ed477
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 35 deletions.
21 changes: 0 additions & 21 deletions fixbackend/auth/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,6 @@ def __init__(
self.window = window
self.refill_rate = limit / window.total_seconds()

def _new_tokens(self, tokens: int, ttl: int) -> float:
time_passed = self.window.total_seconds() - ttl
return min(self.limit, tokens + time_passed * self.refill_rate)

async def check(self, key: str) -> bool:
[tokens, ttl] = await self.redis.eval(
""" local ttl = redis.call('TTL', KEYS[1])
local tokens = redis.call('GET', KEYS[1])
return {tokens, ttl}
""",
1,
f"rate_limit:{key}",
) # type: ignore

if tokens is None:
return True
tokens = int(tokens)
ttl = int(ttl)
new_tokens = self._new_tokens(tokens, ttl)
return new_tokens >= 1

async def consume(self, key: str) -> bool:
allowed: int = await self.redis.eval(
dedent(
Expand Down
7 changes: 2 additions & 5 deletions fixbackend/auth/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def auth_router(
) -> APIRouter:
router = APIRouter()

login_rate_limiter = LoginRateLimiter(redis, limit=4, window=timedelta(minutes=1))
login_rate_limiter = LoginRateLimiter(redis, limit=config.auth_rate_limit_per_minute, window=timedelta(minutes=1))

auth_backend = get_auth_backend(config)

Expand Down Expand Up @@ -191,17 +191,14 @@ async def login(
if request.client:
rate_limiter_key = f"{rate_limiter_key}:{request.client.host}"

allowed = await login_rate_limiter.check(rate_limiter_key)
allowed = await login_rate_limiter.consume(rate_limiter_key)
if not allowed:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many login attempts, please try again in 15 seconds",
)
user = await user_manager.authenticate(credentials)

if user is None:
await login_rate_limiter.consume(rate_limiter_key)

if user is None or not user.is_active:
metric = FailedLoginAttempts.labels(user_id=None)
try:
Expand Down
2 changes: 2 additions & 0 deletions fixbackend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class Config(BaseSettings):
azure_client_secret: str
account_failed_resource_count: int
degraded_accounts_ping_interval_hours: int
auth_rate_limit_per_minute: int

def frontend_cdn_origin(self) -> str:
return f"{self.cdn_endpoint}/{self.cdn_bucket}/{self.fixui_sha}"
Expand Down Expand Up @@ -209,6 +210,7 @@ def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
parser.add_argument("--azure-client-secret", default=os.environ.get("AZURE_APP_CLIENT_SECRET", ""))
parser.add_argument("--account-failed-resource-count", default=1)
parser.add_argument("--degraded-accounts-ping-interval-hours", default=24)
parser.add_argument("--auth-rate-limit-per-minute", default=4)
return parser.parse_known_args(argv if argv is not None else sys.argv[1:])[0]


Expand Down
9 changes: 0 additions & 9 deletions tests/fixbackend/auth/rate_limiter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,17 @@ def rate_limiter(redis: Redis) -> LoginRateLimiter:
return LoginRateLimiter(redis=redis, limit=5, window=timedelta(seconds=1))


@pytest.mark.asyncio
async def test_check(rate_limiter: LoginRateLimiter) -> None:
assert await rate_limiter.check("user") is True


@pytest.mark.asyncio
async def test_consume(rate_limiter: LoginRateLimiter) -> None:
assert await rate_limiter.consume("user") is True
assert await rate_limiter.check("user") is True


@pytest.mark.asyncio
async def test_consume_exceed_limit(rate_limiter: LoginRateLimiter) -> None:
# Ensure the bucket is empty initially
for _ in range(5):
assert await rate_limiter.check("user") is True
assert await rate_limiter.consume("user") is True
assert await rate_limiter.consume("user") is False
assert await rate_limiter.check("user") is False
# Wait for the window to expire
await asyncio.sleep(1)
assert await rate_limiter.check("user") is True
assert await rate_limiter.consume("user") is True
1 change: 1 addition & 0 deletions tests/fixbackend/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def default_config() -> Config:
azure_tenant_id="",
account_failed_resource_count=1,
degraded_accounts_ping_interval_hours=24,
auth_rate_limit_per_minute=100,
)


Expand Down

0 comments on commit d0ed477

Please sign in to comment.