Skip to content

Commit c8c8914

Browse files
mrkitomrkito
mrkito
authored and
mrkito
committed
loop
1 parent 6158ae8 commit c8c8914

File tree

2 files changed

+23
-47
lines changed

2 files changed

+23
-47
lines changed

config.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import albumentations as A
22
from albumentations.pytorch import ToTensorV2
33

4-
batch_size=64
4+
batch_size = 64
55

6-
train_dir='/home/deepkot/Downloads/imagenette2-160/train/'
7-
val_dir='/home/deepkot/Downloads/imagenette2-160/val/'
6+
train_dir = 'imagenette2-160/train/'
7+
val_dir = 'imagenette2-160/val/'
88

99

1010
def get_train_transforms(size):
1111
TRAIN_TRANSFORMS = A.Compose([
1212
A.Resize(size, size),
13-
1413
A.HorizontalFlip(p=0.5),
1514
A.Rotate(limit=30),
1615
A.ImageCompression(),
@@ -46,7 +45,6 @@ def get_train_transforms(size):
4645

4746
def get_val_transforms(size):
4847
VALID_TRANSFORMS = A.Compose([
49-
5048
A.Resize(size, size),
5149
A.Normalize(),
5250
ToTensorV2(),

train_loop.py

+20-42
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,16 @@
1-
import os
2-
import torch
31
from catalyst.contrib.nn import RAdam, Lookahead
42
from torch.nn import functional as F
5-
from torch.utils.data import DataLoader
63
from catalyst import dl
7-
from catalyst.contrib.data.transforms import ToTensor
8-
from catalyst.contrib.datasets import MNIST
94
from catalyst.utils import metrics, set_global_seed, prepare_cudnn
10-
11-
12-
from catalyst.utils import (
13-
create_dataset, create_dataframe, get_dataset_labeling, map_dataframe
14-
)
15-
# train_dir='/home/deepkot/Downloads/imagenette2-160/train/n01440764'
16-
# dataset = create_dataset(dirs=f"{train_dir}/*", extension="*.jpg")
17-
# df = create_dataframe(dataset, columns=["class", "filepath"])
18-
#
19-
# tag_to_label = get_dataset_labeling(df, "class")
20-
# class_names = [
21-
# name for name, id_ in sorted(tag_to_label.items(), key=lambda x: x[1])
22-
# ]
23-
#
24-
# df_with_labels = map_dataframe(
25-
# df,
26-
# tag_column="class",
27-
# class_column="label",
28-
# tag2class=tag_to_label,
29-
# verbose=False
30-
# )
31-
# df_with_labels.head()
325
from config import train_dir, val_dir, batch_size, get_train_transforms, get_val_transforms
336
from dataset_bn import get_loaders
347
from model_batchnorm import resnet18
358

369

37-
def train_research_bn(model):
38-
SEED = 69
39-
set_global_seed(SEED)
40-
prepare_cudnn(deterministic=True)
41-
10+
def train_research_bn(model, log):
4211
optimizer = Lookahead(RAdam(model.parameters(), lr=0.02))
4312

44-
45-
train_data_transforms, val_data_transforms = get_train_transforms(224), get_val_transforms(224)
13+
train_data_transforms, val_data_transforms = get_train_transforms(112), get_val_transforms(112)
4614

4715
loaders = get_loaders(
4816
train_dir=train_dir,
@@ -52,17 +20,14 @@ def train_research_bn(model):
5220
batch_size=batch_size,
5321
)
5422

55-
56-
5723
class CustomRunner(dl.Runner):
5824

5925
def predict_batch(self, batch):
6026
# model inference step
6127
return self.model(batch[0].to(self.device).view(batch[0].size(0), -1))
6228

6329
def _handle_batch(self, batch):
64-
# model train/valid step
65-
x, y = batch['image'],batch['label'],
30+
x, y = batch['image'], batch['label'],
6631
y_hat = self.model(x)
6732

6833
loss = F.cross_entropy(y_hat, y)
@@ -77,18 +42,31 @@ def _handle_batch(self, batch):
7742
self.state.optimizer.zero_grad()
7843

7944
runner = CustomRunner()
80-
# model training
8145
runner.train(
8246
model=model,
8347
optimizer=optimizer,
8448
loaders=loaders,
85-
logdir="./logs",
49+
logdir=log,
8650
num_epochs=15,
8751
verbose=True,
8852
load_best_on_end=True,
8953
)
90-
return runner
54+
return runner
9155

9256

9357
if __name__ == '__main__':
94-
train_research_bn(resnet18(train_bn=False))
58+
SEED = 69
59+
set_global_seed(SEED)
60+
prepare_cudnn(deterministic=True)
61+
62+
model = resnet18(train_bn=True)
63+
train_research_bn(model, "resnet")
64+
65+
model = resnet18(train_bn=False)
66+
train_research_bn(model, 'resnet_no_bn')
67+
68+
model = resnet18(train_bn=True)
69+
for name, p in model.named_parameters():
70+
if 'bn' not in name:
71+
p.requires_grad = False
72+
train_research_bn(model, 'resnet_bn_only')

0 commit comments

Comments
 (0)