-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
728 lines (607 loc) · 22 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
import argparse
import json
import logging
import os
import random
from copy import deepcopy
from math import ceil, floor, isinf, isnan
from typing import Type

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from detorch import DE, Policy
from detorch.config import Config, default_config
from scipy.stats.contingency import relative_risk
class PullPolicy(Policy):
    """DE policy whose parameters are a single dataset feature vector.

    Each population member starts from a vector drawn out of
    ``set_existing_vecs`` and is scored with ``eval_fn``. Before every
    evaluation the parameters are clipped to ``PullPolicy.bounds``.
    """

    # Lower and upper clipping limits applied to the parameters.
    bounds = [0, 1]

    def __init__(self, eval_fn, p, set_existing_vecs, thresh, n_sigmas):
        """
        Args:
            eval_fn (callable): scoring function. It is called with the
                parameters plus a leading batch dimension of size 1. It must
                return a 4-tuple ``(sample_r, activation_grad, mu, cb)``.
            p (torch.Tensor): sampling weights used to pick the starting
                vector; one weight per row of `set_existing_vecs`.
            set_existing_vecs (torch.Tensor): candidate vectors to start from.
            thresh (float): value the lower confidence bound is compared to.
            n_sigmas (float): multiplier applied to ``cb`` when forming the
                lower confidence bound ``mu - n_sigmas * cb``.
        """
        super().__init__()
        # Pick the starting row at random according to the weights in `p`.
        start_row = p.multinomial(num_samples=1).item()
        initial_vec = set_existing_vecs[start_row].clone()
        self.params = nn.Parameter(initial_vec, requires_grad=False)
        self.eval_fn = eval_fn
        self.thresh = thresh
        self.n_sigmas = n_sigmas
        # The attributes below are filled in by `evaluate`.
        self.mu = None
        self.cb = None
        self.sample_r = None
        self.activation_grad = None
        self.lower_ci_under_thresh = None

    def evaluate(self):
        """Clip the parameters, score them with `eval_fn` and cache the outputs.

        Returns:
            float: the sampled reward of the current parameters.
        """
        self.transform()
        batch = self.params.data[None]
        outputs = self.eval_fn(batch)
        sample_r, activation_grad, mu, cb = outputs
        self.activation_grad = activation_grad
        self.sample_r = sample_r.detach().item()
        self.mu = mu
        self.cb = cb
        lower_bound = mu - self.n_sigmas * cb
        self.lower_ci_under_thresh = lower_bound <= self.thresh
        return self.sample_r

    def transform(self):
        """Replace the parameters with a copy clipped to ``PullPolicy.bounds``."""
        low, high = PullPolicy.bounds
        clipped = torch.clip(self.params, low, high)
        self.params = nn.Parameter(clipped, requires_grad=False)
class EarlyStopping:
    """Track a validation loss and keep a copy of the best model seen.

    Call the instance once per evaluation with the current loss and model.
    Read ``early_stop`` to learn whether ``patience`` consecutive calls have
    failed to strictly improve on the best loss.
    """

    def __init__(self, patience):
        """
        Args:
            patience (int): number of consecutive non-improving calls after
                which ``early_stop`` becomes True.
        """
        self.patience = patience
        self.min_loss = float("inf")
        self.count = 0
        self.best_model = None

    def __call__(self, cur_loss, model):
        """Record `cur_loss`; snapshot `model` when the loss strictly improves."""
        if cur_loss >= self.min_loss:
            # No strict improvement: one more strike towards stopping.
            self.count += 1
        else:
            # Improvement: reset the counter and keep this model.
            self.count = 0
            self.store(model)
            self.min_loss = cur_loss

    def store(self, model):
        """Keep a deep copy of `model` with its gradients released."""
        self.best_model = deepcopy(model)
        # The snapshot only needs the weights; dropping .grad saves memory.
        self.best_model.zero_grad(set_to_none=True)

    @property
    def early_stop(self):
        """bool: True once `patience` consecutive calls brought no improvement."""
        # Always a real bool (never an implicit None).
        return self.count >= self.patience
