"""Brute-force greedy pruning of a GPT-2 small attention projection matrix: at
each step, zero the candidate weight whose removal least perturbs the
projection of a sample of layer-normalized token embeddings."""
import argparse

import numpy as np
from tqdm import tqdm

def calc_approx_errors(orig_wts, approx_wts, embed, orig_proj_embed):
    """Return the embedding-projection error of `approx_wts`, both as-is and
    after rescaling it to match the norm of `orig_wts`."""
    # Direct comparison of original and approximated weights
    # recon_dist = np.linalg.norm(approx_wts - orig_wts)
    # recon_len_diff = np.absolute(np.linalg.norm(approx_wts) - np.linalg.norm(orig_wts))
    recon_embed = np.matmul(embed, approx_wts)
    recon_embed_dist = np.linalg.norm(recon_embed - orig_proj_embed)
    # recon_embed_len_diff = np.absolute(np.linalg.norm(recon_embed) - np.linalg.norm(orig_proj_embed))
    # Correct the magnitude of the approximated weights
    approx_corr_wts = approx_wts * (np.linalg.norm(orig_wts) / np.linalg.norm(approx_wts))
    # corr_recon_dist = np.linalg.norm(approx_corr_wts - orig_wts)
    # corr_recon_len_diff = np.absolute(np.linalg.norm(approx_corr_wts) - np.linalg.norm(orig_wts))
    corr_recon_embed = np.matmul(embed, approx_corr_wts)
    corr_recon_embed_dist = np.linalg.norm(corr_recon_embed - orig_proj_embed)
    # corr_recon_embed_len_diff = np.absolute(np.linalg.norm(corr_recon_embed) - np.linalg.norm(orig_proj_embed))
    return recon_embed_dist, corr_recon_embed_dist

parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', help='drop into pdb after each pruning step')
parser.add_argument('--projection', default='k', help='attention projection to prune (e.g. q, k, v)')
parser.add_argument('--layer', default=0, type=int, help='decoder layer index')
parser.add_argument('--step_id', default=90000, type=int, help='checkpoint step to load')
parser.add_argument('--full_multiply', action='store_true',
                    help='recompute the affected column with a full matmul instead of the rank-1 shortcut')
args = parser.parse_args()
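# Example invocation (a sketch; the checkpoint path below is hard-coded, so
# only the flags vary):
#   python brute_force_greedy_prune.py --projection k --layer 0 --step_id 90000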
weights = np.load(f'outputs/running_on_cs1/gpt2_small_msl128_bs144_lr0.00028_gpu-baseline_1/model.ckpt-{args.step_id}.dict.npz')
embed = weights['input_embedding/embedding_weights']
layer_id = ''
if args.layer > 0:
    layer_id = f'_{args.layer}'
proj_wts = weights[f'decoder/self_attention{layer_id}/{args.projection}_projection/{args.projection}_projection/kernel']
abs_proj_wts = np.absolute(proj_wts)
# Indices into the flattened projection matrix, sorted by ascending magnitude
argsort_proj_wts = np.argsort(abs_proj_wts, axis=None)
# Process the weights we're interested in:
# subsample every 125th token embedding (400 rows) and layer-normalize each
# one (no learned gain/bias).
embsample = np.arange(400) * 125
embed = embed[embsample]
embmean = np.mean(embed, axis=-1, keepdims=True)
embstdev = np.std(embed, axis=-1, keepdims=True)
embed = (embed - embmean) / embstdev
# Reference projection of the sampled embeddings through the dense weights
proj_embed = np.matmul(embed, proj_wts)
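# Greedy brute-force pruning: at each step, look at the `top_k`
# smallest-magnitude weights that are still nonzero, measure how much zeroing
# each one would perturb the projected embeddings, and permanently zero the
# single weight with the smallest perturbation. Each step prints the chosen
# candidate alongside the pure magnitude-pruning choice for comparison.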
# TODO: Vectorize the algorithm below!
curr_wts = proj_wts.copy()
# TODO: curr_corr_wts = proj_wts.copy()
top_k = 1000
for step in tqdm(range(proj_wts.size)):
# for step in range(proj_wts.size):
    partial_result = np.matmul(embed, curr_wts)
    # Error incurred by zeroing each of the top_k candidates; +inf if unscored
    emb_proj_errors = np.full(top_k, np.inf)
    for argmin_position in range(min(top_k, argsort_proj_wts.size)):
        argmin_idx = argsort_proj_wts[argmin_position]
        col_id = argmin_idx % proj_wts.shape[1]
        row_id = argmin_idx // proj_wts.shape[1]
        if curr_wts[row_id, col_id] != 0.0:
            # recon_embed = partial_result.copy()
            if args.full_multiply:
                # Recompute column col_id of the projection with the weight zeroed
                curr_col = curr_wts[:, col_id].copy()
                curr_col[row_id] = 0.0
                diff_recon_col = partial_result[:, col_id] - np.matmul(embed, curr_col)
            else:
                # Even quicker way: this is about 10-25% faster.
                # The difference above is just the weight being zeroed
                # multiplied by the embedding column (at row_id).
                diff_recon_col = curr_wts[row_id, col_id] * embed[:, row_id]
            emb_proj_error = np.linalg.norm(diff_recon_col)
            emb_proj_errors[argmin_position] = emb_proj_error
    # Baseline: what plain magnitude pruning would remove next
    min_wt = curr_wts[argsort_proj_wts[0] // proj_wts.shape[1], argsort_proj_wts[0] % proj_wts.shape[1]]
    min_wt_error = emb_proj_errors[0]
    # Zero the candidate with the smallest projection error
    best_error = emb_proj_errors.min()
    best_error_argmin_position = emb_proj_errors.argmin()
    argmin_idx = argsort_proj_wts[best_error_argmin_position]
    col_id = argmin_idx % proj_wts.shape[1]
    row_id = argmin_idx // proj_wts.shape[1]
    best_error_wt = curr_wts[row_id, col_id]
    curr_wts[row_id, col_id] = 0.0
    argsort_proj_wts = np.delete(argsort_proj_wts, best_error_argmin_position)
    sparsity = float(step) / proj_wts.size
    # Columns: sparsity, chosen candidate rank, flat index, chosen error,
    # magnitude-pruning error, magnitude-pruning weight, chosen weight
    print(f'{sparsity}\t{best_error_argmin_position}\t{argmin_idx}\t{best_error}\t{min_wt_error}\t{min_wt}\t{best_error_wt}')
    # import pdb
    # pdb.set_trace()
    if args.debug:
        import pdb
        pdb.set_trace()
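# After (or partway through) the loop, the helper defined above could be used
# to score the sparse weights against the dense ones; a sketch, not part of the
# original flow:
#   recon_err, corr_err = calc_approx_errors(proj_wts, curr_wts, embed, proj_embed)
#   print(f'reconstruction error: {recon_err}\tmagnitude-corrected: {corr_err}')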