-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
123 lines (90 loc) · 4.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from io import BytesIO
from typing import List

from fastapi import FastAPI, UploadFile, File, HTTPException, APIRouter
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import cv2
import numpy as np

from clock_model_color import ColorClockVAEHandler
from clock_model_mono import MonoClockVAEHandler
from face_embedding import FaceRecognizer
from color_picker import extract_dominant_color_with_priority
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model checkpoint paths. Each can be overridden through an environment
# variable; the defaults are the checkpoint file names used so far.
color_model_path = os.getenv("COLOR_MODEL_PATH", "clock-vae-color-140x-v1-500epoch.pth")
mono_model_path = os.getenv("MONO_MODEL_PATH", "clock-vae-mono-100x-v1-1000epoch.pth")

# Handlers are constructed once at import time and shared by every request.
color_model = ColorClockVAEHandler(model_path=color_model_path, device=device)
# Pass mono_model_path (rather than repeating the literal) so the configured
# path is the one actually loaded.
mono_model = MonoClockVAEHandler(model_path=mono_model_path, size=100, device=device)
face_recognizer = FaceRecognizer()

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins to the known front-end origin(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # allow all HTTP methods
    allow_headers=["*"],  # allow all HTTP headers
)
class TimeRequst(BaseModel):
    """Request body for the clock-captcha endpoints: the time the clock should show.

    The (misspelled) class name is kept as-is because the endpoints refer to it.
    """

    # Hour to render. NOTE(review): the accepted range (0-11, 0-23, 1-12?)
    # depends on the clock model and is not validated here — TODO confirm.
    hour: int
    # Minute to render (not range-validated here).
    minute: int
class SimilarityRequest(BaseModel):
    """Request body for /check_similarity/: the two embeddings to compare."""

    # Embedding vectors, presumably as produced by /extract_embedding/.
    embedding1: List[float]
    embedding2: List[float]
@app.get("/")
def read_root():
    """Root endpoint; responds with a fixed JSON greeting."""
    payload = dict(Hello="World")
    return payload
@app.post("/clock_captcha_color/")
async def clock_captcha_color(input_time: TimeRequst):
    """Render a color clock-captcha image for the requested time as a PNG.

    Args:
        input_time: Hour and minute the generated clock should display.

    Returns:
        A ``StreamingResponse`` whose body is the PNG-encoded image.

    Raises:
        HTTPException: 500 if OpenCV reports that PNG encoding failed.
    """
    # NOTE(review): generate_image is called synchronously inside an async
    # handler, so it blocks the event loop while it runs.
    correct_image = color_model.generate_image(input_time.hour, input_time.minute)

    # cv2.imencode can signal failure through its boolean first return value;
    # check it instead of streaming an invalid body.
    ok, encoded_image = cv2.imencode(".png", correct_image)
    if not ok:
        raise HTTPException(status_code=500, detail="Error encoding image")

    # A freshly constructed BytesIO is already positioned at offset 0.
    buffer = BytesIO(encoded_image.tobytes())
    return StreamingResponse(buffer, media_type="image/png")
@app.post("/clock_captcha_mono/")
async def clock_captcha_mono(input_time: TimeRequst):
    """Render a monochrome clock-captcha image for the requested time as a PNG.

    Args:
        input_time: Hour and minute the generated clock should display.

    Returns:
        A ``StreamingResponse`` whose body is the PNG-encoded image.

    Raises:
        HTTPException: 500 if OpenCV reports that PNG encoding failed.
    """
    # NOTE(review): generate_image is called synchronously inside an async
    # handler, so it blocks the event loop while it runs.
    correct_image = mono_model.generate_image(input_time.hour, input_time.minute)

    # cv2.imencode can signal failure through its boolean first return value;
    # check it instead of streaming an invalid body.
    ok, encoded_image = cv2.imencode(".png", correct_image)
    if not ok:
        raise HTTPException(status_code=500, detail="Error encoding image")

    # A freshly constructed BytesIO is already positioned at offset 0.
    buffer = BytesIO(encoded_image.tobytes())
    return StreamingResponse(buffer, media_type="image/png")
