Skip to content

Commit

Permalink
Merge pull request #35 from cesaregarza/feature/plus_minus_api
Browse files Browse the repository at this point in the history
Feature/plus minus api
  • Loading branch information
cesaregarza authored Oct 29, 2024
2 parents 149916e + a4861ff commit 05db4d5
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 80 deletions.
238 changes: 160 additions & 78 deletions src/fast_api_app/routes/infer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
import os
import time
import traceback
import uuid
from contextlib import asynccontextmanager

import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sqlalchemy.dialects.postgresql import insert as pg_insert

from fast_api_app.connections import (
async_session,
Expand All @@ -21,7 +23,7 @@
MAIN_ONLY_ABILITIES,
STANDARD_ABILITIES,
)
from shared_lib.models import ModelInferenceLog
from shared_lib.models import FeedbackLog, ModelInferenceLog

router = APIRouter()

Expand Down Expand Up @@ -54,6 +56,12 @@ class InferenceResponse(BaseModel):
metadata: MetaData


class FeedbackRequest(BaseModel):
request_id: str
user_agent: str
feedback: bool


# Create a persistent client
persistent_client = httpx.AsyncClient()

Expand All @@ -67,7 +75,7 @@ async def infer_instructions():
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>SplatGPT Inference API Instructions</title>
<title>SplatGPT API Documentation</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
Expand Down Expand Up @@ -128,41 +136,55 @@ async def infer_instructions():
font-family: monospace;
color: #6c757d;
}
.section {
border: 1px solid #e1e4e8;
border-radius: 6px;
margin: 2rem 0;
padding: 1rem;
}
.implementation-details {
background-color: #f8f9fa;
margin-top: 3rem;
padding: 1.5rem;
border-radius: 8px;
}
</style>
</head>
<body>
<h1>SplatGPT Inference API Instructions</h1>
<p>This endpoint provides detailed instructions on how to use the inference API for Splatoon 3 gear ability predictions.</p>
<h1>SplatGPT API Documentation</h1>
<div class="endpoint">
<h2>Endpoint Details</h2>
<div class="section">
<h2>Core API Endpoints</h2>
<h3>1. Inference Endpoint</h3>
<div class="endpoint">
<h4>Endpoint Details</h4>
<ul>
<li><strong>Method:</strong> POST</li>
<li><strong>Endpoint:</strong> <code>/api/infer</code></li>
<li><strong>Header:</strong> A custom User-Agent is required</li>
</ul>
</div>
<h4>Request Headers</h4>
<p>A custom User-Agent header is required for all requests to this endpoint. Requests without a custom User-Agent will be rejected.</p>
<h4>Request Body</h4>
<h5>abilities</h5>
<p>A dictionary of ability names and their corresponding Ability Point (AP) values. Each ability is represented by an integer AP value, where:</p>
<ul>
<li><strong>Method:</strong> POST</li>
<li><strong>Endpoint:</strong> <code>/api/infer</code></li>
<li><strong>Header:</strong> A custom User-Agent is required</li>
<li>A main slot ability has a weight of <code>10 AP</code></li>
<li>A sub slot ability has a weight of <code>3 AP</code></li>
<li>Main-Slot-Only abilities should always be represented as <code>10 AP</code></li>
</ul>
</div>
<h2>Request Headers</h2>
<p>A custom User-Agent header is required for all requests to this endpoint. Requests without a custom User-Agent will be rejected.</p>
<h5>weapon_id</h5>
<p>An integer representing the unique identifier for a specific weapon in Splatoon 3. The internal ID, where 50 is the ID for 52 gal.</p>
<h2>Request Body</h2>
<h3>abilities</h3>
<p>A dictionary of ability names and their corresponding Ability Point (AP) values. Each ability is represented by an integer AP value, where:</p>
<ul>
<li>A main slot ability has a weight of <code>10 AP</code></li>
<li>A sub slot ability has a weight of <code>3 AP</code></li>
<li>Main-Slot-Only abilities should always be represented as <code>10 AP</code></li>
</ul>
<p>The total AP for an ability is the sum of its main and sub slot values. For example, one main (10 AP) and three subs (3 AP each) of Swim Speed Up would be represented as 19 AP.</p>
<h3>weapon_id</h3>
<p>An integer representing the unique identifier for a specific weapon in Splatoon 3. The internal ID, where 50 is the ID for 52 gal.</p>
<h2>Example Request</h2>
<pre>{
<h4>Example Inference Request</h4>
<pre>{
"abilities": {
"swim_speed_up": 19,
"ninja_squid": 10,
Expand All @@ -175,68 +197,97 @@ async def infer_instructions():
"weapon_id": 50
}</pre>
<h2>Response</h2>
<p>The response contains two main parts:</p>
<ol>
<li><strong>predictions:</strong> A list of tuples, each containing:
<ul>
<li>An ability token (string)</li>
<li>The predicted value for that token (float)</li>
</ul>
</li>
<li><strong>metadata:</strong> Additional information about the request and response, including:
<h4>Inference Response</h4>
<p>The response contains two main parts:</p>
<ol>
<li><strong>predictions:</strong> A list of tuples, each containing:
<ul>
<li>An ability token (string)</li>
<li>The predicted value for that token (float)</li>
</ul>
</li>
<li><strong>metadata:</strong> Additional information about the request and response, including:
<ul>
<li>request_id: A unique identifier for the request</li>
<li>api_version: The version of the API used</li>
<li>splatgpt_version: The version of the model used for prediction</li>
<li>cache_status: Whether the result was retrieved from cache ("hit") or newly computed ("miss")</li>
<li>processing_time_ms: The time taken to process the request, in milliseconds</li>
</ul>
</li>
</ol>
<div class="note">
<p>The inference endpoint is rate-limited to 10 requests per minute to ensure fair usage and system stability.</p>
</div>
<h3>2. Feedback Endpoint</h3>
<div class="endpoint">
<h4>Endpoint Details</h4>
<ul>
<li>request_id: A unique identifier for the request</li>
<li>api_version: The version of the API used</li>
<li>splatgpt_version: The version of the model used for prediction</li>
<li>cache_status: Whether the result was retrieved from cache ("hit") or newly computed ("miss")</li>
<li>processing_time_ms: The time taken to process the request, in milliseconds</li>
<li><strong>Method:</strong> POST</li>
<li><strong>Endpoint:</strong> <code>/api/feedback</code></li>
</ul>
</li>
</ol>
<p>Ability tokens are formatted as follows:</p>
<ul>
<li>For main-slot-only abilities: the ability name (e.g., <code>ninja_squid</code>)</li>
<li>For standard abilities: the ability name followed by a number representing the AP breakpoint (e.g., <code>swim_speed_up_3</code>, <code>swim_speed_up_6</code>, etc.)</li>
</ul>
<p>The number in the token represents the minimum AP value for that prediction. For instance, <code>swim_speed_up_3</code> represents the effect of Swim Speed Up with at least 3 AP.</p>
<div class="note">
<h2>Note</h2>
<p>This endpoint is rate-limited to 10 requests per minute to ensure fair usage and system stability.</p>
</div>
<p>The feedback endpoint allows users to provide feedback on inference predictions.</p>
<h4>Feedback Request Body</h4>
<pre>{
"request_id": "string", // The request_id from the inference response
"user_agent": "string", // The User-Agent used in the request
"feedback": boolean // true for positive feedback, false for negative
}</pre>
<h4>Feedback Response</h4>
<p>Upon successful submission, the endpoint returns a status message indicating whether the feedback was inserted or updated:</p>
<pre>{
"status": "Feedback updated successfully" // or "New feedback inserted successfully"
}</pre>
</div>
<h2>Ability Lists</h2>
<h3>Main-Only Abilities</h3>
<ul>
"""
<div class="implementation-details">
<h2>Implementation Details</h2>
<h3>Token Format</h3>
<p>Ability tokens in the response follow these formatting rules:</p>
<ul>
<li>For main-slot-only abilities: the ability name (e.g., <code>ninja_squid</code>)</li>
<li>For standard abilities: the ability name followed by a number representing the AP breakpoint (e.g., <code>swim_speed_up_3</code>, <code>swim_speed_up_6</code>, etc.)</li>
</ul>
<p>The number in the token represents the minimum AP value for that prediction. For instance, <code>swim_speed_up_3</code> represents the effect of Swim Speed Up with at least 3 AP.</p>
<h3>Special Tokens</h3>
<p>These special tokens may appear in the output with near-zero probability:</p>
<ul>
<li><span class="special-token">&lt;NULL&gt;</span>: Placeholder token to build from no input, safe to ignore</li>
<li><span class="special-token">&lt;PAD&gt;</span>: Padding token used in training, safe to ignore</li>
</ul>
<h3>Available Abilities</h3>
<h4>Main-Only Abilities</h4>
<ul>
"""
+ "".join([f"<li>{ability}</li>" for ability in MAIN_ONLY_ABILITIES])
+ """
</ul>
</ul>
<h3>Standard Abilities</h3>
<ul>
"""
<h4>Standard Abilities</h4>
<ul>
"""
+ "".join([f"<li>{ability}</li>" for ability in STANDARD_ABILITIES])
+ """
</ul>
<h3>Special Tokens</h3>
<p>There are special tokens that will be returned that should be close to zero probability in the output, here are the meanings:</p>
<ul>
<li><span class="special-token">&lt;NULL&gt;</span>: Placeholder token to build from no input, safe to ignore</li>
<li><span class="special-token">&lt;PAD&gt;</span>: Padding token used in training, safe to ignore</li>
</ul>
<h2>AP Breakpoints</h2>
<ul>
"""
</ul>
<h3>AP Breakpoints</h3>
<ul>
"""
+ "".join(
[f"<li>{breakpoint}</li>" for breakpoint in BUCKET_THRESHOLDS]
)
+ """
</ul>
</ul>
</div>
</body>
</html>
"""
Expand Down Expand Up @@ -405,3 +456,34 @@ async def infer(inference_request: InferenceRequest, request: Request):
"processing_time_ms": processing_time,
},
)