class QuantileLoss(nn.Module):
    """Pinball (quantile) loss over several quantile levels.

    Column ``i`` of `preds` / `target` is scored against ``quantiles[i]``.
    Per-row losses are summed over the quantiles and averaged over the rows.
    """

    def __init__(self, quantiles):
        """
        Args:
            quantiles (Sequence[float]): quantile levels, one per column.
        """
        super().__init__()
        self.quantiles = quantiles

    def forward(self, preds, target):
        """Return the row-mean of the pinball losses summed over quantiles."""
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        residuals = target - preds
        # For 0 <= q <= 1, max((q - 1) * e, q * e) is q * e when e >= 0
        # and (q - 1) * e otherwise.
        per_quantile = [
            torch.max((q - 1) * residuals[:, col], q * residuals[:, col])
            for col, q in enumerate(self.quantiles)
        ]
        return torch.stack(per_quantile, dim=1).sum(dim=1).mean()
def sample_from_sql(con, table, num_samples, num_rows, columns):
    """Fetch a random subset of rows from an SQL table.

    `num_samples` row ids are drawn uniformly, with replacement, from
    ``1..num_rows``. The matching rows are then fetched with a single
    ``ROWID IN (...)`` query. ``IN`` matches each row at most once, so
    duplicate draws collapse and fewer than `num_samples` rows may be
    returned.

    NOTE(review): this assumes the table's rowids are exactly the
    contiguous range 1..num_rows (the SQLite default for an append-only
    table). Confirm this for the datasets in use.

    Args:
        con: database connection accepted by ``pandas.read_sql`` (the
            ``ROWID`` column is SQLite-specific).
        table (str): table name, interpolated verbatim (must be trusted).
        num_samples (int): number of row ids to draw.
        num_rows (int): number of rows in `table`.
        columns (str): SQL column list, interpolated verbatim; the last
            column is returned separately as the outcome.

    Returns:
        tuple: ``(combis, outcomes)``. `combis` holds every selected column
        except the last; `outcomes` is the last column.

    Raises:
        ValueError: if `num_rows` or `num_samples` is smaller than 1.
    """
    if num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    # SQLite rowids start at 1; randint's upper bound is exclusive.
    samples_idx = torch.randint(1, num_rows + 1, size=(num_samples,)).tolist()
    # Render the ids explicitly. Formatting a Python tuple produced "(5,)" for
    # a single id, which is invalid SQL. The ids are ints generated above, so
    # inlining them is safe.
    id_list = ", ".join(str(i) for i in samples_idx)
    statement = f"SELECT {columns} FROM {table} WHERE ROWID IN ({id_list})"
    sample = torch.from_numpy(pd.read_sql(statement, con=con).to_numpy())
    combis = sample[:, :-1]
    outcomes = sample[:, -1]
    return combis, outcomes
def get_table_stats(con, table, combi_indexes=None):
    """Return the row count, column count and atc column names of `table`.

    Args:
        con (sqlite3.Connection): connection to the database.
        table (str): table name. It is interpolated into the ``COUNT`` query,
            so it must be trusted.
        combi_indexes (tuple[int, int], optional): inclusive range of column
            ids (``cid`` as reported by SQLite's ``table_info`` pragma,
            0-based) that hold the atc/combination columns. When omitted, a
            module-level ``COMBI_INDEXES`` is used if one is defined.

    Returns:
        tuple: (number of rows, number of cols, atc column names (list))

    Raises:
        ValueError: if `combi_indexes` is omitted and no module-level
            ``COMBI_INDEXES`` exists.
    """
    if combi_indexes is None:
        try:
            combi_indexes = COMBI_INDEXES  # noqa: F821 - optional module-level default
        except NameError as err:
            raise ValueError(
                "combi_indexes was not given and no module-level "
                "COMBI_INDEXES is defined"
            ) from err
    first_cid, last_cid = combi_indexes
    cur = con.cursor()
    num_rows = cur.execute(f"SELECT COUNT(*) FROM {table};").fetchone()[0]
    # Bind the table name as an argument of the pragma table-valued function
    # instead of double-quoting it (double quotes denote identifiers in SQL).
    num_cols = cur.execute(
        "SELECT COUNT(*) FROM pragma_table_info(?)", (table,)
    ).fetchone()[0]
    atc_cols_names = [
        row[0]
        for row in cur.execute(
            "SELECT name FROM pragma_table_info(?) WHERE cid BETWEEN ? AND ?",
            (table, first_cid, last_cid),
        ).fetchall()
    ]
    return num_rows, num_cols, atc_cols_names
