From 919bd1f0f2127a9419dddb684c131c89d563f496 Mon Sep 17 00:00:00 2001 From: RikiSot Date: Wed, 3 Aug 2022 11:50:42 +0200 Subject: [PATCH] Add documentation for OpenAPI Schema --- api.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/api.py b/api.py index 6f39cdc..2ddbd77 100644 --- a/api.py +++ b/api.py @@ -10,21 +10,61 @@ import json import uvicorn +from pydantic import BaseModel from fastapi import FastAPI from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from fastapi import HTTPException import os -ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '.')) + +# 2. API Documentation +tags_metadata = [ + { + "name": "Home", + "description": "Home page" + }, + { + 'name': 'predict', + 'description': 'Predicts the class of different types of weather. There are 11 output classes: dew, fog/smog, ' + 'frost, glaze, hail, lightning , rain, rainbow, rime, sandstorm and snow. ' + } +] + +predict_responses = { + 200: { + 'description': 'Prediction successful', + 'content': { + 'application/json': { + 'examples': { + 'Dew image': { + 'summary': 'Dew image', + 'value': { + 'prediction': 'dew', + 'probability': 0.9, + } + }, + } + } + } + } +} +class PredictResponse(BaseModel): + """Response model for the predict endpoint""" + prediction: str + probability: float + + +ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '.')) + # Load labels from json file labels_path = os.path.join(ROOT_DIR, 'data', 'labels.json') labels = json.load(open(labels_path)) IMG_SIZE = (120, 120) # Same size as the model's input during training # 2. Create the app object -app = FastAPI() +app = FastAPI(openapi_tags=tags_metadata) # 3. Load models model_path = os.path.join(ROOT_DIR, 'models', 'ResNet50.h5') @@ -32,12 +72,12 @@ # 4. API Endpoints and methods -@app.get("/") +@app.get("/", tags=["Home"]) async def home(): return {"message": "Technical test"} -@app.post("/predict") +@app.post("/predict", tags=["predict"], response_model=PredictResponse, responses=predict_responses) async def predict_image(image_link: str = ''): """ Predict the label from a given image (url) @@ -57,11 +97,8 @@ async def predict_image(image_link: str = ''): prediction = np.argmax(score, axis=1) label = labels[str(prediction[0])] - response = { - 'label': label, - 'confidence': model_score, - } - return JSONResponse(content=response) + response = PredictResponse(prediction=label, probability=model_score) + return response def start_server():