@app.post("/extract_embedding/")
async def extract_embedding(file: UploadFile = File(...)):
    """Extract a face embedding from an uploaded image.

    Args:
        file: Uploaded image file in any format OpenCV can decode.

    Returns:
        ``{"embedding": [...]}`` — the embedding converted with ``tolist()``.

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image,
            or if no face is found in it.
    """
    contents = await file.read()

    # Skip decoding an empty upload outright; cv2.imdecode returns None for data
    # it cannot decode. Either way, answer with a client error instead of
    # handing None to the face recognizer.
    image = None
    if contents:
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        # Detail (Korean): "Not a valid image file."
        raise HTTPException(status_code=400, detail="유효한 이미지 파일이 아닙니다.")

    embedding = face_recognizer.get_face_embedding_from_image(image)
    if embedding is None:
        # Detail (Korean): "Could not find a face in the image."
        raise HTTPException(status_code=400, detail="이미지에서 얼굴을 찾을 수 없습니다.")
    return {"embedding": embedding.tolist()}
@app.post("/check_similarity/")
async def check_similarity(request: SimilarityRequest):
    """Compare two client-supplied face embeddings.

    Args:
        request: Body carrying ``embedding1`` and ``embedding2``.

    Returns:
        Whatever ``face_recognizer.calculate_similarity`` returns for the pair.

    Raises:
        HTTPException: 400 if the embeddings are empty or differ in length.
    """
    # Validate before building arrays: mismatched or empty vectors cannot be
    # compared and would otherwise surface as an internal server error.
    # (If embedding1 is non-empty and the lengths match, embedding2 is non-empty too.)
    if not request.embedding1 or len(request.embedding1) != len(request.embedding2):
        # Detail (Korean): "Both embeddings must be non-empty and have the same length."
        raise HTTPException(
            status_code=400,
            detail="두 임베딩은 비어 있지 않아야 하며 길이가 같아야 합니다.",
        )

    embedding1 = np.array(request.embedding1)
    embedding2 = np.array(request.embedding2)
    return face_recognizer.calculate_similarity(embedding1, embedding2)
@app.post("/check_two_face/")
async def check_two_face(file1: UploadFile = File(...), file2: UploadFile = File(...)):
    """Compare the faces found in two uploaded images.

    Args:
        file1: First uploaded image.
        file2: Second uploaded image.

    Returns:
        Whatever ``face_recognizer.calculate_similarity`` returns for the two
        extracted embeddings.

    Raises:
        HTTPException: 400 if either upload is empty or not a decodable image,
            or if a face cannot be found in either image.
    """
    images = []
    for upload in (file1, file2):
        contents = await upload.read()
        # Skip decoding empty uploads; cv2.imdecode returns None for undecodable data.
        image = None
        if contents:
            image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            # Detail (Korean): "Not a valid image file."
            raise HTTPException(status_code=400, detail="유효한 이미지 파일이 아닙니다.")
        images.append(image)
    image1, image2 = images

    embedding1 = face_recognizer.get_face_embedding_from_image(image1)
    embedding2 = face_recognizer.get_face_embedding_from_image(image2)
    if embedding1 is None or embedding2 is None:
        # Detail (Korean): "Could not find a face in the image."
        raise HTTPException(status_code=400, detail="이미지에서 얼굴을 찾을 수 없습니다.")

    return face_recognizer.calculate_similarity(embedding1, embedding2)
@app.post("/extract_dominant_color/")
async def extract_dominant_color(file: UploadFile = File(...)):
    """Return the dominant color of an uploaded image.

    Args:
        file: Uploaded image file.

    Returns:
        ``{"dominant_color": ...}`` holding the value produced by
        ``extract_dominant_color_with_priority``.

    Raises:
        HTTPException: 500 if reading the upload or extracting the color fails.
    """
    try:
        contents = await file.read()
        dominant_color = extract_dominant_color_with_priority(BytesIO(contents))
    except Exception as e:
        # Chain the original exception so server-side tracebacks keep the cause.
        # NOTE(review): str(e) is echoed to the client and may expose internals;
        # bad client input is also reported as 500 rather than 400.
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
    return {"dominant_color": dominant_color}