-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathbar_distribution.py
147 lines (114 loc) · 7.17 KB
/
bar_distribution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
from torch import nn
class BarDistribution(nn.Module):
    """Discretized ("bar"/Riemann) distribution over a fixed support.

    The support [borders[0], borders[-1]] is partitioned into `num_bars`
    buckets; a model predicts one logit per bucket and the density is uniform
    within each bucket, i.e. softmax(logits)[i] / bucket_widths[i].
    """
    def __init__(self, borders: torch.Tensor):
        """
        :param borders: sorted 1d tensor of bucket borders (len = num_bars + 1),
            starting at the support minimum and ending at its maximum; all
            targets must lie within [borders[0], borders[-1]].
        """
        super().__init__()
        assert len(borders.shape) == 1
        # buffers (not parameters): they follow .to(device) and are stored in the state_dict
        self.register_buffer('borders', borders)
        self.register_buffer('bucket_widths', self.borders[1:] - self.borders[:-1])
        full_width = self.bucket_widths.sum()
        assert (full_width - (self.borders[-1] - self.borders[0])).abs() < 1e-4, f'diff: {full_width - (self.borders[-1] - self.borders[0])}'
        assert (torch.argsort(borders) == torch.arange(len(borders))).all(), "Please provide sorted borders!"
        self.num_bars = len(borders) - 1

    def map_to_bucket_idx(self, y):
        """Map target values to bucket indices; values exactly on the
        outermost borders are assigned to the first/last bucket."""
        target_sample = torch.searchsorted(self.borders, y) - 1
        target_sample[y == self.borders[0]] = 0
        target_sample[y == self.borders[-1]] = self.num_bars - 1
        return target_sample

    def forward(self, logits, y):
        """Negative log density (the _loss_). y: T x B, logits: T x B x num_bars."""
        target_sample = self.map_to_bucket_idx(y)
        assert (target_sample >= 0).all() and (target_sample < self.num_bars).all(), f'y {y} not in support set for borders (min_y, max_y) {self.borders}'
        assert logits.shape[-1] == self.num_bars, f'{logits.shape[-1]} vs {self.num_bars}'
        bucket_log_probs = torch.log_softmax(logits, -1)
        # dividing a bucket's probability by its width turns it into a density
        scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
        return -scaled_bucket_log_probs.gather(-1, target_sample.unsqueeze(-1)).squeeze(-1)

    def mean(self, logits):
        """Expected value: probability-weighted sum of bucket centers."""
        bucket_means = self.borders[:-1] + self.bucket_widths/2
        p = torch.softmax(logits, -1)
        return p @ bucket_means

    def quantile(self, logits, center_prob=.682):
        """Return the central interval containing `center_prob` probability mass.

        :return: tensor of shape logits.shape[:-1] + (2,) holding the lower and
            upper quantile for each distribution in the batch.
        """
        logits_shape = logits.shape
        logits = logits.view(-1, logits.shape[-1])
        side_prob = (1-center_prob)/2
        probs = logits.softmax(-1)
        flipped_probs = probs.flip(-1)
        cumprobs = torch.cumsum(probs, -1)
        flipped_cumprobs = torch.cumsum(flipped_probs, -1)
        def find_lower_quantile(probs, cumprobs, side_prob, borders):
            idx = (torch.searchsorted(cumprobs, side_prob)).clamp(0, len(cumprobs)-1)
            # mass accumulated before bucket `idx`. BUGFIX: the previous
            # `cumprobs[idx-1]` wrapped around to the LAST entry when idx == 0,
            # yielding a negative remainder for quantiles in the first bucket.
            left_prob = cumprobs[idx-1] if idx > 0 else 0.
            rest_prob = side_prob - left_prob
            # linear interpolation inside the bucket (density is uniform there)
            left_border, right_border = borders[idx:idx+2]
            return left_border + (right_border-left_border)*rest_prob/probs[idx]
        results = []
        for p, cp, f_p, f_cp in zip(probs, cumprobs, flipped_probs, flipped_cumprobs):
            # the upper quantile is the lower quantile of the flipped distribution
            r = find_lower_quantile(p, cp, side_prob, self.borders), find_lower_quantile(f_p, f_cp, side_prob, self.borders.flip(0))
            results.append(r)
        return torch.tensor(results).reshape(*logits_shape[:-1], 2)

    def mode(self, logits):
        """Center of the most likely bucket."""
        mode_inds = logits.argmax(-1)
        bucket_means = self.borders[:-1] + self.bucket_widths/2
        return bucket_means[mode_inds]

    def ei(self, logits, best_f, maximize=True):
        """Expected improvement over `best_f`. logits: evaluation_points x batch x feature_dim."""
        bucket_means = self.borders[:-1] + self.bucket_widths/2
        if maximize:
            # per-bucket expected improvement of the (uniform) bucket, clipped at 0
            bucket_contributions = torch.tensor(
                [max((bucket_max + max(bucket_min, best_f)) / 2 - best_f,0) for
                 bucket_min, bucket_max, bucket_mean in zip(self.borders[:-1], self.borders[1:], bucket_means)], dtype=logits.dtype, device=logits.device)
        else:
            bucket_contributions = torch.tensor(
                [-min((min(bucket_max,best_f) + bucket_min) / 2 - best_f,0) for # min on max instead of max on min, and compare min < instead of max >
                 bucket_min, bucket_max, bucket_mean in zip(self.borders[:-1], self.borders[1:], bucket_means)], dtype=logits.dtype, device=logits.device)
        p = torch.softmax(logits, -1)
        return p @ bucket_contributions
