-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
刘亮均
authored and
刘亮均
committed
Apr 22, 2020
1 parent
5ba89bc
commit 9f7862a
Showing
118 changed files
with
31,614 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
图像文字方向检测 | ||
@author: xiaofeng | ||
""" |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# _Author_: xiaofeng | ||
# Date: 2018-04-22 18:13:46 | ||
# Last Modified by: xiaofeng | ||
# Last Modified time: 2018-04-22 18:13:46 | ||
''' | ||
根据给定的图形,分析文字的朝向 | ||
''' | ||
# from keras.models import load_model | ||
import numpy as np | ||
from PIL import Image | ||
from keras.applications.vgg16 import preprocess_input, VGG16 | ||
from keras.layers import Dense | ||
from keras.models import Model | ||
# 编译模型,以较小的学习参数进行训练 | ||
from keras.optimizers import SGD | ||
|
||
|
||
def load(): | ||
vgg = VGG16(weights=None, input_shape=(224, 224, 3)) | ||
# 修改输出层 3个输出 | ||
x = vgg.layers[-2].output | ||
predictions_class = Dense( | ||
4, activation='softmax', name='predictions_class')(x) | ||
prediction = [predictions_class] | ||
model = Model(inputs=vgg.input, outputs=prediction) | ||
sgd = SGD(lr=0.00001, momentum=0.9) | ||
model.compile( | ||
optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy']) | ||
model.load_weights( | ||
'angle/modelAngle.h5') | ||
return model | ||
|
||
|
||
# 加载模型 | ||
model = None | ||
|
||
|
||
def predict(path=None, img=None): | ||
global model | ||
if model is None: | ||
model = load() | ||
""" | ||
图片文字方向预测 | ||
""" | ||
ROTATE = [0, 90, 180, 270] | ||
if path is not None: | ||
im = Image.open(path).convert('RGB') | ||
elif img is not None: | ||
im = Image.fromarray(img).convert('RGB') | ||
w, h = im.size | ||
# 对图像进行剪裁 | ||
# 左上角(int(0.1 * w), int(0.1 * h)) | ||
# 右下角(w - int(0.1 * w), h - int(0.1 * h)) | ||
xmin, ymin, xmax, ymax = int(0.1 * w), int( | ||
0.1 * h), w - int(0.1 * w), h - int(0.1 * h) | ||
im = im.crop((xmin, ymin, xmax, ymax)) # 剪切图片边缘,清除边缘噪声 | ||
# 对图片进行剪裁之后进行resize成(224,224) | ||
im = im.resize((224, 224)) | ||
# 将图像转化成数组形式 | ||
img = np.array(im) | ||
img = preprocess_input(img.astype(np.float32)) | ||
pred = model.predict(np.array([img])) | ||
index = np.argmax(pred, axis=1)[0] | ||
return ROTATE[index] |
Empty file.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+2.37 KB
ctpn_detect_v1/ctpn/ctpn/__pycache__/text_proposal_connector.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+3.19 KB
ctpn_detect_v1/ctpn/ctpn/__pycache__/text_proposal_graph_builder.cpython-36.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import numpy as np | ||
|
||
|
||
class Config: | ||
MEAN = np.float32([102.9801, 115.9465, 122.7717]) | ||
# MEAN=np.float32([100.0, 100.0, 100.0]) | ||
TEST_GPU_ID = 0 | ||
SCALE = 900 | ||
MAX_SCALE = 1500 | ||
TEXT_PROPOSALS_WIDTH = 0 | ||
MIN_RATIO = 0.01 | ||
LINE_MIN_SCORE = 0.6 | ||
TEXT_LINE_NMS_THRESH = 0.3 | ||
MAX_HORIZONTAL_GAP = 30 | ||
TEXT_PROPOSALS_MIN_SCORE = 0.7 | ||
TEXT_PROPOSALS_NMS_THRESH = 0.3 | ||
MIN_NUM_PROPOSALS = 0 | ||
MIN_V_OVERLAPS = 0.6 | ||
MIN_SIZE_SIM = 0.6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import glob | ||
import os | ||
import shutil | ||
import sys | ||
|
||
import cv2 | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
sys.path.append(parentdir) | ||
|
||
from lib.networks.factory import get_network | ||
from lib.fast_rcnn.config import cfg | ||
from lib.fast_rcnn.test import test_ctpn | ||
from lib.fast_rcnn.nms_wrapper import nms | ||
from lib.utils.timer import Timer | ||
from text_proposal_connector import TextProposalConnector | ||
|
||
CLASSES = ('__background__', 'text') | ||
|
||
|
||
def connect_proposal(text_proposals, scores, im_size): | ||
cp = TextProposalConnector() | ||
line = cp.get_text_lines(text_proposals, scores, im_size) | ||
return line | ||
|
||
|
||
def save_results(image_name, im, line, thresh): | ||
inds = np.where(line[:, -1] >= thresh)[0] | ||
if len(inds) == 0: | ||
return | ||
|
||
for i in inds: | ||
bbox = line[i, :4] | ||
score = line[i, -1] | ||
cv2.rectangle( | ||
im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), | ||
color=(0, 0, 255), | ||
thickness=1) | ||
image_name = image_name.split('/')[-1] | ||
cv2.imwrite(os.path.join("../data/results", image_name), im) | ||
|
||
|
||
def check_img(im): | ||
im_size = im.shape | ||
if max(im_size[0:2]) < 600: | ||
img = np.zeros((600, 600, 3), dtype=np.uint8) | ||
start_row = int((600 - im_size[0]) / 2) | ||
start_col = int((600 - im_size[1]) / 2) | ||
end_row = start_row + im_size[0] | ||
end_col = start_col + im_size[1] | ||
img[start_row:end_row, start_col:end_col, :] = im | ||
return img | ||
else: | ||
return im | ||
|
||
|
||
def ctpn(sess, net, image_name): | ||
img = cv2.imread(image_name) | ||
im = check_img(img) | ||
timer = Timer() | ||
timer.tic() | ||
scores, boxes = test_ctpn(sess, net, im) | ||
timer.toc() | ||
# print('Detection took {:.3f}s for ' | ||
# '{:d} object proposals').format(timer.total_time, boxes.shape[0]) | ||
|
||
# Visualize detections for each class | ||
CONF_THRESH = 0.9 | ||
NMS_THRESH = 0.3 | ||
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32) | ||
keep = nms(dets, NMS_THRESH) | ||
dets = dets[keep, :] | ||
|
||
keep = np.where(dets[:, 4] >= 0.7)[0] | ||
dets = dets[keep, :] | ||
line = connect_proposal(dets[:, 0:4], dets[:, 4], im.shape) | ||
save_results(image_name, im, line, thresh=0.9) | ||
|
||
|
||
if __name__ == '__main__': | ||
if os.path.exists("../data/results/"): | ||
shutil.rmtree("../data/results/") | ||
os.makedirs("../data/results/") | ||
|
||
cfg.TEST.HAS_RPN = True # Use RPN for proposals | ||
# init session | ||
config = tf.ConfigProto(allow_soft_placement=True) | ||
sess = tf.Session(config=config) | ||
# load network | ||
net = get_network("VGGnet_test") | ||
# load model | ||
print('Loading network {:s}... '.format("VGGnet_test")), | ||
saver = tf.train.Saver() | ||
# saver.restore(sess, | ||
# os.path.join(os.getcwd(), "checkpoints/model_final.ckpt")) | ||
saver.restore(sess, | ||
os.path.join(os.getcwd(), | ||
"/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/checkpoints/VGGnet_fast_rcnn_iter_50000.ckpt")) | ||
print(' done.') | ||
|
||
# Warmup on a dummy image | ||
im = 128 * np.ones((300, 300, 3), dtype=np.uint8) | ||
for i in range(2): | ||
_, _ = test_ctpn(sess, net, im) | ||
|
||
im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \ | ||
glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg')) | ||
|
||
for im_name in im_names: | ||
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') | ||
print('Demo for {:s}'.format(im_name)) | ||
ctpn(sess, net, im_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# coding:utf-8 | ||
import sys | ||
|
||
import numpy as np | ||
|
||
from .cfg import Config as cfg | ||
from .other import normalize | ||
|
||
sys.path.append('..') | ||
from ..lib.fast_rcnn.nms_wrapper import nms | ||
# from lib.fast_rcnn.test import test_ctpn | ||
|
||
from .text_proposal_connector import TextProposalConnector | ||
|
||
|
||
class TextDetector: | ||
""" | ||
Detect text from an image | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
pass | ||
""" | ||
self.text_proposal_connector = TextProposalConnector() | ||
|
||
def detect(self, text_proposals, scores, size): | ||
""" | ||
Detecting texts from an image | ||
:return: the bounding boxes of the detected texts | ||
""" | ||
# text_proposals, scores=self.text_proposal_detector.detect(im, cfg.MEAN) | ||
keep_inds = np.where(scores > cfg.TEXT_PROPOSALS_MIN_SCORE)[0] | ||
text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] | ||
|
||
sorted_indices = np.argsort(scores.ravel())[::-1] | ||
text_proposals, scores = text_proposals[sorted_indices], scores[sorted_indices] | ||
|
||
# nms for text proposals | ||
keep_inds = nms(np.hstack((text_proposals, scores)), cfg.TEXT_PROPOSALS_NMS_THRESH) | ||
text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] | ||
|
||
scores = normalize(scores) | ||
|
||
text_lines = self.text_proposal_connector.get_text_lines(text_proposals, scores, size) | ||
|
||
keep_inds = self.filter_boxes(text_lines) | ||
text_lines = text_lines[keep_inds] | ||
|
||
if text_lines.shape[0] != 0: | ||
keep_inds = nms(text_lines, cfg.TEXT_LINE_NMS_THRESH) | ||
text_lines = text_lines[keep_inds] | ||
|
||
return text_lines | ||
|
||
def filter_boxes(self, boxes): | ||
heights = boxes[:, 3] - boxes[:, 1] + 1 | ||
widths = boxes[:, 2] - boxes[:, 0] + 1 | ||
scores = boxes[:, -1] | ||
return np.where((widths / heights > cfg.MIN_RATIO) & (scores > cfg.LINE_MIN_SCORE) & | ||
(widths > (cfg.TEXT_PROPOSALS_WIDTH * cfg.MIN_NUM_PROPOSALS)))[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import sys | ||
import os | ||
|
||
import tensorflow as tf | ||
|
||
from .cfg import Config | ||
from .other import resize_im | ||
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) | ||
|
||
|
||
sys.path.append(os.getcwd()) | ||
from lib.fast_rcnn.config import cfg | ||
from lib.networks.factory import get_network | ||
from lib.fast_rcnn.test import test_ctpn | ||
|
||
# from ..lib.networks.factory import get_network | ||
# from ..lib.fast_rcnn.config import cfg | ||
# from..lib.fast_rcnn.test import test_ctpn | ||
''' | ||
load network | ||
输入的名称为'Net_model' | ||
'VGGnet_test'--test | ||
'VGGnet_train'-train | ||
''' | ||
|
||
|
||
def load_tf_model(): | ||
cfg.TEST.HAS_RPN = True # Use RPN for proposals | ||
# init session | ||
config = tf.ConfigProto(allow_soft_placement=True) | ||
net = get_network("VGGnet_test") | ||
# load model | ||
saver = tf.train.Saver() | ||
# sess = tf.Session(config=config) | ||
sess = tf.Session() | ||
ckpt_path = './ctpn/ctpn/retrain/ckpt' | ||
ckpt = tf.train.get_checkpoint_state(ckpt_path) | ||
reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path) | ||
var_to_shape_map = reader.get_variable_to_shape_map() | ||
for key in var_to_shape_map: | ||
print("Tensor_name is : ", key) | ||
# print(reader.get_tensor(key)) | ||
saver.restore(sess, ckpt.model_checkpoint_path) | ||
print("load vggnet done") | ||
return sess, saver, net | ||
|
||
|
||
|
||
# init model | ||
sess, saver, net = load_tf_model() | ||
|
||
|
||
# 进行文本识别 | ||
def ctpn(img): | ||
""" | ||
text box detect | ||
""" | ||
scale, max_scale = Config.SCALE, Config.MAX_SCALE | ||
# 对图像进行resize,输出的图像长宽 | ||
img, f = resize_im(img, scale=scale, max_scale=max_scale) | ||
scores, boxes = test_ctpn(sess, net, img) | ||
return scores, boxes, img |
Oops, something went wrong.