-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerate_shift_lookup_table.py
54 lines (48 loc) · 1.66 KB
/
generate_shift_lookup_table.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import pandas as pd
from pyampute.ampute import MultivariateAmputation
from pyampute.utils import LOOKUP_TABLE_PATH
def generate_shift_lookup_table(
    lookup_table_path: str = LOOKUP_TABLE_PATH,
    n_samples: int = int(1e6),
    lower_range: float = MultivariateAmputation.DEFAULTS["lower_range"],
    upper_range: float = MultivariateAmputation.DEFAULTS["upper_range"],
    max_iter: int = MultivariateAmputation.DEFAULTS["max_iter"],
    max_diff_with_target: float = MultivariateAmputation.DEFAULTS[
        "max_diff_with_target"
    ],
):
    """
    Build the shift lookup table and write it to ``lookup_table_path`` as CSV.

    A sample of ``n_samples`` standard-normal values is drawn once. For every
    score-to-probability function name and every missingness proportion
    (``np.arange(0.01, 1.01, 0.01)``, i.e. nominally 0.01 through 1.00),
    ``MultivariateAmputation._binary_search`` is called on that sample and
    the first element of its result is stored in the table.

    The resulting table has one row per function name and one column per
    proportion, with column labels formatted to two decimals.

    Note: This should be run from the root folder so it is properly stored
    in "/data".

    Parameters
    ----------
    lookup_table_path : str
        Destination of the CSV file.
    n_samples : int
        Number of standard-normal draws used as input to the search.
    lower_range, upper_range, max_iter, max_diff_with_target
        Forwarded unchanged to ``MultivariateAmputation._binary_search``;
        defaults come from ``MultivariateAmputation.DEFAULTS``.
    """
    func_names = [
        "SIGMOID-RIGHT",
        "SIGMOID-LEFT",
        "SIGMOID-TAIL",
        "SIGMOID-MID",
    ]
    proportions = np.arange(0.01, 1.01, 0.01)
    # Drawn once and shared by every search so all table cells are computed
    # against the same sample.
    scores = np.random.standard_normal(size=n_samples)

    def shift_for(func_name, proportion):
        # Only the first element of the search result is kept in the table.
        result = MultivariateAmputation._binary_search(
            scores,
            func_name,
            proportion,
            lower_range,
            upper_range,
            max_iter,
            max_diff_with_target,
        )
        return result[0]

    rows = [
        [shift_for(func_name, proportion) for proportion in proportions]
        for func_name in func_names
    ]

    column_labels = [f"{proportion:.2f}" for proportion in proportions]
    table = pd.DataFrame(rows, index=func_names, columns=column_labels)
    table.to_csv(lookup_table_path)
# Script entry point: regenerate the lookup table using the default arguments
# (default output path, sample size, and search settings).
if __name__ == "__main__":
    generate_shift_lookup_table()