class FullSupportBarDistribution(BarDistribution):
    """BarDistribution whose two outermost buckets are replaced by half-normal
    tails, extending the support to all of R. The tail scale is chosen so that
    (by default) half of the tail's mass lies within the original bucket width.
    """
    @staticmethod
    def halfnormal_with_p_weight_before(range_max,p=.5):
        """Return a HalfNormal whose CDF reaches `p` exactly at `range_max`."""
        s = range_max / torch.distributions.HalfNormal(torch.tensor(1.)).icdf(torch.tensor(p))
        return torch.distributions.HalfNormal(s)
    def forward(self, logits, y): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
        assert self.num_bars > 1
        target_sample = self.map_to_bucket_idx(y)
        # out-of-support targets are clamped into the outermost (tail) buckets
        target_sample.clamp_(0,self.num_bars-1)
        assert logits.shape[-1] == self.num_bars
        bucket_log_probs = torch.log_softmax(logits, -1)
        # turn bucket probabilities into densities by dividing by bucket width
        scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
        log_probs = scaled_bucket_log_probs.gather(-1,target_sample.unsqueeze(-1)).squeeze(-1)
        side_normals = (self.halfnormal_with_p_weight_before(self.bucket_widths[0]), self.halfnormal_with_p_weight_before(self.bucket_widths[-1]))
        # TODO look over it again
        # For tail buckets, swap the uniform in-bucket density for the
        # half-normal density: adding log(width) cancels the -log(width) term
        # applied above. The clamp keeps the distance argument positive.
        log_probs[target_sample == 0] += side_normals[0].log_prob((self.borders[1]-y[target_sample == 0]).clamp(min=.00000001)) + torch.log(self.bucket_widths[0])
        log_probs[target_sample == self.num_bars-1] += side_normals[1].log_prob(y[target_sample == self.num_bars-1]-self.borders[-2]) + torch.log(self.bucket_widths[-1])
        return -log_probs
    def mean(self, logits):
        """Expected value; the two tail buckets contribute their half-normal means
        (measured from the inner border of the tail) instead of bucket centers."""
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        p = torch.softmax(logits, -1)
        side_normals = (self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
                        self.halfnormal_with_p_weight_before(self.bucket_widths[-1]))
        bucket_means[0] = -side_normals[0].mean + self.borders[1]
        bucket_means[-1] = side_normals[1].mean + self.borders[-2]
        return p @ bucket_means
def get_bucket_limits(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=None):
    """Compute `num_outputs` + 1 sorted bucket borders.

    If `ys` is given, borders are placed so every bucket holds the same number
    of samples (equal-mass); otherwise buckets are equally wide over `full_range`.

    :param num_outputs: number of buckets.
    :param full_range: (min, max) of the support; inferred from `ys` if omitted.
    :param ys: sample of target values used to estimate equal-mass borders.
    :return: 1d tensor of num_outputs + 1 borders.
    """
    assert (ys is not None) or (full_range is not None)
    if ys is not None:
        ys = ys.flatten()
        # drop the remainder so every bucket gets the same number of samples.
        # BUGFIX: the cutoff must be computed before truncating, otherwise the
        # log line below always reported "Cut off the last 0 ys".
        cut_off = len(ys) % num_outputs
        if cut_off:
            ys = ys[:-cut_off]
        print(f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {cut_off} ys.')
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
            full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)  # sort order itself is unused
        # interior borders: midpoints between the last sample of one bucket
        # and the first sample of the next
        bucket_limits = (ys_sorted[ys_per_bucket-1::ys_per_bucket][:-1]+ys_sorted[ys_per_bucket::ys_per_bucket])/2
        bucket_limits = torch.cat([full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)],0)
    else:
        # equally-spaced borders over the given range
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat([full_range[0] + torch.arange(num_outputs).float()*class_width, torch.tensor(full_range[1]).unsqueeze(0)], 0)
    assert len(bucket_limits) - 1 == num_outputs and full_range[0] == bucket_limits[0] and full_range[-1] == bucket_limits[-1]
    return bucket_limits