diff --git a/src/fast_api_app/routes/infer.py b/src/fast_api_app/routes/infer.py index d405924..48da42a 100644 --- a/src/fast_api_app/routes/infer.py +++ b/src/fast_api_app/routes/infer.py @@ -9,12 +9,18 @@ from fastapi.responses import HTMLResponse from pydantic import BaseModel -from fast_api_app.connections import limiter, model_queue, redis_conn, async_session +from fast_api_app.connections import ( + async_session, + limiter, + model_queue, + redis_conn, +) from shared_lib.constants import ( BUCKET_THRESHOLDS, MAIN_ONLY_ABILITIES, STANDARD_ABILITIES, ) +from shared_lib.models import ModelInferenceLog router = APIRouter() @@ -298,29 +304,21 @@ async def log_inference_request( ) else: async with async_session() as session: - await session.execute( - """ - INSERT INTO splatgpt.model_inference_logs ( - request_id, ip_address, user_agent, http_method, - endpoint, input_data, model_version, - processing_time_ms, status_code, error_message, - output_data - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 - ) - """, - request_id, - log_entry["ip_address"], - log_entry["user_agent"], - log_entry["http_method"], - log_entry["endpoint"], - log_entry["input_data"], - log_entry["model_version"], - log_entry["processing_time_ms"], - log_entry["status_code"], - log_entry["error_message"], - log_entry.get("output_data"), + new_log_entry = ModelInferenceLog( + request_id=request_id, + ip_address=log_entry["ip_address"], + user_agent=log_entry["user_agent"], + http_method=log_entry["http_method"], + endpoint=log_entry["endpoint"], + input_data=log_entry["input_data"], + model_version=log_entry["model_version"], + processing_time_ms=log_entry["processing_time_ms"], + status_code=log_entry["status_code"], + error_message=log_entry["error_message"], + output_data=log_entry.get("output_data"), ) + session.add(new_log_entry) + await session.commit() except Exception as db_error: logger.error(f"Failed to log inference request: {db_error}") diff --git 
class ModelInferenceLog(Base):
    """ORM mapping for the ``splatgpt.model_inference_logs`` table.

    One instance corresponds to one row. Apart from ``id``, no column
    declares ``nullable=False``, so every other field may be left unset
    when constructing an instance.
    """

    __tablename__ = "model_inference_logs"
    __table_args__ = {"schema": "splatgpt"}

    # --- identity -----------------------------------------------------
    # 64-bit integer primary key.
    id = Column(BigInteger, primary_key=True)
    # Handled as ``uuid.UUID`` objects on the Python side; a new uuid4 is
    # generated at insert time when the caller does not supply a value.
    request_id = Column(UUID(as_uuid=True), default=uuid.uuid4)
    # Timezone-aware; the database fills in CURRENT_TIMESTAMP when the
    # INSERT omits this column.
    timestamp = Column(
        DateTime(timezone=True),
        server_default=text("CURRENT_TIMESTAMP"),
    )

    # --- request metadata ---------------------------------------------
    # Stored using the PostgreSQL INET type.
    ip_address = Column(INET)
    user_agent = Column(Text)
    http_method = Column(String)
    endpoint = Column(String)
    client_id = Column(String)

    # --- payload and outcome ------------------------------------------
    # Stored using the PostgreSQL JSONB type.
    input_data = Column(JSONB)
    model_version = Column(String)
    processing_time_ms = Column(Integer)
    status_code = Column(SmallInteger)
    # Stored using the PostgreSQL JSONB type.
    output_data = Column(JSONB)
    error_message = Column(Text)