def compute_relative_risk(combi, pop_combis, pop_outcomes):
    """Relative risk of the outcome for rows exposed to every drug of `combi`.

    A population row is "exposed" when it has a 1 in every position where
    `combi` has a 1; all other rows are controls. The relative risk is
    ``(exposed cases / n exposed) / (control cases / n control)``.

    Args:
        combi (torch.Tensor): binary combination vector, shape (d,) or (1, d).
        pop_combis (torch.Tensor): binary population matrix, shape (n, d).
        pop_outcomes (torch.Tensor): per-row outcome values indexed by row
            (presumably 0/1 indicators — confirm with callers).

    Returns:
        float | int: the relative risk. It is 0 when `combi` sums to fewer
        than 5 or when the RR is undefined (NaN), and 100 when the RR is
        infinite (no control cases).
    """
    # Determined by polypharmacy definition (at least 5 drugs).
    if combi.sum() < 5:
        return 0
    vec_indices = (combi == 1).squeeze(0)
    # Boolean masks for exposed rows and controls.
    rows_exposed = (pop_combis[:, vec_indices] == 1).all(dim=1)
    rows_control = torch.logical_not(rows_exposed)
    n_exposed = rows_exposed.sum()
    n_control = rows_control.sum()
    n_exposed_cases = pop_outcomes[rows_exposed].sum()
    n_control_cases = pop_outcomes[rows_control].sum()
    rr = ((n_exposed_cases / n_exposed) / (n_control_cases / n_control)).item()
    if isnan(rr):
        # Interpreted as 0 by experts.
        return 0
    if isinf(rr):
        # Cap an infinite RR (no control cases) at a finite sentinel value.
        return 100
    return rr
def change_to_closest_existing_vector(vec, set_existing_vecs):
    """Snap `vec` to its nearest row of `set_existing_vecs` (L1 distance).

    Args:
        vec (torch.Tensor): query vector, broadcast against every row.
        set_existing_vecs (torch.Tensor): candidate vectors, one per row.

    Returns:
        tuple: ``(closest, nearest)``. `closest` is the nearest row with a
        leading dimension of size 1; `nearest` is its index as a 0-d tensor.
    """
    l1_distances = torch.norm(vec - set_existing_vecs, dim=1, p=1)
    nearest = torch.topk(l1_distances, 1, largest=False).indices[0]
    closest = set_existing_vecs[nearest].unsqueeze(0)
    return closest, nearest
def gen_warmup_vecs_and_rewards(n_warmup, combis, risks, p):
    """Draw `n_warmup` (vector, reward) pairs used to warm up the agent.

    Rows are drawn independently, with replacement, with weights `p`.

    Args:
        n_warmup (int): number of pairs to draw.
        combis (torch.Tensor): candidate vectors, one per row.
        risks (torch.Tensor): reward associated with each row of `combis`.
        p (torch.Tensor): sampling weight of each row.

    Returns:
        tuple: ``(vecs, rewards)``. `vecs` holds the drawn rows; `rewards`
        holds their rewards, each wrapped in a one-element list before
        conversion to a tensor.
    """
    drawn_rows = [p.multinomial(num_samples=1).item() for _ in range(n_warmup)]
    vecs = torch.tensor([combis[row].tolist() for row in drawn_rows])
    rewards = torch.tensor([[risks[row]] for row in drawn_rows])
    return vecs, rewards
