from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import sys
from functools import wraps
from typing import Tuple, Any, Union, Callable
import numpy as np
from PIL import ImageFile
from PIL.Image import Image
from flask import Flask, jsonify, request, Response, make_response, Request
from flask_cors import CORS, cross_origin
from werkzeug.datastructures import FileStorage
from torch import Tensor
from models import get_classifier
from models.model_handlers import (
MODEL_REGISTER,
get_generator,
get_autoencoder,
get_text_function,
get_style_transfer_function,
get_text_translate_function,
get_image_captioning_function,
get_speech_to_text_function,
)
from utils import setup_logger, allowed_file, file2image, file2audiotensor
from utils.upload_utils import image2b64
ImageFile.LOAD_TRUNCATED_IMAGES = True
logger = setup_logger(__name__)
logger.info("=> Finished Importing")
# attach our logger to the system exceptions
sys.excepthook = lambda type, val, tb: logger.error(
"Unhandled exception:", exc_info=val
)
app: Flask = Flask(__name__)
cors: CORS = CORS(app=app)
app.config["CORS_HEADERS"] = "Content-Type"
if "PRODUCTION" not in os.environ:
app.config["DEBUG"] = True
@app.route("/")
@cross_origin()
def hello_thetensorclan() -> Tuple[Any, int]:
return (
jsonify({"message": "You've reached the TensorClan Heroku Backend EndPoint"}),
200,
)
@app.route("/classify/<model_handle>", methods=["POST"])
@cross_origin()
def classify_image_api(model_handle="resnet34-imagenet") -> Response:
"""
Args:
model_handle: the model handle string, should be in `models.model_handler.MODEL_REGISTER`
Returns:
(Response): if error then a json of {'error': 'message'} is sent
else return a json of sorted List[Dict[{'class_idx': idx, 'class_name': cn, 'confidence': 'c'}]]
"""
    if model_handle not in MODEL_REGISTER:
        return make_response(
            jsonify({"error": f"{model_handle} not found in registered models"}), 404
        )
    if "file" not in request.files:
        return make_response(jsonify({"error": "No file part"}), 412)
    file: FileStorage = request.files["file"]
    if file.filename == "":
        return make_response(jsonify({"error": "No file selected"}), 417)
    if allowed_file(file.filename):
        image: Image = file2image(file)
        classifier = get_classifier(model_handle)
        output = classifier(image)
        return make_response(jsonify(output), 200)
    else:
        return make_response(jsonify({"error": f"{file.mimetype} not allowed"}), 412)
@app.route("/generators/<model_handle>", methods=["POST"])
@cross_origin()
def generator_api(model_handle="red-car-gan-generator") -> Response:
"""
generator_api
This is the generator end point, that has the model handle as the parameter
and takes in the latent_z values in the POST requests, followed by passing this
vector to the model and generates an image, which is returned as a b64 image in
the Response
Args:
model_handle: the model handle string
Returns:
Response: the base 64 encoded generated image
"""
if model_handle not in MODEL_REGISTER:
return make_response(
jsonify({"error": f"{model_handle} not found in registered models"}), 404
)
if (
model_handle in MODEL_REGISTER
and MODEL_REGISTER[model_handle]["type"] != "gan-generator"
):
return make_response(
jsonify({"error": f"{model_handle} model is not a GAN"}), 412
)
if "latent_z_size" in MODEL_REGISTER[model_handle]:
        # this is a latent-z-input type of GAN model
if "latent_z" not in request.form:
return make_response(
jsonify({"error": "latent_z not found in the form"}), 412
)
latent_z = json.loads(f"[{request.form['latent_z']}]")
latent_z = np.array(latent_z, dtype=np.float32)
generator = get_generator(model_handle)
output = generator(latent_z)
# convert it to b64 bytes
b64_image = image2b64(output)
return make_response(jsonify(b64_image), 200)
if "input_shape" in MODEL_REGISTER[model_handle]:
        # this is an image-input type of GAN model
if "file" not in request.files:
return make_response(jsonify({"error": "No file part"}), 412)
file: FileStorage = request.files["file"]
if file.filename == "":
return make_response(jsonify({"error": "No file selected"}), 417)
if allowed_file(file.filename):
image: Image = file2image(file)
generator = get_generator(model_handle)
output = generator(image)
# convert it to b64 bytes
b64_image = image2b64(output)
return make_response(jsonify(b64_image), 200)
return make_response(jsonify({"error": f"{model_handle} is not a valid GAN"}), 412)
@app.route("/autoencoders/<model_handle>", methods=["POST"])
@cross_origin()
def autoencoder_api(model_handle="red-car-autoencoder") -> Response:
"""
autoencoder_api
This end point is used to encode an image and then get the latentz vector as well as
the reconstructed image, this kind of technique can be used for image compression
and video compression, but right now only supports images and specific type of input
data.
The latentz vector is a unique representation of the input, and thus the latentz given
to a encoder and reconstruct the image exactly, thus reducing the data transmitted.
Args:
model_handle: the model handle string, must be in the MODEL_REGISTER
Returns:
Response: The response is a JSON containing the reconstructed image and the latent z
vector for the image
"""
if model_handle not in MODEL_REGISTER:
return make_response(
jsonify({"error": f"{model_handle} not found in registered models"}), 404
)
if (
model_handle in MODEL_REGISTER
and MODEL_REGISTER[model_handle]["type"] != "variational-autoencoder"
):
return make_response(
jsonify({"error": f"{model_handle} model is not an AutoEncoder"}), 412
)
if "file" not in request.files:
return make_response(jsonify({"error": "No file part"}), 412)
file: FileStorage = request.files["file"]
if file.filename == "":
return make_response(jsonify({"error": "No file selected"}), 417)
if allowed_file(file.filename):
image: Image = file2image(file)
autoencoder = get_autoencoder(model_handle)
output: Image
latent_z: np.ndarray
output, latent_z = autoencoder(image)
# convert it to b64 bytes
b64_image = image2b64(output)
return make_response(
jsonify(dict(recon_image=b64_image, latent_z=latent_z.tolist())), 200
)
else:
return make_response(jsonify({"error": f"{file.mimetype} not allowed"}), 412)
@app.route("/text/<model_handle>", methods=["POST"])
@cross_origin()
def text_api(model_handle="conv-sentimental-mclass") -> Response:
    """
    Runs the registered text model on the `input_text` form field and returns the model
    output as JSON.

    Args:
        model_handle: the model handle string, must be in the MODEL_REGISTER

    Returns:
        Response: the model output as JSON
    """
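    # Illustrative usage (host, port and text are assumptions):
    #   curl -X POST -F "input_text=what a lovely day" http://localhost:5000/text/conv-sentimental-mclass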
if model_handle not in MODEL_REGISTER:
return make_response(
jsonify({"error": f"{model_handle} not found in registered models"}), 404
)
if "input_text" not in request.form:
return make_response(
jsonify({"error": "input_text not found in the form"}), 412
)
input_text = request.form["input_text"]
text_func = get_text_function(model_handle)
output = text_func(input_text)
return make_response(jsonify(output), 200)
@app.route("/text/translate/<source_ln>/<target_ln>", methods=["POST"])
@cross_origin()
def translate_text(source_ln="de", target_ln="en") -> Response:
    """
    Translates the `source_text` form field from `source_ln` to `target_ln`.
    Currently only de -> en is supported.

    Returns:
        Response: the translated text as JSON
    """
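    # Illustrative usage (host, port and text are assumptions):
    #   curl -X POST -F "source_text=Guten Morgen" http://localhost:5000/text/translate/de/en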
if "source_text" not in request.form:
return make_response(
jsonify({"error": "input_text not found in the form"}), 412
)
source_text = request.form["source_text"]
if source_ln == "de" and target_ln == "en":
translate_func = get_text_translate_function("annotated-encoder-decoder-de-en")
output = translate_func(source_text)
return make_response(jsonify(output), 200)
else:
return make_response(
jsonify({"error": f"{source_ln} -> {target_ln} not supported"}), 404
)
def form_file_check(file_key):
"""
Checks if the file key is present in request.files
Args:
file_key: the key used to retrieve file from the request.files dict
"""
def decorator(api_func):
@wraps(api_func)
def wrapper(*args, **kwargs):
if file_key not in request.files:
return make_response(jsonify({"error": "No file part"}), 412)
return api_func(*args, **kwargs)
return wrapper
return decorator
def model_handle_check(model_type):
"""
Checks for the model_type and model_handle on the api function,
model_type is a argument to this decorator, it steals model_handle and checks if it is
present in the MODEL_REGISTER
the api must have model_handle in it
Args:
model_type: the "type" of the model, as specified in the MODEL_REGISTER
Returns:
wrapped api function
"""
def decorator(api_func):
@wraps(api_func)
def wrapper(*args, model_handle, **kwargs):
if model_handle not in MODEL_REGISTER:
return make_response(
jsonify(
{"error": f"{model_handle} not found in registered models"}
),
404,
)
if (
model_handle in MODEL_REGISTER
and MODEL_REGISTER[model_handle]["type"] != model_type
):
return make_response(
jsonify({"error": f"{model_handle} model is not an {model_type}"}),
412,
)
return api_func(*args, model_handle=model_handle, **kwargs)
return wrapper
return decorator
def get_audio_from_request(
from_request: Request, file_key: str
) -> Union[Response, Tensor]:
file: FileStorage = from_request.files[file_key]
if file.filename == "":
return make_response(jsonify({"error": "No file selected"}), 417)
if allowed_file(file.filename):
audio: Tensor = file2audiotensor(file)
return audio
else:
return make_response(jsonify({"error": f"{file.mimetype} not allowed"}), 412)
def get_image_from_request(
from_request: Request, file_key: str
) -> Union[Response, Image]:
file: FileStorage = from_request.files[file_key]
if file.filename == "":
return make_response(jsonify({"error": "No file selected"}), 417)
if allowed_file(file.filename):
image: Image = file2image(file)
return image
else:
return make_response(jsonify({"error": f"{file.mimetype} not allowed"}), 412)
@app.route("/style-transfer/<model_handle>/<style_name>", methods=["POST"])
@cross_origin()
@model_handle_check(model_type="style-transfer")
def style_transfer_api(
model_handle="fast-style-transfer", style_name="candy"
) -> Response:
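    # Illustrative usage (host, port and file name are assumptions):
    #   curl -X POST -F "file=@lake.jpg" http://localhost:5000/style-transfer/fast-style-transfer/candy
    # the response body is the stylized image as a base64-encoded JSON string.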
    # check if it's a valid style
if style_name not in MODEL_REGISTER[model_handle]["model_stack"]:
return make_response(
jsonify({"error": f"{style_name} not in model_stack of {model_handle}"}),
404,
)
# get the input image from the request
returned_val: Union[Response, Image] = get_image_from_request(
from_request=request, file_key="file"
)
    # if a response was already created during processing (i.e. an error), return it
if isinstance(returned_val, Response):
response: Response = returned_val
return response
image: Image = returned_val
# now process the image
style_transfer = get_style_transfer_function(model_handle, style_name)
output: Image = style_transfer(image)
# convert it to b64 bytes
b64_image = image2b64(output)
return make_response(jsonify(b64_image), 200)
@app.route("/image-captioning/<model_handle>", methods=["POST"])
@cross_origin()
@model_handle_check(model_type="image-caption")
@form_file_check(file_key="file")
def image_caption_api(model_handle="flickr8k-image-caption") -> Response:
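    # Illustrative usage (host, port and file name are assumptions):
    #   curl -X POST -F "file=@beach.jpg" http://localhost:5000/image-captioning/flickr8k-image-caption
    # should return JSON shaped like {"caption": "<generated caption>"}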
# get the input image from the request
returned_val: Union[Response, Image] = get_image_from_request(
from_request=request, file_key="file"
)
    # if a response was already created during processing (i.e. an error), return it
if isinstance(returned_val, Response):
response: Response = returned_val
return response
image: Image = returned_val
# now process the image
image_caption: Callable[[Image], str] = get_image_captioning_function(model_handle)
output: str = image_caption(image)
return make_response(jsonify({"caption": output}), 200)
@app.route("/speech-to-text/<model_handle>", methods=["POST"])
@cross_origin()
@model_handle_check(model_type="speech-to-text")
@form_file_check(file_key="file")
def speech_to_text_api(model_handle="speech-recognition-residual-model") -> Response:
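    # Illustrative usage (host, port and file name are assumptions):
    #   curl -X POST -F "file=@utterance.wav" http://localhost:5000/speech-to-text/speech-recognition-residual-model
    # should return JSON shaped like {"text": "<transcription>"}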
# get the input audio from the request
returned_val: Union[Response, Tensor] = get_audio_from_request(
from_request=request, file_key="file"
)
    # if a response was already created during processing (i.e. an error), return it
if isinstance(returned_val, Response):
response: Response = returned_val
return response
audio: Tensor = returned_val
speech_to_text: Callable[[Tensor], str] = get_speech_to_text_function(model_handle)
output: str = speech_to_text(audio)
return make_response(jsonify({"text": output}), 200)
@app.route("/human-pose", methods=["POST"])
@cross_origin()
def get_human_pose() -> Response:
"""
Handles the human pose POST request, takes the pose image and identifies the human pose keypoints,
stitches them together and returns a response with the image as b64 encoded, with the detected human pose
Returns:
(Response): b64 image string with the detected human pose
"""
from models import get_pose
if "file" not in request.files:
return Response({"error": "No file part"}, status=412)
file: FileStorage = request.files["file"]
if file.filename == "":
return Response({"error": "No file selected"}, status=417)
if allowed_file(file.filename):
image: Image = file2image(file)
pose_img = get_pose(image)
# convert it to b64 bytes
b64_pose = image2b64(pose_img)
return jsonify(b64_pose), 200
else:
return Response({"error": f"{file.mimetype} not allowed"}, status=412)