-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_functions.py
156 lines (136 loc) · 6.09 KB
/
train_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# python3
#
# PROGRAMMER: Félix Ramón López Martínez
# DATE CREATED: 18/11/2020
# REVISED DATE:
# PURPOSE: This module collects all the helper functions called from train.py.
#
##
# Imports python modules
import argparse
#import sys
from torchvision import models
from torch import nn
import torch
def get_input_args():
    """
    Retrieve and parse the command line arguments given when the program is
    run from a terminal window. Any argument the user does not provide falls
    back to its default value.

    Command line arguments:
      data_directory  - path to the data directory (positional, optional)
      --save_dir      - folder where the checkpoint file is saved
      --arch          - CNN model architecture name
      --learning_rate - learning rate (float)
      --epochs        - number of epochs (int)
      --dropout       - dropout probability (float)

    Returns:
     parse_args() - argparse.Namespace holding the parsed argument values
    """
    # Create Parse
    parser = argparse.ArgumentParser(description='Retrieving inputs from user')

    # Create command line arguments.
    # nargs='?' is what makes the positional argument optional; without it
    # argparse ignores `default` and aborts when the argument is missing.
    parser.add_argument('data_directory', type=str, nargs='?', default='./',
                        help='path to the data directory (default: ./)')
    parser.add_argument('--save_dir', type=str, default='./',
                        help='path to the folder to save checkpoint file (default: ./)')
    # NOTE(review): the default is upper-case 'VGG16' while the names listed in
    # the help text are lower-case. It is left untouched because the value may
    # be used verbatim downstream (e.g. in the checkpoint file name) - confirm
    # before normalising it.
    parser.add_argument('--arch', type=str, default='VGG16',
                        help='CNN Model Architecture: vgg16, alexnet or densenet161 (default: VGG16)')
    parser.add_argument('--learning_rate', type=float, default=0.002,
                        help='Learning rate (default: 0.002)')
    parser.add_argument('--epochs', type=int, default=1,
                        help='Epochs (default: 1)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout (default: 0.1)')

    return parser.parse_args()
def load_pretrained_model(model_arch):
    ''' Load the pretrained CNN model matching the architecture chosen by the
        user with the --arch argument when launching the code.

        Recognised names (exact, case-sensitive match): 'vgg16', 'alexnet'
        and 'densenet161'. For any other value a warning is printed and the
        VGG-16 model is loaded instead.

        Parameters:
         model_arch - architecture name
        Returns:
         the torchvision model created with pretrained=True
    '''
    builders = {'vgg16': models.vgg16,
                'alexnet': models.alexnet,
                'densenet161': models.densenet161}

    if model_arch in builders:
        model_name = model_arch
        model = builders[model_name](pretrained=True)
    else:
        model_name = 'vgg16'
        model = builders[model_name](pretrained=True)
        print('Invalid model name input in --arch. Loaded VGG16 model instead')

    # Report the model that was really loaded, not the (possibly invalid)
    # name that was requested.
    print('Loaded {} pretrained model'.format(model_name))
    return model
def load_classifier(model_arch, dpout):
    ''' Build and return a new classifier whose layer sizes match the CNN
        architecture chosen by the user with the --arch argument.

        The classifier is three Linear layers, the first two each followed by
        ReLU and Dropout, ending in a 102-unit output and LogSoftmax(dim=1).
        Any name other than 'vgg16', 'alexnet' or 'densenet161' gets the
        VGG16 layer sizes.

        Parameters:
         model_arch - architecture name
         dpout - dropout probability used in both Dropout layers
        Returns:
         the nn.Sequential classifier
    '''
    # (input features, first hidden width, second hidden width)
    vgg16_widths = (25088, 4096, 512)
    widths_by_arch = {'vgg16': vgg16_widths,
                      'alexnet': (9216, 1024, 512),
                      'densenet161': (2208, 512, 256)}

    if model_arch in widths_by_arch:
        n_inputs, n_hidden1, n_hidden2 = widths_by_arch[model_arch]
    else:
        # Unknown name: use the VGG16 layout and report it under that label.
        n_inputs, n_hidden1, n_hidden2 = vgg16_widths
        model_arch = "VGG16"

    layers = [nn.Linear(n_inputs, n_hidden1),
              nn.ReLU(),
              nn.Dropout(p=dpout),
              nn.Linear(n_hidden1, n_hidden2),
              nn.ReLU(),
              nn.Dropout(p=dpout),
              nn.Linear(n_hidden2, 102),
              nn.LogSoftmax(dim=1)]
    classifier = nn.Sequential(*layers)

    print('Loaded new classifier for {} pretrained model'.format(model_arch))
    return classifier
def save_checkpoint(model_arch, dropout, model_class_to_idx, model_state_dict,
                    save_dir=None):
    ''' Save a checkpoint file storing the classifier layer sizes that match
        the CNN architecture chosen by the user with the --arch argument,
        together with the given training data.

        The checkpoint is a dict with the keys 'input_size', 'layer1_size',
        'layer2_size', 'output_size', 'pretrained_model', 'dropout',
        'class_to_idx' and 'state_dict'. Any architecture name other than
        'vgg16', 'alexnet' or 'densenet161' is stored with the VGG16 sizes.
        The file is named '<model_arch>_model_checkpoint.pth'.

        Parameters:
         model_arch - architecture name, stored and used in the file name
         dropout - dropout value stored in the checkpoint
         model_class_to_idx - stored under 'class_to_idx'
         model_state_dict - stored under 'state_dict'
         save_dir - optional folder to write the file into; it is created if
                    missing. When None or empty the file is written relative
                    to the current working directory.
        Returns:
         None
    '''
    import os  # local import so the block does not rely on file-level imports

    # (input_size, layer1_size, layer2_size)
    vgg16_sizes = (25088, 4096, 512)
    sizes_by_arch = {'vgg16': vgg16_sizes,
                     'alexnet': (9216, 1024, 512),
                     'densenet161': (2208, 512, 256)}
    input_size, layer1_size, layer2_size = sizes_by_arch.get(model_arch,
                                                             vgg16_sizes)

    checkpoint = {'input_size': input_size,
                  'layer1_size': layer1_size,
                  'layer2_size': layer2_size,
                  'output_size': 102,
                  'pretrained_model': model_arch,
                  'dropout': dropout,
                  'class_to_idx': model_class_to_idx,
                  'state_dict': model_state_dict}

    checkpoint_path = '{}_model_checkpoint.pth'.format(model_arch)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        checkpoint_path = os.path.join(save_dir, checkpoint_path)

    torch.save(checkpoint, checkpoint_path)
    print('Checkpoint file saved as: {}'.format(checkpoint_path))