-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 9b04a4d
Showing
29 changed files
with
877 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
label_idx,label_name | ||
0,0 | ||
1,1 | ||
2,2 | ||
3,3 | ||
4,4 | ||
5,5 | ||
6,6 | ||
7,7 | ||
8,8 | ||
9,9 | ||
10,A | ||
11,B | ||
12,C | ||
13,D | ||
14,E | ||
15,F | ||
16,G | ||
17,H | ||
18,J | ||
19,K | ||
20,L | ||
21,M | ||
22,N | ||
23,P | ||
24,Q | ||
25,R | ||
26,S | ||
27,T | ||
28,U | ||
29,V | ||
30,W | ||
31,X | ||
32,Y | ||
33,Z | ||
34,�� | ||
35,�� | ||
36,�� | ||
37,�� | ||
38,�� | ||
39,�� | ||
40,�� | ||
41,�� | ||
42,�� | ||
43,�� | ||
44,�� | ||
45,�� | ||
46,�� | ||
47,³ | ||
48,�� | ||
49,�� | ||
50,�� | ||
51,�� | ||
52,�� | ||
53,�� | ||
54,�� | ||
55,�� | ||
56,�� | ||
57,�� | ||
58,�� | ||
59,ԥ | ||
60,�� | ||
61,�� | ||
62,�� | ||
63,�� | ||
64,�� |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
|
||
import warnings | ||
import torch as t | ||
|
||
class DefaultConfig(object): | ||
env = 'default' # visdom 环境 | ||
vis_port =8097 # visdom 端口 | ||
model = 'SqueezeNet' # 使用的模型,名字必须与models/__init__.py中的名字一致 | ||
classifier_num = 2 # 分类器最终的分类数量 | ||
gray = False # 读取图片是否为灰度图 | ||
|
||
train_data_root = './imgs/images/cnn_plate_train/' # 训练集存放路径 | ||
test_data_root = './data/test/' # 测试集存放路径 | ||
load_model_path = None # 加载预训练的模型的路径,为None代表不加载 | ||
|
||
batch_size = 16 # batch size | ||
use_gpu = True # user GPU or not | ||
num_workers = 0 # how many workers for loading data | ||
print_freq = 20 # print info every N batch | ||
|
||
debug_file = '/tmp/debug' # if os.path.exists(debug_file): enter ipdb | ||
result_file = 'result.csv' | ||
id_file = './findplate/plate.csv' | ||
|
||
max_epoch = 100 | ||
lr = 0.001 # initial learning rate | ||
lr_decay = 0.5 # when val_loss increase, lr = lr*lr_decay | ||
weight_decay = 0e-5 # 损失函数 | ||
|
||
|
||
def _parse(self, kwargs): | ||
""" | ||
根据字典kwargs 更新 config参数 | ||
""" | ||
for k, v in kwargs.items(): | ||
if not hasattr(self, k): | ||
warnings.warn("Warning: opt has not attribut %s" % k) | ||
setattr(self, k, v) | ||
|
||
self.device =t.device('cuda') if self.use_gpu else t.device('cpu') | ||
|
||
|
||
print('user config:') | ||
for k, v in self.__class__.__dict__.items(): | ||
if not k.startswith('_'): | ||
print(k, getattr(self, k)) | ||
|
||
opt = DefaultConfig() |
Empty file.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
|
||
import os | ||
from PIL import Image | ||
from torch.utils import data | ||
import numpy as np | ||
from torchvision import transforms as T | ||
from torchvision.datasets import ImageFolder | ||
import random | ||
from findplate.config import opt | ||
|
||
|
||
class MyDataset(data.Dataset): | ||
|
||
def __init__(self, root, transforms=None, train=True, test=False): | ||
""" | ||
主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据 | ||
""" | ||
self.test = test | ||
|
||
if self.test: | ||
imgs = [os.path.join(root, img) for img in os.listdir(root)] | ||
else: | ||
dataset = ImageFolder(root) | ||
self.data_classes = dataset.classes | ||
imgs = [dataset.imgs[i][0] for i in range(len(dataset.imgs))] | ||
labels = [dataset.imgs[i][1] for i in range(len(dataset.imgs))] | ||
imgs_num = len(imgs) | ||
|
||
if self.test: | ||
self.imgs = imgs | ||
|
||
# 按7:3的比例划分训练集和验证集 | ||
elif train: | ||
self.imgs = [] | ||
self.labels = [] | ||
for i in range(imgs_num): | ||
if random.random()<0.7: | ||
self.imgs.append(imgs[i]) | ||
self.labels.append(labels[i]) | ||
else: | ||
self.imgs = [] | ||
self.labels = [] | ||
for i in range(imgs_num): | ||
if random.random()>0.7: | ||
self.imgs.append(imgs[i]) | ||
self.labels.append(labels[i]) | ||
if transforms is None: | ||
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], | ||
std=[0.229, 0.224, 0.225]) | ||
if self.test or not train: | ||
self.transforms = T.Compose([ | ||
T.Resize(224), | ||
T.CenterCrop(224), | ||
T.ToTensor(), | ||
normalize | ||
]) | ||
else: | ||
self.transforms = T.Compose([ | ||
T.Resize(256), | ||
T.RandomResizedCrop(224), | ||
T.RandomHorizontalFlip(), | ||
T.ToTensor(), | ||
normalize | ||
]) | ||
|
||
def id_to_class(self, index): | ||
return self.data_classes(index) | ||
|
||
def __getitem__(self, index): | ||
""" | ||
一次返回一张图片的数据 | ||
""" | ||
img_path = self.imgs[index] | ||
if self.test: | ||
# label = self.imgs[index].split('.')[-2].split('/')[-1] | ||
label = img_path.split('/')[-1] | ||
else: | ||
label = self.labels[index] | ||
data = Image.open(img_path) | ||
if opt.gray == True: | ||
dataRGB = data.convert('RGB') | ||
dataRGB = self.transforms(dataRGB) | ||
return dataRGB, label | ||
|
||
data = self.transforms(data) | ||
return data, label | ||
|
||
def __len__(self): | ||
return len(self.imgs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .squeezenet import SqueezeNet | ||
from .squeezenet_gray import SqueezeNetGray |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
|
||
import torch as t | ||
import time | ||
|
||
|
||
class BasicModule(t.nn.Module): | ||
""" | ||
封装了nn.Module,主要是提供了save和load两个方法 | ||
""" | ||
|
||
def __init__(self): | ||
super(BasicModule,self).__init__() | ||
self.model_name=str(type(self))# 默认名字 | ||
|
||
def load(self, path): | ||
""" | ||
可加载指定路径的模型 | ||
""" | ||
self.load_state_dict(t.load(path, map_location='cpu')) | ||
|
||
def save(self, name=None): | ||
""" | ||
保存模型,默认使用“模型名字+时间”作为文件名 | ||
""" | ||
if name is None: | ||
prefix = './findplate/checkpoints/' + self.model_name + '_' | ||
name = time.strftime(prefix + '%m%d_%H%M%S.pth') | ||
t.save(self.state_dict(), name) | ||
return name | ||
|
||
def get_optimizer(self, lr, weight_decay): | ||
return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay) | ||
|
||
|
||
class Flat(t.nn.Module): | ||
""" | ||
把输入reshape成(batch_size,dim_length) | ||
""" | ||
|
||
def __init__(self): | ||
super(Flat, self).__init__() | ||
#self.size = size | ||
|
||
def forward(self, x): | ||
return x.view(x.size(0), -1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from torchvision.models import squeezenet1_1 | ||
from findplate.models.basic_module import BasicModule | ||
from torch import nn | ||
from torch.optim import Adam | ||
from findplate.config import opt | ||
|
||
class SqueezeNet(BasicModule): | ||
def __init__(self, num_classes=2): | ||
super(SqueezeNet, self).__init__() | ||
self.model_name = 'squeezenet' | ||
self.model = squeezenet1_1(pretrained=False) | ||
# 修改 原始的num_class: 预训练模型是1000分类 | ||
self.model.num_classes = num_classes | ||
self.model.classifier = nn.Sequential( | ||
nn.Dropout(p=0.5), | ||
nn.Conv2d(512, num_classes, 1), | ||
nn.ReLU(inplace=True), | ||
nn.AvgPool2d(13, stride=1) | ||
) | ||
|
||
def forward(self,x): | ||
return self.model(x) | ||
|
||
def get_optimizer(self, lr, weight_decay): | ||
# 因为使用了预训练模型,我们只需要训练后面的分类 | ||
# 前面的特征提取部分可以保持不变 | ||
return Adam(self.model.classifier.parameters(), lr, weight_decay=weight_decay) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from torchvision.models import squeezenet1_1 | ||
from findplate.models.basic_module import BasicModule | ||
from torch import nn | ||
from torch.optim import Adam | ||
from findplate.config import opt | ||
|
||
class SqueezeNetGray(BasicModule): | ||
def __init__(self, num_classes=65): | ||
super(SqueezeNetGray, self).__init__() | ||
self.model_name = 'squeezenet_gray' | ||
self.model = squeezenet1_1(pretrained=False) | ||
# 修改 原始的num_class: 预训练模型是1000分类 | ||
self.model.num_classes = num_classes | ||
self.model.classifier = nn.Sequential( | ||
nn.Dropout(p=0.5), | ||
nn.Conv2d(512, num_classes, 1), | ||
nn.ReLU(inplace=True), | ||
nn.AvgPool2d(13, stride=1) | ||
) | ||
|
||
def forward(self,x): | ||
return self.model(x) | ||
|
||
def get_optimizer(self, lr, weight_decay): | ||
# 因为使用了预训练模型,我们只需要训练后面的分类 | ||
# 前面的特征提取部分可以保持不变 | ||
return Adam(self.model.classifier.parameters(), lr, weight_decay=weight_decay) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
label_idx,label_name | ||
0,has | ||
1,no |
Oops, something went wrong.