def find_best_member(
    agent_eval_fn,
    de_config,
    proba,
    set_init_vecs,
    seed,
    ci_thresh,
    threshold,
    n_sigmas_conf,
):
    """
    Run differential evolution (DE) and return the selected population member.

    Args:
        agent_eval_fn (function): agent's evaluation function; registered as
            the `eval_fn` field of the "policy" config (PullPolicy's `eval_fn`).
        de_config (DEConfig): DE configuration. Its `seed` attribute is
            overwritten in place with `seed`.
        proba (torch.Tensor): initialization weights for picking each
            member's starting vector (registered as PullPolicy's `p`).
        set_init_vecs (torch.Tensor): available vectors to initialize from.
        seed (int): seed written into `de_config` before DE runs.
        ci_thresh (bool): if True and the highest-reward member's
            `lower_ci_under_thresh` is falsy, return instead the
            highest-reward member whose `lower_ci_under_thresh` is true.
        threshold (float): threshold for the lower-CI comparison
            (registered as PullPolicy's `thresh`).
        n_sigmas_conf (float): Number of sigmas to consider for confidence
            (sigma-rule) around network activation (mu), registered as
            PullPolicy's `n_sigmas`. Used for stopping exploitation of a known
            good arm.
    Returns:
        The selected element of ``de.population``. It is presumably a
        PullPolicy instance (its ``lower_ci_under_thresh`` attribute is
        read) — confirm against detorch.
    Raises:
        IndexError: presumably, when `ci_thresh` is set and no member has a
            true ``lower_ci_under_thresh`` (``np.where(...)[0][0]`` on an
            empty result).
    """
    # Reseed DE optim to diversify populations across timesteps
    de_config.seed = seed
    config = Config(default_config)

    # Register the policy class and its constructor arguments under the
    # "policy" section of the detorch config (presumably DE instantiates each
    # member from these fields — confirm against detorch).
    @config("policy")
    class PolicyConfig:
        policy: Type[Policy] = PullPolicy
        eval_fn: object = agent_eval_fn
        p: torch.Tensor = proba
        set_existing_vecs: torch.Tensor = set_init_vecs
        thresh: float = threshold
        n_sigmas: float = n_sigmas_conf

    # Register the caller's DE settings under the "de" section.
    config("de")(de_config)
    de = DE(config)
    de.train()
    # If this is true, then returned best member must
    # have its lower CI bound be lower or equal to the
    # threshold
    best = de.population[de.current_best]
    if ci_thresh and not best.lower_ci_under_thresh:
        # Member indices ordered by decreasing reward.
        sorted_rewards_idx = np.flip(np.argsort(de.rewards))
        # NOTE(review): fancy indexing assumes `de.population` supports array
        # indexing (e.g. a NumPy object array) — confirm against detorch.
        sorted_pop = de.population[sorted_rewards_idx]
        sorted_pop_inter = [m.lower_ci_under_thresh for m in sorted_pop]
        # Position of the highest-reward member satisfying the CI condition.
        first_occ_idx = np.where(sorted_pop_inter)[0][0]
        return sorted_pop[first_occ_idx]
    else:
        return best
