# cdan_mcc_sdat_masking.py -- forked from lhoyer/MIC
# 393 lines (347 loc), 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
# Credits: https://github.com/thuml/Transfer-Learning-Library
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import os
import wandb
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
sys.path.append('../')
from dalib.modules.domain_discriminator import DomainDiscriminator
from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
from dalib.adaptation.mcc import MinimumClassConfusionLoss
from dalib.modules.masking import Masking
from dalib.modules.teacher import EMATeacher
from common.utils.data import ForeverDataIterator
from common.utils.metric import accuracy
from common.utils.meter import AverageMeter, ProgressMeter
from common.utils.logger import CompleteLogger
from common.utils.analysis import collect_feature, tsne, a_distance
from common.utils.sam import SAM
sys.path.append('.')
import utils
def main(args: argparse.Namespace):
    """Build data, model, losses and optimizers, then run the requested phase.

    Depending on ``args.phase`` this either

    * ``'train'``: trains for ``args.epochs`` epochs, validating after each
      one, keeping a 'latest' and a 'best' checkpoint, and finally evaluating
      the best checkpoint on the test split;
    * ``'test'``: loads the 'best' checkpoint and prints its test accuracy;
    * ``'analysis'``: loads the 'best' checkpoint, saves a t-SNE plot of
      source/target features and prints the A-distance.

    Args:
        args: Parsed command-line namespace. ``args.device`` must already be
            set by the caller; ``args.class_names`` is filled in here from the
            dataset.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.log_results:
        wandb.init(
            project="MIC",
            name=args.log_name)
        wandb.config.update(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is enabled even for seeded runs, exactly as
    # before; cuDNN autotuning may pick different kernels between runs --
    # confirm whether seeded runs should disable it.
    cudnn.benchmark = True
    device = args.device

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source,
                          args.target, train_transform, val_transform)
    # drop_last keeps source and target batches the same size, which the
    # training loop relies on when it splits the concatenated batch in two.
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    print(backbone)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    classifier_feature_dim = classifier.features_dim

    # The discriminator input size depends on whether the randomized
    # multi-linear map is used.
    if args.randomized:
        domain_discri = DomainDiscriminator(
            args.randomized_dim, hidden_size=1024).to(device)
    else:
        domain_discri = DomainDiscriminator(
            classifier_feature_dim * num_classes, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    base_optimizer = torch.optim.SGD
    ad_optimizer = SGD(domain_discri.get_parameters(
    ), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    optimizer = SAM(classifier.get_parameters(), base_optimizer, rho=args.rho, adaptive=False,
                    lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

    def lr_lambda(step):
        # Same decay rule for the classifier and the discriminator.
        return args.lr * (1. + args.lr_gamma * float(step)) ** (-args.lr_decay)

    lr_scheduler = LambdaLR(optimizer, lr_lambda)
    lr_scheduler_ad = LambdaLR(ad_optimizer, lr_lambda)

    # define loss function
    domain_adv = ConditionalDomainAdversarialLoss(
        domain_discri, entropy_conditioning=args.entropy,
        num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
        randomized_dim=args.randomized_dim
    ).to(device)
    mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature)

    teacher = EMATeacher(classifier, alpha=args.alpha, pseudo_label_weight=args.pseudo_label_weight).to(device)
    masking = Masking(
        block_size=args.mask_block_size,
        ratio=args.mask_ratio,
        color_jitter_s=args.mask_color_jitter_s,
        color_jitter_p=args.mask_color_jitter_p,
        blur=args.mask_blur,
        mean=args.norm_mean,
        std=args.norm_std)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(
            logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(
            classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(
            train_source_loader, feature_extractor, device)
        target_feature = collect_feature(
            train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(
            source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        current_lrs = lr_scheduler.get_last_lr()
        print("lr_bbone:", current_lrs[0])
        print("lr_btlnck:", current_lrs[1])
        if args.log_results:
            wandb.log({"lr_bbone": current_lrs[0],
                       "lr_btlnck": current_lrs[1]})

        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, teacher,
              domain_adv, mcc_loss, masking, optimizer, ad_optimizer,
              lr_scheduler, lr_scheduler_ad, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        if args.log_results:
            wandb.log({'epoch': epoch, 'val_acc': acc1})

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        # The first epoch always seeds the 'best' checkpoint so that the final
        # evaluation below can load it even if accuracy never rises above 0.
        if epoch == 0 or acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
    if args.log_results:
        # Equals the last loop index whenever at least one epoch ran, and
        # avoids depending on the loop variable when args.epochs == 0.
        wandb.log({'epoch': args.epochs - 1, 'test_acc': acc1})

    logger.close()
def _masked_consistency_loss(logits_masked, pseudo_label, pseudo_prob, weighted):
    """Cross-entropy between masked-target predictions and teacher pseudo-labels.

    When ``weighted`` is true, each sample's cross-entropy is multiplied by
    ``pseudo_prob`` before averaging; otherwise the plain mean cross-entropy
    is returned.
    """
    if weighted:
        per_sample = F.cross_entropy(logits_masked, pseudo_label, reduction='none')
        return torch.mean(pseudo_prob * per_sample)
    return F.cross_entropy(logits_masked, pseudo_label)


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, teacher: EMATeacher,
          domain_adv: ConditionalDomainAdversarialLoss, mcc, masking, optimizer, ad_optimizer,
          lr_scheduler: LambdaLR, lr_scheduler_ad, epoch: int, args: argparse.Namespace):
    """Train ``model`` for ``args.iters_per_epoch`` iterations.

    Each iteration performs the two-pass sharpness-aware update exposed by
    ``optimizer`` (``first_step`` / ``second_step``):

    1. forward/backward on source classification + MCC + masked-consistency
       loss, then ``optimizer.first_step`` perturbs the weights;
    2. forward/backward on classification + ``args.trade_off`` * (domain
       adversarial + MCC + masked-consistency) loss, then ``ad_optimizer``
       and ``optimizer.second_step`` apply the update.

    Both learning-rate schedulers are stepped once per iteration.

    Args:
        train_source_iter: Yields ``(images, labels)`` source batches.
        train_target_iter: Yields target batches; their labels are ignored.
        model: Student classifier; called as ``model(x) -> (logits, features)``.
        teacher: Produces ``(pseudo_label, pseudo_prob)`` for target images.
        domain_adv: Conditional domain adversarial loss module.
        mcc: Minimum class confusion loss applied to target logits.
        masking: Callable that masks a batch of target images.
        optimizer: Optimizer for ``model`` with ``first_step``/``second_step``.
        ad_optimizer: Optimizer for the domain discriminator.
        lr_scheduler: Scheduler attached to ``optimizer``.
        lr_scheduler_ad: Scheduler attached to ``ad_optimizer``.
        epoch: Zero-based epoch index, used for logging and the global step.
        args: Run configuration; ``args.device`` selects the compute device.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # Take the device from args (as main() does) instead of relying on a
    # module-level global that only exists when the file is run as a script.
    device = args.device
    # The same masked-consistency weighting must be used in both passes.
    weighted_masking = teacher.pseudo_label_weight is not None

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_masked = masking(x_t)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)
        optimizer.zero_grad()
        ad_optimizer.zero_grad()

        # generate pseudo-label
        global_step = epoch * args.iters_per_epoch + i
        teacher.update_weights(model, global_step)
        pseudo_label_t, pseudo_prob_t = teacher(x_t)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)
        cls_loss = F.cross_entropy(y_s, labels_s)
        mcc_loss_value = mcc(y_t)

        y_t_masked, _ = model(x_t_masked)
        masking_loss_value = _masked_consistency_loss(
            y_t_masked, pseudo_label_t, pseudo_prob_t, weighted_masking)

        loss = cls_loss + mcc_loss_value + masking_loss_value
        loss.backward()

        # Calculate ϵ̂ (w) and add it to the weights
        optimizer.first_step(zero_grad=True)

        # Calculate task loss and domain loss
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        y_t_masked, _ = model(x_t_masked)
        masking_loss_value = _masked_consistency_loss(
            y_t_masked, pseudo_label_t, pseudo_prob_t, weighted_masking)
        transfer_loss = domain_adv(y_s, f_s, y_t, f_t) + mcc(y_t) + \
            masking_loss_value
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        if args.log_results:
            wandb.log({
                'iteration': global_step, 'loss': loss, 'cls_loss': cls_loss,
                'transfer_loss': transfer_loss, 'domain_acc': domain_acc,
                'pseudo_weight_avg': torch.mean(pseudo_prob_t),
            })

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc, x_s.size(0))
        domain_accs.update(domain_acc, x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        loss.backward()
        # Update parameters of domain classifier
        ad_optimizer.step()
        # Update parameters (Sharpness-Aware update)
        optimizer.second_step(zero_grad=True)
        lr_scheduler.step()
        lr_scheduler_ad.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching the previous bool('') behaviour.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='CDAN+MCC with SDAT for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true',
                        help='whether train from scratch.')
    parser.add_argument('-r', '--randomized', action='store_true',
                        help='using randomized multi-linear-map (default: False)')
    parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
                        help='randomized dimension when using randomized multi-linear-map (default: 1024)')
    parser.add_argument('--entropy', default=False,
                        action='store_true', help='use entropy conditioning')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001,
                        type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75,
                        type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9,
                        type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='cdan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    parser.add_argument('--log_results', action='store_true',
                        help="To log results in wandb")
    parser.add_argument('--gpu', type=str, default="0", help="GPU ID")
    parser.add_argument('--log_name', type=str,
                        default="log", help="log name for wandb")
    parser.add_argument('--rho', type=float, default=0.05,
                        help="rho passed to the SAM optimizer")
    parser.add_argument('--temperature', default=2.0,
                        type=float, help='parameter temperature scaling')
    # masked image consistency
    parser.add_argument('--alpha', default=0.999, type=float)
    parser.add_argument('--pseudo_label_weight', default=None)
    parser.add_argument('--mask_block_size', default=32, type=int)
    parser.add_argument('--mask_ratio', default=0.5, type=float)
    parser.add_argument('--mask_color_jitter_s', default=0, type=float)
    parser.add_argument('--mask_color_jitter_p', default=0, type=float)
    # type=bool would turn "--mask_blur False" into True; parse the text.
    parser.add_argument('--mask_blur', default=False, type=str2bool,
                        help='blur flag passed to Masking (true/false)')
    args = parser.parse_args()

    # Restrict visible GPUs before the device is selected.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    main(args)