-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathpredictor.py
479 lines (398 loc) · 16.5 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
import cv2
import math
import numpy as np
import pyclipper
from shapely.geometry import Polygon
from segm.predictor import SegmPredictor
from ocr.predictor import OcrPredictor
from ocrpipeline.utils import img_crop
from ocrpipeline.config import Config
from ocrpipeline.linefinder import (
add_polygon_center, add_page_idx_for_lines, add_line_idx_for_lines,
add_line_idx_for_words, add_column_idx_for_words, add_word_indexes
)
def is_valid_polygon(polygon):
    """Tell whether a polygon is usable for contour offsetting.

    A polygon is rejected when its perimeter (``length``) is below 1 or when
    its area is not strictly positive.

    Args:
        polygon (shapely.geometry.Polygon): The polygon.

    Returns:
        bool: True if the polygon passes both checks, False otherwise.
    """
    too_small = polygon.length < 1 or polygon.area <= 0
    return not too_small
def change_contour_size(polygon, scale_ratio=1):
    """Offset a polygon so that its size changes according to scale_ratio.

    The offset distance is ``int(area * (1 - scale_ratio ** 2) / perimeter)``
    and the polygon is offset by its negation with pyclipper. A ratio above 1
    therefore gives a positive pyclipper delta (growing the polygon). A ratio
    below 1 gives a negative delta (shrinking it).

    Args:
        polygon (np.array or list): Polygon coordinates
            ``[[x, y], ...]``.
        scale_ratio (float): The scale ratio used to change the contour.

    Returns:
        list or None: The polygons returned by
        ``pyclipper.PyclipperOffset.Execute``. This list may be empty, e.g.
        when shrinking collapses the polygon. None is returned when the input
        is degenerate: fewer than three points, perimeter below 1, or
        non-positive area.
    """
    # A polygon needs at least three vertices; shapely cannot build one from
    # fewer and would raise instead of letting the caller fall back.
    if len(polygon) < 3:
        return None
    poly = Polygon(polygon)
    if not is_valid_polygon(poly):
        return None
    distance = int(poly.area * (1 - scale_ratio ** 2) / poly.length)
    # Build the offsetter only once the polygon is known to be usable.
    pco = pyclipper.PyclipperOffset()
    pco.AddPath(polygon, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    return pco.Execute(-distance)
def get_upscaled_bbox(bbox, upscale_x=1, upscale_y=1):
    """Resize a bbox around its centre by the given per-axis factors.

    Args:
        bbox: ``(x_min, y_min, x_max, y_max)``.
        upscale_x (float): Width multiplier.
        upscale_y (float): Height multiplier.

    Returns:
        tuple: The new ``(x_min, y_min, x_max, y_max)``. Half of the size
        change (truncated toward zero) is applied on each side. The minimum
        coordinates are clamped at 0; the maximum ones are not clamped.
    """
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    box_w = right - left
    box_h = bottom - top
    half_dx = int(((box_w * upscale_x) - box_w) / 2)
    half_dy = int(((box_h * upscale_y) - box_h) / 2)
    return (
        max(0, left - half_dx),
        max(0, top - half_dy),
        right + half_dx,
        bottom + half_dy,
    )
def contour2bbox(contour):
    """Return the axis-aligned bbox ``(x_min, y_min, x_max, y_max)`` of a contour.

    Uses ``cv2.boundingRect`` and converts its ``(x, y, w, h)`` output.
    """
    left, top, box_w, box_h = cv2.boundingRect(contour)
    right = left + box_w
    bottom = top + box_h
    return left, top, right, bottom
def get_contours_from_mask(mask, min_area=5):
    """Extract contours from a mask, keeping only the sufficiently large ones.

    Args:
        mask (np.array): The mask. It is cast to uint8 before contour search.
        min_area (float): Contours with ``cv2.contourArea`` below this value
            are dropped.

    Returns:
        list: The kept contours, found with ``cv2.RETR_LIST`` and
        ``cv2.CHAIN_APPROX_SIMPLE``.
    """
    found = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 3.x returns (image, contours, hierarchy) and OpenCV 4.x returns
    # (contours, hierarchy); the contours are at index -2 in both.
    contours = found[-2]
    return [
        contour for contour in contours
        if cv2.contourArea(contour) >= min_area
    ]
def get_angle_between_vectors(x1, y1, x2=1, y2=0):
    """Return the unsigned angle in degrees between (x1, y1) and (x2, y2).

    The result is always in [0, 180]. By default the second vector is the
    x axis ``(1, 0)``. The angle is undefined (NaN) if either vector has
    zero length.
    """
    vector_1 = np.asarray([x1, y1], dtype=float)
    vector_2 = np.asarray([x2, y2], dtype=float)
    unit_vector_1 = vector_1 / np.linalg.norm(vector_1)
    unit_vector_2 = vector_2 / np.linalg.norm(vector_2)
    # Rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1], where arccos returns NaN; clip it back into the valid domain.
    dot_product = np.clip(np.dot(unit_vector_1, unit_vector_2), -1.0, 1.0)
    radian = np.arccos(dot_product)
    return math.degrees(radian)
def get_angle_by_fitline(contour):
    """Return the angle (degrees) of a straight line fitted through a contour.

    The line is fitted with ``cv2.fitLine`` (L2 distance). The unsigned angle
    between its direction vector and the x axis is negated when the
    direction's y component is positive, i.e. points down in image
    coordinates.
    """
    direction_x, direction_y, _, _ = cv2.fitLine(
        contour, cv2.DIST_L2, 0, 0.01, 0.01)
    angle = get_angle_between_vectors(direction_x[0], direction_y[0])
    # get_angle_between_vectors never returns a negative value, so restore
    # the sign from the y component of the fitted direction.
    if direction_y > 0:
        angle = -angle
    return angle
def get_angle_by_minarearect(contour):
    """Return the angle (degrees) of the minimal-area rectangle of a contour.

    The angle reported by ``cv2.minAreaRect`` is negated and then folded
    into [-45, 45]. A rectangle whose long side is reported as its "height"
    is thus not treated as being rotated by ~90 degrees.
    """
    angle = cv2.minAreaRect(contour)[2]
    # Negate the angle because the image y axis points down.
    angle *= -1
    # The reported range changed across OpenCV releases (roughly [-90, 0) in
    # older ones, positive values up to 90 in newer ones). After negation the
    # value can therefore be near -90 or near +90 for an almost-level box;
    # take the complementary angle in both cases.
    if angle < -45:
        angle += 90
    elif angle > 45:
        angle -= 90
    return angle
def rotate_image_and_contours(image, contours, angle):
    """Rotate an image by angle (degrees) and move the contours with it.

    Returns:
        tuple: ``(rotated_image, rotated_contours)``. Each contour is mapped
        with ``cv2.transform`` using the affine matrix returned by
        rotate_image.
    """
    rotated_image, affine = rotate_image(image, angle)
    moved_contours = [cv2.transform(contour, affine) for contour in contours]
    return rotated_image, moved_contours
def get_image_angle(contours, by_fitline=True):
    """Estimate the rotation angle of an image from the contours of its words.

    Args:
        contours (list of np.array): Contours used to measure the angle.
        by_fitline (bool): Measure each contour with get_angle_by_fitline
            (True) or get_angle_by_minarearect (False).

    Returns:
        float: The median per-contour angle in degrees. Returns 0.0 when
        ``contours`` is empty, instead of the NaN (and RuntimeWarning) that
        the median of an empty array would produce.
    """
    if len(contours) == 0:
        return 0.0
    measure = get_angle_by_fitline if by_fitline else get_angle_by_minarearect
    angles = [measure(contour) for contour in contours]
    return np.median(np.array(angles))
def rotate_image(mat, angle):
    """Rotate an image by angle (degrees) and enlarge the canvas to avoid cropping.

    Adapted from
    https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides

    Returns:
        tuple: ``(rotated_image, rotation_matrix)``. ``rotation_matrix`` is
        the 2x3 affine matrix actually applied, including the translation
        to the enlarged canvas.
    """
    img_h, img_w = mat.shape[:2]
    centre_x, centre_y = img_w / 2, img_h / 2
    affine = cv2.getRotationMatrix2D((centre_x, centre_y), angle, 1.)
    cos_a = abs(affine[0, 0])
    sin_a = abs(affine[0, 1])
    # Size of the axis-aligned box enclosing the rotated image.
    canvas_w = int(img_h * sin_a + img_w * cos_a)
    canvas_h = int(img_h * cos_a + img_w * sin_a)
    # Translate so the original centre lands at the centre of the new canvas.
    affine[0, 2] += canvas_w / 2 - centre_x
    affine[1, 2] += canvas_h / 2 - centre_y
    rotated = cv2.warpAffine(mat, affine, (canvas_w, canvas_h))
    return rotated, affine
class SegmPrediction:
    """Pipeline step that runs the segmentation model on the image.

    The incoming ``pred_img`` is ignored and replaced by the segmentation
    predictor's result for the image.

    Args:
        pipeline_config: Unused; accepted because every main-process step is
            constructed with ``pipeline_config``.
        model_path, num_threads, config_path, device, runtime: Forwarded to
            SegmPredictor.
    """

    def __init__(
        self, pipeline_config, model_path, num_threads,
        config_path, device, runtime
    ):
        predictor_kwargs = dict(
            model_path=model_path,
            config_path=config_path,
            num_threads=num_threads,
            device=device,
            runtime=runtime,
        )
        self.segm_predictor = SegmPredictor(**predictor_kwargs)

    def __call__(self, image, pred_img):
        # The predictor takes a batch; wrap the single image and take the
        # first result back out.
        batch_result = self.segm_predictor([image])
        return image, batch_result[0]
class PrepareJSON:
    """Pipeline step that strips intermediate keys before the result is saved.

    Removes ``'crop'`` and ``'polygon_center'`` from every prediction, in
    place, when present.

    Args:
        pipeline_config: Unused; accepted for the common step signature.
    """

    def __init__(self, pipeline_config):
        self.elements_to_remove = ['crop', 'polygon_center']

    def __call__(self, image, pred_img):
        unwanted_keys = self.elements_to_remove
        for prediction in pred_img['predictions']:
            for key in unwanted_keys:
                prediction.pop(key, None)
        return image, pred_img
class OCRPrediction:
    """Pipeline step that recognises text on the crops of selected classes.

    Every prediction whose ``class_name`` is in ``classes_to_ocr`` has its
    ``'crop'`` sent to the OCR model. The recognised text is stored under
    ``'text'``.

    Args:
        pipeline_config: Unused; accepted for the common step signature.
        classes_to_ocr (list of str): Class names whose crops are recognised.
        model_path, config_path, num_threads, lm_path, device, batch_size,
        runtime: Forwarded to OcrPredictor.
    """

    def __init__(
        self, pipeline_config, model_path, config_path, num_threads,
        lm_path, classes_to_ocr, device, batch_size, runtime
    ):
        self.classes_to_ocr = classes_to_ocr
        self.ocr_predictor = OcrPredictor(
            model_path=model_path,
            config_path=config_path,
            num_threads=num_threads,
            lm_path=lm_path,
            device=device,
            batch_size=batch_size,
            runtime=runtime
        )

    def __call__(self, image, pred_img):
        predictions = pred_img['predictions']
        indexes = [
            idx for idx, prediction in enumerate(predictions)
            if prediction['class_name'] in self.classes_to_ocr
        ]
        # Nothing to recognise: do not call the model with an empty batch.
        if not indexes:
            return image, pred_img
        crops = [predictions[idx]['crop'] for idx in indexes]
        text_preds = self.ocr_predictor(crops)
        for idx, text_pred in zip(indexes, text_preds):
            predictions[idx]['text'] = text_pred
        return image, pred_img
class LineFinder:
    """Pipeline step that assigns page, line, column and word indexes.

    Heuristic helpers from ``ocrpipeline.linefinder`` are applied to
    ``pred_img`` in a fixed order, each mutating it in place.

    Args:
        pipeline_config: Unused; accepted for the common step signature.
        line_classes (list of strs): List of line class names.
        text_classes (list of strs): List of text class names.
        pages_clust_dist (float): Relative (to image width) distance between
            two clusters of line polygons above which the image is considered
            to contain two pages.
    """

    def __init__(
        self, pipeline_config, line_classes, text_classes,
        pages_clust_dist=0.25
    ):
        self.line_classes = line_classes
        self.text_classes = text_classes
        self.pages_clust_dist = pages_clust_dist

    def __call__(self, image, pred_img):
        _, image_width = image.shape[:2]
        lines = self.line_classes
        texts = self.text_classes
        # The order matters: each helper relies on fields added by the
        # previous ones.
        add_polygon_center(pred_img)
        add_page_idx_for_lines(
            pred_img, lines, image_width, self.pages_clust_dist)
        add_line_idx_for_lines(pred_img, lines)
        add_line_idx_for_words(pred_img, lines, texts)
        add_column_idx_for_words(pred_img, texts)
        add_word_indexes(pred_img, texts)
        return image, pred_img
class RestoreImageAngle:
    """Estimate the image rotation and straighten the image and polygons.

    The angle is estimated with get_image_angle from the polygons of the
    ``restoring_class_names`` classes. When its absolute value exceeds
    ``min_angle_to_rotate``, the image and all polygons are rotated by the
    negated angle. Every prediction receives a ``'rotated_polygon'``. It is
    the original polygon when no rotation is applied.

    Args:
        pipeline_config (ocrpipeline.config.Config): The pipeline config.json.
            Unused here.
        restoring_class_names (list of str): Class names whose polygons are
            used to find the angle of the image.
        min_angle_to_rotate (float): The safe range of angles within which
            image rotation does not occur
            (-min_angle_to_rotate; min_angle_to_rotate).
    """

    def __init__(
        self, pipeline_config, restoring_class_names, min_angle_to_rotate=0.5
    ):
        self.restoring_class_names = restoring_class_names
        self.min_angle_to_rotate = min_angle_to_rotate

    def __call__(self, image, pred_img):
        contours = []
        restoring_contours = []
        for prediction in pred_img['predictions']:
            # Wrap the [[x, y], ...] polygon in an extra leading axis, the
            # layout later passed to cv2.transform.
            contour = np.array([prediction['polygon']])
            contours.append(contour)
            if prediction['class_name'] in self.restoring_class_names:
                restoring_contours.append(contour)
        # Without reference polygons there is nothing to measure; use 0
        # instead of relying on the NaN median of an empty list.
        if restoring_contours:
            angle = get_image_angle(restoring_contours)
        else:
            angle = 0.0
        if abs(angle) > self.min_angle_to_rotate:
            image, contours = rotate_image_and_contours(image, contours, -angle)
        for prediction, contour in zip(pred_img['predictions'], contours):
            prediction['rotated_polygon'] = [
                [int(point[0]), int(point[1])] for point in contour[0]
            ]
        return image, pred_img
class BboxFromContour:
    """Contour step: set the bbox to the bounding box of the current contour."""

    def __call__(self, image, crop, bbox, contour):
        contour_array = np.array([contour])
        new_bbox = contour2bbox(contour_array)
        return crop, new_bbox, contour
class CropByBbox:
    """Contour step: cut the crop out of the image with the current bbox."""

    def __call__(self, image, crop, bbox, contour):
        return img_crop(image, bbox), bbox, contour
class MakeMaskedCrop:
    """Make crop by mask: blacken the crop area outside the predicted polygon.

    The contour is in image coordinates and is shifted into crop coordinates
    before drawing the mask. The crop origin ``(bbox[0], bbox[1])`` is used
    when a bbox is available, so the mask stays aligned even when the bbox
    was enlarged (e.g. by UpscaleBbox) beyond the contour. Without a bbox the
    contour's own minimum point is used.
    """

    def __call__(self, image, crop, bbox, contour):
        pts = np.array(contour)
        if bbox is not None:
            # NOTE(review): assumes the crop was cut at (bbox[0], bbox[1]),
            # i.e. CropByBbox ran with this bbox; confirm against img_crop.
            origin = np.array(bbox[:2])
        else:
            origin = pts.min(axis=0)
        # cv2 drawing functions expect int32 point coordinates.
        pts = (pts - origin).astype(np.int32)
        mask = np.zeros(crop.shape[:2], np.uint8)
        cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
        crop = cv2.bitwise_and(crop, crop, mask=mask)
        return crop, bbox, contour
class UpscaleContour:
    """Contour step: rescale the contour with change_contour_size.

    Contours may have been shrunk during segmentation training, so they are
    enlarged back here. If scaling yields nothing usable, the original
    contour is kept.

    Args:
        upscale_contour (float): The scale ratio passed to
            change_contour_size.
    """

    def __init__(self, upscale_contour):
        self.upscale_contour = upscale_contour

    def __call__(self, image, crop, bbox, contour):
        upscaled_contours = change_contour_size(contour, self.upscale_contour)
        # None means a degenerate input polygon; an empty list means the
        # offset produced no polygon (possible for ratios below 1). Indexing
        # [0] would raise in the latter case, so keep the original contour.
        if not upscaled_contours:
            return crop, bbox, contour
        # Take the first contour (when upscaling only one contour is expected).
        upscaled_contour = upscaled_contours[0]
        # Coords shouldn't be negative (outside the image) after upscaling.
        upscaled_contour = [[max(0, point[0]), max(0, point[1])]
                            for point in upscaled_contour]
        return crop, bbox, upscaled_contour
class UpscaleBbox:
    """Contour step: enlarge the current bbox around its centre.

    Args:
        upscale_bbox (sequence of two floats): ``(x_factor, y_factor)``
            passed to get_upscaled_bbox.
    """

    def __init__(self, upscale_bbox):
        self.upscale_bbox = upscale_bbox

    def __call__(self, image, crop, bbox, contour):
        factor_x = self.upscale_bbox[0]
        factor_y = self.upscale_bbox[1]
        new_bbox = get_upscaled_bbox(
            bbox=bbox, upscale_x=factor_x, upscale_y=factor_y)
        return crop, new_bbox, contour
class RotateVerticalCrops:
    """Rotate vertical text crops (e.g. text in the margins) by 90 degrees.

    Args:
        h2w_ratio (float): Height-to-width ratio from which a crop is
            considered vertical text.
        clockwise (bool): Rotate crop clockwise (True) or counterclockwise
            (False).
    """

    def __init__(self, h2w_ratio=5, clockwise=False):
        self.h2w_ratio = h2w_ratio
        # transpose followed by a flip around the y axis (code 1) rotates
        # clockwise; followed by a flip around the x axis (code 0) it rotates
        # counterclockwise.
        self.flip_code = 1 if clockwise else 0

    def __call__(self, image, crop, bbox, contour):
        h, w = crop.shape[:2]
        # A zero-width crop has no defined ratio (and nothing to rotate);
        # keep it unchanged instead of raising ZeroDivisionError.
        if w > 0 and h / w >= self.h2w_ratio:
            rotated_crop = cv2.transpose(crop)
            rotated_crop = cv2.flip(rotated_crop, flipCode=self.flip_code)
            return rotated_crop, bbox, contour
        return crop, bbox, contour
# Registry of contour post-processing steps: the step names allowed under a
# class's 'contour_posptrocess' section of the pipeline config, mapped to the
# classes that implement them.
CONTOUR_PROCESS_DICT = dict(
    BboxFromContour=BboxFromContour,
    UpscaleBbox=UpscaleBbox,
    CropByBbox=CropByBbox,
    RotateVerticalCrops=RotateVerticalCrops,
    UpscaleContour=UpscaleContour,
    MakeMaskedCrop=MakeMaskedCrop,
)
class ClassContourPosptrocess:
    """Pipeline step that runs per-class chains of contour/bbox/crop steps.

    For every class returned by ``pipeline_config.get_classes()``, the steps
    listed under its ``'contour_posptrocess'`` mapping (step name -> kwargs)
    are built from CONTOUR_PROCESS_DICT. Each step is called as
    ``step(image, crop, bbox, contour)`` and returns ``(crop, bbox, contour)``.
    The chain starts from ``crop=None``, ``bbox=None`` and the prediction's
    ``'rotated_polygon'``. Its final values are stored back as
    ``'rotated_polygon'``, ``'rotated_bbox'`` and ``'crop'``. Predictions of
    classes without a configured chain are left untouched.
    """

    def __init__(self, pipeline_config):
        self.class2process_funcs = {
            class_name: [
                CONTOUR_PROCESS_DICT[process_name](**args)
                for process_name, args
                in params['contour_posptrocess'].items()
            ]
            for class_name, params in pipeline_config.get_classes().items()
        }

    def __call__(self, image, pred_img):
        for prediction in pred_img['predictions']:
            process_funcs = self.class2process_funcs.get(
                prediction['class_name'])
            if process_funcs is None:
                continue
            crop, bbox = None, None
            contour = prediction['rotated_polygon']
            for process_func in process_funcs:
                crop, bbox, contour = process_func(image, crop, bbox, contour)
            prediction['rotated_polygon'] = contour
            prediction['rotated_bbox'] = bbox
            prediction['crop'] = crop
        return image, pred_img
class ImageToBGR:
    """Pipeline step that converts an RGB input image to BGR channel order.

    Args:
        pipeline_config: Unused; accepted for the common step signature.
        input_format (str): ``"RGB"`` to convert; any other value leaves the
            image unchanged.
    """

    def __init__(self, pipeline_config, input_format="BGR"):
        self.input_format = input_format

    def __call__(self, image, pred_img):
        if self.input_format != "RGB":
            return image, pred_img
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR), pred_img
# Registry of main pipeline steps: the step names allowed in the
# 'main_process' section of the pipeline config, mapped to their classes.
MAIN_PROCESS_DICT = dict(
    ImageToBGR=ImageToBGR,
    SegmPrediction=SegmPrediction,
    ClassContourPosptrocess=ClassContourPosptrocess,
    RestoreImageAngle=RestoreImageAngle,
    OCRPrediction=OCRPrediction,
    PrepareJSON=PrepareJSON,
    LineFinder=LineFinder,
)
class PipelinePredictor:
    """Run the configured chain of pipeline steps, from segmentation to OCR.

    Every step must be listed in the ``main_process`` section of the
    pipeline config.json (step name -> constructor kwargs) and registered in
    MAIN_PROCESS_DICT. Each step is built with ``pipeline_config`` plus its
    kwargs. Steps are called in the order the config returns them, each
    taking and returning ``(image, pred_img)``.

    Args:
        pipeline_config_path (str): A path to the pipeline config.json.
    """

    def __init__(self, pipeline_config_path):
        self.config = Config(pipeline_config_path)
        self.main_process_funcs = [
            MAIN_PROCESS_DICT[process_name](pipeline_config=self.config, **args)
            for process_name, args in self.config.get('main_process').items()
        ]

    def __call__(self, image):
        """Run every configured step on the image.

        The first step receives ``pred_img=None``; every later step receives
        the previous step's output.

        Args:
            image (np.array): An input image in BGR format.

        Returns:
            rotated_image (np.array): The input image which was rotated to
                restore rotation angle.
            pred_data (dict): A result dict for the input image, e.g.
                {
                'image': {'height': Int, 'width': Int} params of the input image,
                'predictions': [
                    {
                    'polygon': [ [x1,y1], [x2,y2], ..., [xN,yN] ] initial polygon
                    'bbox': [x_min, y_min, x_max, y_max] initial bounding box
                    'class_name': str, class name of the polygon.
                    'text': predicted text.
                    'rotated_bbox': [x_min, y_min, x_max, y_max] processed
                        bbox for a rotated image with the restored angle
                    'rotated_polygon': [ [x1,y1], [x2,y2], ..., [xN,yN] ]
                        processed polygon for a rotated image with the
                        restored angle
                    'page_idx': int, The page index of the polygon.
                    'line_idx': int, The line index of the polygon within
                        given page.
                    'column_idx': int, The column index of the polygon within
                        given line and page.
                    'word_idx': int, The positional index of the text polygons.
                        Using this index structured text can be extracted
                        from the prediction.
                    },
                    ...
                ]
                }
                The exact keys depend on which steps are configured.
        """
        pred_img = None
        for step in self.main_process_funcs:
            image, pred_img = step(image, pred_img)
        return image, pred_img