Skip to content

Commit

Permalink
Merge pull request #34 from cesaregarza/bugfix/client-ip
Browse files Browse the repository at this point in the history
fixed client IP not tracking correctly
  • Loading branch information
cesaregarza authored Oct 26, 2024
2 parents d6333a3 + 74f25df commit cbc9ec9
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/fast_api_app/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from celery import Celery
from fastapi import WebSocket
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import scoped_session, sessionmaker

from fast_api_app.utils import get_client_ip
from shared_lib.constants import REDIS_HOST, REDIS_PORT
from shared_lib.db import create_uri

Expand Down Expand Up @@ -124,7 +124,7 @@ async def broadcast_player_data(self, message: str, player_id: str):
sqlite_cursor = sqlite_conn.cursor()

# Create slowapi limiter
limiter = Limiter(key_func=get_remote_address)
limiter = Limiter(key_func=get_client_ip)


# Model Queue for SplatGPT
Expand Down
3 changes: 2 additions & 1 deletion src/fast_api_app/routes/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
model_queue,
redis_conn,
)
from fast_api_app.utils import get_client_ip
from shared_lib.constants import (
BUCKET_THRESHOLDS,
MAIN_ONLY_ABILITIES,
Expand Down Expand Up @@ -268,7 +269,7 @@ async def log_inference_request(
# Prepare log entry
log_entry = {
"request_id": request_id,
"ip_address": request.client.host,
"ip_address": get_client_ip(request),
"user_agent": request.headers.get("user-agent"),
"http_method": request.method,
"endpoint": str(request.url.path),
Expand Down
27 changes: 27 additions & 0 deletions src/fast_api_app/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from fastapi import Request


def get_client_ip(request: Request) -> str:
"""Get the real client IP address considering proxy headers."""

forwarded_for = next(
(
v
for k, v in request.headers.items()
if k.lower() == "x-forwarded-for"
),
None,
)
if forwarded_for:
ip = forwarded_for.split(",")[0].strip()
return ip

real_ip = next(
(v for k, v in request.headers.items() if k.lower() == "x-real-ip"),
None,
)
if real_ip:
return real_ip

direct_ip = request.client.host
return direct_ip

0 comments on commit cbc9ec9

Please sign in to comment.