1
- import os
2
- import torch
3
1
from catalyst .contrib .nn import RAdam , Lookahead
4
2
from torch .nn import functional as F
5
- from torch .utils .data import DataLoader
6
3
from catalyst import dl
7
- from catalyst .contrib .data .transforms import ToTensor
8
- from catalyst .contrib .datasets import MNIST
9
4
from catalyst .utils import metrics , set_global_seed , prepare_cudnn
10
-
11
-
12
- from catalyst .utils import (
13
- create_dataset , create_dataframe , get_dataset_labeling , map_dataframe
14
- )
15
- # train_dir='/home/deepkot/Downloads/imagenette2-160/train/n01440764'
16
- # dataset = create_dataset(dirs=f"{train_dir}/*", extension="*.jpg")
17
- # df = create_dataframe(dataset, columns=["class", "filepath"])
18
- #
19
- # tag_to_label = get_dataset_labeling(df, "class")
20
- # class_names = [
21
- # name for name, id_ in sorted(tag_to_label.items(), key=lambda x: x[1])
22
- # ]
23
- #
24
- # df_with_labels = map_dataframe(
25
- # df,
26
- # tag_column="class",
27
- # class_column="label",
28
- # tag2class=tag_to_label,
29
- # verbose=False
30
- # )
31
- # df_with_labels.head()
32
5
from config import train_dir , val_dir , batch_size , get_train_transforms , get_val_transforms
33
6
from dataset_bn import get_loaders
34
7
from model_batchnorm import resnet18
35
8
36
9
37
- def train_research_bn (model ):
38
- SEED = 69
39
- set_global_seed (SEED )
40
- prepare_cudnn (deterministic = True )
41
-
10
+ def train_research_bn (model , log ):
42
11
optimizer = Lookahead (RAdam (model .parameters (), lr = 0.02 ))
43
12
44
-
45
- train_data_transforms , val_data_transforms = get_train_transforms (224 ), get_val_transforms (224 )
13
+ train_data_transforms , val_data_transforms = get_train_transforms (112 ), get_val_transforms (112 )
46
14
47
15
loaders = get_loaders (
48
16
train_dir = train_dir ,
@@ -52,17 +20,14 @@ def train_research_bn(model):
52
20
batch_size = batch_size ,
53
21
)
54
22
55
-
56
-
57
23
class CustomRunner (dl .Runner ):
58
24
59
25
def predict_batch (self , batch ):
60
26
# model inference step
61
27
return self .model (batch [0 ].to (self .device ).view (batch [0 ].size (0 ), - 1 ))
62
28
63
29
def _handle_batch (self , batch ):
64
- # model train/valid step
65
- x , y = batch ['image' ],batch ['label' ],
30
+ x , y = batch ['image' ], batch ['label' ],
66
31
y_hat = self .model (x )
67
32
68
33
loss = F .cross_entropy (y_hat , y )
@@ -77,18 +42,31 @@ def _handle_batch(self, batch):
77
42
self .state .optimizer .zero_grad ()
78
43
79
44
runner = CustomRunner ()
80
- # model training
81
45
runner .train (
82
46
model = model ,
83
47
optimizer = optimizer ,
84
48
loaders = loaders ,
85
- logdir = "./logs" ,
49
+ logdir = log ,
86
50
num_epochs = 15 ,
87
51
verbose = True ,
88
52
load_best_on_end = True ,
89
53
)
90
- return runner
54
+ return runner
91
55
92
56
93
57
if __name__ == '__main__' :
94
- train_research_bn (resnet18 (train_bn = False ))
58
+ SEED = 69
59
+ set_global_seed (SEED )
60
+ prepare_cudnn (deterministic = True )
61
+
62
+ model = resnet18 (train_bn = True )
63
+ train_research_bn (model , "resnet" )
64
+
65
+ model = resnet18 (train_bn = False )
66
+ train_research_bn (model , 'resnet_no_bn' )
67
+
68
+ model = resnet18 (train_bn = True )
69
+ for name , p in model .named_parameters ():
70
+ if 'bn' not in name :
71
+ p .requires_grad = False
72
+ train_research_bn (model , 'resnet_bn_only' )
0 commit comments