@router.post("/api/feedback")
async def feedback(feedback_request: FeedbackRequest):
try:
async with async_session() as session:
stmt = (
pg_insert(FeedbackLog)
.values(
request_id=feedback_request.request_id,
user_agent=feedback_request.user_agent,
feedback=feedback_request.feedback,
)
.on_conflict_do_update(
index_elements=["request_id"],
set_={"feedback": feedback_request.feedback},
)
)
result = await session.execute(stmt)
await session.commit()

status_message = (
"Feedback updated"
if result.rowcount > 0
else "New feedback inserted"
)
return {"status": f"{status_message} successfully"}

except Exception as e:
logger.error(f"Error logging feedback: {e}")
raise HTTPException(status_code=500, detail="Error logging feedback")
29 changes: 27 additions & 2 deletions src/shared_lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Column,
DateTime,
Float,
ForeignKeyConstraint,
Index,
Integer,
SmallInteger,
Expand Down Expand Up @@ -242,10 +243,15 @@ class WeaponLeaderboard(Base):

class ModelInferenceLog(Base):
__tablename__ = "model_inference_logs"
__table_args__ = {"schema": "splatgpt"}
__table_args__ = (
UniqueConstraint(
"request_id", name="uq_model_inference_logs_request_id"
),
{"schema": "splatgpt"},
)

id = Column(BigInteger, primary_key=True)
request_id = Column(UUID(as_uuid=True), default=uuid.uuid4)
request_id = Column(UUID(as_uuid=True), default=uuid.uuid4, unique=True)
timestamp = Column(
DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP")
)
Expand All @@ -261,3 +267,22 @@ class ModelInferenceLog(Base):
status_code = Column(SmallInteger)
output_data = Column(JSONB)
error_message = Column(Text)


class FeedbackLog(Base):
__tablename__ = "feedback_logs"
__table_args__ = (
ForeignKeyConstraint(
["request_id"],
["splatgpt.model_inference_logs.request_id"],
name="fk_request_id",
ondelete="CASCADE",
),
UniqueConstraint("request_id", name="uq_feedback_logs_request_id"),
{"schema": "splatgpt"},
)

id = Column(BigInteger, primary_key=True)
request_id = Column(UUID(as_uuid=True), nullable=False)
user_agent = Column(Text)
feedback = Column(Boolean)

0 comments on commit 05db4d5

Please sign in to comment.