-
Notifications
You must be signed in to change notification settings - Fork 2
/
do_ot.py
70 lines (63 loc) · 4.33 KB
/
do_ot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
'''
we use OT to test the effect of MAPCO+neural_matching_based method.
code from https://github.com/sidak/otfusion
'''
import argparse
import json
import copy
import os
import torch
from wasserstein_ensemble import geometric_ensembling_modularized
from wasserstein_ensemble import get_wasserstein
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--gpu-id', default=1, type=int, help='GPU id to use')
parser.add_argument('--skip-last-layer', default=True, help='skip the last layer in calculating optimal transport')
parser.add_argument('--skip-last-layer-type', type=str, default='average', choices=['second', 'average'],
help='how to average the parameters for the last layer')
parser.add_argument('--debug', default=False, help='print debug statements')
parser.add_argument('--reg', default=1e-2, type=float, help='regularization strength for sinkhorn (default: 1e-2)')
parser.add_argument('--reg-m', default=1e-3, type=float, help='regularization strength for marginals in unbalanced sinkhorn (default: 1e-3)')
parser.add_argument('--ground-metric', type=str, default='euclidean', choices=['euclidean', 'cosine'],
help='ground metric for OT calculations, only works in free support v2 and soon with Ground Metric class in all! .')
parser.add_argument('--ground-metric-normalize', type=str, default='none', choices=['log', 'max', 'none', 'median', 'mean'],
help='ground metric normalization to consider! ')
parser.add_argument('--not-squared', default=True, help='dont square the ground metric')
parser.add_argument('--clip-gm', default=False, help='to clip ground metric')
parser.add_argument('--clip-min', type=float, default=0,
help='Value for clip-min for gm')
parser.add_argument('--clip-max', type=float, default=5,
help='Value for clip-max for gm')
parser.add_argument('--tmap-stats', default=False, help='print tmap stats')
parser.add_argument('--ensemble-step', type=float, default=0.5, action='store', help='rate of adjustment towards the second model')
parser.add_argument('--ground-metric-eff', default=True, help='memory efficient calculation of ground metric')
parser.add_argument('--eval-aligned', action='store_true',
help='evaluate the accuracy of the aligned model 0')
parser.add_argument('--correction',default=True, help='scaling correction for OT')
# parser.add_argument('--weight-stats', default=True, help='log neuron-wise weight vector stats.')
parser.add_argument('--sinkhorn-type', type=str, default='normal', choices=['normal', 'stabilized', 'epsilon', 'gpu'],
help='Type of sinkhorn algorithm to consider.')
parser.add_argument('--geom-ensemble-type', type=str, default='wts', choices=['wts', 'acts'],
help='Ensemble based on weights (wts) or activations (acts).')
parser.add_argument('--normalize-wts', default=True,
help='normalize the vector of weights')
parser.add_argument('--gromov', default=False, help='use gromov wasserstein distance and barycenters')
parser.add_argument('--gromov-loss', type=str, default='square_loss',
choices=['square_loss', 'kl_loss'], help="choice of loss function for gromov wasserstein computations")
parser.add_argument('--past-correction', default=True, help='use the current weights aligned by multiplying with past transport map')
parser.add_argument('--exact', default=True, help='compute exact optimal transport')
parser.add_argument('--proper-marginals', default=False, help='consider the marginals of transport map properly')
parser.add_argument('--print-distances', default=True, help='print OT distances for every layer')
parser.add_argument('--importance', type=str, default=None,
help='importance measure to use for building probab mass! (options, l1, l2, l11, l12)')
parser.add_argument('--unbalanced', default=False, help='use unbalanced OT')
return parser
parser = get_parser()
base_args = parser.parse_known_args()[0]
def doot(net,model_type='mlpnet',args=None):
'''
net=[net1,net2]
'''
base_args.gpu_id = args.gpu_id
modelout = geometric_ensembling_modularized(base_args, net,mtype=model_type)
return modelout