# ===========================================================================
# Project: How I Learned to Stop Worrying and Love Retraining - IOL Lab @ ZIB
# File: metrics/metrics.py
# Description: Useful metrics
# ===========================================================================
import math
from typing import Union, Tuple, List
import torch
from metrics import flops


@torch.no_grad()
def get_flops(model, x_input):
    """Returns the FLOP count of model on x_input, delegating to metrics.flops."""
    return flops.flops(model, x_input)


@torch.no_grad()
def get_theoretical_speedup(n_flops: int, n_nonzero_flops: int) -> Union[float, dict]:
    """Returns the theoretical speedup n_flops / n_nonzero_flops, or {} if undefined."""
    if n_nonzero_flops == 0:
        # Would yield infinite speedup
        return {}
    return float(n_flops) / n_nonzero_flops
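
# A quick usage sketch for get_theoretical_speedup (the FLOP values below are
# illustrative, not taken from the project):
#     get_theoretical_speedup(n_flops=1_000_000, n_nonzero_flops=250_000)  # -> 4.0
#     get_theoretical_speedup(n_flops=1_000_000, n_nonzero_flops=0)        # -> {} (guard case)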


def modular_sparsity(parameters_to_prune: List) -> float:
    """Returns the global sparsity over all prunable parameters."""
    n_total, n_zero = 0., 0.
    for module, param_type in parameters_to_prune:
        if hasattr(module, param_type) and getattr(module, param_type) is not None:
            param = getattr(module, param_type)
            n_param = float(torch.numel(param))
            n_zero_param = float(torch.sum(param == 0))
            n_total += n_param
            n_zero += n_zero_param
    return float(n_zero) / n_total if n_total > 0 else 0.
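
# Usage sketch (hypothetical): `parameters_to_prune` follows the
# (module, parameter_name) convention used by torch.nn.utils.prune, e.g.:
#     layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
#     modular_sparsity([(m, 'weight') for m in layers])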


def global_sparsity(module: torch.nn.Module, param_type: Union[str, None] = None) -> float:
    """Returns the global sparsity of module (typically an entire model)."""
    n_total, n_zero = 0., 0.
    param_list = ['weight', 'bias'] if not param_type else [param_type]
    for name, submodule in module.named_modules():
        for p_type in param_list:
            if hasattr(submodule, p_type) and getattr(submodule, p_type) is not None:
                param = getattr(submodule, p_type)
                n_param = float(torch.numel(param))
                n_zero_param = float(torch.sum(param == 0))
                n_total += n_param
                n_zero += n_zero_param
    return float(n_zero) / n_total if n_total > 0 else 0.
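
# Usage sketch (hypothetical): with the default param_type=None the count pools
# zeros over both weights and biases; passing a name restricts it:
#     global_sparsity(model)                       # weights + biases
#     global_sparsity(model, param_type='weight')  # weights only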


@torch.no_grad()
def get_parameter_count(model: torch.nn.Module) -> Tuple[int, int]:
    """Returns (n_total, n_nonzero) counted over the weight and bias tensors of model."""
    n_total = 0
    n_nonzero = 0
    param_list = ['weight', 'bias']
    for name, module in model.named_modules():
        for param_type in param_list:
            if hasattr(module, param_type) and getattr(module, param_type) is not None:
                p = getattr(module, param_type)
                n_total += int(p.numel())
                n_nonzero += int(torch.sum(p != 0))
    return n_total, n_nonzero
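
# Usage sketch (hypothetical): the returned pair gives the overall density directly:
#     n_total, n_nonzero = get_parameter_count(model)
#     density = n_nonzero / n_total  # 1.0 for a fully dense model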


@torch.no_grad()
def get_distance_to_pruned(model: torch.nn.Module, sparsity: float) -> Tuple[float, float]:
    """Returns the L2 distance (absolute and relative) of model to its magnitude-pruned version."""
    prune_vector = torch.cat([
        module.weight.flatten() for name, module in model.named_modules()
        if hasattr(module, 'weight') and module.weight is not None
        and not isinstance(module, torch.nn.BatchNorm2d)
    ])
    n_params = float(prune_vector.numel())
    k = int((1 - sparsity) * n_params)  # Number of weights kept at the target sparsity
    total_norm = float(torch.norm(prune_vector, p=2))
    pruned_norm = float(torch.norm(torch.topk(torch.abs(prune_vector), k=k).values, p=2))
    distance_to_pruned = math.sqrt(abs(total_norm ** 2 - pruned_norm ** 2))
    rel_distance_to_pruned = distance_to_pruned / total_norm if total_norm > 0 else 0
    return distance_to_pruned, rel_distance_to_pruned
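
# The identity behind get_distance_to_pruned: magnitude pruning at the given
# sparsity keeps the k largest-magnitude weights, so the removed weights satisfy
#     ||w - w_pruned||_2 = sqrt(||w||_2^2 - ||topk(|w|, k)||_2^2).
# Worked example (hypothetical values): for w = [3, 4, 0.5] and k = 2,
# sqrt((9 + 16 + 0.25) - (9 + 16)) = 0.5, i.e. the norm of the dropped entry.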


@torch.no_grad()
def get_distance_to_origin(model: torch.nn.Module) -> float:
    """Returns the L2 norm of all prunable weights of model."""
    prune_vector = torch.cat([
        module.weight.flatten() for name, module in model.named_modules()
        if hasattr(module, 'weight') and module.weight is not None
        and not isinstance(module, torch.nn.BatchNorm2d)
    ])
    return float(torch.norm(prune_vector, p=2))


def per_layer_sparsity(model: torch.nn.Module) -> dict:
    """Returns the per-layer sparsity of model, keyed by module name."""
    per_layer_sparsity_dict = dict()
    param_type = 'weight'  # Only compute for weights, since we do not sparsify biases
    for name, submodule in model.named_modules():
        if hasattr(submodule, param_type) and getattr(submodule, param_type) is not None:
            if name in per_layer_sparsity_dict:
                continue
            per_layer_sparsity_dict[name] = global_sparsity(submodule, param_type=param_type)
    return per_layer_sparsity_dict
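

if __name__ == '__main__':
    # Minimal smoke-test sketch: a toy two-layer model with half of the first
    # layer's weight rows zeroed to imitate pruning. The model and numbers are
    # illustrative assumptions, not part of the original project; run from the
    # repository root so `from metrics import flops` resolves.
    torch.manual_seed(0)
    model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
    with torch.no_grad():
        model[0].weight[:2].zero_()  # 16 of the 46 parameters become zero
    n_total, n_nonzero = get_parameter_count(model)
    print(f"Parameters: {n_nonzero}/{n_total} nonzero")
    print(f"Global sparsity: {global_sparsity(model):.4f}")
    print(f"Per-layer sparsity: {per_layer_sparsity(model)}")
    print(f"Distance to origin: {get_distance_to_origin(model):.4f}")
    print(f"Distance to 50%-pruned: {get_distance_to_pruned(model, sparsity=0.5)}")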