diff --git a/easyocr/detection.py b/easyocr/detection.py
index b1a4b71351..56df3d6b9c 100644
--- a/easyocr/detection.py
+++ b/easyocr/detection.py
@@ -22,42 +22,55 @@ def copyStateDict(state_dict):
     return new_state_dict
 
 def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
+    if isinstance(image, np.ndarray) and len(image.shape) == 4:  # image is batch of np arrays
+        image_arrs = image
+    else:                                                        # image is single numpy array
+        image_arrs = [image]
+
+    img_resized_list = []
     # resize
-    img_resized, target_ratio, size_heatmap = resize_aspect_ratio(image, canvas_size,\
-                                                                  interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
+    for img in image_arrs:
+        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size,
+                                                                      interpolation=cv2.INTER_LINEAR,
+                                                                      mag_ratio=mag_ratio)
+        img_resized_list.append(img_resized)
     ratio_h = ratio_w = 1 / target_ratio
 
-    # preprocessing
-    x = normalizeMeanVariance(img_resized)
-    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
-    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
+    x = np.array([normalizeMeanVariance(n_img) for n_img in img_resized_list])
+    x = Variable(torch.from_numpy(x).permute(0, 3, 1, 2))  # [b,h,w,c] to [b,c,h,w]
     x = x.to(device)
 
     # forward pass
     with torch.no_grad():
         y, feature = net(x)
 
-    # make score and link map
-    score_text = y[0,:,:,0].cpu().data.numpy()
-    score_link = y[0,:,:,1].cpu().data.numpy()
+    boxes_list, polys_list = [], []
+    for out in y:
+        # make score and link map
+        score_text = out[:, :, 0].cpu().data.numpy()
+        score_link = out[:, :, 1].cpu().data.numpy()
 
-    # Post-processing
-    boxes, polys, mapper = getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
+        # Post-processing
+        boxes, polys, mapper = getDetBoxes(
+            score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
 
-    # coordinate adjustment
-    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
-    polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
-    if estimate_num_chars:
-        boxes = list(boxes)
-        polys = list(polys)
-    for k in range(len(polys)):
+        # coordinate adjustment
+        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
+        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
         if estimate_num_chars:
-            boxes[k] = (boxes[k], mapper[k])
-        if polys[k] is None: polys[k] = boxes[k]
-
-    return boxes, polys
-
-def get_detector(trained_model, device='cpu', quantize=True):
+            boxes = list(boxes)
+            polys = list(polys)
+        for k in range(len(polys)):
+            if estimate_num_chars:
+                boxes[k] = (boxes[k], mapper[k])
+            if polys[k] is None:
+                polys[k] = boxes[k]
+        boxes_list.append(boxes)
+        polys_list.append(polys)
+
+    return boxes_list, polys_list
+
+def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
     net = CRAFT()
 
     if device == 'cpu':
@@ -70,7 +83,7 @@ def get_detector(trained_model, device='cpu', quantize=True):
     else:
         net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
         net = torch.nn.DataParallel(net).to(device)
-        cudnn.benchmark = False
+        cudnn.benchmark = cudnn_benchmark
 
     net.eval()
     return net
@@ -78,13 +91,19 @@ def get_detector(trained_model, device='cpu', quantize=True):
 def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None):
     result = []
     estimate_num_chars = optimal_num_chars is not None
-    bboxes, polys = test_net(canvas_size, mag_ratio, detector, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars)
-
+    bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
+                                       image, text_threshold,
+                                       link_threshold, low_text, poly,
+                                       device, estimate_num_chars)
     if estimate_num_chars:
-        polys = [p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
-
-    for i, box in enumerate(polys):
-        poly = np.array(box).astype(np.int32).reshape((-1))
-        result.append(poly)
+        polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
+                      for polys in polys_list]
+
+    for polys in polys_list:
+        single_img_result = []
+        for i, box in enumerate(polys):
+            poly = np.array(box).astype(np.int32).reshape((-1))
+            single_img_result.append(poly)
+        result.append(single_img_result)
 
     return result
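
Note on the detection.py changes: test_net now treats a 4-D numpy array as a batch and wraps anything else as a one-element batch, returning per-image lists (boxes_list, polys_list). A minimal sketch of the shapes involved, using dummy all-zero images (illustrative only, not part of the patch):

    import numpy as np

    # Two same-sized dummy images stand in for real photos.
    img_a = np.zeros((480, 640, 3), dtype=np.uint8)
    img_b = np.zeros((480, 640, 3), dtype=np.uint8)

    # The 4-D case test_net now recognises: stack along a new batch axis.
    batch = np.stack([img_a, img_b])  # shape (2, 480, 640, 3)
    assert isinstance(batch, np.ndarray) and len(batch.shape) == 4

    # A single 3-D image still works: test_net wraps it as [image] internally,
    # so callers always get back one boxes/polys entry per input image.
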
diff --git a/easyocr/easyocr.py b/easyocr/easyocr.py
index 911bc30029..6657b3110e 100644
--- a/easyocr/easyocr.py
+++ b/easyocr/easyocr.py
@@ -4,7 +4,8 @@
 from .recognition import get_recognizer, get_text
 from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
                    download_and_unzip, printProgressBar, diff, reformat_input,\
-                   make_rotated_img_list, set_result_with_confidence
+                   make_rotated_img_list, set_result_with_confidence,\
+                   reformat_input_batched
 from .config import *
 from bidi.algorithm import get_display
 import numpy as np
@@ -31,7 +32,7 @@ class Reader(object):
     def __init__(self, lang_list, gpu=True, model_storage_directory=None,
                  user_network_directory=None, recog_network = 'standard',
                  download_enabled=True, detector=True, recognizer=True,
-                 verbose=True, quantize=True):
+                 verbose=True, quantize=True, cudnn_benchmark=False):
         """Create an EasyOCR Reader.
 
         Parameters:
@@ -75,7 +76,7 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
         else:
             self.device = gpu
         self.recognition_models = recognition_models
-
+
         # check and download detection model
         detector_model = 'craft'
         corrupt_msg = 'MD5 hash mismatch, possible file corruption'
@@ -215,7 +216,7 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
                 dict_list[lang] = os.path.join(BASE_PATH, 'dict', lang + ".txt")
 
         if detector:
-            self.detector = get_detector(detector_path, self.device, quantize)
+            self.detector = get_detector(detector_path, self.device, quantize, cudnn_benchmark=cudnn_benchmark)
         if recognizer:
             if recog_network == 'generation1':
                 network_params = {
@@ -271,19 +272,25 @@ def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\
         if reformat:
             img, img_cv_grey = reformat_input(img)
 
-        text_box = get_textbox(self.detector, img, canvas_size, mag_ratio,\
-                               text_threshold, link_threshold, low_text,\
-                               False, self.device, optimal_num_chars)
-        horizontal_list, free_list = group_text_box(text_box, slope_ths,\
-                                                    ycenter_ths, height_ths,\
-                                                    width_ths, add_margin, \
-                                                    (optimal_num_chars is None))
-
-        if min_size:
-            horizontal_list = [i for i in horizontal_list if max(i[1]-i[0],i[3]-i[2]) > min_size]
-            free_list = [i for i in free_list if max(diff([c[0] for c in i]), diff([c[1] for c in i]))>min_size]
-
-        return horizontal_list, free_list
+        text_box_list = get_textbox(self.detector, img, canvas_size, mag_ratio,
+                                    text_threshold, link_threshold, low_text,
+                                    False, self.device, optimal_num_chars)
+
+        horizontal_list_agg, free_list_agg = [], []
+        for text_box in text_box_list:
+            horizontal_list, free_list = group_text_box(text_box, slope_ths,
+                                                        ycenter_ths, height_ths,
+                                                        width_ths, add_margin,
+                                                        (optimal_num_chars is None))
+            if min_size:
+                horizontal_list = [i for i in horizontal_list if max(
+                    i[1] - i[0], i[3] - i[2]) > min_size]
+                free_list = [i for i in free_list if max(
+                    diff([c[0] for c in i]), diff([c[1] for c in i])) > min_size]
+            horizontal_list_agg.append(horizontal_list)
+            free_list_agg.append(free_list)
+
+        return horizontal_list_agg, free_list_agg
 
     def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
                   decoder = 'greedy', beamWidth= 5, batch_size = 1,\
@@ -381,7 +388,8 @@ def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
                                                  slope_ths, ycenter_ths,\
                                                  height_ths,width_ths,\
                                                  add_margin, False)
-
+        # take the first entry of each list, as self.detect now returns one list per image
+        horizontal_list, free_list = horizontal_list[0], free_list[0]
         result = self.recognize(img_cv_grey, horizontal_list, free_list,\
                                 decoder, beamWidth, batch_size,\
                                 workers, allowlist, blocklist, detail, rotation_info,\
@@ -389,3 +397,40 @@ def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
                                 filter_ths, y_ths, x_ths, False, output_format)
 
         return result
+
+    def readtext_batched(self, image, n_width=None, n_height=None,\
+                         decoder = 'greedy', beamWidth= 5, batch_size = 1,\
+                         workers = 0, allowlist = None, blocklist = None, detail = 1,\
+                         rotation_info = None, paragraph = False, min_size = 20,\
+                         contrast_ths = 0.1, adjust_contrast = 0.5, filter_ths = 0.003,\
+                         text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
+                         canvas_size = 2560, mag_ratio = 1.,\
+                         slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
+                         width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1, output_format='standard'):
+        '''
+        Parameters:
+        image: file path or numpy-array or a byte stream object
+        When sending a list of images, they must all be of the same size;
+        the following parameters will automatically resize them if they are not None
+        n_width: int, new width
+        n_height: int, new height
+        '''
+        img, img_cv_grey = reformat_input_batched(image, n_width, n_height)
+
+        horizontal_list_agg, free_list_agg = self.detect(img, min_size, text_threshold,\
+                                                         low_text, link_threshold,\
+                                                         canvas_size, mag_ratio,\
+                                                         slope_ths, ycenter_ths,\
+                                                         height_ths, width_ths,\
+                                                         add_margin, False)
+        result_agg = []
+        # put img_cv_grey in a list if it is a single image
+        img_cv_grey = [img_cv_grey] if len(img_cv_grey.shape) == 2 else img_cv_grey
+        for grey_img, horizontal_list, free_list in zip(img_cv_grey, horizontal_list_agg, free_list_agg):
+            result_agg.append(self.recognize(grey_img, horizontal_list, free_list,\
+                                             decoder, beamWidth, batch_size,\
+                                             workers, allowlist, blocklist, detail, rotation_info,\
+                                             paragraph, contrast_ths, adjust_contrast,\
+                                             filter_ths, y_ths, x_ths, False, output_format))
+
+        return result_agg
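
Note on the easyocr.py changes: Reader.readtext_batched is the new public entry point, and cudnn_benchmark is a new constructor flag forwarded to get_detector. A usage sketch under stated assumptions (the file names here are hypothetical; cudnn_benchmark defaults to False, preserving the old behaviour):

    import easyocr

    # cudnn_benchmark is forwarded to get_detector (see detection.py above).
    reader = easyocr.Reader(['en'], gpu=True, cudnn_benchmark=True)

    # A list of image paths; n_width/n_height auto-resize so the batch is uniform.
    results = reader.readtext_batched(['doc_1.png', 'doc_2.png'],
                                      n_width=800, n_height=600)

    # One result list per input image, each entry in the usual readtext format.
    for per_image_result in results:
        for bbox, text, confidence in per_image_result:
            print(text, confidence)
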
diff --git a/easyocr/utils.py b/easyocr/utils.py
index 0254ad6149..4972a57fb1 100644
--- a/easyocr/utils.py
+++ b/easyocr/utils.py
@@ -726,6 +726,35 @@ def reformat_input(image):
 
     return img, img_cv_grey
 
+def reformat_input_batched(image, n_width=None, n_height=None):
+    """
+    reformats an image or a list of images or a 4D numpy image array and
+    returns corresponding lists of img and img_cv_grey nd.arrays
+    image:
+        [file path, numpy-array, byte stream object,
+        list of file paths, list of numpy-arrays, 4D numpy array,
+        list of byte stream objects]
+    """
+    if ((isinstance(image, np.ndarray) and len(image.shape) == 4) or isinstance(image, list)):
+        # process image batches if image is a list of image np arrays, paths or bytes
+        img, img_cv_grey = [], []
+        for single_img in image:
+            clr, gry = reformat_input(single_img)
+            if n_width is not None and n_height is not None:
+                clr = cv2.resize(clr, (n_width, n_height))
+                gry = cv2.resize(gry, (n_width, n_height))
+            img.append(clr)
+            img_cv_grey.append(gry)
+        img, img_cv_grey = np.array(img), np.array(img_cv_grey)
+        # a ragged array is created when the input images are not all the same size
+        if len(img.shape) == 1 and len(img_cv_grey.shape) == 1:
+            raise ValueError("The input image array contains images of different sizes. " +
+                             "Please resize all images to the same shape or pass n_width, n_height to auto-resize")
+    else:
+        img, img_cv_grey = reformat_input(image)
+    return img, img_cv_grey
+
+
 def make_rotated_img_list(rotationInfo, img_list):
 
     result_img_list = img_list[:]
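
Note on the utils.py changes: the ValueError guards against "ragged" batches. On older numpy versions, np.array over differently-sized images silently yields a 1-D object array, which the len(img.shape) == 1 check catches (newer numpy raises unless dtype=object is passed explicitly). A small illustration of the two shapes; dtype=object is used here only so the snippet runs on current numpy:

    import numpy as np

    same = [np.zeros((4, 4, 3), dtype=np.uint8),
            np.zeros((4, 4, 3), dtype=np.uint8)]
    mixed = [np.zeros((4, 4, 3), dtype=np.uint8),
             np.zeros((5, 5, 3), dtype=np.uint8)]

    print(np.array(same).shape)                 # (2, 4, 4, 3): a clean 4-D batch
    print(np.array(mixed, dtype=object).shape)  # (2,): ragged 1-D object array
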