def make_deterministic(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with `seed`.

    When CUDA is available, also request deterministic cuDNN kernels and
    disable cuDNN benchmarking (autotuning).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_n_inter(found_solution: set, true_solution: set):
    """Return how many elements `found_solution` and `true_solution` share."""
    return len(found_solution & true_solution)
def parse_args(argv=None):
    """Parse the command-line options for training a NeuralTS/UCB agent.

    Args:
        argv (list[str], optional): arguments to parse. When None, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Train a NeuralTS/UCB agent on a given dataset"
    )
    parser.add_argument(
        "-T", "--trials", type=int, required=True, help="Number of trials for the agent"
    )
    parser.add_argument(
        "-d", "--dataset", required=True, help="Name of dataset (located in datasets/*)"
    )
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        required=True,
        help="Good and bad action threshold",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Set random seed base for training, only affects network initialization",
    )
    parser.add_argument(
        "-w",
        "--width",
        type=int,
        default=128,
        help="Width of the NN (number of neurons)",
    )
    parser.add_argument(
        "-l",
        "--layers",
        type=int,
        default=1,
        help="Number of hidden layers",
    )
    parser.add_argument(
        "-r",
        "--reg",
        type=float,
        default=1,
        help="Regularization factor for the bandit AND weight decay (lambda)",
    )
    parser.add_argument(
        "-e",
        "--exploration",
        type=float,
        default=1,
        help="Exploration multiplier",
    )
    parser.add_argument(
        "--n_epochs",
        type=int,
        default=100,
        help="Number of epochs / gradient steps (if full GD) in NeuralTS/NeuralUCB",
    )
    # Without type=float a value given on the command line stayed a str.
    parser.add_argument(
        "--lr",
        type=float,
        default=0.01,
        help="Learning rate for SGD / Adam optimizer",
    )
    parser.add_argument(
        "--style",
        type=str,
        default="ts",
        choices=["ts", "ucb"],
        help="Choose between NeuralTS and NeuralUCB to train",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["sgd", "adam"],
        help="Select SGD or Adam as optimizer for NN",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=100,
        help="Number of warmup steps",
    )
    # NOTE(review): the default directory is spelled "ouput"; kept unchanged
    # because existing runs may already write there.
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="saves/ouput/",
        help="Output directory for metrics and agents",
    )
    parser.add_argument(
        "--pop_n_members",
        type=int,
        default=256,
        help="Number of members for the population optimizer",
    )
    parser.add_argument(
        "--pop_n_steps",
        type=int,
        default=16,
        help="Number of step for the population optimizer",
    )
    parser.add_argument(
        "--pop_lr",
        type=float,
        default=1e-3,
        help="Learning rate for the population optimizer (if gradient based)",
    )
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size for learning (specify -1 for full batch)",
    )
    parser.add_argument(
        "--n_sigmas",
        type=float,
        default=3,
        help="Number of sigmas to consider (sigma-rule) for confidence around a network's given activation",
    )
    parser.add_argument(
        "--ci_thresh",
        action="store_true",
        help="Tells the agent to play the best arm which has an intersecting lower CI with the threshold",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=25,
        help="Patience for early stopping during training",
    )
    parser.add_argument(
        "--valtype",
        type=str,
        default="noval",
        help="Strategy for validation set selection",
    )
    # The flag disables batch norm; the help text previously said the opposite.
    parser.add_argument(
        "--nobatchnorm",
        action="store_true",
        help="Disable batch norm in neural network",
    )
    parser.add_argument(
        "--lds",
        default="True",
        choices=["True", "sqrt_inv", "False"],
        help="Strategy for label distribution smoothing",
    )
    parser.add_argument(
        "--usedecay",
        action="store_true",
        help="Use weight decay during training",
    )
    parser.add_argument(
        "--ntrain",
        type=int,
        default=-1,
        help="Number of samples to take during training. Can help if there are enough observations to notice a slow down of the training.",
    )
    parser.add_argument(
        "--train_every",
        type=int,
        default=1,
        help="Number of timesteps to play before retraining",
    )
    args = parser.parse_args(argv)
    return args
def do_gradient_optim(agent, n_steps, existing_vecs, lr, bounds=(0, 1)):
    """Gradient-ascent search for a high-reward input, snapped to an existing vector.

    Starting from a random row of `existing_vecs`, run `n_steps` Adam steps
    that maximise the agent's sampled reward with respect to the input.
    Every visited point is scored: the start, each intermediate step, and
    the final point. The highest-scoring point is mapped to its nearest row
    of `existing_vecs` (L1 distance). The agent's gradient is then recomputed
    on a clipped, rounded copy of that point.

    Args:
        agent: object exposing `net` (a torch module), `get_sample(x)`
            returning ``(sample_r, g_list, mu, cb)``, and
            `compute_activation_and_grad(x)` returning ``(activation, g_list)``.
        n_steps (int): number of optimizer steps.
        existing_vecs (torch.Tensor): candidate vectors, one per row.
        lr (float): Adam learning rate.
        bounds (tuple): ``(low, high)`` clipping range applied before rounding.

    Returns:
        tuple: ``(a_t, idx, g_list)``, containing:
        - `a_t`: the nearest existing vector (with a leading dimension of 1).
        - `idx`: its row index.
        - `g_list`: the gradient list computed on the clipped, rounded best point.
    """
    # random.randint includes its upper bound, hence the -1.
    start_row = random.randint(0, len(existing_vecs) - 1)
    input_vec = existing_vecs[start_row][None].clone()
    input_vec.requires_grad = True
    optimizer = torch.optim.Adam([input_vec], lr=lr)

    def _clear_grads():
        # Drop gradients on both the optimized input and the agent's network.
        optimizer.zero_grad(set_to_none=True)
        agent.net.zero_grad(set_to_none=True)

    # population[i] is the i-th visited point; population_values[i] its reward.
    population = input_vec.detach().clone()
    population_values = []
    for _ in range(n_steps):
        _clear_grads()
        sample_r, _, _, _ = agent.get_sample(input_vec)
        # get_sample may backpropagate internally; discard those gradients.
        _clear_grads()
        population_values.append(sample_r.item())
        # Maximise the reward by minimising its negation.
        (-sample_r).backward()
        optimizer.step()
        population = torch.cat((population, input_vec.detach().clone()))
    _clear_grads()
    # Score the point produced by the last optimizer step as well.
    sample_r, _, _, _ = agent.get_sample(input_vec)
    population_values.append(sample_r.item())
    _clear_grads()

    max_idx = torch.argmax(torch.tensor(population_values))
    best_vec = population[max_idx]
    # Coerce to an existing vector via L1 norm.
    a_t, idx = change_to_closest_existing_vector(best_vec, existing_vecs)
    # Use the clipped, binarised point so the gradient is computed on a valid input.
    clipped_and_rounded_vec = torch.round(torch.clip(best_vec, *bounds))
    _, g_list = agent.compute_activation_and_grad(clipped_and_rounded_vec)
    return a_t, idx, g_list
def load_dataset(dataset_name, path_to_dataset="datasets"):
    """Load a dataset's combinations, risks and generating patterns.

    Reads ``{path_to_dataset}/combinations/{dataset_name}.csv``, whose last
    three columns are risk, inter and dist. It also reads
    ``{path_to_dataset}/patterns/{dataset_name}.json``, which maps keys
    ``pattern_0 .. pattern_{k-1}`` to objects with a ``"pattern"`` list.

    Args:
        dataset_name (str): dataset file stem.
        path_to_dataset (str): root directory holding the two sub-folders.

    Returns:
        tuple: the following five values:
        - `combis`: float32 tensor of shape (n_obs, n_dim).
        - `risks`: float32 tensor of shape (n_obs, 1).
        - `pat_vecs`: tensor of the k patterns.
        - `n_obs`: number of observations.
        - `n_dim`: number of dimensions.
    """
    dataset = pd.read_csv(f"{path_to_dataset}/combinations/{dataset_name}.csv")
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(
        f"{path_to_dataset}/patterns/{dataset_name}.json", "r", encoding="utf-8"
    ) as f:
        patterns = json.load(f)
    # Remove last 3 columns that are risk, inter, dist.
    combis = dataset.iloc[:, :-3]
    # Risk is the first of those three trailing columns.
    risks = dataset.iloc[:, -3]
    n_obs, n_dim = combis.shape
    pat_vecs = torch.tensor(
        [patterns[f"pattern_{i}"]["pattern"] for i in range(len(patterns))]
    )
    combis = torch.tensor(combis.values).float()
    risks = torch.tensor(risks.values).unsqueeze(1).float()
    return combis, risks, pat_vecs, n_obs, n_dim
def _split_extrema(combis, risks):
    """Hold out the minimum-risk and maximum-risk rows as the validation set."""
    ids = torch.stack((torch.argmin(risks), torch.argmax(risks)))
    X_val = combis[ids]
    y_val = risks[ids]
    # A boolean mask (instead of two slicing passes) stays correct when the
    # min and the max are the same row, e.g. when all risks are equal.
    keep = torch.ones(combis.shape[0], dtype=torch.bool, device=combis.device)
    keep[ids] = False
    return combis[keep], risks[keep], X_val, y_val


def _split_bins(combis, risks):
    """Hold out the first row in each 0.1-wide risk bin of [0, 4) as validation."""
    bin_size = 0.1
    min_range = 0
    max_range = 4
    bins_edges = [
        round(min_range + (i * bin_size), 1)
        for i in range(int((max_range - min_range) / bin_size) + 1)
    ]
    X_val = []
    y_val = []
    for lower_bound, upper_bound in zip(bins_edges[:-1], bins_edges[1:]):
        inbound = (risks >= lower_bound) & (risks < upper_bound)
        # Accept risks of shape (n,) as well as (n, 1).
        if inbound.dim() > 1:
            inbound = inbound.all(dim=1)
        idx = torch.where(inbound)[0]
        if len(idx) == 0:
            continue
        idx = idx[0]
        X_val.append(combis[idx])
        y_val.append(risks[idx])
        # Remove the held-out row from the training data.
        combis = torch.cat((combis[:idx], combis[idx + 1 :]))
        risks = torch.cat((risks[:idx], risks[idx + 1 :]))
    if not X_val:
        raise ValueError(
            f"valtype='bins' found no risk in [{min_range}, {max_range})"
        )
    return combis, risks, torch.stack(X_val), torch.stack(y_val)


def get_data_splits(combis, risks, valtype="extrema"):
    """Split observations into a training set and an optional validation set.

    Args:
        combis (torch.Tensor): observations, one per row.
        risks (torch.Tensor): targets aligned with `combis`, shape (n,) or (n, 1).
        valtype (str): validation strategy:
            - "extrema": the minimum-risk and maximum-risk rows;
            - "bins": the first row in each 0.1-wide risk bin of [0, 4);
            - "noval": no validation set (`X_val` and `y_val` are None).

    Returns:
        tuple: ``(X_train, y_train, X_val, y_val)``. Rows held out for
        validation are removed from the training tensors.

    Raises:
        ValueError: for an unknown `valtype`, or when "bins" finds no risk in
            the binned range.
    """
    if valtype == "noval":
        return combis, risks, None, None
    if valtype == "extrema":
        return _split_extrema(combis, risks)
    if valtype == "bins":
        return _split_bins(combis, risks)
    raise ValueError(
        f"Unknown valtype {valtype!r}; expected 'extrema', 'bins' or 'noval'"
    )
def compute_metrics(
    agent,
    combis,
    thresh,
    pat_vecs,
    true_sol_idx,
    n_sigmas,
    all_flagged_combis_idx,
    all_flagged_pats_idx,
):
    """Compute flagging metrics for the current step and over all steps so far.

    Args:
        agent (OptimNeuralTS): the bandit agent. Its
            ``find_solution_in_vecs(vecs, thresh, n_sigmas)`` returns a
            3-tuple whose first element is the set of flagged indices.
        combis (torch.Tensor): all possible combinations of Rx in the dataset.
        thresh (float): threshold of risk.
        pat_vecs (torch.Tensor): pattern vectors used to generate the dataset.
        true_sol_idx (set): indices of the true solution in `combis`.
        n_sigmas (float): number of sigmas to consider (sigma-rule sense).
        all_flagged_combis_idx (set): combination indices flagged in earlier
            steps; updated in place with this step's flags.
        all_flagged_pats_idx (set): pattern indices flagged in earlier steps;
            updated in place with this step's flags.

    Returns:
        tuple: ten values, in this order:
        - recall, precision, percent_found_pat and n_inter for the current step;
        - recall, precision, percent_found_pat and n_inter over all steps so far;
        - the updated all_flagged_combis_idx;
        - the updated all_flagged_pats_idx.
        Precision and recall are NaN when nothing has been flagged.
    """
    # Among all existing dataset vectors, which ones are flagged now?
    flagged_now, _, _ = agent.find_solution_in_vecs(combis, thresh, n_sigmas)
    all_flagged_combis_idx.update(flagged_now)
    # Among the dangerous ground-truth patterns, which are flagged as-is?
    flagged_pats_now, _, _ = agent.find_solution_in_vecs(pat_vecs, thresh, n_sigmas)
    all_flagged_pats_idx.update(flagged_pats_now)

    # Overlap with the true solution: this step, then cumulatively.
    n_inter = len(flagged_now & true_sol_idx)
    n_inter_all = len(all_flagged_combis_idx & true_sol_idx)

    # Share of generating patterns flagged: this step, then cumulatively.
    n_pats = len(pat_vecs)
    percent_found_pat = len(flagged_pats_now) / n_pats
    percent_found_pat_all = len(all_flagged_pats_idx) / n_pats

    n_true = len(true_sol_idx)
    # NOTE(review): recall is mathematically 0 when nothing is flagged, but
    # NaN is reported here (same as precision) — confirm this is intended.
    if len(flagged_now) > 0:
        precision = n_inter / len(flagged_now)
        recall = n_inter / n_true
    else:
        precision = float("nan")
        recall = float("nan")

    if len(all_flagged_combis_idx) > 0:
        precision_all = n_inter_all / len(all_flagged_combis_idx)
        recall_all = n_inter_all / n_true
    else:
        precision_all = float("nan")
        recall_all = float("nan")

    return (
        recall,
        precision,
        percent_found_pat,
        n_inter,
        recall_all,
        precision_all,
        percent_found_pat_all,
        n_inter_all,
        all_flagged_combis_idx,
        all_flagged_pats_idx,
    )
def discretize_targets(targets, factor):
    """Floor `targets` onto a grid of step ``1 / factor``; return a NumPy array.

    The result is rounded to one decimal to remove float noise from the
    division.
    """
    floored = (targets * factor).floor().div(factor)
    return np.round(floored.cpu().numpy(), decimals=1)
def build_histogram(targets, factor, bin_size):
    """Histogram `targets` on a grid of `bin_size`-wide bins.

    The grid spans ``floor(min * factor) / factor`` to
    ``ceil(max * factor) / factor``. It gains one extra bin when the maximum,
    rounded to one decimal, lies exactly on that upper edge. Edges are
    rounded to one decimal, so `bin_size` should be a multiple of 0.1.
    Presumably ``factor == 1 / bin_size`` (e.g. 10 and 0.1): the grid ends
    are snapped with `factor` while the bins are `bin_size` wide.

    Args:
        targets (torch.Tensor): values to histogram.
        factor (float): snapping factor for the grid ends.
        bin_size (float): width of each bin.

    Returns:
        tuple: ``(hist, n_bins, list_bin_edges)``. `hist` is the result of
        ``torch.histogram``; `list_bin_edges` is a float32 NumPy array of the
        ``n_bins + 1`` edges.
    """
    # Determine the bin edges.
    min_bin = floor(min(targets) * factor) / factor
    max_bin = ceil(max(targets) * factor) / factor
    # Handles case where maximum is exactly on the edge of the last bin.
    if round(max(targets).item(), 1) == max_bin:
        max_bin += bin_size
    # Divide by the actual bin width (this was hard-coded to 0.1), then round
    # to absorb float error such as 0.8 / 0.1 == 7.999...
    n_bins = round((max_bin - min_bin) / bin_size)
    list_bin_edges = np.around(
        [min_bin + (bin_size * i) for i in range(n_bins + 1)], 1
    ).astype("float32")
    bin_edges = torch.from_numpy(list_bin_edges)
    # Build discretized distribution with histogram.
    hist = torch.histogram(targets.cpu(), bin_edges)
    return hist, n_bins, list_bin_edges
def gaussian_fn(size, std):
    """Unnormalised Gaussian window of `size` samples centred on the middle.

    Sample ``k`` equals ``exp(-(k - (size - 1) / 2) ** 2 / (2 * std ** 2))``.
    """
    offsets = torch.arange(0, size) - (size - 1.0) / 2.0
    two_variance = 2 * std * std
    return offsets.pow(2).neg().div(two_variance).exp()
def get_model_selection_loss(net, X_val, y_val, loss_fn):
    """Return `loss_fn` of `net` on the validation set, without tracking gradients."""
    with torch.no_grad():
        return loss_fn(net(X_val), y_val)