# back_end.py (150 lines) — web-scrape page chrome and line-number gutter removed
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from utils import similarity, get_all_transcripts, get_comb_label, build_phosc_model
from segmentation import multi_line_ext, words_extract
from tensorflow_addons.layers import SpatialPyramidPooling2D
import numpy as np
import uuid
import cv2
import tensorflow as tf
import pandas as pd
'''This module is the main back-end, which handles classification and
word/character extraction.

Importing this module builds the PHOSC network and loads its weights
from disk (side effect at import time).
'''
# NOTE(review): `global` at module level is a no-op — `model` is already a
# module-level name here. The statement is harmless but can be removed.
global model
#if load_phosc_model:
# Build the PHOSC architecture and load pre-trained weights once at import.
model = build_phosc_model()
model.load_weights('model/phosc-model.h5')
# Commented-out alternatives kept by the author: exporting the weights to a
# pickle, and loading a full saved PHOC model with its custom SPP layer.
#weights = model.get_weights()
#df = pd.DataFrame(weights)
#df.to_pickle('model/new_phosc_weights.pkl')
#else:
# model = load_model('model/phoc-model.h5', custom_objects={'SpatialPyramidPooling2D': SpatialPyramidPooling2D})
def classify(img, transcripts):
    '''Classify a single word image against the known transcripts.

    The image is resized to the network's input size, passed through the
    PHOSC model, and the concatenated prediction vector is compared to
    the combined label of every candidate transcript.

    Args:
        img: the cropped word image to be classified
        transcripts: iterable of candidate transcript strings

    Returns:
        out: the transcript whose combined label is most similar to the
            prediction (empty string if nothing scores above zero)
    '''
    # Prepare a single-image batch at the model's expected input size.
    batch = np.expand_dims(tf.image.resize(img_to_array(img), [70, 90]), axis=0)
    prediction = model.predict(batch)
    # The model has two output heads; fuse them into one feature vector.
    feature = np.squeeze(np.concatenate((prediction[0], prediction[1]), axis=1))
    best_match = ''
    best_score = 0
    for candidate in transcripts:
        score = similarity(feature, get_comb_label(candidate))
        if score > best_score:
            best_score = score
            best_match = candidate
    return best_match
def get_result(path):
    ''' Classify all words from within an image.

    The pipeline has four stages:
        1. detect the text lines in the image
        2. segment each line into word images
        3. classify each word using the PHOSC model
        4. format the predictions into a readable string

    Args:
        path: the location of the image to be classified

    Returns:
        result: the classified string, or 'Classification Error' when the
            image is unreadable, blank, or yields no predictions
    '''
    img = cv2.imread(path)
    # cv2.imread returns None for a missing/unreadable file; a fully white
    # image contains no text. Bail out early instead of running the pipeline
    # (the original check assigned '' to output and then carried on anyway).
    if img is None or np.mean(img) == 255:
        return 'Classification Error'
    transcripts = get_all_transcripts()
    # Arabic characters that, in this heuristic, mark the end of a word.
    endCharacters = ['ة', 'ئ', 'ى', 'ه', 'م', 'ن', 'ل', 'ك', 'ق', 'ف', 'غ', 'ع', 'ض', 'ص', 'ش', 'س', 'خ', 'ح', 'ج', 'ث', 'ت', 'ب']
    output = ''
    lines = multi_line_ext(img)
    for line in lines:
        # len() works for both lists and numpy arrays; the original `line != []`
        # comparison is ambiguous for ndarrays.
        if len(line) == 0:
            continue
        arr = words_extract(line)
        for a in arr:
            res = classify(a, transcripts)
            # classify returns '' when no transcript matches; res[-1] would
            # raise IndexError on an empty string, so skip it.
            if not res:
                continue
            # Heuristic word-spacing rules based on the final character.
            lastChar = res[-1]
            if len(res) == 1 and lastChar not in endCharacters:
                output += ' ' + res
            elif len(res) > 4 or lastChar in endCharacters:
                output += res + ' '
            elif (lastChar == 'ا' or lastChar == 'أ') and len(res) > 2:
                output += res + ' '
            else:
                output += res
        output += '<br>'
    if output == '':
        output = 'Classification Error'
    return output
def format_result(arr):
    '''Format the prediction result into a readable string.

    Args:
        arr: a list of predicted strings, one entry per line

    Returns:
        string: the lines joined with newlines, with surrounding
            whitespace stripped
    '''
    # str.join replaces the original quadratic `+=` accumulation; the final
    # strip() removes the trailing newline (and any outer whitespace) exactly
    # as the original accumulate-then-strip loop did.
    return '\n'.join(arr).strip()
def save_thumbnail(user, img):
    ''' Reduce the size of an image and save it as a thumbnail.

    The thumbnail is saved to the user's folder. It is used in the hub
    page to display previous predictions.

    Args:
        user: the current user (folder name under static/users/)
        img: the image (numpy array) to be resized

    Returns:
        location: the relative path of the saved thumbnail
    '''
    max_height = 200
    # Rotate landscape images so the longer side becomes the height.
    if img.shape[0] < img.shape[1]:
        img = np.rot90(img)
    # Scale the width to preserve the aspect ratio at the fixed height.
    hpercent = max_height / float(img.shape[0])
    wsize = int(float(img.shape[1]) * hpercent)
    img = cv2.resize(img, (wsize, max_height))
    # Build the path once instead of formatting it twice (original duplicated
    # the format string in imwrite and the return statement).
    location = 'static/users/{}/{}.png'.format(user, uuid.uuid4())
    cv2.imwrite(location, img)
    return location
# img = cv2.imread('image.png')
# transcripts = get_all_transcripts()
# lines = multi_line_ext(img)
# output = ''
# for line in lines:
# output = ''
# arr = words_extract(line)
# if line != []:
# for a in arr:
# res = classify(a, transcripts)
# output += ' ' + res
# img = cv2.imread('index.jpg')
# transcripts = get_all_transcripts()
# res = classify(img, transcripts)
# print(res)