-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_Base.oh_rsut.py
40 lines (32 loc) · 1.38 KB
/
main_Base.oh_rsut.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Launch Base-model source-only training runs on Office-Home.

For every domain in ``DOMAINS`` and every seed in ``SEEDS``:
- train on ``data/<domain>_RS.txt``;
- evaluate on the ``data/<other>_UT.txt`` split of every other domain.

All runs launched by one invocation share a single timestamp.
"""
import socket
import time

# Domains cycled through as the training source; the remaining ones are the
# test domains for that run.
DOMAINS = ('Product', 'Clipart', 'Real_World')
# Random seeds each source domain is run with.
SEEDS = (0,)

# Banner handed to train_source_main: the first slot receives the argument
# list joined by newline+tab, the second receives the host name.
HEADER_TEMPLATE = '''
++++++++++++++++++++++++++++++++++
{}
++++++++++++++++++++++++++++++++++
@{}
'''


def build_args(src, seed, timestamp, domains=DOMAINS):
    """Return the command-line style argument list for one training run.

    Args:
        src: Source domain name; selects ``data/<src>_RS.txt`` for training.
        seed: Random seed, passed through as ``--random_seed``.
        timestamp: Run identifier string, passed through as ``--timestamp``.
        domains: All domain names. Every entry other than ``src`` contributes
            a ``data/<name>_UT.txt`` path to ``--test_path``, in order.

    Returns:
        A list of ``--key=value`` strings.
    """
    test_paths = ','.join(
        'data/{}_UT.txt'.format(tst) for tst in domains if tst != src
    )
    return [
        '--base_model=Base',
        '--gpu=0',
        '--timestamp={}'.format(timestamp),
        '--random_seed={}'.format(seed),
        '--base_net=ResNet50',
        '--class_criterion=CrossEntropyLoss',
        '--dataset=Office-Home',
        '--source_path=data/{}_RS.txt'.format(src),
        '--test_path=[{}]'.format(test_paths),
        '--train_source_sampler=ClassBalancedBatchSampler',
        '--batch_size=16',
        '--train_steps=10000',
        '--save_interval=5000',
        '--eval_interval=1000',
        '--log_dir=log_base',
        '--use_file_logger=True',
    ]


def build_header(args, host_name):
    """Format the run banner from the argument list and the host name."""
    return HEADER_TEMPLATE.format('\n\t'.join(args), host_name)


def main(domains=DOMAINS, seeds=SEEDS):
    """Run train_source_main once per (source domain, seed) pair.

    Args:
        domains: Domain names to cycle through as the source.
        seeds: Random seeds to repeat each source domain with.
    """
    # Project-local import kept here so that importing this module (e.g. to
    # reuse build_args) does not require the training package.
    from trainer.train import train_source_main

    # Computed once so every run from this invocation shares one timestamp.
    timestamp = time.strftime("%Y-%m-%d_%H.%M.%S", time.localtime())
    host_name = socket.gethostname()
    for src in domains:
        for seed in seeds:
            print('random seed {}'.format(seed))
            args = build_args(src, seed, timestamp, domains)
            train_source_main(args, build_header(args, host_name))


if __name__ == '__main__':
    main()