import random
import time

import torch
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm

from merge_PermSpec_ResNet import mlp_permutation_spec
from PermSpec_Base import PermutationSpec
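
# This module performs permutation-based weight matching in the style of
# "Git Re-Basin": for every permutation group defined by a PermutationSpec it
# builds a similarity matrix between the two models' weights and solves a
# linear sum assignment problem to align the neurons of model B with model A.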


def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Get parameter `k` from `params`, with the permutations applied."""
    w = params[k]
    # Note: printing every key here makes the process very slow; keep it off.
    for axis, p in enumerate(ps.axes_to_perm[k]):
        # Skip the axis we're trying to permute.
        if axis == except_axis:
            continue
        # None indicates that there is no permutation relevant to that axis.
        if p is not None:
            w = torch.index_select(w, axis, perm[p].int())
    return w
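
# Illustrative example (the actual axis-to-permutation mapping comes from
# merge_PermSpec_ResNet, so the entry below is only an assumption): for an
# entry such as "layer1.weight": ("P_1", "P_0"), axis 0 is permuted by P_1 and
# axis 1 by P_0, so
#   get_permuted_param(ps, perm, "layer1.weight", params_b, except_axis=0)
# leaves the rows alone and reorders the columns of the weight by perm["P_0"].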


def apply_permutation(ps: PermutationSpec, perm, params):
    """Apply a `perm` to `params`."""
    # Keys containing "model_" are skipped and left out of the returned dict.
    return {k: get_permuted_param(ps, perm, k, params) for k in params.keys() if "model_" not in k}


def weight_matching(ps: PermutationSpec, params_a, params_b, special_layers=None, device="cpu", max_iter=3, init_perm=None, usefp16=False):
    """Find a permutation of `params_b` to make them match `params_a`.

    Returns a tuple `(perm, average)`, where `perm` maps each permutation name to an
    index tensor and `average` is the mean absolute change of the assignment objective.
    """
    # The tqdm progress bars below start at position 1.
    # Size of each permutation, taken from the first parameter that uses it
    # (permutations whose parameters are missing from `params_b` are skipped).
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items() if axes[0][0] in params_b}
    # Start from the identity permutation unless an initial permutation is given.
    perm = {p: torch.arange(n) for p, n in perm_sizes.items()} if init_perm is None else init_perm
    special_layers = special_layers if special_layers else sorted(perm.keys())
    # Running totals used to report the average improvement per accepted update.
    total = 0.0
    number = 0
    if usefp16:
        for _ in tqdm(range(max_iter), desc="weight_matching in fp16", position=1):
            progress = False
            # Visit the permutation groups in a random order each iteration.
            random.shuffle(special_layers)
            for p in tqdm(special_layers, desc="weight_matching for special_layers", position=2):
                n = perm_sizes[p]
                # A[i, j] accumulates the similarity between row i of A's weights and row j
                # of B's weights, summed over every parameter touched by permutation p.
                A = torch.zeros((n, n), dtype=torch.float16).to(device)
                for wk, axis in ps.perm_to_axes[p]:
                    w_a = params_a[wk]
                    w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                    w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device)
                    w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device)
                    A += torch.matmul(w_a.half(), w_b.half())
                A = A.cpu()
                # Solve the linear assignment problem: the permutation maximising total similarity.
                ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
                assert (torch.tensor(ri) == torch.arange(len(ri))).all()
                oldL = torch.vdot(torch.flatten(A).float(), torch.flatten(torch.eye(n)[perm[p].long()]).float()).half()
                newL = torch.vdot(torch.flatten(A).float(), torch.flatten(torch.eye(n)[ci, :]).float()).half()
                if newL - oldL != 0:
                    total += abs((newL - oldL).item())
                    number += 1
                progress = progress or newL > oldL + 1e-12
                perm[p] = torch.tensor(ci)
            if not progress:
                break
        average = total / number if number > 0 else 0
        return (perm, average)
    else:
        for _ in tqdm(range(max_iter), desc="weight_matching in fp32", position=1):
            progress = False
            random.shuffle(special_layers)
            for p in tqdm(special_layers, desc="weight_matching for special_layers", position=2):
                n = perm_sizes[p]
                # Accumulate the similarity matrix on the CPU; the matmuls themselves run on `device`.
                A = torch.zeros((n, n), dtype=torch.float32)
                for wk, axis in ps.perm_to_axes[p]:
                    w_a = params_a[wk]
                    w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                    w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device)
                    w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device)
                    A += torch.matmul(w_a.float(), w_b.float()).cpu()
                ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
                assert (torch.tensor(ri) == torch.arange(len(ri))).all()
                oldL = torch.vdot(torch.flatten(A), torch.flatten(torch.eye(n)[perm[p].long()]).float())
                newL = torch.vdot(torch.flatten(A), torch.flatten(torch.eye(n)[ci, :]).float())
                if newL - oldL != 0:
                    total += abs((newL - oldL).item())
                    number += 1
                progress = progress or newL > oldL + 1e-12
                perm[p] = torch.tensor(ci)
            if not progress:
                break
        average = total / number if number > 0 else 0
        return (perm, average)
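

# Typical end-to-end usage (a sketch, not part of the original pipeline):
# `params_a` and `params_b` are flat dicts of tensors whose keys match the
# PermutationSpec, e.g. two checkpoints' state_dicts. The matched permutation
# is applied to model B before the two weight sets are averaged/interpolated:
#
#   ps = mlp_permutation_spec(num_hidden_layers=1)
#   perm, avg_change = weight_matching(ps, params_a, params_b, device="cpu")
#   params_b_permuted = apply_permutation(ps, perm, params_b)
#   merged = {k: 0.5 * params_a[k] + 0.5 * params_b_permuted[k] for k in params_b_permuted}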


def test_weight_matching():
    """If we just have a single hidden layer then it should converge after just one step."""
    # One hidden layer, matching the docstring and the shapes below.
    ps = mlp_permutation_spec(num_hidden_layers=1)
    rng = torch.Generator()
    rng.manual_seed(13)
    num_hidden = 10
    # Assumes mlp_permutation_spec follows PyTorch's nn.Linear axis order,
    # i.e. weights are (out_features, in_features).
    shapes = {
        "layer0.weight": (num_hidden, 2),
        "layer0.bias": (num_hidden,),
        "layer1.weight": (3, num_hidden),
        "layer1.bias": (3,),
    }
    params_a = {k: torch.randn(shape, generator=rng) for k, shape in shapes.items()}
    params_b = {k: torch.randn(shape, generator=rng) for k, shape in shapes.items()}
    perm, _ = weight_matching(ps, params_a, params_b)
    print(perm)


if __name__ == "__main__":
    test_weight_matching()