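"""Training script for Crossmodal Feature Mapping (CFM) networks.

For a single object class, trains two projection MLPs: one mapping 2D (RGB)
patch features to 3D (point cloud) patch features and one mapping in the
opposite direction, by maximizing the cosine similarity between predicted
and extracted features. Metrics are logged to Weights & Biases and the
trained CFMs are saved as per-class checkpoints.
"""
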
import argparse
import os
import torch
import wandb
import numpy as np
from itertools import chain
from tqdm import tqdm, trange
from models.features import MultimodalFeatures
from models.dataset import get_data_loader
from models.feature_transfer_nets import FeatureProjectionMLP, FeatureProjectionMLP_big


def set_seeds(sid=115):
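    """Seed NumPy and PyTorch (CPU and all CUDA devices) for reproducibility."""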
    np.random.seed(sid)
    torch.manual_seed(sid)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(sid)
        torch.cuda.manual_seed_all(sid)


def train_CFM(args):
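    """Train the 2D-to-3D and 3D-to-2D feature mapping MLPs for one class.

    `args` is expected to provide class_name, epochs_no, batch_size,
    dataset_path and checkpoint_savepath (see the argument parser below).
    """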
    set_seeds()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_name = f'{args.class_name}_{args.epochs_no}ep_{args.batch_size}bs'

    wandb.init(
        project = 'crossmodal-feature-mappings',
        name = model_name
    )

    # Dataloader.
    train_loader = get_data_loader("train", class_name = args.class_name, img_size = 224, dataset_path = args.dataset_path, batch_size = args.batch_size, shuffle = True)

    # Feature extractor.
    feature_extractor = MultimodalFeatures()

    # Model instantiation.
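    # CFM_2Dto3D maps 768-dim 2D (RGB) patch features to 1152-dim 3D (point
    # cloud) patch features; CFM_3Dto2D is the inverse mapping. The sizes must
    # match the feature maps returned by MultimodalFeatures.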
    CFM_2Dto3D = FeatureProjectionMLP(in_features = 768, out_features = 1152)
    CFM_3Dto2D = FeatureProjectionMLP(in_features = 1152, out_features = 768)

    optimizer = torch.optim.Adam(params = chain(CFM_2Dto3D.parameters(), CFM_3Dto2D.parameters()))

    CFM_2Dto3D.to(device), CFM_3Dto2D.to(device)

    metric = torch.nn.CosineSimilarity(dim = -1, eps = 1e-06)

    for epoch in trange(args.epochs_no, desc = 'Training Feature Transfer Net.'):
        epoch_cos_sim_3Dto2D, epoch_cos_sim_2Dto3D = [], []

        # ------------ [Training Loop] ------------ #

        # * Returns (rgb_img, organized_pc, depth_map_3channel), global_label
        for (rgb, pc, _), _ in tqdm(train_loader, desc = f'Extracting features from class: {args.class_name}.'):
            rgb, pc = rgb.to(device), pc.to(device)

            # Make CFMs trainable.
            CFM_2Dto3D.train(), CFM_3Dto2D.train()
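
            # get_features_maps appears to expect a single sample, so larger
            # batches are processed one item at a time and re-stacked (assumption
            # based on the batch_size == 1 special case below).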
            if args.batch_size == 1:
                rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb, pc)
            else:
                rgb_patches = []
                xyz_patches = []

                for i in range(rgb.shape[0]):
                    rgb_patch, xyz_patch = feature_extractor.get_features_maps(rgb[i].unsqueeze(dim=0), pc[i].unsqueeze(dim=0))
                    rgb_patches.append(rgb_patch)
                    xyz_patches.append(xyz_patch)

                rgb_patch = torch.stack(rgb_patches, dim=0)
                xyz_patch = torch.stack(xyz_patches, dim=0)

            # Predictions.
            rgb_feat_pred = CFM_3Dto2D(xyz_patch)
            xyz_feat_pred = CFM_2Dto3D(rgb_patch)
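
            # Losses: cosine distance, computed only on patches whose 3D feature
            # vector is non-zero (all-zero vectors presumably correspond to empty
            # point-cloud regions, so they carry no supervisory signal).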
            xyz_mask = (xyz_patch.sum(axis = -1) == 0)  # Mask only the feature vectors that are 0 everywhere.

            loss_3Dto2D = 1 - metric(rgb_feat_pred[~xyz_mask], rgb_patch[~xyz_mask]).mean()
            loss_2Dto3D = 1 - metric(xyz_feat_pred[~xyz_mask], xyz_patch[~xyz_mask]).mean()

            cos_sim_3Dto2D, cos_sim_2Dto3D = 1 - loss_3Dto2D.detach().cpu(), 1 - loss_2Dto3D.detach().cpu()
            epoch_cos_sim_3Dto2D.append(cos_sim_3Dto2D), epoch_cos_sim_2Dto3D.append(cos_sim_2Dto3D)

            # Logging.
            wandb.log({
                "train/loss_3Dto2D" : loss_3Dto2D,
                "train/loss_2Dto3D" : loss_2Dto3D,
                "train/cosine_similarity_3Dto2D" : cos_sim_3Dto2D,
                "train/cosine_similarity_2Dto3D" : cos_sim_2Dto3D,
                })

            # Abort if either loss has become NaN or Inf.
            if torch.isnan(loss_3Dto2D) or torch.isinf(loss_3Dto2D) or torch.isnan(loss_2Dto3D) or torch.isinf(loss_2Dto3D):
                exit('Training aborted: loss is NaN or Inf.')
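
            # The two losses touch disjoint parameter sets and, assuming the
            # extractor returns features without gradient history, two separate
            # backward() calls are equivalent to backpropagating their sum.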
            # Optimization.
            optimizer.zero_grad()
            loss_3Dto2D.backward(), loss_2Dto3D.backward()
            optimizer.step()

        # Global logging: per-epoch mean cosine similarities.
        wandb.log({
            "global_train/cos_sim_3Dto2D" : torch.stack(epoch_cos_sim_3Dto2D).mean(),
            "global_train/cos_sim_2Dto3D" : torch.stack(epoch_cos_sim_2Dto3D).mean()
            })

    # Model saving.
    directory = f'{args.checkpoint_savepath}/{args.class_name}'
    os.makedirs(directory, exist_ok = True)

    torch.save(CFM_2Dto3D.state_dict(), os.path.join(directory, 'CFM_2Dto3D_' + model_name + '.pth'))
    torch.save(CFM_3Dto2D.state_dict(), os.path.join(directory, 'CFM_3Dto2D_' + model_name + '.pth'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = 'Train Crossmodal Feature Mappings (CFMs) on a dataset.')

    parser.add_argument('--dataset_path', default = './datasets/mvtec3d', type = str,
                        help = 'Dataset path.')

    parser.add_argument('--checkpoint_savepath', default = './checkpoints/checkpoints_CFM_mvtec', type = str,
                        help = 'Where to save the model checkpoints.')

    parser.add_argument('--class_name', default = None, type = str,
                        choices = ['bagel', 'cable_gland', 'carrot', 'cookie', 'dowel', 'foam', 'peach', 'potato', 'rope', 'tire',
                                   'CandyCane', 'ChocolateCookie', 'ChocolatePraline', 'Confetto', 'GummyBear', 'HazelnutTruffle', 'LicoriceSandwich', 'Lollipop', 'Marshmallow', 'PeppermintCandy'],
                        help = 'Category name.')

    parser.add_argument('--epochs_no', default = 50, type = int,
                        help = 'Number of epochs to train the CFMs.')

    parser.add_argument('--batch_size', default = 4, type = int,
                        help = 'Batch dimension. Usually 16 is around the max.')

    args = parser.parse_args()

    train_CFM(args)
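
# Example invocation, relying on the default paths defined above:
#   python cfm_training.py --class_name bagel --epochs_no 50 --batch_size 4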