-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_model.py
157 lines (129 loc) · 6.02 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import argparse
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
import numpy as np
# function to load the file list from a text file (generated by split_data)
def load_file_list(file_path):
    """Read *file_path* and return its lines, stripped, as a list of strings."""
    with open(file_path, 'r') as handle:
        return [entry.strip() for entry in handle]
# a function that finds the highest label in the dataset to be used to determine the number of classes for the OCR model.
def find_max_label(file_list):
    """Return the largest integer class label encoded in the given file paths.

    Paths encode the label as a directory name; when the third-from-last path
    component is a resolution folder ('200'/'300'/'400'), the label folder is
    one level further up. Returns 0 for an empty list.
    """
    max_label = 0
    for entry in file_list:
        parts = entry.split(os.sep)
        folder = parts[-4] if parts[-3] in ['200', '300', '400'] else parts[-3]
        max_label = max(max_label, int(folder))
    return max_label
# dataset class for loading images and corresponding labels with resolution and style
class OCRDataset(Dataset):
    """Dataset yielding (image, label, resolution, style) tuples.

    Each entry of *file_list* is a comma-separated line of the form
    ``<image path>,<resolution>,<style>``. The class label is derived from
    the directory structure of the image path.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        entry = self.file_list[idx].strip()
        img_path, resolution, style = entry.split(',')
        img_path = img_path.strip()
        # grayscale ('L') to match the model's single input channel
        image = Image.open(img_path).convert('L')
        if self.transform:
            image = self.transform(image)
        return image, self.get_label_from_path(img_path), resolution, style

    def get_label_from_path(self, path):
        """Extract the integer class label from an image path.

        A resolution folder ('200'/'300'/'400') may sit between the label
        folder and the file name; skip over it when present.
        """
        parts = path.split(os.sep)
        if parts[-3] in ['200', '300', '400']:
            return int(parts[-4])
        return int(parts[-3])
# model architecture for OCR with Conv2D and Linear layers
class OCRModel(nn.Module):
    """Small CNN classifier for single-channel glyph images.

    Two 3x3 convolutions followed by one 2x2 max-pool, which halves both
    spatial dimensions — this fixes the flattened feature count expected by
    the first fully connected layer. *img_dims* is (height, width).
    """

    def __init__(self, num_classes, img_dims):
        super(OCRModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 64 channels at half the input resolution after the single pool
        flat_features = 64 * (img_dims[0] // 2) * (img_dims[1] // 2)
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.pool(self.relu(self.conv2(out)))
        out = out.view(out.size(0), -1)  # flatten per sample
        return self.fc2(self.relu(self.fc1(out)))
# function to scale images
def scale_imgs(imgs):
    """Standardize a batch of images feature-wise and return the same shape.

    Flattens each image of the (N, H, W) batch to a row, applies sklearn's
    StandardScaler (zero mean, unit variance per pixel position across the
    batch), then restores the original shape.
    """
    n, h, w = imgs.shape
    flat = imgs.reshape(n, h * w)
    standardized = StandardScaler().fit_transform(flat)
    return standardized.reshape(n, h, w)
# training function
def train_model(train_loader, val_loader, num_epochs=10, learning_rate=0.001, num_classes=100, img_dims=(64, 64), model_path='ocr_model.pth'):
    """Train an OCRModel and save its state_dict to *model_path*.

    Fix: the original accepted ``val_loader`` but never used it, so no
    validation ever happened; each epoch now also runs an evaluation pass
    and reports validation loss and accuracy.

    Args:
        train_loader: iterable of (images, labels, resolution, style) batches.
        val_loader: same batch format, used only for per-epoch evaluation.
        num_epochs: number of passes over the training data.
        learning_rate: Adam learning rate.
        num_classes: size of the classifier output layer.
        img_dims: (height, width) of the input images.
        model_path: destination file for the trained weights.
    """
    model = OCRModel(num_classes=num_classes, img_dims=img_dims)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels, resolution, style in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}")
        # validation pass (previously val_loader was silently ignored)
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels, resolution, style in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        if total:
            print(f"Validation Loss: {val_loss / len(val_loader)}, Accuracy: {correct / total:.4f}")
    torch.save(model.state_dict(), model_path)
    print(f"Model saved {model_path}")
# main function to handle arguments and to run the training script
def main():
    """Parse CLI arguments, build the datasets and loaders, and run training.

    Fix: the output directory (default ``./output``) was never created, so
    ``torch.save`` would crash at the end of a full training run; it is now
    created up front with ``os.makedirs(..., exist_ok=True)``.
    """
    parser = argparse.ArgumentParser(description=" To train an OCR model with resolution and style handling.")
    parser.add_argument('--train_file', type=str, required=True, help="Path to the training file list.")
    parser.add_argument('--val_file', type=str, required=True, help="Path to the validation file list.")
    parser.add_argument('--output_dir', type=str, default='./output', help="Directory where the model is saved.")
    parser.add_argument('--epochs', type=int, default=10, help="Number of epochs to train the model.")
    parser.add_argument('--batch_size', type=int, default=32, help="Batch size for training.")
    parser.add_argument('--learning_rate', type=float, default=0.001, help="Learning rate for optimizer.")
    parser.add_argument('--img_height', type=int, default=64, help="Height of input images.")
    parser.add_argument('--img_width', type=int, default=64, help="Width of input images.")
    args = parser.parse_args()

    train_files = load_file_list(args.train_file)
    val_files = load_file_list(args.val_file)
    # number of classes = highest label seen in either split, plus one (labels are 0-based)
    num_classes = max(find_max_label(train_files), find_max_label(val_files)) + 1

    img_dims = (args.img_height, args.img_width)
    transform = transforms.Compose([
        transforms.Resize(img_dims),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_loader = DataLoader(OCRDataset(train_files, transform=transform),
                              batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(OCRDataset(val_files, transform=transform),
                            batch_size=args.batch_size)

    # ensure the save location exists before training starts, not after
    os.makedirs(args.output_dir, exist_ok=True)
    model_path = os.path.join(args.output_dir, 'ocr_model.pth')
    train_model(train_loader, val_loader, num_epochs=args.epochs,
                learning_rate=args.learning_rate, num_classes=num_classes,
                img_dims=img_dims, model_path=model_path)


if __name__ == "__main__":
    main()