Skip to content

Commit

Permalink
Add documentation for OpenAPI Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
RikiSot committed Aug 3, 2022
1 parent 37619d9 commit 919bd1f
Showing 1 changed file with 46 additions and 9 deletions.
55 changes: 46 additions & 9 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,74 @@
import json

import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi import HTTPException
import os
ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '.'))

# 2. API Documentation
tags_metadata = [
{
"name": "Home",
"description": "Home page"
},
{
'name': 'predict',
'description': 'Predicts the class of different types of weather. There are 11 output classes: dew, fog/smog, '
'frost, glaze, hail, lightning , rain, rainbow, rime, sandstorm and snow. '
}
]

predict_responses = {
200: {
'description': 'Prediction successful',
'content': {
'application/json': {
'examples': {
'Dew image': {
'summary': 'Dew image',
'value': {
'prediction': 'dew',
'probability': 0.9,
}
},
}
}
}
}
}


class PredictResponse(BaseModel):
"""Response model for the predict endpoint"""
prediction: str
probability: float


ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '.'))

# Load labels from json file
labels_path = os.path.join(ROOT_DIR, 'data', 'labels.json')
labels = json.load(open(labels_path))
IMG_SIZE = (120, 120) # Same size as the model's input during training

# 2. Create the app object
app = FastAPI()
app = FastAPI(openapi_tags=tags_metadata)

# 3. Load models
model_path = os.path.join(ROOT_DIR, 'models', 'ResNet50.h5')
model = tf.keras.models.load_model(model_path)


# 4. API Endpoints and methods
@app.get("/")
@app.get("/", tags=["Home"])
async def home():
return {"message": "Technical test"}


@app.post("/predict")
@app.post("/predict", tags=["predict"], response_model=PredictResponse, responses=predict_responses)
async def predict_image(image_link: str = ''):
"""
Predict the label from a given image (url)
Expand All @@ -57,11 +97,8 @@ async def predict_image(image_link: str = ''):

prediction = np.argmax(score, axis=1)
label = labels[str(prediction[0])]
response = {
'label': label,
'confidence': model_score,
}
return JSONResponse(content=response)
response = PredictResponse(prediction=label, probability=model_score)
return response


def start_server():
Expand Down

0 comments on commit 919bd1f

Please sign in to comment.