Refinement and refactor: model.py (ocr/model.py) #176

Open
aronfelipe opened this issue Oct 28, 2024 · 0 comments


aronfelipe commented Oct 28, 2024

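Code overview

The revised code improves on the code you provided in the following respects:

  1. Imports: the imported libraries are grouped and tidied so the code structure is clearer.
  2. CTC loss: defines the CTC loss function used during training.
  3. Model architecture: builds a CNN-RNN model from convolutional layers followed by bidirectional GRU layers.
  4. Model loading: loads the model weights when the specified weights file exists, so the model can be reused quickly.
  5. Prediction function: preprocesses the input image, runs the prediction, and cleans up the output.
  6. Decoding function: converts the predicted output into human-readable text.
  7. Output cleanup: strips leading punctuation (the '。' character) from the final result.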

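Improvement highlights

  • Clear structure: the tidied imports make the code easier to read.
  • Function documentation: comments describe what each function does, which helps maintainability.
  • Modular output handling: the output cleanup logic is extracted into its own function.
  • Consistent variable naming: more descriptive variable names make the intent of the code clearer.
  • Less redundancy: repeated logic is consolidated, so the code is more concise overall.

Together these changes make the code easier to read and maintain, and make its behavior clearer.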
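The decoding step in item 6 is a greedy CTC decode: take the most likely class at each time step, collapse consecutive repeats, and drop the blank class. The following is a minimal, dependency-free sketch of that idea on a toy alphabet. The name `greedy_ctc_decode` and the example values are illustrative assumptions and do not appear in ocr/model.py.

```python
def greedy_ctc_decode(class_ids, alphabet, blank_index):
    """Collapse consecutive repeats, then drop blanks (greedy CTC decoding)."""
    chars = []
    previous = None
    for class_id in class_ids:
        # Emit a character only when the class changes and is not the blank.
        if class_id != previous and class_id != blank_index:
            chars.append(alphabet[class_id])
        previous = class_id
    return "".join(chars)


# Toy example (hypothetical values): the blank is the last class index,
# mirroring nclass = len(characters) + 1 in the listing.
alphabet = "abc"
blank = len(alphabet)
frames = [0, 0, 3, 0, 1, 1, 3, 2]  # per-frame argmax class ids
decoded = greedy_ctc_decode(frames, alphabet, blank)
print(decoded)
```

A blank between two identical classes is what allows a doubled letter to survive the collapse, which is why the repeat check compares against the previous frame rather than the previous emitted character.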

import os
import sys
import numpy as np
from keras.layers import (Input, Conv2D, MaxPooling2D, ZeroPadding2D,
                         BatchNormalization, Permute, TimeDistributed, 
                         Flatten, Bidirectional, GRU, Dense, Lambda)
from keras.models import Model
from keras.optimizers import SGD
import keras.backend as K
import keys_ocr

# CTC loss function, wrapped in a Lambda layer by get_model()
def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # Drop the first two time steps; the earliest RNN outputs tend to be unreliable
    y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

# Model architecture
def get_model(height, nclass):
    rnnunit = 256
    input_tensor = Input(shape=(height, None, 1), name='the_input')

    # CNN layers
    x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same')(input_tensor)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same')(x)
    x = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same')(x)
    x = ZeroPadding2D(padding=(0, 1))(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 1))(x)
    x = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization(axis=1)(x)
    x = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization(axis=1)(x)
    x = ZeroPadding2D(padding=(0, 1))(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 1))(x)
    x = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid')(x)

    # Reshape for RNN
    x = Permute((2, 1, 3))(x)
    x = TimeDistributed(Flatten())(x)

    # RNN layers
    x = Bidirectional(GRU(rnnunit, return_sequences=True))(x)
    x = Dense(rnnunit, activation='linear')(x)
    x = Bidirectional(GRU(rnnunit, return_sequences=True))(x)
    y_pred = Dense(nclass, activation='softmax')(x)

    # Create model for training
    basemodel = Model(inputs=input_tensor, outputs=y_pred)

    # Define inputs for CTC loss
    labels = Input(name='the_labels', shape=[None, ], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
    model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])
    
    # Compile model
    sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)

    return model, basemodel

# Load model
characters = keys_ocr.alphabet[:]
modelPath = os.path.join(os.getcwd(), "ocr/ocr0.2.h5")
height = 32
nclass = len(characters) + 1

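# Build the network and load weights only when the weights file exists;
# otherwise basemodel stays None and predict() raises a clear error.
model, basemodel = None, None
if os.path.exists(modelPath):
    model, basemodel = get_model(height, nclass)
    basemodel.load_weights(modelPath)

def predict(im):
    """
    Recognize the text in a PIL image with the Keras model and return it as a string.
    """
    if basemodel is None:
        raise RuntimeError("Model weights not found at " + modelPath)

    im = im.convert('L')  # Convert image to grayscale
    # Resize to a height of 32 pixels while keeping the aspect ratio
    scale = im.size[1] / 32.0
    w = int(im.size[0] / scale)
    im = im.resize((w, 32))

    # Scale pixel values to [0, 1] and add batch and channel dimensions
    img = np.array(im).astype(np.float32) / 255.0
    X = img.reshape((32, w, 1))
    X = np.array([X])

    # Predict
    y_pred = basemodel.predict(X)
    y_pred = y_pred[:, 2:, :]  # Drop the first two time steps, as in the CTC loss
    out = decode(y_pred)

    # Clean output
    out = clean_output(out)
    return out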

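def decode(pred):
    """
    Greedy CTC decoding: take the most likely class per time step, collapse
    consecutive repeats, and drop the blank class (index len(characters)).
    """
    charactersS = characters + ' '  # Placeholder so the blank index maps to a valid position
    t = pred.argmax(axis=2)[0]
    char_list = []
    n = len(characters)  # Index of the CTC blank class

    for i in range(len(t)):
        # Skip blanks and consecutive repeats of the same class
        if t[i] != n and (i == 0 or t[i] != t[i - 1]):
            char_list.append(charactersS[t[i]])

    return ''.join(char_list)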

def clean_output(out):
    """Strip any leading '。' characters from the recognized text."""
    return out.lstrip('。')
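To see how many time steps the RNN receives for a given image width, the pooling and padding layers in `get_model` can be traced by hand. The sketch below does that arithmetic in plain Python. It assumes Keras's default `'valid'` padding for `MaxPooling2D` and a channels-last layout, and the helper names `pool_out` and `trace_feature_map` are hypothetical, not part of ocr/model.py.

```python
def pool_out(size, window, stride):
    """Output length of a 'valid' pooling or convolution along one axis."""
    return (size - window) // stride + 1


def trace_feature_map(height, width):
    """Follow the spatial size through the CNN layers listed in get_model."""
    # Two 2x2 max-pools with stride 2 ('same' convolutions keep the size).
    for _ in range(2):
        height, width = pool_out(height, 2, 2), pool_out(width, 2, 2)
    # Twice: ZeroPadding2D((0, 1)) adds one column per side,
    # then a 2x2 max-pool with strides (2, 1).
    for _ in range(2):
        width += 2
        height, width = pool_out(height, 2, 2), pool_out(width, 2, 1)
    # Final 2x2 convolution with 'valid' padding.
    height, width = pool_out(height, 2, 1), pool_out(width, 2, 1)
    return height, width


final_height, time_steps = trace_feature_map(32, 100)
print(final_height, time_steps)
```

Under these assumptions a 32-pixel-high input collapses to a feature height of 1, and the number of time steps works out to `width // 4 + 1`. Both `ctc_lambda_func` and `predict` then discard the first two of those steps.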