-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: main.py
57 lines (46 loc) · 1.77 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import json

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from app.config import Configuration
from app.forms.classification_form import ClassificationForm
from app.ml.classification_utils import classify_image
from app.utils import list_images
# Application object and module-level singletons shared by the routes below.
app = FastAPI()
# NOTE(review): `config` is not referenced by the routes shown in this file;
# they read `Configuration.models` from the class directly.
config = Configuration()
# HTML templates are loaded from app/templates.
templates = Jinja2Templates(directory="app/templates")

# Serve the files under app/static at the /static URL prefix.
app.mount("/static", StaticFiles(directory="app/static"), name="static")
@app.get("/info")
def info() -> dict[str, list[str]]:
    """Report which models and images the service can work with.

    Returns:
        A mapping with two keys: ``"models"`` holds ``Configuration.models``
        and ``"images"`` holds the result of ``list_images()``.
    """
    images = list_images()
    models = Configuration.models
    return {"models": models, "images": images}
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render the service's landing page from the ``home.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("home.html", context)
@app.get("/classifications")
def create_classify(request: Request):
    """Render the ``classification_select.html`` page.

    The template receives the images returned by ``list_images()`` and the
    models listed in ``Configuration.models``, presumably so the user can
    pick one of each before POSTing to ``/classifications``.
    """
    context = {
        "request": request,
        "images": list_images(),
        "models": Configuration.models,
    }
    return templates.TemplateResponse("classification_select.html", context)
@app.post("/classifications")
async def request_classification(request: Request):
    """Classify the submitted image with the submitted model and render the result.

    The posted form is parsed by ``ClassificationForm``. Its ``image_id`` and
    ``model_id`` are passed to ``classify_image``. The returned scores are
    serialized with ``json.dumps`` and rendered through
    ``classification_output.html`` together with the image id.
    """
    form = ClassificationForm(request)
    await form.load_data()
    image_id = form.image_id
    model_id = form.model_id
    # classify_image is a synchronous call (it is not awaited). Calling it
    # directly inside this ``async def`` would block the event loop for the
    # whole classification and stall every other request. Run it in the
    # worker thread pool instead, as FastAPI already does for plain ``def``
    # routes.
    # NOTE(review): this allows classifications to run concurrently; confirm
    # classify_image is safe to call from multiple threads.
    classification_scores = await run_in_threadpool(
        classify_image, model_id=model_id, img_id=image_id
    )
    return templates.TemplateResponse(
        "classification_output.html",
        {
            "request": request,
            "image_id": image_id,
            "classification_scores": json.dumps(classification_scores),
        },
    )