forked from unica-isde/isde-projects-2023-E
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
61 lines (50 loc) · 1.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
from typing import Dict, List
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import redis
from rq import Connection, Queue
from rq.job import Job
from app.config import Configuration
from app.forms.classification_form import ClassificationForm
from app.ml.classification_utils import classify_image
from app.utils import list_images
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()
# Instance of the project configuration. NOTE(review): the handlers visible in
# this file read `Configuration.models` from the class, not from this instance.
config = Configuration()
# Serve the files under app/static at the /static URL prefix.
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# Jinja2 template loader rooted at app/templates, used by the HTML handlers.
templates = Jinja2Templates(directory="app/templates")
@app.get("/info")
def info() -> Dict[str, List[str]]:
    """Report the models and image files known to the service.

    Returns:
        A dictionary with two keys: ``"models"``, taken from
        ``Configuration.models``, and ``"images"``, the result of
        ``list_images()``.
    """
    available_images = list_images()
    return {"models": Configuration.models, "images": available_images}
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render the service landing page from the ``home.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("home.html", context)
@app.get("/classifications")
def create_classify(request: Request):
    """Render ``classification_select.html`` with the images and models.

    The template context carries the request, the output of
    ``list_images()`` under ``"images"`` and ``Configuration.models`` under
    ``"models"``.
    """
    context = {
        "request": request,
        "images": list_images(),
        "models": Configuration.models,
    }
    return templates.TemplateResponse("classification_select.html", context)
@app.post("/classifications")
async def request_classification(request: Request):
    """Classify the image chosen in the submitted form and render the scores.

    The image and model identifiers are read from the request through
    ``ClassificationForm``; the classification itself is delegated to
    ``classify_image``.

    Args:
        request: The incoming POST request carrying the form data.

    Returns:
        The rendered ``classification_output.html`` template. Its context
        holds the request, the image id, and the classification scores
        serialized as a JSON string.
    """
    # Imported here so the handler is self-contained; `fastapi` is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    form = ClassificationForm(request)
    await form.load_data()
    image_id = form.image_id
    model_id = form.model_id

    # classify_image is a plain synchronous call. Invoking it directly inside
    # this coroutine would occupy the event loop for its whole duration and
    # stall every other request, so it is handed to the worker thread pool
    # (the same pool FastAPI uses for non-async endpoints).
    classification_scores = await run_in_threadpool(
        classify_image, model_id=model_id, img_id=image_id
    )

    return templates.TemplateResponse(
        "classification_output.html",
        {
            "request": request,
            "image_id": image_id,
            "classification_scores": json.dumps(classification_scores),
        },
    )