Skip to content

Commit

Permalink
TorchScript used for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
voun7 committed Nov 7, 2024
1 parent 0870420 commit 2dcfa63
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 57 deletions.
11 changes: 3 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@

setup(
name="subtitle_ocr",
version="1.3.2",
packages=[
'sub_ocr', 'sub_ocr.modeling', 'sub_ocr.modeling.heads', 'sub_ocr.modeling.necks', 'sub_ocr.modeling.backbones',
'sub_ocr.modeling.transforms', 'sub_ocr.modeling.architectures', 'sub_ocr.postprocess', "sub_ocr.alphabets"
],
version="1.4",
packages=['sub_ocr', 'sub_ocr.postprocess', "sub_ocr.alphabets"],
include_package_data=True,
package_data={"sub_ocr.alphabets": ["*.txt"]},
install_requires=[
"torch@https://download.pytorch.org/whl/cu124/"
"torch-2.5.1%2Bcu124-cp312-cp312-win_amd64.whl ;platform_system=='Windows'",
"torchvision@https://download.pytorch.org/whl/cu124/"
"torchvision-0.20.1%2Bcu124-cp312-cp312-win_amd64.whl ;platform_system=='Windows'",
"torchvision;platform_system!='Windows'", "opencv-python", "shapely", "pyclipper", "requests"
"torch;platform_system!='Windows'", "opencv-python", "shapely", "pyclipper", "requests"
],
url="https://github.com/voun7/Subtitle_OCR",
license="",
Expand Down
4 changes: 2 additions & 2 deletions sub_ocr/modeling/necks/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def __init__(self, in_channels, **kwargs):
self.out_channels = in_channels

def forward(self, x):
B, C, H, W = x.shape
assert H == 1
# B, C, H, W = x.shape
# assert H == 1 # raises torch TracerWarning
x = x.squeeze(dim=2)
x = x.permute(0, 2, 1) # (NTC)(batch, width, channels)
return x
Expand Down
37 changes: 4 additions & 33 deletions sub_ocr/subtitle_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import requests
import torch

from sub_ocr.modeling import build_model
from sub_ocr.postprocess import build_post_process
from sub_ocr.utils import read_image, normalize_img, pascal_voc_bb

Expand All @@ -32,31 +31,18 @@ class SubtitleOCR:
"det": {
"en": {
"en_det_ppocr_v3": {
"Architecture": {'model_type': 'det', 'algorithm': 'DB', 'Transform': None,
'Backbone': {'name': 'MobileNetV3', 'scale': 0.5, 'model_name': 'large',
'disable_se': True},
'Neck': {'name': 'RSEFPN', 'out_channels': 96, 'shortcut': True},
'Head': {'name': 'DBHead', 'k': 50}},
"params": {"height": 960, "width": 960, "m32": True, "sort_merge": True},
"PostProcess": {'name': 'DBPostProcess', 'thresh': 0.3, 'box_thresh': 0.6, 'max_candidates': 1000,
'unclip_ratio': 2.5}
},
},
"ch": {
"ch_PP-OCRv4_det_student": {
"Architecture": {'model_type': 'det', 'algorithm': 'DB', 'Transform': None,
'Backbone': {'name': 'PPLCNetV3', 'scale': 0.75, 'det': True},
'Neck': {'name': 'RSEFPN', 'out_channels': 96, 'shortcut': True},
'Head': {'name': 'DBHead', 'k': 50}},
"params": {"height": 640, "width": 640, "m32": True, "sort_merge": True},
"PostProcess": {'name': 'DBPostProcess', 'thresh': 0.3, 'box_thresh': 0.6, 'max_candidates': 1000,
'unclip_ratio': 2.5}
},
"ch_ptocr_v4_det_infer.pth": {
"Architecture": {'model_type': 'det', 'algorithm': 'DB', 'Transform': None,
'Backbone': {'name': 'PPLCNetV3', 'scale': 0.75, 'det': True},
'Neck': {'name': 'RSEFPN', 'out_channels': 96, 'shortcut': True},
'Head': {'name': 'DBHead', 'k': 50}},
"params": {"height": 640, "width": 960, "m32": True, "sort_merge": False},
"PostProcess": {'name': 'DBPostProcess', 'thresh': 0.3, 'box_thresh': 0.6, 'max_candidates': 1000,
'unclip_ratio': 2.5}
Expand All @@ -66,24 +52,12 @@ class SubtitleOCR:
"rec": {
"en": {
"en_PP-OCRv4_rec": {
"Architecture": {'model_type': 'rec', 'algorithm': 'SVTR_LCNet', 'Transform': None,
'Backbone': {'name': 'PPLCNetV3', 'scale': 0.95},
'Neck': {'name': 'SequenceEncoder', 'encoder_type': 'svtr', 'dims': 120,
'depth': 2,
'hidden_dims': 120, 'kernel_size': [1, 3], 'use_guide': True},
'Head': {'name': 'CTCHead'}},
"params": {"height": 48, "width": 320},
"PostProcess": {'name': 'CTCLabelDecode'}
},
},
"ch": {
"ch_PP-OCRv4_rec": {
"Architecture": {'model_type': 'rec', 'algorithm': 'SVTR_LCNet', 'Transform': None,
'Backbone': {'name': 'PPLCNetV3', 'scale': 0.95},
'Neck': {'name': 'SequenceEncoder', 'encoder_type': 'svtr', 'dims': 120,
'depth': 2,
'hidden_dims': 120, 'kernel_size': [1, 3], 'use_guide': True},
'Head': {'name': 'CTCHead'}},
"params": {"height": 48, "width": 320},
"PostProcess": {'name': 'CTCLabelDecode'}
},
Expand Down Expand Up @@ -123,18 +97,15 @@ def init_model(self, model_type: str) -> tuple:
Setup model and post processor.
"""
config_name = self.default_configs[f"{model_type}_{self.lang}"]
config = self.configs[model_type][self.lang][config_name]
config = self.configs[model_type][self.lang][config_name] | {"lang": self.lang}
if ".pt" in config_name:
model_file = self.models_dir / config_name
else:
model_file = next(self.models_dir.glob(f"{config_name} *.pt")) # best loss will be used
config.update({"lang": self.lang})
model, post_processor = build_model(config), build_post_process(config)

traced_model, post_processor = torch.jit.load(model_file, map_location=self.device), build_post_process(config)
traced_model.eval()
logger.debug(f"Device: {self.device}, Model Config: {config},\nModel File: {model_file}")
model.load_state_dict(torch.load(model_file, self.device, weights_only=True))
model.to(self.device).eval()
return model, post_processor, config["params"]
return traced_model, post_processor, config["params"]

def det_image_resize(self, image: np.ndarray) -> np.ndarray:
scale = min(self.det_params["height"] / image.shape[0], self.det_params["width"] / image.shape[1])
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def train_model(model_dir: str, config_name: str, config: dict) -> None:
trainer = ModelTrainer(model, train_params)
trainer.set_loaders(train_ds, val_ds, params["batch_size"], params["val_batch_size"], params["num_workers"])
trainer.load_checkpoint("")
trainer.train()
trainer.train_model()
trainer.save_model()


def main() -> None:
Expand Down
14 changes: 8 additions & 6 deletions utilities/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def update_writer(self) -> None:
self.writer.add_scalars("Validation Metric", {k: self.val_metrics[k][i] for k in self.val_metrics}, i + 1)
logger.info("Writer Updated with Checkpoint data.")

def train(self, seed: int = None) -> None:
def train_model(self, seed: int = None) -> None:
assert self.train_loader and self.val_loader, "Train or Val data loader has not been set!"
start_time, self.writer = perf_counter(), SummaryWriter()
best_model_wts = deepcopy(self.model.state_dict()) # Initial copy of model weights is saved
Expand Down Expand Up @@ -229,7 +229,6 @@ def train(self, seed: int = None) -> None:
self.model.load_state_dict(best_model_wts)

self.writer.close() # Closes the writer
self.save_model(val_loss["loss"])
logger.info(f"Model Training Completed. Duration: {self.dur_calc(start_time)}")
logger.debug(f"Trainer Values:\n{self.losses=}\n{self.val_losses=}\n{self.metrics=}\n{self.val_metrics=}\n"
f"{self.learning_rates=}")
Expand Down Expand Up @@ -302,13 +301,16 @@ def load_checkpoint(self, checkpoint_file: str, new_learning_rate: float = None,
logger.debug(f"Checkpoint Values:\n{self.losses=}\n{self.val_losses=}\n{self.metrics=}\n{self.val_metrics=}\n"
f"{self.learning_rates=}")

def save_model(self, last_val_loss: float = None) -> None:
def save_model(self) -> None:
"""
Save the model state and checkpoint from the last epoch.
:param last_val_loss: Value of validation loss in the last epoch.
"""
save_path = self.model_dir / f"{self.model_filename} ({last_val_loss or self.best_val_loss}).pt"
torch.save(self.model.state_dict(), save_path)
save_path = self.model_dir / f"{self.model_filename} ({self.best_val_loss}).pt"
self.model.eval()
with torch.inference_mode():
dummy_input = torch.rand(1, 3, 48, 320, device=self.device)
traced_model = torch.jit.trace(self.model, dummy_input)
torch.jit.save(traced_model, save_path)
logger.info(f"Model Saved! Path: {save_path}")

def create_model_checkpoint(self, model_file: str) -> None:
Expand Down
12 changes: 5 additions & 7 deletions utilities/visualize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from collections import Counter

import cv2 as cv
Expand Down Expand Up @@ -122,12 +121,11 @@ def visualize_model(model, ds_data: dict) -> None:
"""
Visualize the model with a tensorboard graph.
"""
with warnings.catch_warnings(action="ignore", category=torch.jit.TracerWarning):
input_image = torch.from_numpy(ds_data["image"]).unsqueeze(0)
writer = SummaryWriter(comment="_model_graph")
writer.add_graph(model, input_image)
writer.close()
print("\nModel Graph Created! Run 'tensorboard --logdir=runs' to view graph.")
input_image = torch.from_numpy(ds_data["image"]).unsqueeze(0)
writer = SummaryWriter(comment="_model_graph")
writer.add_graph(model, input_image)
writer.close()
print("\nModel Graph Created! Run 'tensorboard --logdir=runs' to view graph.")


def visualize_feature_maps(model, ds_data: dict, debug: bool = False) -> None:
Expand Down

0 comments on commit 2dcfa63

Please sign in to comment.