Skip to content

Commit

Permalink
init commit for paddlestructure
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jun 5, 2021
1 parent a5f7511 commit bc0d766
Show file tree
Hide file tree
Showing 17 changed files with 385 additions and 110 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
include LICENSE.txt
include README.md

recursive-include ppocr/utils *.txt utility.py logging.py
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
Expand Down
3 changes: 3 additions & 0 deletions paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
sys.path.append(os.path.join(__dir__, ''))

import cv2
import logging
import numpy as np
from pathlib import Path

Expand Down Expand Up @@ -150,6 +151,8 @@ def __init__(self, **kwargs):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
self.use_angle_cls = params.use_angle_cls
lang = params.lang
latin_lang = [
Expand Down
3 changes: 1 addition & 2 deletions ppocr/utils/dict/table_dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ D
Π
H
</
>
</strike>
L
Φ
Χ
Expand Down
9 changes: 9 additions & 0 deletions ppstructure/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
include LICENSE.txt
include README.md

recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include table *.py
recursive-include ppstructure *.py
Empty file removed ppstructure/layout/README.md
Empty file.
Empty file removed ppstructure/layout/README_ch.md
Empty file.
161 changes: 161 additions & 0 deletions ppstructure/paddlestructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))


import cv2
import numpy as np
from pathlib import Path

from ppocr.utils.logging import get_logger
from predict_system import OCRSystem, save_res
from utility import init_args

logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar

__all__ = ['PaddleStructure']

VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")

model_urls = {
'det': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'rec': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'structure': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
}


def parse_args(mMain=True):
import argparse
parser = init_args()
parser.add_help = mMain

for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args()
else:
inference_args_dict = {}
for action in parser._actions:
inference_args_dict[action.dest] = action.default
return argparse.Namespace(**inference_args_dict)


class PaddleStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
params.use_angle_cls = False
# init model dir
if params.det_model_dir is None:
params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det')
if params.rec_model_dir is None:
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec')
if params.structure_model_dir is None:
params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure')
# download model
maybe_download(params.det_model_dir, model_urls['det'])
maybe_download(params.det_model_dir, model_urls['rec'])
maybe_download(params.det_model_dir, model_urls['structure'])

if params.rec_char_dict_path is None:
params.rec_char_type = 'EN'
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
else:
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
if params.structure_char_dict_path is None:
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
params.structure_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
else:
params.structure_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')

print(params)
super().__init__(params)

def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

res = super().__call__(img)
return res


def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
save_folder = args.output
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return

structure_engine = PaddleStructure(**(args.__dict__))
for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = structure_engine(img_path)
save_res(result, args.output, os.path.basename(img_path).split('.')[0])
for item in result:
logger.info(item['res'])
save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))



if __name__ == '__main__':
table_engine = PaddleStructure(det_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_det_infer',
rec_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_rec_infer',
structure_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_structure_infer',
output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
show_log=True)
img = cv2.imread('/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/table_1.png')
result = table_engine(img)
for line in result:
print(line)
102 changes: 49 additions & 53 deletions ppstructure/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,97 +18,93 @@

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.layout.predict_layout import LayoutDetector

import layoutparser as lp

from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args

logger = get_logger()


def parse_args():
parser = utility.init_args()

# params for output
parser.add_argument("--table_output", type=str, default='output/table')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_max_text_length", type=int, default=100)
parser.add_argument("--table_max_elem_length", type=int, default=800)
parser.add_argument("--table_max_cell_num", type=int, default=500)
parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--table_char_type", type=str, default='en')
parser.add_argument("--table_char_dict_path", type=str, default="./ppocr/utils/dict/table_structure_dict.txt")

# params for layout detector
parser.add_argument("--layout_model_dir", type=str)
return parser.parse_args()


class OCRSystem():
class OCRSystem(object):
def __init__(self, args):
self.text_system = TextSystem(args)
self.table_system = TableSystem(args)
self.table_layout = LayoutDetector(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score

def __call__(self, img):
ori_im = img.copy()
layout_res = self.table_layout(copy.deepcopy(img))
layout_res = self.table_layout.detect(img[..., ::-1])
res_list = []
for region in layout_res:
x1, y1, x2, y2 = region['bbox']
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region['label'] == 'table':
res = self.text_system(roi_img)
if region.type == 'Table':
res = self.table_system(roi_img)
elif region.type == 'Figure':
continue
else:
res = self.text_system(roi_img)
region['res'] = res
return layout_res
filter_boxes, filter_rec_res = self.text_system(roi_img)
filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
res = (filter_boxes, filter_rec_res)
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list


def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['type'] == 'Table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
elif region['type'] == 'Figure':
pass
else:
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))


def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
save_folder = args.table_output
save_folder = args.output
os.makedirs(save_folder, exist_ok=True)

text_sys = OCRSystem(args)
structure_sys = OCRSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
img_name = os.path.basename(image_file).split('.')[0]
# excel_path = os.path.join(excel_save_folder, + '.xlsx')

if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res = text_sys(img)

excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['label'] == 'table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
else:
with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
logger.info(res)
res = structure_sys(img)
save_res(res, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))

Expand Down
Loading

0 comments on commit bc0d766

Please sign in to comment.