-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain-Gradient-Reversal.py
63 lines (51 loc) · 2.76 KB
/
main-Gradient-Reversal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Train and evaluate gradient-reversal models on a source and a target domain.

Trains one ``GradientReversal`` model per domain, evaluates each model on its
own domain, and finally evaluates the source model on the target domain as a
baseline with no transfer learning / domain adaptation applied.
"""
# Imported only for its side effect (presumably prettier traceback
# formatting); the name itself is never referenced.
import pretty_errors
import os
from six.moves import urllib
# Install a process-wide urllib opener that sends a browser-like User-agent.
# NOTE(review): presumably needed because some dataset hosts reject the default
# Python-urllib agent when datasets are downloaded -- confirm.
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)
import params_gradientReversal as params
from core import eval_src, eval_tgt, train_src, train_tgt, train_gradientReversal, eval_gradientReversal
from models import Discriminator, LeNetClassifier, LeNetEncoder, GradientReversal
from utils import get_data_loader, init_model, init_random_seed
_SNAPSHOT_DIR = 'snapshots'


def _train_domain(heading, title, model, data_loader):
    """Print a training banner and the model, then train it.

    Args:
        heading: section banner printed first.
        title: model title line printed after the banner.
        model: network returned by ``init_model``.
        data_loader: loader passed straight to ``train_gradientReversal``.

    Returns:
        The value returned by ``train_gradientReversal(model, data_loader)``.
    """
    print(heading)
    print(title)
    print(model)
    return train_gradientReversal(model, data_loader)


def main():
    """Train source/target gradient-reversal models and evaluate them.

    Evaluates each model on its own domain, then the source model on the
    target domain (the no-adaptation baseline).
    """
    # init random seed before building any loader or model
    init_random_seed(params.manual_seed)

    # load dataset
    # NOTE(review): keep this call order -- with a fixed seed, results may
    # depend on the order in which random state is consumed.
    src_data_loader = get_data_loader(params.src_dataset)
    src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
    tgt_data_loader = get_data_loader(params.tgt_dataset)
    tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)

    # load models; both are constructed before any training, as originally.
    # `restore` names a snapshot file -- what happens when it is missing is
    # decided by utils.init_model.
    src_gradientReversal = init_model(
        net=GradientReversal(),
        restore=os.path.join(_SNAPSHOT_DIR, 'src-gradientReversal-final.pt'))
    tgt_gradientReversal = init_model(
        net=GradientReversal(),
        restore=os.path.join(_SNAPSHOT_DIR, 'tgt-gradientReversal-final.pt'))

    # train source gradientReversal
    src_gradientReversal = _train_domain(
        "=== Training gradientReversal for source domain ===",
        ">>> Source GradientReversal <<<",
        src_gradientReversal, src_data_loader)
    # train target gradientReversal
    # NOTE(review): same training routine applied to the target loader;
    # whether target labels are used depends on core.train_gradientReversal.
    tgt_gradientReversal = _train_domain(
        "=== Training gradientReversal for target domain ===",
        ">>> Target GradientReversal <<<",
        tgt_gradientReversal, tgt_data_loader)

    # eval source model on source data
    print("=== Evaluating source gradientReversal for source domain ===")
    eval_gradientReversal(src_gradientReversal, src_data_loader_eval)
    # eval target model on target data
    print("=== Evaluating target gradientReversal for target domain ===")
    eval_gradientReversal(tgt_gradientReversal, tgt_data_loader_eval)

    print('=====================================================')
    print('==================== TL/DA Magic ====================')
    print('=====================================================')
    print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
    # eval source model on target data (baseline: no adaptation applied)
    print('=====================================================')
    print("=== Evaluating source gradientReversal for target domain ===")
    print("=== This is what happens if no TL/DA is applied ===")
    print("=== get source model's classification on target ===")
    print('=====================================================')
    eval_gradientReversal(src_gradientReversal, tgt_data_loader_eval)


if __name__ == '__main__':
    main()