From 16c6b953342d5f12940d39f0c0dfdd2a3df54c29 Mon Sep 17 00:00:00 2001
From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com>
Date: Fri, 15 Jul 2022 07:19:44 +0000
Subject: [PATCH] Format code with autopep8

This commit fixes the style issues introduced in 5244817 according to the
output from autopep8.

Details: https://deepsource.io/gh/gulldan/EasyOCR/transform/a4b4d3ba-d295-4792-9e43-d5307febeaef/
---
 easyocr/easyocr.py                     | 322 ++++++++++++++-----------
 trainer/craft/metrics/eval_det_iou.py  |  14 +-
 trainer/craft/utils/inference_boxes.py |  25 +-
 trainer/modules/transformation.py      |  68 ++++--
 trainer/test.py                        |  38 +--
 5 files changed, 268 insertions(+), 199 deletions(-)

diff --git a/easyocr/easyocr.py b/easyocr/easyocr.py
index 7127c28f4..6987176ae 100644
--- a/easyocr/easyocr.py
+++ b/easyocr/easyocr.py
@@ -3,9 +3,9 @@
 from .detection import get_detector, get_textbox
 from .recognition import get_recognizer, get_text
 from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
-                  download_and_unzip, printProgressBar, diff, reformat_input,\
-                  make_rotated_img_list, set_result_with_confidence,\
-                  reformat_input_batched
+    download_and_unzip, printProgressBar, diff, reformat_input,\
+    make_rotated_img_list, set_result_with_confidence,\
+    reformat_input_batched
 from .config import *
 from bidi.algorithm import get_display
 import numpy as np
@@ -27,10 +27,11 @@

 LOGGER = getLogger(__name__)

+
 class Reader(object):

     def __init__(self, lang_list, gpu=True, model_storage_directory=None,
-                 user_network_directory=None, recog_network = 'standard',
+                 user_network_directory=None, recog_network='standard',
                  download_enabled=True, detector=True, recognizer=True,
                  verbose=True, quantize=True, cudnn_benchmark=False):
         """Create an EasyOCR Reader
@@ -66,11 +67,13 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
         if gpu is False:
             self.device = 'cpu'
             if verbose:
-                LOGGER.warning('Using CPU. Note: This module is much faster with a GPU.')
+                LOGGER.warning(
+                    'Using CPU. Note: This module is much faster with a GPU.')
         elif not torch.cuda.is_available():
             self.device = 'cpu'
             if verbose:
-                LOGGER.warning('CUDA not available - defaulting to CPU. Note: This module is much faster with a GPU.')
+                LOGGER.warning(
+                    'CUDA not available - defaulting to CPU. Note: This module is much faster with a GPU.')
         elif gpu is True:
             self.device = 'cuda'
         else:
@@ -80,25 +83,32 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
         # check and download detection model
         detector_model = 'craft'
         corrupt_msg = 'MD5 hash mismatch, possible file corruption'
-        detector_path = os.path.join(self.model_storage_directory, detection_models[detector_model]['filename'])
+        detector_path = os.path.join(
+            self.model_storage_directory, detection_models[detector_model]['filename'])
         if detector:
             if os.path.isfile(detector_path) == False:
                 if not self.download_enabled:
-                    raise FileNotFoundError("Missing %s and downloads disabled" % detector_path)
+                    raise FileNotFoundError(
+                        "Missing %s and downloads disabled" % detector_path)
                 LOGGER.warning('Downloading detection model, please wait. '
                               'This may take several minutes depending upon your network connection.')
-                download_and_unzip(detection_models[detector_model]['url'], detection_models[detector_model]['filename'], self.model_storage_directory, verbose)
-                assert calculate_md5(detector_path) == detection_models[detector_model]['md5sum'], corrupt_msg
+                download_and_unzip(detection_models[detector_model]['url'],
+                                   detection_models[detector_model]['filename'], self.model_storage_directory, verbose)
+                assert calculate_md5(
+                    detector_path) == detection_models[detector_model]['md5sum'], corrupt_msg
                 LOGGER.info('Download complete')
             elif calculate_md5(detector_path) != detection_models[detector_model]['md5sum']:
                 if not self.download_enabled:
-                    raise FileNotFoundError("MD5 mismatch for %s and downloads disabled" % detector_path)
+                    raise FileNotFoundError(
+                        "MD5 mismatch for %s and downloads disabled" % detector_path)
                 LOGGER.warning(corrupt_msg)
                 os.remove(detector_path)
                 LOGGER.warning('Re-downloading the detection model, please wait. '
                                'This may take several minutes depending upon your network connection.')
-                download_and_unzip(detection_models[detector_model]['url'], detection_models[detector_model]['filename'], self.model_storage_directory, verbose)
-                assert calculate_md5(detector_path) == detection_models[detector_model]['md5sum'], corrupt_msg
+                download_and_unzip(detection_models[detector_model]['url'],
+                                   detection_models[detector_model]['filename'], self.model_storage_directory, verbose)
+                assert calculate_md5(
+                    detector_path) == detection_models[detector_model]['md5sum'], corrupt_msg

         # recognition model
         separator_list = {}

@@ -112,57 +122,69 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
                 model = recognition_models['gen2'][recog_network]
                 recog_network = 'generation2'
             self.model_lang = model['model_script']
-        else: # auto-detect
+        else:  # auto-detect
             unknown_lang = set(lang_list) - set(all_lang_list)
             if unknown_lang != set():
                 raise ValueError(unknown_lang, 'is not supported')
             # choose recognition model
             if lang_list == ['en']:
-                self.setModelLanguage('english', lang_list, ['en'], '["en"]')
+                self.setModelLanguage(
+                    'english', lang_list, ['en'], '["en"]')
                 model = recognition_models['gen2']['english_g2']
                 recog_network = 'generation2'
             elif 'th' in lang_list:
-                self.setModelLanguage('thai', lang_list, ['th','en'], '["th","en"]')
+                self.setModelLanguage('thai', lang_list, [
+                                      'th', 'en'], '["th","en"]')
                 model = recognition_models['gen1']['thai_g1']
                 recog_network = 'generation1'
             elif 'ch_tra' in lang_list:
-                self.setModelLanguage('chinese_tra', lang_list, ['ch_tra','en'], '["ch_tra","en"]')
+                self.setModelLanguage('chinese_tra', lang_list, [
+                                      'ch_tra', 'en'], '["ch_tra","en"]')
                 model = recognition_models['gen1']['zh_tra_g1']
                 recog_network = 'generation1'
             elif 'ch_sim' in lang_list:
-                self.setModelLanguage('chinese_sim', lang_list, ['ch_sim','en'], '["ch_sim","en"]')
+                self.setModelLanguage('chinese_sim', lang_list, [
+                                      'ch_sim', 'en'], '["ch_sim","en"]')
                 model = recognition_models['gen2']['zh_sim_g2']
                 recog_network = 'generation2'
             elif 'ja' in lang_list:
-                self.setModelLanguage('japanese', lang_list, ['ja','en'], '["ja","en"]')
+                self.setModelLanguage('japanese', lang_list, [
+                                      'ja', 'en'], '["ja","en"]')
                 model = recognition_models['gen2']['japanese_g2']
                 recog_network = 'generation2'
             elif 'ko' in lang_list:
-                self.setModelLanguage('korean', lang_list, ['ko','en'], '["ko","en"]')
+                self.setModelLanguage('korean', lang_list, [
+                                      'ko', 'en'], '["ko","en"]')
                 model = recognition_models['gen2']['korean_g2']
                 recog_network = 'generation2'
             elif 'ta' in lang_list:
-                self.setModelLanguage('tamil', lang_list, ['ta','en'], '["ta","en"]')
+                self.setModelLanguage('tamil', lang_list, [
+                                      'ta', 'en'], '["ta","en"]')
                 model = recognition_models['gen1']['tamil_g1']
                 recog_network = 'generation1'
             elif 'te' in lang_list:
-                self.setModelLanguage('telugu', lang_list, ['te','en'], '["te","en"]')
+                self.setModelLanguage('telugu', lang_list, [
+                                      'te', 'en'], '["te","en"]')
                 model = recognition_models['gen2']['telugu_g2']
                 recog_network = 'generation2'
             elif 'kn' in lang_list:
-                self.setModelLanguage('kannada', lang_list, ['kn','en'], '["kn","en"]')
+                self.setModelLanguage('kannada', lang_list, [
+                                      'kn', 'en'], '["kn","en"]')
                 model = recognition_models['gen2']['kannada_g2']
                 recog_network = 'generation2'
             elif set(lang_list) & set(bengali_lang_list):
-                self.setModelLanguage('bengali', lang_list, bengali_lang_list+['en'], '["bn","as","en"]')
+                self.setModelLanguage(
+                    'bengali', lang_list, bengali_lang_list+['en'], '["bn","as","en"]')
                 model = recognition_models['gen1']['bengali_g1']
                 recog_network = 'generation1'
             elif set(lang_list) & set(arabic_lang_list):
-                self.setModelLanguage('arabic', lang_list, arabic_lang_list+['en'], '["ar","fa","ur","ug","en"]')
+                self.setModelLanguage(
+                    'arabic', lang_list, arabic_lang_list+['en'], '["ar","fa","ur","ug","en"]')
                 model = recognition_models['gen1']['arabic_g1']
                 recog_network = 'generation1'
             elif set(lang_list) & set(devanagari_lang_list):
-                self.setModelLanguage('devanagari', lang_list, devanagari_lang_list+['en'], '["hi","mr","ne","en"]')
+                self.setModelLanguage(
+                    'devanagari', lang_list, devanagari_lang_list+['en'], '["hi","mr","ne","en"]')
                 model = recognition_models['gen1']['devanagari_g1']
                 recog_network = 'generation1'
             elif set(lang_list) & set(cyrillic_lang_list):
@@ -176,42 +198,50 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
                 recog_network = 'generation2'

         self.character = model['characters']

-        model_path = os.path.join(self.model_storage_directory, model['filename'])
+        model_path = os.path.join(
+            self.model_storage_directory, model['filename'])
         # check recognition model file
         if recognizer:
             if os.path.isfile(model_path) == False:
                 if not self.download_enabled:
-                    raise FileNotFoundError("Missing %s and downloads disabled" % model_path)
+                    raise FileNotFoundError(
+                        "Missing %s and downloads disabled" % model_path)
                 LOGGER.warning('Downloading recognition model, please wait. '
                                'This may take several minutes depending upon your network connection.')
-                download_and_unzip(model['url'], model['filename'], self.model_storage_directory, verbose)
-                assert calculate_md5(model_path) == model['md5sum'], corrupt_msg
+                download_and_unzip(
+                    model['url'], model['filename'], self.model_storage_directory, verbose)
+                assert calculate_md5(
+                    model_path) == model['md5sum'], corrupt_msg
                 LOGGER.info('Download complete.')
             elif calculate_md5(model_path) != model['md5sum']:
                 if not self.download_enabled:
-                    raise FileNotFoundError("MD5 mismatch for %s and downloads disabled" % model_path)
+                    raise FileNotFoundError(
+                        "MD5 mismatch for %s and downloads disabled" % model_path)
                 LOGGER.warning(corrupt_msg)
                 os.remove(model_path)
                 LOGGER.warning('Re-downloading the recognition model, please wait. '
                               'This may take several minutes depending upon your network connection.')
-                download_and_unzip(model['url'], model['filename'], self.model_storage_directory, verbose)
-                assert calculate_md5(model_path) == model['md5sum'], corrupt_msg
+                download_and_unzip(
+                    model['url'], model['filename'], self.model_storage_directory, verbose)
+                assert calculate_md5(
+                    model_path) == model['md5sum'], corrupt_msg
                 LOGGER.info('Download complete')
             self.setLanguageList(lang_list, model)

-        else: # user-defined model
-            with open(os.path.join(self.user_network_directory, recog_network+ '.yaml'), encoding='utf8') as file:
+        else:  # user-defined model
+            with open(os.path.join(self.user_network_directory, recog_network + '.yaml'), encoding='utf8') as file:
                 recog_config = yaml.load(file, Loader=yaml.FullLoader)
-
-            global imgH # if custom model, save this variable. (from *.yaml)
+
+            global imgH  # if custom model, save this variable. (from *.yaml)
             if recog_config['imgH']:
                 imgH = recog_config['imgH']
-
+
             available_lang = recog_config['lang_list']
-            self.setModelLanguage(recog_network, lang_list, available_lang, str(available_lang))
+            self.setModelLanguage(recog_network, lang_list,
+                                  available_lang, str(available_lang))
             #char_file = os.path.join(self.user_network_directory, recog_network+ '.txt')
             self.character = recog_config['character_list']

-            model_file = recog_network+ '.pth'
+            model_file = recog_network + '.pth'
             model_path = os.path.join(self.model_storage_directory, model_file)

         self.setLanguageList(lang_list, recog_config)

@@ -220,32 +250,34 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
             dict_list[lang] = os.path.join(BASE_PATH, 'dict', lang + ".txt")

         if detector:
-            self.detector = get_detector(detector_path, self.device, quantize, cudnn_benchmark=cudnn_benchmark)
+            self.detector = get_detector(
+                detector_path, self.device, quantize, cudnn_benchmark=cudnn_benchmark)
         if recognizer:
             if recog_network == 'generation1':
                 network_params = {
                     'input_channel': 1,
                     'output_channel': 512,
                     'hidden_size': 512
-                    }
+                }
             elif recog_network == 'generation2':
                 network_params = {
                     'input_channel': 1,
                     'output_channel': 256,
                     'hidden_size': 256
-                    }
+                }
             else:
                 network_params = recog_config['network_params']
-            self.recognizer, self.converter = get_recognizer(recog_network, network_params,\
-                                                             self.character, separator_list,\
-                                                             dict_list, model_path, device = self.device, quantize=quantize)
+            self.recognizer, self.converter = get_recognizer(recog_network, network_params,
+                                                             self.character, separator_list,
+                                                             dict_list, model_path, device=self.device, quantize=quantize)

     def setModelLanguage(self, language, lang_list, list_lang, list_lang_string):
         self.model_lang = language
         if set(lang_list) - set(list_lang) != set():
             if language == 'ch_tra' or language == 'ch_sim':
                 language = 'chinese'
-            raise ValueError(language.capitalize() + ' is only compatible with English, try lang_list=' + list_lang_string)
+            raise ValueError(language.capitalize(
+            ) + ' is only compatible with English, try lang_list=' + list_lang_string)

     def getChar(self, fileName):
         char_file = os.path.join(BASE_PATH, 'character', fileName)
@@ -257,9 +289,10 @@ def getChar(self, fileName):
     def setLanguageList(self, lang_list, model):
         self.lang_char = []
         for lang in lang_list:
-            char_file = os.path.join(BASE_PATH, 'character', lang + "_char.txt")
-            with open(char_file, "r", encoding = "utf-8-sig") as input_file:
-                char_list = input_file.read().splitlines()
+            char_file = os.path.join(
+                BASE_PATH, 'character', lang + "_char.txt")
+            with open(char_file, "r",
encoding="utf-8-sig") as input_file: + char_list = input_file.read().splitlines() self.lang_char += char_list if model.get('symbols'): symbol = model['symbols'] @@ -270,10 +303,10 @@ def setLanguageList(self, lang_list, model): self.lang_char = set(self.lang_char).union(set(symbol)) self.lang_char = ''.join(self.lang_char) - def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\ - link_threshold = 0.4,canvas_size = 2560, mag_ratio = 1.,\ - slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\ - width_ths = 0.5, add_margin = 0.1, reformat=True, optimal_num_chars=None): + def detect(self, img, min_size=20, text_threshold=0.7, low_text=0.4, + link_threshold=0.4, canvas_size=2560, mag_ratio=1., + slope_ths=0.1, ycenter_ths=0.5, height_ths=0.5, + width_ths=0.5, add_margin=0.1, reformat=True, optimal_num_chars=None): if reformat: img, img_cv_grey = reformat_input(img) @@ -298,12 +331,12 @@ def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\ return horizontal_list_agg, free_list_agg - def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\ - decoder = 'greedy', beamWidth= 5, batch_size = 1,\ - workers = 0, allowlist = None, blocklist = None, detail = 1,\ - rotation_info = None,paragraph = False,\ - contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\ - y_ths = 0.5, x_ths = 1.0, reformat=True, output_format='standard'): + def recognize(self, img_cv_grey, horizontal_list=None, free_list=None, + decoder='greedy', beamWidth=5, batch_size=1, + workers=0, allowlist=None, blocklist=None, detail=1, + rotation_info=None, paragraph=False, + contrast_ths=0.1, adjust_contrast=0.5, filter_ths=0.003, + y_ths=0.5, x_ths=1.0, reformat=True, output_format='standard'): if reformat: img, img_cv_grey = reformat_input(img_cv_grey) @@ -315,9 +348,10 @@ def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\ else: ignore_char = ''.join(set(self.character)-set(self.lang_char)) - if self.model_lang in ['chinese_tra','chinese_sim']: decoder = 'greedy' + if self.model_lang in ['chinese_tra', 'chinese_sim']: + decoder = 'greedy' - if (horizontal_list==None) and (free_list==None): + if (horizontal_list == None) and (free_list == None): y_max, x_max = img_cv_grey.shape horizontal_list = [[0, x_max, 0, y_max]] free_list = [] @@ -328,33 +362,36 @@ def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\ for bbox in horizontal_list: h_list = [bbox] f_list = [] - image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH) - result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\ - ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\ - workers, self.device) + image_list, max_width = get_image_list( + h_list, f_list, img_cv_grey, model_height=imgH) + result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list, + ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths, + workers, self.device) result += result0 for bbox in free_list: h_list = [] f_list = [bbox] - image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH) - result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\ - ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\ - workers, self.device) + image_list, max_width = get_image_list( + h_list, f_list, img_cv_grey, 
+                    h_list, f_list, img_cv_grey, model_height=imgH)
+                result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,
+                                   ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,
+                                   workers, self.device)
                 result += result0

         # default mode will try to process multiple boxes at the same time
         else:
-            image_list, max_width = get_image_list(horizontal_list, free_list, img_cv_grey, model_height = imgH)
+            image_list, max_width = get_image_list(
+                horizontal_list, free_list, img_cv_grey, model_height=imgH)
             image_len = len(image_list)
             if rotation_info and image_list:
                 image_list = make_rotated_img_list(rotation_info, image_list)
                 max_width = max(max_width, imgH)

-            result = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
-                              ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
-                              workers, self.device)
+            result = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,
+                              ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,
+                              workers, self.device)

             if rotation_info and (horizontal_list+free_list):
-                # Reshape result to be a list of lists, each row being for 
+                # Reshape result to be a list of lists, each row being for
                 # one of the rotations (first row being no rotation)
                 result = set_result_with_confidence(
                     [result[image_len*i:image_len*(i+1)] for i in range(len(rotation_info) + 1)])
@@ -368,110 +405,112 @@ def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
             direction_mode = 'ltr'

         if paragraph:
-            result = get_paragraph(result, x_ths=x_ths, y_ths=y_ths, mode = direction_mode)
+            result = get_paragraph(result, x_ths=x_ths,
+                                   y_ths=y_ths, mode=direction_mode)

         if detail == 0:
             return [item[1] for item in result]
         elif output_format == 'dict':
-            return [ {'boxes':item[0],'text':item[1],'confident':item[2]} for item in result]
+            return [{'boxes': item[0], 'text':item[1], 'confident':item[2]} for item in result]
         else:
             return result

-    def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
-                 workers = 0, allowlist = None, blocklist = None, detail = 1,\
-                 rotation_info = None, paragraph = False, min_size = 20,\
-                 contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
-                 text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
-                 canvas_size = 2560, mag_ratio = 1.,\
-                 slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
-                 width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1, output_format='standard'):
+    def readtext(self, image, decoder='greedy', beamWidth=5, batch_size=1,
+                 workers=0, allowlist=None, blocklist=None, detail=1,
+                 rotation_info=None, paragraph=False, min_size=20,
+                 contrast_ths=0.1, adjust_contrast=0.5, filter_ths=0.003,
+                 text_threshold=0.7, low_text=0.4, link_threshold=0.4,
+                 canvas_size=2560, mag_ratio=1.,
+                 slope_ths=0.1, ycenter_ths=0.5, height_ths=0.5,
+                 width_ths=0.5, y_ths=0.5, x_ths=1.0, add_margin=0.1, output_format='standard'):
         '''
         Parameters:
         image: file path or numpy-array or a byte stream object
         '''
         img, img_cv_grey = reformat_input(image)

-        horizontal_list, free_list = self.detect(img, min_size, text_threshold,\
-                                                 low_text, link_threshold,\
-                                                 canvas_size, mag_ratio,\
-                                                 slope_ths, ycenter_ths,\
-                                                 height_ths,width_ths,\
+        horizontal_list, free_list = self.detect(img, min_size, text_threshold,
+                                                 low_text, link_threshold,
+                                                 canvas_size, mag_ratio,
+                                                 slope_ths, ycenter_ths,
+                                                 height_ths, width_ths,
                                                  add_margin, False)
         # get the 1st result from hor & free list as self.detect returns a list of depth 3
         horizontal_list, free_list = horizontal_list[0], free_list[0]

-        result = self.recognize(img_cv_grey, horizontal_list, free_list,\
-                                decoder, beamWidth, batch_size,\
-                                workers, allowlist, blocklist, detail, rotation_info,\
-                                paragraph, contrast_ths, adjust_contrast,\
+        result = self.recognize(img_cv_grey, horizontal_list, free_list,
+                                decoder, beamWidth, batch_size,
+                                workers, allowlist, blocklist, detail, rotation_info,
+                                paragraph, contrast_ths, adjust_contrast,
                                 filter_ths, y_ths, x_ths, False, output_format)

         return result
-
-    def readtextlang(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
-                 workers = 0, allowlist = None, blocklist = None, detail = 1,\
-                 rotation_info = None, paragraph = False, min_size = 20,\
-                 contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
-                 text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
-                 canvas_size = 2560, mag_ratio = 1.,\
-                 slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
-                 width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1, output_format='standard'):
+
+    def readtextlang(self, image, decoder='greedy', beamWidth=5, batch_size=1,
+                     workers=0, allowlist=None, blocklist=None, detail=1,
+                     rotation_info=None, paragraph=False, min_size=20,
+                     contrast_ths=0.1, adjust_contrast=0.5, filter_ths=0.003,
+                     text_threshold=0.7, low_text=0.4, link_threshold=0.4,
+                     canvas_size=2560, mag_ratio=1.,
+                     slope_ths=0.1, ycenter_ths=0.5, height_ths=0.5,
+                     width_ths=0.5, y_ths=0.5, x_ths=1.0, add_margin=0.1, output_format='standard'):
         '''
         Parameters:
         image: file path or numpy-array or a byte stream object
         '''
         img, img_cv_grey = reformat_input(image)

-        horizontal_list, free_list = self.detect(img, min_size, text_threshold,\
-                                                 low_text, link_threshold,\
-                                                 canvas_size, mag_ratio,\
-                                                 slope_ths, ycenter_ths,\
-                                                 height_ths,width_ths,\
+        horizontal_list, free_list = self.detect(img, min_size, text_threshold,
+                                                 low_text, link_threshold,
+                                                 canvas_size, mag_ratio,
+                                                 slope_ths, ycenter_ths,
+                                                 height_ths, width_ths,
                                                  add_margin, False)
         # get the 1st result from hor & free list as self.detect returns a list of depth 3
         horizontal_list, free_list = horizontal_list[0], free_list[0]

-        result = self.recognize(img_cv_grey, horizontal_list, free_list,\
-                                decoder, beamWidth, batch_size,\
-                                workers, allowlist, blocklist, detail, rotation_info,\
-                                paragraph, contrast_ths, adjust_contrast,\
+        result = self.recognize(img_cv_grey, horizontal_list, free_list,
+                                decoder, beamWidth, batch_size,
+                                workers, allowlist, blocklist, detail, rotation_info,
+                                paragraph, contrast_ths, adjust_contrast,
                                 filter_ths, y_ths, x_ths, False, output_format)
-
+
         char = []
         directory = 'characters/'
         for i in range(len(result)):
             char.append(result[i][1])
-
-        def search(arr,x):
+
+        def search(arr, x):
             g = False
             for i in range(len(arr)):
-                if arr[i]==x:
+                if arr[i] == x:
                     g = True
                     return 1
             if g == False:
                 return -1
+
         def tupleadd(i):
             a = result[i]
             b = a + (filename[0:2],)
             return b
-
+
         for filename in os.listdir(directory):
             if filename.endswith(".txt"):
-                with open ('characters/'+ filename,'rt',encoding="utf8") as myfile:
-                    chartrs = str(myfile.read().splitlines()).replace('\n','')
+                with open('characters/' + filename, 'rt', encoding="utf8") as myfile:
+                    chartrs = str(myfile.read().splitlines()).replace('\n', '')
                     for i in range(len(char)):
-                        res = search(chartrs,char[i])
+                        res = search(chartrs, char[i])
                         if res != -1:
-                            if filename[0:2]=="en" or filename[0:2]=="ch":
+                            if filename[0:2] == "en" or filename[0:2] == "ch":
                                 print(tupleadd(i))
-    def readtext_batched(self, image, n_width=None, n_height=None,\
-                 decoder = 'greedy', beamWidth= 5, batch_size = 1,\
-                 workers = 0, allowlist = None, blocklist = None, detail = 1,\
-                 rotation_info = None, paragraph = False, min_size = 20,\
-                 contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
-                 text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
-                 canvas_size = 2560, mag_ratio = 1.,\
-                 slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
-                 width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1, output_format='standard'):
+    def readtext_batched(self, image, n_width=None, n_height=None,
+                         decoder='greedy', beamWidth=5, batch_size=1,
+                         workers=0, allowlist=None, blocklist=None, detail=1,
+                         rotation_info=None, paragraph=False, min_size=20,
+                         contrast_ths=0.1, adjust_contrast=0.5, filter_ths=0.003,
+                         text_threshold=0.7, low_text=0.4, link_threshold=0.4,
+                         canvas_size=2560, mag_ratio=1.,
+                         slope_ths=0.1, ycenter_ths=0.5, height_ths=0.5,
+                         width_ths=0.5, y_ths=0.5, x_ths=1.0, add_margin=0.1, output_format='standard'):
         '''
         Parameters:
         image: file path or numpy-array or a byte stream object
@@ -482,20 +521,21 @@ def readtext_batched(self, image, n_width=None, n_height=None,\
         '''
         img, img_cv_grey = reformat_input_batched(image, n_width, n_height)

-        horizontal_list_agg, free_list_agg = self.detect(img, min_size, text_threshold,\
-                                                         low_text, link_threshold,\
-                                                         canvas_size, mag_ratio,\
-                                                         slope_ths, ycenter_ths,\
-                                                         height_ths, width_ths,\
+        horizontal_list_agg, free_list_agg = self.detect(img, min_size, text_threshold,
+                                                         low_text, link_threshold,
+                                                         canvas_size, mag_ratio,
+                                                         slope_ths, ycenter_ths,
+                                                         height_ths, width_ths,
                                                          add_margin, False)
         result_agg = []
         # put img_cv_grey in a list if its a single img
-        img_cv_grey = [img_cv_grey] if len(img_cv_grey.shape) == 2 else img_cv_grey
+        img_cv_grey = [img_cv_grey] if len(
+            img_cv_grey.shape) == 2 else img_cv_grey
         for grey_img, horizontal_list, free_list in zip(img_cv_grey, horizontal_list_agg, free_list_agg):
-            result_agg.append(self.recognize(grey_img, horizontal_list, free_list,\
-                                             decoder, beamWidth, batch_size,\
-                                             workers, allowlist, blocklist, detail, rotation_info,\
-                                             paragraph, contrast_ths, adjust_contrast,\
-                                             filter_ths, y_ths, x_ths, False, output_format))
+            result_agg.append(self.recognize(grey_img, horizontal_list, free_list,
+                                             decoder, beamWidth, batch_size,
+                                             workers, allowlist, blocklist, detail, rotation_info,
+                                             paragraph, contrast_ths, adjust_contrast,
+                                             filter_ths, y_ths, x_ths, False, output_format))

         return result_agg
diff --git a/trainer/craft/metrics/eval_det_iou.py b/trainer/craft/metrics/eval_det_iou.py
index a518865cb..76d119e4f 100644
--- a/trainer/craft/metrics/eval_det_iou.py
+++ b/trainer/craft/metrics/eval_det_iou.py
@@ -97,7 +97,7 @@ def compute_ap(confList, matchList, numGtCare):
                     if not Polygon(points).is_valid or not Polygon(points).is_simple:
                         continue
                 except:
-                    import ipdb;
+                    import ipdb
                     ipdb.set_trace()
                     #import ipdb;ipdb.set_trace()

@@ -158,7 +158,8 @@ def compute_ap(confList, matchList, numGtCare):
                             pairs.append({'gt': gtNum, 'det': detNum})
                             detMatchedNums.append(detNum)
                             evaluationLog += "Match GT #" + \
-                                str(gtNum) + " with Det #" + str(detNum) + "\n"
+                                str(gtNum) + " with Det #" + \
+                                str(detNum) + "\n"

         numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
         numDetCare = (len(detPols) - len(detDontCarePolsNum))
@@ -167,10 +168,11 @@ def compute_ap(confList, matchList, numGtCare):
             precision = float(0) if numDetCare > 0 else float(1)
         else:
             recall = float(detMatched) / numGtCare
-            precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
+            precision = 0 if numDetCare == 0 else float(
+                detMatched) / numDetCare

         hmean = 0 if (precision + recall) == 0 else 2.0 * \
-            precision * recall / (precision + recall)
+                precision * recall / (precision + recall)

         matchedSum += detMatched
         numGlobalCareGt += numGtCare
@@ -208,8 +210,8 @@ def combine_results(self, results):
         methodPrecision = 0 if numGlobalCareDet == 0 else float(
             matchedSum) / numGlobalCareDet
         methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
-            methodRecall * methodPrecision / (
-                methodRecall + methodPrecision)
+                methodRecall * methodPrecision / (
+                    methodRecall + methodPrecision)
         # print(methodRecall, methodPrecision, methodHmean)
         # sys.exit(-1)
         methodMetrics = {
diff --git a/trainer/craft/utils/inference_boxes.py b/trainer/craft/utils/inference_boxes.py
index 14fa13672..92eee9906 100644
--- a/trainer/craft/utils/inference_boxes.py
+++ b/trainer/craft/utils/inference_boxes.py
@@ -27,6 +27,7 @@ def rotatePoint(xc, yc, xp, yp, theta):
     #  pRes = (xc + pResx, yc + pResy)
     return int(xc + pResx), int(yc + pResy)

+
 def addRotatedShape(cx, cy, w, h, angle):
     p0x, p0y = rotatePoint(cx, cy, cx - w / 2, cy - h / 2, -angle)
     p1x, p1y = rotatePoint(cx, cy, cx + w / 2, cy - h / 2, -angle)
@@ -37,6 +38,7 @@ def addRotatedShape(cx, cy, w, h, angle):

     return points

+
 def xml_parsing(xml):
     tree = elemTree.parse(xml)

@@ -46,7 +48,8 @@ def xml_parsing(xml):
     for element in iter_element:
         annotation = {}  # Initialize the dict to store labels

-        annotation['name'] = element.find("name").text # Save the name tag value
+        annotation['name'] = element.find(
+            "name").text  # Save the name tag value

         box_coords = element.iter(tag="robndbox")

@@ -75,9 +78,6 @@ def xml_parsing(xml):
                                         [xmin, ymax]]

             annotations.append(annotation)
-
-
-
     bounds = []
     for i in range(len(annotations)):
         box_info_dict = {"points": None, "text": None, "ignore": None}
@@ -92,14 +92,12 @@ def xml_parsing(xml):

         bounds.append(box_info_dict)
-
-
     return bounds


 #-------------------------------------------------------------------------------------------------------------------#

-def load_prescription_gt(dataFolder):
+def load_prescription_gt(dataFolder):

     total_img_path = []
     total_imgs_bboxes = []
@@ -112,7 +110,6 @@ def load_prescription_gt(dataFolder):
             gt_path = os.path.join(root, file)
             total_imgs_bboxes.append(gt_path)

-
     total_imgs_parsing_bboxes = []
     for img_path, bbox in zip(sorted(total_img_path), sorted(total_imgs_bboxes)):
         # check file
@@ -122,14 +119,12 @@ def load_prescription_gt(dataFolder):
         result_label = xml_parsing(bbox)
         total_imgs_parsing_bboxes.append(result_label)

-
     return total_imgs_parsing_bboxes, sorted(total_img_path)


 # NOTE
 def load_prescription_cleval_gt(dataFolder):
-
     total_img_path = []
     total_gt_path = []
     for (root, directories, files) in os.walk(dataFolder):
@@ -141,7 +136,6 @@ def load_prescription_cleval_gt(dataFolder):
             gt_path = os.path.join(root, file)
             total_gt_path.append(gt_path)

-
     total_imgs_parsing_bboxes = []
     for img_path, gt_path in zip(sorted(total_img_path), sorted(total_gt_path)):
         # check file
@@ -216,7 +210,8 @@ def load_icdar2015_gt(dataFolder, isTraing=False):
     total_imgs_bboxes = []
     total_img_path = []
     for gt_path in gt_folder_path:
-        gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_path)
+        gt_path = os.path.join(os.path.join(
+            dataFolder, gt_folderName), gt_path)
         img_path = (
             gt_path.replace(gt_folderName, img_folderName)
             .replace(".txt", ".jpg")
@@ -234,7 +229,8 @@ def load_icdar2015_gt(dataFolder, isTraing=False):
word = ",".join(word) box_points = np.array(box_points, np.int32).reshape(4, 2) cv2.polylines( - image, [np.array(box_points).astype(np.int)], True, (0, 0, 255), 1 + image, [np.array(box_points).astype( + np.int)], True, (0, 0, 255), 1 ) box_info_dict["points"] = box_points box_info_dict["text"] = word @@ -264,7 +260,8 @@ def load_icdar2013_gt(dataFolder, isTraing=False): total_imgs_bboxes = [] total_img_path = [] for gt_path in gt_folder_path: - gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_path) + gt_path = os.path.join(os.path.join( + dataFolder, gt_folderName), gt_path) img_path = ( gt_path.replace(gt_folderName, img_folderName) .replace(".txt", ".jpg") diff --git a/trainer/modules/transformation.py b/trainer/modules/transformation.py index 17c0edb0a..9a0d3ae49 100644 --- a/trainer/modules/transformation.py +++ b/trainer/modules/transformation.py @@ -23,14 +23,18 @@ def __init__(self, F, I_size, I_r_size, I_channel_num=1): self.I_size = I_size self.I_r_size = I_r_size # = (I_r_height, I_r_width) self.I_channel_num = I_channel_num - self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.LocalizationNetwork = LocalizationNetwork( + self.F, self.I_channel_num) self.GridGenerator = GridGenerator(self.F, self.I_r_size) def forward(self, batch_I): batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 - build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 - build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) - batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) + build_P_prime_reshape = build_P_prime.reshape( + [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + batch_I_r = F.grid_sample( + batch_I, build_P_prime_reshape, padding_mode='border') return batch_I_r @@ -46,15 +50,19 @@ def __init__(self, F, I_channel_num): nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 - nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d( + 128), nn.ReLU(True), nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 - nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d( + 256), nn.ReLU(True), nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 - nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d( + 512), nn.ReLU(True), nn.AdaptiveAvgPool2d(1) # batch_size x 512 ) - self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) + self.localization_fc1 = nn.Sequential( + nn.Linear(512, 256), nn.ReLU(True)) self.localization_fc2 = nn.Linear(256, self.F * 2) # Init fc2 in LocalizationNetwork @@ -66,7 +74,8 @@ def __init__(self, F, I_channel_num): ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) - self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + 
+        self.localization_fc2.bias.data = torch.from_numpy(
+            initial_bias).float().view(-1)

     def forward(self, batch_I):
@@ -75,7 +84,8 @@ def forward(self, batch_I):
         """
         batch_size = batch_I.size(0)
         features = self.conv(batch_I).view(batch_size, -1)
-        batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2)
+        batch_C_prime = self.localization_fc2(
+            self.localization_fc1(features)).view(batch_size, self.F, 2)
         return batch_C_prime

@@ -90,12 +100,14 @@ def __init__(self, F, I_r_size):
         self.F = F
         self.C = self._build_C(self.F)  # F x 2
         self.P = self._build_P(self.I_r_width, self.I_r_height)
-        ## for multi-gpu, you need register buffer
-        self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
-        self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float())  # n x F+3
-        ## for fine-tuning with different image width, you may use below instead of self.register_buffer
-        #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda()  # F+3 x F+3
-        #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda()  # n x F+3
+        # for multi-gpu, you need register buffer
+        self.register_buffer("inv_delta_C", torch.tensor(
+            self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
+        self.register_buffer("P_hat", torch.tensor(
+            self._build_P_hat(self.F, self.C, self.P)).float())  # n x F+3
+        # for fine-tuning with different image width, you may use below instead of self.register_buffer
+        # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3
+        # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3

     def _build_C(self, F):
         """ Return coordinates of fiducial points in I_r; C """
@@ -121,8 +133,10 @@ def _build_inv_delta_C(self, F, C):
         delta_C = np.concatenate(  # F+3 x F+3
             [
                 np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
-                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
-                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
+                np.concatenate(
+                    [np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
+                np.concatenate(
+                    [np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
             ],
             axis=0
         )
@@ -130,8 +144,10 @@ def _build_inv_delta_C(self, F, C):
         return inv_delta_C  # F+3 x F+3

     def _build_P(self, I_r_width, I_r_height):
-        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width  # self.I_r_width
-        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height  # self.I_r_height
+        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) +
+                      1.0) / I_r_width  # self.I_r_width
+        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) +
+                      1.0) / I_r_height  # self.I_r_height
         P = np.stack(  # self.I_r_width x self.I_r_height x 2
             np.meshgrid(I_r_grid_x, I_r_grid_y),
             axis=2
@@ -140,11 +156,14 @@ def _build_P(self, I_r_width, I_r_height):

     def _build_P_hat(self, F, C, P):
         n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
-        P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1))  # n x 2 -> n x 1 x 2 -> n x F x 2
+        P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)
+                         )  # n x 2 -> n x 1 x 2 -> n x F x 2
         C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
         P_diff = P_tile - C_tile  # n x F x 2
-        rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)  # n x F
+        rbf_norm = np.linalg.norm(
+            P_diff, ord=2, axis=2, keepdims=False)  # n x F
-        rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps))  # n x F
+        rbf = np.multiply(np.square(rbf_norm), np.log(
+            rbf_norm + self.eps))  # n x F
         P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
         return P_hat  # n x F+3

@@ -155,6 +174,7 @@ def build_P_prime(self, batch_C_prime):
         batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
         batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
             batch_size, 3, 2).float().to(device)), dim=1)  # batch_size x F+3 x 2
-        batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros)  # batch_size x F+3 x 2
+        # batch_size x F+3 x 2
+        batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros)
         batch_P_prime = torch.bmm(batch_P_hat, batch_T)  # batch_size x n x 2
         return batch_P_prime  # batch_size x n x 2
diff --git a/trainer/test.py b/trainer/test.py
index 48eaa76c7..531c5bb24 100644
--- a/trainer/test.py
+++ b/trainer/test.py
@@ -14,6 +14,7 @@
 from dataset import hierarchical_dataset, AlignCollate
 from model import Model

+
 def validation(model, criterion, evaluation_loader, converter, opt, device):
     """ validation or evaluation """
     n_correct = 0
@@ -27,11 +28,14 @@ def validation(model, criterion, evaluation_loader, converter, opt, device):
         length_of_data = length_of_data + batch_size
         image = image_tensors.to(device)
         # For max length prediction
-        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
-        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
+        length_for_pred = torch.IntTensor(
+            [opt.batch_max_length] * batch_size).to(device)
+        text_for_pred = torch.LongTensor(
+            batch_size, opt.batch_max_length + 1).fill_(0).to(device)
+
+        text_for_loss, length_for_loss = converter.encode(
+            labels, batch_max_length=opt.batch_max_length)

-        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
-
         start_time = time.time()
         if 'CTC' in opt.Prediction:
             preds = model(image, text_for_pred)
@@ -40,13 +44,15 @@ def validation(model, criterion, evaluation_loader, converter, opt, device):

             # Calculate evaluation loss for CTC decoder.
             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
             # permute 'preds' to use CTCloss format
-            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
+            cost = criterion(preds.log_softmax(2).permute(
+                1, 0, 2), text_for_loss, preds_size, length_for_loss)

             if opt.decode == 'greedy':
                 # Select max probabilty (greedy decoding) then decode index to character
                 _, preds_index = preds.max(2)
                 preds_index = preds_index.view(-1)
-                preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
+                preds_str = converter.decode_greedy(
+                    preds_index.data, preds_size.data)
             elif opt.decode == 'beamsearch':
                 preds_str = converter.decode_beamsearch(preds, beamWidth=2)

@@ -56,7 +62,8 @@ def validation(model, criterion, evaluation_loader, converter, opt, device):
             preds = preds[:, :text_for_loss.shape[1] - 1, :]
             target = text_for_loss[:, 1:]  # without [GO] Symbol
-            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
+            cost = criterion(preds.contiguous().view(-1,
+                             preds.shape[-1]), target.contiguous().view(-1))

             # select max probabilty (greedy decoding) then decode index to character
             _, preds_index = preds.max(2)
@@ -70,12 +77,13 @@ def validation(model, criterion, evaluation_loader, converter, opt, device):
         preds_prob = F.softmax(preds, dim=2)
         preds_max_prob, _ = preds_prob.max(dim=2)
         confidence_score_list = []
-
+
         for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
             if 'Attn' in opt.Prediction:
                 gt = gt[:gt.find('[s]')]
                 pred_EOS = pred.find('[s]')
-                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
+                # prune after "end of sentence" token ([s])
+                pred = pred[:pred_EOS]
                 pred_max_prob = pred_max_prob[:pred_EOS]

             if pred == gt:
@@ -89,9 +97,9 @@ def validation(model, criterion, evaluation_loader, converter, opt, device):
             else:
                 norm_ED += edit_distance(pred, gt) / len(gt)
             '''
-
-            # ICDAR2019 Normalized Edit Distance 
-            if len(gt) == 0 or len(pred) ==0:
+
+            # ICDAR2019 Normalized Edit Distance
+            if len(gt) == 0 or len(pred) == 0:
                 norm_ED += 0
             elif len(gt) > len(pred):
                 norm_ED += 1 - edit_distance(pred, gt) / len(gt)
@@ -102,11 +110,13 @@ def validation(model, criterion, evaluation_loader, converter, opt, device):
             try:
                 confidence_score = pred_max_prob.cumprod(dim=0)[-1]
             except:
-                confidence_score = 0  # for empty pred case, when prune after "end of sentence" token ([s])
+                # for empty pred case, when prune after "end of sentence" token ([s])
+                confidence_score = 0
             confidence_score_list.append(confidence_score)
             # print(pred, gt, pred==gt, confidence_score)

     accuracy = n_correct / float(length_of_data) * 100
-    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance
+    # ICDAR2019 Normalized Edit Distance
+    norm_ED = norm_ED / float(length_of_data)

     return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
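
--
The hunks above are mechanical autopep8 rewrites: line wrapping to the PEP 8
line-length limit, whitespace around operators, and blank lines between
top-level definitions, with no behavioural change intended. To reproduce or
verify such a pass locally, a minimal sketch follows. It is illustrative
only; the exact autopep8 options DeepSource's transformer used are not
recorded here, so the line length below is an assumption.

    # autopep8_check.py - list files that autopep8 would still rewrite.
    import sys

    import autopep8  # pip install autopep8


    def needs_formatting(path, max_line_length=79):
        """Return True if autopep8 would change the file's formatting."""
        with open(path, encoding='utf-8') as f:
            source = f.read()
        # fix_code() returns the PEP 8-formatted source; if it differs
        # from the input, the file is not yet clean.
        fixed = autopep8.fix_code(
            source, options={'max_line_length': max_line_length})
        return fixed != source


    if __name__ == '__main__':
        dirty = [p for p in sys.argv[1:] if needs_formatting(p)]
        for p in dirty:
            print('would reformat:', p)
        sys.exit(1 if dirty else 0)

Running `autopep8 --in-place` over the five files in the diffstat should then
yield a diff equivalent to this one, assuming the same autopep8 version as
the bot used.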