From 588126068b45f55a74a5e6d23d3fef4348322620 Mon Sep 17 00:00:00 2001 From: Ben Thompson Date: Mon, 17 Oct 2022 12:32:46 -0400 Subject: [PATCH] Lewis model (#58) * Add skeleton lei stuff * Add lei notebook for now and poisson process fun study * Add stuff for ben to see * Add current changes for Ben once again (so needy :P) * Fixed the hang problem. * Commit what I have * Add currently broken code once again * Simplify notebook * Add working version of lewis * Add current progress * Add batching method and current logic of lei * Move lei stuff to its own package * Working on linear interpolation for the Lei problem. * inlaw -> outlaw. * Add lei test n_config * Fix settings.json * Add unit tests and update notebook with correct simulations * Update test, add simulation tests, add point batcher, share RNG * Update comment * JAX implementation of scipy.interpolate.interpn (#47) * JAX Interpolation. * JAX implementation of scipy.interpolate.interpn * Update todo list. * Add current version lol * Fix bugs and integrate good version * Fix small bug in stage 2 and clean up code * Modify interpn to work with multi-dimensional values * Add current version of notebook * WTF * Finish final lei * Fix test in outlaw * Add python notebook (weird vscode lol) * Add lei simulator batching method * Remove unnecessary files cluttering up space * Add current state * Add upper bound logic to lei example * Add ignore to frontend and update lei flow * Clean up lewis code and include some of Ben's changes * Add new script * Add new changes to make memory ok * Add full changes to everything except key * Add checkpointing * Add modified version * First pass at holder-odi bound in binomial.py * Holder-ODI, feeling more confident. * Add analyze lei example * Move lewis into confirm * Fix analyze notebook with new import structure * Add np.isnan check for holder bound and update lei analyze scripts * Moving files, small tweaks. * Pre-commit fixes. * Most tests passing. * Fix test stage1. 
Co-authored-by: James Yang --- .vscode/launch.json | 2 +- .vscode/settings.json | 249 ++-- confirm/confirm/lewislib/__init__.py | 0 confirm/confirm/lewislib/batch.py | 84 ++ confirm/confirm/lewislib/grid.py | 23 + confirm/confirm/lewislib/jax_wrappers.py | 48 + confirm/confirm/lewislib/lewis.py | 993 ++++++++++++++ confirm/confirm/lewislib/table.py | 133 ++ confirm/confirm/outlaw/interp.py | 98 ++ confirm/tests/lewis/test_hash.py | 85 ++ confirm/tests/lewis/test_n_configs.py | 235 ++++ .../tests/lewis/test_permute_invariance.py | 86 ++ .../tests/lewis/test_posterior_difference.py | 33 + confirm/tests/lewis/test_simulation.py | 77 ++ confirm/tests/test_interp.py | 29 + imprint/.vscode/build.sh | 4 +- imprint/frontend/.gitignore | 3 + imprint/frontend/tsconfig.json | 2 +- install.sh | 2 +- research/berry/berry_part1.ipynb | 12 +- research/berry/berry_part1.md | 2 +- research/lei/.gitignore | 1 + research/lei/analyze/analyze.ipynb | 321 +++++ research/lei/analyze/analyze.md | 196 +++ research/lei/analyze/download_data.sh | 10 + research/lei/lei.ipynb | 1177 +++++++++++++++++ research/lei/lei.md | 712 ++++++++++ research/stat/poisson_process.ipynb | 171 +++ research/stat/poisson_process.md | 111 ++ 29 files changed, 4763 insertions(+), 136 deletions(-) create mode 100644 confirm/confirm/lewislib/__init__.py create mode 100644 confirm/confirm/lewislib/batch.py create mode 100644 confirm/confirm/lewislib/grid.py create mode 100644 confirm/confirm/lewislib/jax_wrappers.py create mode 100644 confirm/confirm/lewislib/lewis.py create mode 100644 confirm/confirm/lewislib/table.py create mode 100644 confirm/confirm/outlaw/interp.py create mode 100644 confirm/tests/lewis/test_hash.py create mode 100644 confirm/tests/lewis/test_n_configs.py create mode 100644 confirm/tests/lewis/test_permute_invariance.py create mode 100644 confirm/tests/lewis/test_posterior_difference.py create mode 100644 confirm/tests/lewis/test_simulation.py create mode 100644 confirm/tests/test_interp.py 
create mode 100644 research/lei/.gitignore create mode 100644 research/lei/analyze/analyze.ipynb create mode 100644 research/lei/analyze/analyze.md create mode 100755 research/lei/analyze/download_data.sh create mode 100644 research/lei/lei.ipynb create mode 100644 research/lei/lei.md create mode 100644 research/stat/poisson_process.ipynb create mode 100644 research/stat/poisson_process.md diff --git a/.vscode/launch.json b/.vscode/launch.json index 306f58eb..0f49b091 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,7 @@ "request": "launch", "program": "${file}", "console": "integratedTerminal", - "justMyCode": true + "justMyCode": false, } ] } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index b267766f..8c914ab4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,127 +1,126 @@ { - "bazel-cpp-tools.compileCommands.targets": [ - "//...", - ], - "jupyter.jupyterServerType": "local", - "files.associations": { - "functional": "cpp", - "*.evaluator": "cpp", - "*.traits": "cpp", - "fft": "cpp", - "openglsupport": "cpp", - "regex": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "any": "cpp", - "array": "cpp", - "atomic": "cpp", - "bit": "cpp", - "*.tcc": "cpp", - "bitset": "cpp", - "cctype": "cpp", - "chrono": "cpp", - "cinttypes": "cpp", - "clocale": "cpp", - "cmath": "cpp", - "codecvt": "cpp", - "complex": "cpp", - "condition_variable": "cpp", - "cstdarg": "cpp", - "cstddef": "cpp", - "cstdint": "cpp", - "cstdio": "cpp", - "cstdlib": "cpp", - "cstring": "cpp", - "ctime": "cpp", - "cwchar": "cpp", - "cwctype": "cpp", - "deque": "cpp", - "forward_list": "cpp", - "list": "cpp", - "map": "cpp", - "set": "cpp", - "unordered_map": "cpp", - "unordered_set": "cpp", - "vector": "cpp", - "exception": "cpp", - "algorithm": "cpp", - "iterator": "cpp", - "memory": "cpp", - "memory_resource": "cpp", - "numeric": "cpp", - "optional": "cpp", - "random": "cpp", - "ratio": "cpp", - "string": "cpp", - 
"string_view": "cpp", - "system_error": "cpp", - "utility": "cpp", - "hash_map": "cpp", - "fstream": "cpp", - "future": "cpp", - "initializer_list": "cpp", - "iomanip": "cpp", - "iosfwd": "cpp", - "iostream": "cpp", - "istream": "cpp", - "limits": "cpp", - "mutex": "cpp", - "new": "cpp", - "ostream": "cpp", - "shared_mutex": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "streambuf": "cpp", - "thread": "cpp", - "typeinfo": "cpp", - "valarray": "cpp", - "variant": "cpp", - "filesystem": "cpp", - "locale": "cpp", - "mprealsupport": "cpp", - "nonlinearoptimization": "cpp", - "dense": "cpp", - "__bit_reference": "cpp", - "__bits": "cpp", - "__config": "cpp", - "__debug": "cpp", - "__errc": "cpp", - "__hash_table": "cpp", - "__locale": "cpp", - "__mutex_base": "cpp", - "__node_handle": "cpp", - "__nullptr": "cpp", - "__split_buffer": "cpp", - "__string": "cpp", - "__threading_support": "cpp", - "__tree": "cpp", - "__tuple": "cpp", - "compare": "cpp", - "concepts": "cpp", - "ios": "cpp", - "queue": "cpp", - "stack": "cpp", - "__functional_base": "cpp", - "alignedvector3": "cpp", - "typeindex": "cpp", - "*.ipp": "cpp", - "*.inc": "cpp", - "core": "cpp", - "geometry": "cpp", - "qtalignedmalloc": "cpp", - "matrixfunctions": "cpp", - "bvh": "cpp" - }, - "C_Cpp.errorSquiggles": "Enabled", - "C_Cpp.clang_format_fallbackStyle": "{ BasedOnStyle: LLVM, UseTab: Never, IndentWidth: 4, TabWidth: 4, AllowShortIfStatementsOnASingleLine: false, IndentCaseLabels: false, ColumnLimit: 100, AccessModifierOffset: -4, NamespaceIndentation: All, FixNamespaceComments: false, PointerAlignment: Left}", - "cmake.configureOnOpen": false, - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, - "python.linting.pylintEnabled": false, - "python.linting.flake8Enabled": true, - "python.linting.enabled": true, - "python.formatting.provider": "black", - "autoDocstring.docstringFormat": "google-notypes", - "r.bracketedPaste": true, - "r.plot.useHttpgd": true + 
"bazel-cpp-tools.compileCommands.targets": ["//..."], + "jupyter.jupyterServerType": "local", + "files.associations": { + "functional": "cpp", + "*.evaluator": "cpp", + "*.traits": "cpp", + "fft": "cpp", + "openglsupport": "cpp", + "regex": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "any": "cpp", + "array": "cpp", + "atomic": "cpp", + "bit": "cpp", + "*.tcc": "cpp", + "bitset": "cpp", + "cctype": "cpp", + "chrono": "cpp", + "cinttypes": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "codecvt": "cpp", + "complex": "cpp", + "condition_variable": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "deque": "cpp", + "forward_list": "cpp", + "list": "cpp", + "map": "cpp", + "set": "cpp", + "unordered_map": "cpp", + "unordered_set": "cpp", + "vector": "cpp", + "exception": "cpp", + "algorithm": "cpp", + "iterator": "cpp", + "memory": "cpp", + "memory_resource": "cpp", + "numeric": "cpp", + "optional": "cpp", + "random": "cpp", + "ratio": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "utility": "cpp", + "hash_map": "cpp", + "fstream": "cpp", + "future": "cpp", + "initializer_list": "cpp", + "iomanip": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "limits": "cpp", + "mutex": "cpp", + "new": "cpp", + "ostream": "cpp", + "shared_mutex": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "thread": "cpp", + "typeinfo": "cpp", + "valarray": "cpp", + "variant": "cpp", + "filesystem": "cpp", + "locale": "cpp", + "mprealsupport": "cpp", + "nonlinearoptimization": "cpp", + "dense": "cpp", + "__bit_reference": "cpp", + "__bits": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__nullptr": "cpp", + "__split_buffer": "cpp", + "__string": "cpp", + 
"__threading_support": "cpp", + "__tree": "cpp", + "__tuple": "cpp", + "compare": "cpp", + "concepts": "cpp", + "ios": "cpp", + "queue": "cpp", + "stack": "cpp", + "__functional_base": "cpp", + "alignedvector3": "cpp", + "typeindex": "cpp", + "*.ipp": "cpp", + "*.inc": "cpp", + "core": "cpp", + "geometry": "cpp", + "qtalignedmalloc": "cpp", + "matrixfunctions": "cpp", + "bvh": "cpp" + }, + "C_Cpp.errorSquiggles": "Enabled", + "C_Cpp.clang_format_fallbackStyle": "{ BasedOnStyle: LLVM, UseTab: Never, IndentWidth: 4, TabWidth: 4, AllowShortIfStatementsOnASingleLine: false, IndentCaseLabels: false, ColumnLimit: 100, AccessModifierOffset: -4, NamespaceIndentation: All, FixNamespaceComments: false, PointerAlignment: Left}", + "cmake.configureOnOpen": false, + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "python.linting.pylintEnabled": false, + "python.linting.flake8Enabled": true, + "python.linting.enabled": true, + "python.formatting.provider": "black", + "autoDocstring.docstringFormat": "google-notypes", + "r.bracketedPaste": true, + "r.plot.useHttpgd": true, + "python.analysis.extraPaths": ["./outlaw", "./imprint/python"] } diff --git a/confirm/confirm/lewislib/__init__.py b/confirm/confirm/lewislib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/confirm/confirm/lewislib/batch.py b/confirm/confirm/lewislib/batch.py new file mode 100644 index 00000000..bf9f0229 --- /dev/null +++ b/confirm/confirm/lewislib/batch.py @@ -0,0 +1,84 @@ +import numpy as np + + +def pad_arg__(a, axis, n_pad: int): + pad_element = np.take(a, indices=0, axis=axis) + pad_element = np.expand_dims(pad_element, axis=axis) + new_shape = tuple(a.shape[i] if i != axis else n_pad for i in range(a.ndim)) + return np.concatenate((a, np.full(new_shape, pad_element)), axis=axis) + + +def create_batched_args__(args, in_axes, start, end, n_pad=None): + def arg_transform(arg, axis): + return pad_arg__(arg, axis, n_pad) if n_pad is not None else arg 
+ + return [ + arg_transform( + np.take(arg, indices=range(start, end), axis=axis), + axis, + ) + if axis is not None + else arg + for arg, axis in zip(args, in_axes) + ] + + +def batch(f, batch_size: int, in_axes): + def internal(*args): + dims = np.array( + [arg.shape[axis] for arg, axis in zip(args, in_axes) if axis is not None] + ) + if len(dims) <= 0: + raise ValueError( + "f must take at least one argument " + "whose corresponding in_axes is not None." + ) + + dims_all_equal = np.sum(dims != dims[0]) == 0 + if not dims_all_equal: + raise ValueError( + "All batched arguments must have the same dimension " + "along their corresopnding in_axes." + ) + + dim = dims[0] + batch_size_new = min(batch_size, dim) + n_full_batches = dim // batch_size_new + remainder = dim % batch_size_new + n_pad = batch_size_new - remainder + pad_last = remainder > 0 + start = 0 + end = batch_size_new + + for _ in range(n_full_batches): + batched_args = create_batched_args__( + args=args, + in_axes=in_axes, + start=start, + end=end, + ) + yield (f(*batched_args), 0) + start += batch_size_new + end += batch_size_new + + if pad_last: + batched_args = create_batched_args__( + args=args, + in_axes=in_axes, + start=start, + end=dim, + n_pad=n_pad, + ) + yield (f(*batched_args), n_pad) + + return internal + + +def batch_all(f, batch_size: int, in_axes): + f_batch = batch(f, batch_size, in_axes) + + def internal(*args): + outs = tuple(out for out in f_batch(*args)) + return tuple(out[0] for out in outs), outs[-1][-1] + + return internal diff --git a/confirm/confirm/lewislib/grid.py b/confirm/confirm/lewislib/grid.py new file mode 100644 index 00000000..0bd3a11e --- /dev/null +++ b/confirm/confirm/lewislib/grid.py @@ -0,0 +1,23 @@ +import numpy as np +import pyimprint.grid as pygrid + + +def make_cartesian_grid_range(size, lower, upper): + assert lower.shape[0] == upper.shape[0] + + # make initial 1d grid + center_grids = ( + pygrid.Gridder.make_grid(size, lower[i], upper[i]) for i in 
range(len(lower)) + ) + + # make a grid of centers + coords = np.meshgrid(*center_grids) + centers = np.concatenate([c.flatten().reshape(-1, 1) for c in coords], axis=1) + + # make corresponding radius + radius = np.array( + [pygrid.Gridder.radius(size, lower[i], upper[i]) for i in range(len(lower))] + ) + radii = np.full(shape=centers.shape, fill_value=radius) + + return centers, radii diff --git a/confirm/confirm/lewislib/jax_wrappers.py b/confirm/confirm/lewislib/jax_wrappers.py new file mode 100644 index 00000000..91686676 --- /dev/null +++ b/confirm/confirm/lewislib/jax_wrappers.py @@ -0,0 +1,48 @@ +import jax.numpy as jnp + + +class ArraySlice0: + def __init__(self, a, start, end): + self.array = a + self.start = start + self.end = end # TODO: unused + + def __getitem__(self, index): + return self.array[self.start + index] + + +class ArrayReshape0: + def __init__(self, a, shape): + self.array = a + self.shape = shape + self.mask = jnp.flip(jnp.cumprod(jnp.flip(self.shape[1:]))) + + def __getitem__(self, index): + i = index[-1] + jnp.sum(self.mask * index[:-1]) + return self.array[i] + + +def slice0(a, start, end): + """ + Slices an array along axis 0 from start and end. + + Parameters: + ----------- + a: array to slice along axis 0. + start: starting position to slice. + end: ending position to slice (non-inclusive). + """ + return ArraySlice0(a, start, end) + + +def reshape0(a, shape): + """ + Reshapes a given array along the 0th axis + with a new shape. + + Parameters: + ----------- + a: array to reshape along axis 0. + shape: new shape of array along axis 0. 
+ """ + return ArrayReshape0(a, shape) diff --git a/confirm/confirm/lewislib/lewis.py b/confirm/confirm/lewislib/lewis.py new file mode 100644 index 00000000..f257d246 --- /dev/null +++ b/confirm/confirm/lewislib/lewis.py @@ -0,0 +1,993 @@ +import jax +import jax.numpy as jnp +import numpy as np + +import confirm.outlaw.berry as berry +import confirm.outlaw.inla as inla +import confirm.outlaw.quad as quad +from confirm.lewislib import batch +from confirm.lewislib.table import LinearInterpTable +from confirm.lewislib.table import LookupTable + + +""" +The following class implements the Lei example. +See research/lei/lei.ipynb for the description. + +We define concepts used in the code: + +- `data in canonical form`: + `data` is of shape (n_arms, 2) + where n_arms is the number of arms in the trial and each row is a (y, n) + pair for the corresponding arm index. + The first row always corresponds to the control arm. + +- `n configuration`: + A valid sequence of `n` parameters in `data` + that is observable at any point in the trial. + +- `cached table`: + A cached table is assumed to be many tables of values + row-stacked in the same order as a list of n configurations. 
+ +- `pd`: + Posterior difference (between treatment and control arms): + P(p_i - p_0 < t | y, n) +- `pr_best`: + Posterior probability of best arm: + P(p_i = max_{j} p_j | y, n) +- `pps`: + Posterior probability of success: + P(Reject at stage 2 with all remaining + patients added to control and selected arm | + y, n, + selected arm = i) +""" + + +class Lewis45: + def __init__( + self, + n_arms: int, + n_stage_1: int, + n_stage_2: int, + n_stage_1_interims: int, + n_stage_1_add_per_interim: int, + n_stage_2_add_per_interim: int, + stage_1_futility_threshold: float, + stage_1_efficacy_threshold: float, + stage_2_futility_threshold: float, + stage_2_efficacy_threshold: float, + inter_stage_futility_threshold: float, + posterior_difference_threshold: float, + rejection_threshold: float, + sig2_int=quad.log_gauss_rule(15, 2e-6, 1e3), + n_sig2_sims: int = 20, + dtype=jnp.float64, + cache_tables=False, + **kwargs, + ): + """ + Constructs an object to run the Lei example. + + Parameters: + ----------- + n_arms: number of arms. + n_stage_1: number of patients to enroll at stage 1 for each arm. + n_stage_2: number of patients to enroll at stage 2 for each arm. + n_stage_1_interims: number of interims in stage 1. + n_stage_1_add_per_interim: number of total patients to + add per interim in stage 1. + n_stage_2_add_per_interim: number of patients to + add in stage 2 interim to control + and the selected treatment arms. + futility_threshold: probability cut-off to decide + futility for treatment arms. + If P(arm_i best | data) < futility_threshold, + declare arm_i as futile. + pps_threshold_lower: threshold for checking futility: + PPS < pps_threshold_lower <=> futility. + pps_threshold_upper: threshold for checking efficacy: + PPS > pps_threshold_upper <=> efficacy. + posterior_difference_threshold: threshold to compute posterior difference + of selected arm p and control arm p. 
+ rejection_threshold: threshold for rejection at the final analysis + (if reached): + P(p_selected_treatment_arm - p_control_arm < + posterior_difference_threshold | data) + < rejection_threshold + <=> rejection. + """ + self.n_arms = n_arms + self.n_stage_1 = n_stage_1 + self.n_stage_2 = n_stage_2 + self.n_stage_1_interims = n_stage_1_interims + self.n_stage_1_add_per_interim = n_stage_1_add_per_interim + self.n_stage_2_add_per_interim = n_stage_2_add_per_interim + self.stage_1_futility_threshold = stage_1_futility_threshold + self.stage_1_efficacy_threshold = stage_1_efficacy_threshold + self.stage_2_futility_threshold = stage_2_futility_threshold + self.stage_2_efficacy_threshold = stage_2_efficacy_threshold + self.inter_stage_futility_threshold = inter_stage_futility_threshold + self.posterior_difference_threshold = posterior_difference_threshold + self.rejection_threshold = rejection_threshold + self.dtype = dtype + + # sig2 for quadrature integration + self.sig2_int = sig2_int + self.sig2_int.pts = self.sig2_int.pts.astype(self.dtype) + self.sig2_int.wts = self.sig2_int.wts.astype(self.dtype) + self.custom_ops_int = berry.optimized(self.sig2_int.pts, n_arms=n_arms).config( + opt_tol=1e-3 + ) + + # sig2 for simulation + self.sig2_sim = 10 ** jnp.linspace(-6, 3, n_sig2_sims, dtype=self.dtype) + self.dsig2_sim = jnp.diff(self.sig2_sim) + self.custom_ops_sim = berry.optimized(self.sig2_sim, n_arms=self.n_arms).config( + opt_tol=1e-3 + ) + + ## cache + # n configuration information + ( + self.n_configs_pr_best_pps_1, + self.n_configs_pps_2, + self.n_configs_pd, + ) = self.make_n_configs__() + + # diff_matrix[i]^T p = p[i+1] - p[0] + self.diff_matrix = np.zeros((self.n_arms - 1, self.n_arms)) + self.diff_matrix[:, 0] = -1 + np.fill_diagonal(self.diff_matrix[:, 1:], 1) + self.diff_matrix = jnp.array(self.diff_matrix) + + # order of arms used for auxiliary computations + self.order = jnp.arange(0, self.n_arms, dtype=int) + + # cache jitted internal functions + 
self.posterior_difference_table_internal_jit__ = None + self.pr_best_pps_1_internal_jit__ = None + self.pps_2_internal_jit__ = None + + # posterior difference tables for every possible combination of n + if cache_tables: + self.pd_table = self.posterior_difference_table__( + batch_size=kwargs["batch_size"] + ) + self.pr_best_pps_1_table = self.pr_best_pps_1_table__( + key=kwargs["key"], + n_pr_sims=kwargs["n_pr_sims"], + batch_size=kwargs["batch_size"], + ) + _, key = jax.random.split(kwargs["key"]) + self.pps_2_table = self.pps_2_table__( + key=key, + n_pr_sims=kwargs["n_pr_sims"], + batch_size=kwargs["batch_size"], + ) + + # =============================================== + # Table caching logic + # =============================================== + + def make_canonical__(self, data): + # we use the facts that: + # - arms that are not dropped always have + # n value at least as large as those that were dropped. + # - arms that are not dropped all have the same n values. + # This means a stable sort will always: + # - keep the first row in-place + # - only the treatment rows will be sorted + n = data[:, 1] + n_order = jnp.flip(n.shape[0] - 1 - jnp.argsort(jnp.flip(n), kind="stable")) + data = data[n_order] + data = jnp.stack((data[:, 0], data[:, 1] + 1), axis=-1) + n_order_inverse = jnp.argsort(n_order)[1:] - 1 + return data, n_order_inverse + + def make_n_configs__(self): + """ + Creates two 2-D arrays of all possible configurations of the `n` + Binomial parameter configurations throughout the trial. + Each row is a possible `n` configuration. + The first array contains all possible Phase II configurations. + The second array contains all possible Phase III configurations. 
+ """ + + def internal(n_arr, n_add, n_interims, n_drop): + n_arms = n_arr.shape[-1] + out_all_ph2 = np.empty((0, n_arms), dtype=int) + + if n_interims <= 0: + return out_all_ph2 + + n_arr_new = np.copy(n_arr) + for n_drop_new in range(n_drop, n_arms - 1): + n_arr_incr = n_add // (n_arms - n_drop_new) + n_arr_new[n_drop_new:] += n_arr_incr + rest_all_ph2 = internal(n_arr_new, n_add, n_interims - 1, n_drop_new) + out_all_ph2 = np.vstack( + ( + out_all_ph2, + n_arr_new, + rest_all_ph2, + ) + ) + n_arr_new[n_drop_new:] -= n_arr_incr + + return out_all_ph2 + + # make array of all n configurations + n_arr = np.full(self.n_arms, self.n_stage_1, dtype=int) + n_configs_ph2 = internal( + n_arr, self.n_stage_1_add_per_interim, self.n_stage_1_interims, 0 + ) + n_configs_ph2 = np.vstack( + ( + n_arr, + n_configs_ph2, + ) + ) + n_configs_ph2 = np.unique(n_configs_ph2, axis=0) + + n_configs_ph2 = np.fliplr(n_configs_ph2) + + n_configs_ph3 = np.copy(n_configs_ph2) + n_configs_ph3[:, :2] += self.n_stage_2 + + n_configs_pr_best_pps_1 = n_configs_ph2 + n_configs_pps_2 = n_configs_ph3 + n_configs_pd = np.copy(n_configs_ph3) + n_configs_pd[:, :2] += self.n_stage_2_add_per_interim + + return n_configs_pr_best_pps_1, n_configs_pps_2, n_configs_pd + + def table_data__(self, ns, coords): + """ + Creates a data array used to construct internal tables. + + Parameters: + ----------- + ns: n parameter. + coords: result of calling jnp.meshgrid(..., indexing="ij") + + Returns: + -------- + data used for table construction. + """ + data = jnp.concatenate([c.flatten().reshape(-1, 1) for c in coords], axis=1) + n_arr = jnp.full_like(data, ns) + data = jnp.stack((data, n_arr), axis=-1) + return data + + def make_grid__(self, ns, n_points): + """ + Creates a 2-D array of shape (d, n_points) + where d is n.shape[0]. + Each row is a 1-D gridding of points for each entry of n + by creating evenly-spaced gridding from [0, n[i]). + The gridding always includes 0 and n[i]-1. 
+ If n_points is greater than the min(n) it is clipped to be min(n). + """ + + def internal(n): + n_points_clip = jnp.minimum(jnp.min(n), n_points) + steps = (n - 1) // (n_points_clip - 1) + n_no_end = steps * (n_points_clip - 1) + return jnp.array( + [ + jnp.concatenate( + (jnp.arange(n_no_end[idx], step=steps[idx]), n[idx][None] - 1) + ) + for idx in range(len(n)) + ] + ) + + return jnp.array([internal(n) for n in ns]) + + def posterior_difference_table__( + self, + batch_size, + n_points=None, + ): + def internal(data): + return jax.vmap(self.posterior_difference, in_axes=(0,))(data) + + if n_points: + grid = self.make_grid__(self.n_configs_pd, n_points) + + def process_batch__(i, f, batch_size): + f_batched = batch.batch_all( + f, + batch_size, + in_axes=(0,), + ) + + if n_points: + meshgrid = jnp.meshgrid(*grid[i], indexing="ij") + else: + meshgrid = jnp.meshgrid( + *(jnp.arange(0, n + 1) for n in self.n_configs_pd[i]), indexing="ij" + ) + + outs, n_padded = f_batched( + self.table_data__(self.n_configs_pd[i], meshgrid) + ) + out = jnp.row_stack(outs) + return out[:(-n_padded)] if n_padded > 0 else out + + # if called for the first time, register jitted function + if self.posterior_difference_table_internal_jit__ is None: + self.posterior_difference_table_internal_jit__ = jax.jit(internal) + + tup_tables = tuple( + process_batch__( + i, self.posterior_difference_table_internal_jit__, batch_size + ) + for i in range(self.n_configs_pd.shape[0]) + ) + + if n_points: + return LinearInterpTable( + self.n_configs_pd + 1, + grid, + jnp.array(tup_tables), + ) + + else: + return LookupTable(self.n_configs_pd + 1, tup_tables) + + def pr_best_pps_1_table__(self, key, n_pr_sims, batch_size, n_points=None): + unifs = jax.random.uniform( + key=key, + shape=( + n_pr_sims, + self.n_stage_2 + self.n_stage_2_add_per_interim, + self.n_arms, + ), + ) + _, key = jax.random.split(key) + unifs_sig2 = jax.random.uniform( + key=key, + shape=(n_pr_sims,), + ) + _, key = 
jax.random.split(key) + normals = jax.random.normal(key, shape=(n_pr_sims, self.n_arms)) + + if n_points: + grid = self.make_grid__(self.n_configs_pr_best_pps_1, n_points) + + def internal(data): + return jax.vmap(self.pr_best_pps_1, in_axes=(0, None, None, None))( + data, normals, unifs_sig2, unifs + ) + + def process_batch__(i, f, batch_size): + f_batched = batch.batch_all( + f, + batch_size, + in_axes=(0,), + ) + + if n_points: + meshgrid = jnp.meshgrid(*grid[i], indexing="ij") + else: + meshgrid = jnp.meshgrid( + *(jnp.arange(0, n + 1) for n in self.n_configs_pr_best_pps_1[i]), + indexing="ij", + ) + + outs, n_padded = f_batched( + self.table_data__(self.n_configs_pr_best_pps_1[i], meshgrid) + ) + pr_best_outs = tuple(t[0] for t in outs) + pps_outs = tuple(t[1] for t in outs) + pr_best_out = jnp.row_stack(pr_best_outs) + pps_outs = jnp.row_stack(pps_outs) + return ( + (pr_best_out[:(-n_padded)], pps_outs[:(-n_padded)]) + if n_padded > 0 + else (pr_best_out, pps_outs) + ) + + # if called for the first time, register jitted function + if self.pr_best_pps_1_internal_jit__ is None: + self.pr_best_pps_1_internal_jit__ = jax.jit(internal) + + tup_tables = tuple( + process_batch__(i, self.pr_best_pps_1_internal_jit__, batch_size) + for i in range(self.n_configs_pr_best_pps_1.shape[0]) + ) + pr_best_tables = tuple(t[0] for t in tup_tables) + pps_tables = tuple(t[1] for t in tup_tables) + if n_points: + return LinearInterpTable( + self.n_configs_pr_best_pps_1 + 1, + grid, + (jnp.array(pr_best_tables), jnp.array(pps_tables)), + ) + else: + return LookupTable( + self.n_configs_pr_best_pps_1 + 1, (pr_best_tables, pps_tables) + ) + + def pps_2_table__(self, key, n_pr_sims, batch_size, n_points=None): + unifs = jax.random.uniform( + key=key, + shape=( + n_pr_sims, + self.n_stage_2_add_per_interim, + self.n_arms, + ), + ) + _, key = jax.random.split(key) + unifs_sig2 = jax.random.uniform( + key=key, + shape=(n_pr_sims,), + ) + _, key = jax.random.split(key) + normals = 
jax.random.normal( + key=key, + shape=(n_pr_sims, self.n_arms), + ) + + if n_points: + grid = self.make_grid__(self.n_configs_pps_2, n_points) + + def internal(data): + return jax.vmap(self.pps_2, in_axes=(0, None, None, None))( + data, normals, unifs_sig2, unifs + ) + + def process_batch__(i, f, batch_size): + f_batched = batch.batch_all( + f, + batch_size, + in_axes=(0,), + ) + + if n_points: + meshgrid = jnp.meshgrid(*grid[i], indexing="ij") + else: + meshgrid = jnp.meshgrid( + *(jnp.arange(0, n + 1) for n in self.n_configs_pps_2[i]), + indexing="ij", + ) + + outs, n_padded = f_batched( + self.table_data__(self.n_configs_pps_2[i], meshgrid) + ) + out = jnp.row_stack(outs) + return out[:(-n_padded)] if n_padded > 0 else out + + # if called for the first time, register jitted function + if self.pps_2_internal_jit__ is None: + self.pps_2_internal_jit__ = jax.jit(internal) + + tup_tables = tuple( + process_batch__(i, self.pps_2_internal_jit__, batch_size) + for i in range(self.n_configs_pps_2.shape[0]) + ) + if n_points: + return LinearInterpTable( + self.n_configs_pps_2 + 1, + grid, + jnp.array(tup_tables), + ) + else: + return LookupTable(self.n_configs_pps_2 + 1, tup_tables) + + def get_posterior_difference__(self, data): + data, n_order_inverse = self.make_canonical__(data) + return self.pd_table.at(data)[0][n_order_inverse] + + def get_pr_best_pps_1__(self, data): + data, n_order_inverse = self.make_canonical__(data) + outs = self.pr_best_pps_1_table.at(data) + return tuple(out[n_order_inverse] for out in outs) + + def get_pps_2__(self, data): + data, n_order_inverse = self.make_canonical__(data) + return self.pps_2_table.at(data)[0][n_order_inverse] + + # =============================================== + # Core routines for computing Bayesian quantities + # =============================================== + + def sample_posterior_sigma_sq(self, post, unifs): + """ + Samples from p(sigma^2 | data) given by the density + (up to a constant), post. 
+ Assumes that post is computed on the grid self.sig2_sim in the same order. + The sampling is approximate as it samples from the discrete + measure defined by normalizing the histogram given by post. + """ + dFx = post[:-1] * self.dsig2_sim + Fx = jnp.cumsum(dFx) + Fx /= Fx[-1] + i_star = jnp.searchsorted(Fx, unifs) + return i_star + 1 + + def hessian_to_covariance(self, hess): + """ + Computes the covariance from the Hessian + (w.r.t. theta) of p(data, theta, sigma^2). + + Parameters: + ----------- + hess: tuple of (H_a, H_b) where H_a is of + shape (..., n) and H_b is of shape (..., 1). + The full Hessian is given by diag(H_a) + H_b 11^T. + + Returns: + -------- + Covariance matrix of shape (..., n, n) by inverting + and negating each term of hess. + """ + _, n_arms = hess[0].shape + hess_fn = jax.vmap( + lambda h: jnp.diag(h[0]) + jnp.full(shape=(n_arms, n_arms), fill_value=h[1]) + ) + prec = -hess_fn(hess) # (n_sigs, n_arms, n_arms) + return jnp.linalg.inv(prec) + + def posterior_sigma_sq_int(self, data): + """ + Computes p(sigma^2 | data) using INLA on a grid defined by self.sig2_int.pts. + + Returns: + -------- + post: p(sigma^2 | data) evaluated on self.sig2_int.pts. + x_max: mode of x -> p(data, x, sigma^2) evaluated + for each point in self.sig2_int.pts. + hess: tuple of Hessian information (H_a, H_b) such that the Hessian + is given as in hessian_to_covariance(). + iters: number of iterations. + """ + + n_arms, _ = data.shape + sig2 = self.sig2_int.pts + n_sig2 = sig2.shape[0] + p_pinned = dict(sig2=sig2, theta=None) + f = self.custom_ops_int.laplace_logpost + logpost, x_max, hess, iters = f( + np.zeros((n_sig2, n_arms), dtype=self.dtype), p_pinned, data + ) + post = inla.exp_and_normalize(logpost, self.sig2_int.wts, axis=-1) + return post, x_max, hess, iters + + def posterior_sigma_sq_sim(self, data): + """ + Computes p(sigma^2 | data) using INLA on a grid defined by self.sig2_sim. 
+ + Returns: + -------- + post: p(sigma^2 | data) (up to a constant) evaluated on self.sig2_sim. + x_max: mode of x -> p(data, x, sigma^2) + evaluated for each point in self.sig2_sim. + hess: tuple of Hessian information (H_a, H_b) such that + the Hessian is given as in hessian_to_covariance(). + iters: number of iterations. + """ + n_arms, _ = data.shape + sig2 = self.sig2_sim + n_sig2 = sig2.shape[0] + p_pinned = dict(sig2=sig2, theta=None) + logpost, x_max, hess, iters = self.custom_ops_sim.laplace_logpost( + np.zeros((n_sig2, n_arms), dtype=self.dtype), p_pinned, data + ) + max_logpost = jnp.max(logpost) + max_post = jnp.exp(max_logpost) + post = jnp.exp(logpost - max_logpost) * max_post + return post, x_max, hess, iters + + def posterior_difference(self, data): + """ + Computes p(p_i - p_0 < self.posterior_threshold | data) + for i = 1,..., d-1 where d is the total number of arms. + + Returns: + -------- + 1-D array of length d-1 where the ith component is + p(p_{i+1} - p_0 < self.posterior_threshold | data) + """ + post, x_max, hess, _ = self.posterior_sigma_sq_int(data) + + post_weighted = self.sig2_int.wts * post + cov = self.hessian_to_covariance(hess) + + def post_diff_given_sigma(mean, cov): + loc = self.diff_matrix @ mean + # var = [..., qi^T C qi, ..., ] where qi = self.diff_matrix[i] + var = jnp.sum((self.diff_matrix @ cov) * self.diff_matrix, axis=-1) + scale = jnp.sqrt(jnp.maximum(var, 0)) + normal_term = jax.scipy.stats.norm.cdf( + self.posterior_difference_threshold, loc=loc, scale=scale + ) + return normal_term + + normal_term = jax.vmap(post_diff_given_sigma, in_axes=(0, 0))(x_max, cov) + return post_weighted @ normal_term + + def pr_best(self, x): + """ + Computes P[X_i > max_{j != i} X_j] for each i = 0,..., d-1 + where x is of shape (..., d). 
+ """ + x_argmax = jnp.argmax(x, axis=-1) + compute_best = jax.vmap(lambda i: self.order == i) + return jnp.mean(compute_best(x_argmax), axis=0) + + def pps(self, data, thetas, unifs): + # estimate P(A_i | y, n, theta_0, theta_i) + def simulate_Ai(data, arm, new_data): + new_data = jnp.where( + self.diff_matrix[arm].reshape((new_data.shape[0], -1)), new_data, 0 + ) + # pool outcomes for each arm + data = data + new_data + + return self.get_posterior_difference__(data)[arm] < self.rejection_threshold + + # compute p from logit space + p_samples = jax.scipy.special.expit(thetas) + berns = unifs < p_samples[:, None] + binoms = jnp.sum(berns, axis=1) + n_arr = jnp.full_like(binoms, unifs.shape[1]) + new_data = jnp.stack((binoms, n_arr), axis=-1) + + simulate_Ai_vmapped = jax.vmap( + jax.vmap(simulate_Ai, in_axes=(None, 0, None)), + in_axes=(None, None, 0), + ) + Ai_indicators = simulate_Ai_vmapped( + data, + self.order[:-1], + new_data, + ) + out = jnp.mean(Ai_indicators, axis=0) + return out + + def pr_best_pps_common(self, data, normals, unifs): + # compute p(sigma^2 | y, n), mode, hessian for simulation + # p(sigma^2 | y, n) is up to a constant + post, x_max, hess, _ = self.posterior_sigma_sq_sim(data) + + # compute covariance of theta | data, sigma^2 for each value of self.sig2_sim. + cov = self.hessian_to_covariance(hess) + chol = jnp.linalg.cholesky(cov) + + # sample from p(sigma^2 | data) by getting the indices of self.sig2_sim. + i_star = self.sample_posterior_sigma_sq(post, unifs) + + # sample theta from p(theta | data, sigma^2) given each sigma^2 from i_star. 
+ mean_sub = x_max[i_star] + chol_sub = chol[i_star] + thetas = ( + jax.vmap(lambda chol, n: chol @ n, in_axes=(0, 0))(chol_sub, normals) + + mean_sub + ) + + return thetas + + def pr_best_pps_1(self, data, normals, unifs_sig2, unifs): + thetas = self.pr_best_pps_common(data, normals, unifs_sig2) + pr_best_out = self.pr_best(thetas)[1:] + pps_out = self.pps(data, thetas, unifs) + return pr_best_out, pps_out + + def pps_2(self, data, normals, unifs_sig2, unifs): + thetas = self.pr_best_pps_common(data, normals, unifs_sig2) + pps_out = self.pps(data, thetas, unifs) + return pps_out + + # =========== + # Trial Logic + # =========== + + def unifs_shape(self): + """ + Helper function that returns the necessary shape of + uniform draws for a single simulation to ensure enough Binomial + samples are guaranteed. + """ + # the n-configs used to compute posterior difference + # means that we've reached the very end of the simulation + # so it's sufficient to find the max n among these n-configs. + n_max = jnp.max(self.n_configs_pd) + return (n_max, self.n_arms) + + def sample(self, berns, berns_order, berns_start, n_new_per_arm): + berns_end = berns_start + n_new_per_arm + berns_subset = jnp.where( + ((berns_order >= berns_start) & (berns_order < berns_end))[:, None], + berns, + 0, + ) + n_new = jnp.full(shape=self.n_arms, fill_value=n_new_per_arm) + y_new = jnp.sum(berns_subset, axis=0) + data_new = jnp.stack((y_new, n_new), axis=-1) + return ( + data_new, + berns_end, + ) + + def score(self, data, p): + return data[:, 0] - data[:, 1] * p + + def stage_1(self, berns, berns_order, berns_start=0): + """ + Runs a single simulation of Stage 1 of the Lei example. + + Parameters: + ----------- + berns: a 2-D array of Bernoulli(p) draws of shape (n, d) where + n is the max number of patients to enroll + and d is the total number of arms. + berns_order: result of calling jnp.arange(0, berns.shape[0]). + It is made an argument to be able to reuse this array. 
+ berns_start: starting row position into berns to begin accumulation. + + Returns: + -------- + data, n_non_futile, non_futile_idx, pr_best, berns_start + + data: (number of arms, 2) where column 0 is the + simulated binomial data for each arm + and column 1 is the corresponding value + for the Binomial n parameter. + n_non_futile: number of non-futile treatment arms. + non_futile_idx: vector of booleans indicating whether each arm is non-futile. + pr_best: vector containing probability of + being the best arm for each arm. + It is set to jnp.nan if the arm was dropped for + futility or if the arm is control (index 0). + berns_start: the next starting position to accumulate berns. + """ + + # aliases + n_arms = berns.shape[1] + n_stage_1 = self.n_stage_1 + n_interims = self.n_stage_1_interims + n_add_per_interim = self.n_stage_1_add_per_interim + futility_threshold = self.stage_1_futility_threshold + efficacy_threshold = self.stage_1_efficacy_threshold + + # create initial data + data, berns_start = self.sample( + berns=berns, + berns_order=berns_order, + berns_start=berns_start, + n_new_per_arm=n_stage_1, + ) + + # auxiliary variables + non_dropped_idx = jnp.ones(n_arms - 1, dtype=bool) + pr_best, pps = self.get_pr_best_pps_1__(data) + + # Stage 1: + def body_func(args): + ( + i, + _, + _, + data, + _, + non_dropped_idx, + pr_best, + pps, + berns_start, + ) = args + + # get next non-dropped indices + non_dropped_idx = (pr_best >= futility_threshold) * non_dropped_idx + n_non_dropped = jnp.sum(non_dropped_idx) + + # check for futility + early_exit_futility = n_non_dropped == 0 + + # check for efficacy + n_effective = jnp.sum(pps > efficacy_threshold) + early_exit_efficacy = n_effective > 0 + + # evenly distribute the next patients across non-dropped arms + # only if we are not early stopping stage 1. + # Note: for simplicity, we remove the remainder patients. 
+ do_add = jnp.logical_not(early_exit_futility | early_exit_efficacy) + add_idx = jnp.concatenate( + (jnp.array(True)[None], non_dropped_idx), dtype=bool + ) + add_idx = add_idx * do_add + n_new_per_arm = n_add_per_interim // (n_non_dropped + 1) + data_new, berns_start = self.sample( + berns=berns, + berns_order=berns_order, + berns_start=berns_start, + n_new_per_arm=n_new_per_arm, + ) + data_new = jnp.where(add_idx[:, None], data_new, 0) + data = data + data_new + + pr_best, pps = self.get_pr_best_pps_1__(data) + + return ( + i + 1, + early_exit_futility, + early_exit_efficacy, + data, + n_non_dropped, + non_dropped_idx, + pr_best, + pps, + berns_start, + ) + + ( + _, + early_exit_futility, + _, + data, + _, + non_dropped_idx, + _, + pps, + berns_start, + ) = jax.lax.while_loop( + lambda tup: (tup[0] < n_interims) & jnp.logical_not(tup[1] | tup[2]), + body_func, + ( + 0, + False, + False, + data, + non_dropped_idx.shape[0], + non_dropped_idx, + pr_best, + pps, + berns_start, + ), + ) + + return ( + early_exit_futility, + data, + non_dropped_idx, + pps, + berns_start, + ) + + def stage_2( + self, + data, + best_arm, + berns, + berns_order, + berns_start, + ): + """ + Runs a single simulation of stage 2 of the Lei example. + + Parameters: + ----------- + data: data in canonical form. + best_arm: treatment arm index that is chosen for stage 2. + berns: a 2-D array of Bernoulli(p) draws of shape (n, d) + where n is the max number of patients and + d is the number of arms. + berns_order: result of calling jnp.arange(0, berns.shape[0]). + It is made an argument to be able to reuse this array. + berns_start: start row position into berns to start accumulation. + + Returns: + -------- + 0 if no rejection, otherwise 1. 
+ """ + n_stage_2 = self.n_stage_2 + n_stage_2_add_per_interim = self.n_stage_2_add_per_interim + pps_threshold_lower = self.stage_2_futility_threshold + pps_threshold_upper = self.stage_2_efficacy_threshold + rejection_threshold = self.rejection_threshold + + non_dropped_idx = (self.order == 0) | (self.order == best_arm) + + # add n_stage_2 number of patients to each + # of the control and selected treatment arms. + data_new, berns_start = self.sample( + berns=berns, + berns_order=berns_order, + berns_start=berns_start, + n_new_per_arm=n_stage_2, + ) + data_new = jnp.where(non_dropped_idx[:, None], data_new, 0) + data = data + data_new + + pps = self.get_pps_2__(data)[best_arm - 1] + + # interim: check early-stop based on futility (lower) or efficacy (upper) + early_exit_futility = pps < pps_threshold_lower + early_exit_efficacy = pps > pps_threshold_upper + early_exit = early_exit_futility | early_exit_efficacy + early_exit_out = jnp.logical_not(early_exit_futility) | early_exit_efficacy + + def final_analysis(data, berns_start): + data_new, berns_start = self.sample( + berns=berns, + berns_order=berns_order, + berns_start=berns_start, + n_new_per_arm=n_stage_2_add_per_interim, + ) + data_new = jnp.where(non_dropped_idx[:, None], data_new, 0) + data = data + data_new + rej = ( + self.get_posterior_difference__(data)[best_arm - 1] + < rejection_threshold + ) + return (rej, data) + + return jax.lax.cond( + early_exit, + lambda: (early_exit_out, data), + lambda: final_analysis(data, berns_start), + ) + + def simulate(self, p, null_truths, unifs, unifs_order): + """ + Runs a single simulation of both stage 1 and stage 2. + + Parameters: + ----------- + p: simulation grid-point. + unifs: a 2-D array of uniform draws of shape (n, d) where + n is the max number of patients to enroll + and d is the total number of arms. + unifs_order: result of calling jnp.arange(0, unifs.shape[0]). + It is made an argument to be able to reuse this array. 
+ """ + # construct bernoulli draws + berns = unifs < p[None] + + # Stage 1: + (early_exit_futility, data, non_dropped_idx, pps, berns_start) = self.stage_1( + berns=berns, + berns_order=unifs_order, + ) + + # if early-exited because of efficacy, + # pick the best arm based on PPS along with control. + # otherwise, pick the best arm based on pr_best along with control. + best_arm_info = jnp.where(non_dropped_idx, pps, -1) + best_arm = jnp.argmax(best_arm_info) + 1 + + early_exit = early_exit_futility | ( + pps[best_arm - 1] < self.inter_stage_futility_threshold + ) + + def stage_2_wrap( + null_truths, data, p, best_arm, berns, unifs_order, berns_start + ): + rej, data = self.stage_2( + data=data, + best_arm=best_arm, + berns=berns, + berns_order=unifs_order, + berns_start=berns_start, + ) + false_rej = rej * null_truths[best_arm - 1] + score = self.score(data, p) * false_rej + return (false_rej, score) + + # Stage 2 only if no early termination based on futility + return jax.lax.cond( + early_exit, + lambda: (False, jnp.zeros(self.n_arms)), + lambda: stage_2_wrap( + null_truths=null_truths, + data=data, + p=p, + best_arm=best_arm, + berns=berns, + unifs_order=unifs_order, + berns_start=berns_start, + ), + ) diff --git a/confirm/confirm/lewislib/table.py b/confirm/confirm/lewislib/table.py new file mode 100644 index 00000000..83f7d88f --- /dev/null +++ b/confirm/confirm/lewislib/table.py @@ -0,0 +1,133 @@ +import jax.numpy as jnp + +import confirm.lewislib.jax_wrappers as jwp +from confirm.outlaw.interp import interpn + + +class BaseTable: + def __init__(self, n_sizes): + # compute mask to hash n_sizes + n_arms = n_sizes.shape[-1] + n_sizes_max = jnp.max(n_sizes) + 1 + n_sizes_max_mask = n_sizes_max ** jnp.arange(0, n_arms) + self.n_sizes_max_mask = n_sizes_max_mask.astype(int) + + # create hashes + hashes = jnp.array([self.hash_n__(ns) for ns in n_sizes]) + + # reorder data based on increasing order of hashes + self.hashes_order = jnp.argsort(hashes) + self.hashes 
= hashes[self.hashes_order] + + def hash_n__(self, n): + """ + Hashes the n configuration with a given mask. + + Parameters: + ----------- + n: n configuration sorted in decreasing order. + """ + return jnp.sum(n * self.n_sizes_max_mask) + + def search(self, n): + n_hash = self.hash_n__(n) + idx = jnp.searchsorted(self.hashes, n_hash) + return idx + + def hash_ordered(self, seq): + return tuple(seq[i] for i in self.hashes_order) + + +class LinearInterpTable(BaseTable): + def __init__(self, n_sizes, grids, tables): + """ + Parameters: + ----------- + n_sizes: a 2-D array of shape (n, d). + grid: a 3-D array of shape (n, d, a). + tables: a sequence of N-D arrays + each of shape (n, a^d, ...) where each slice + (a^d, ...) corresponds to values in a cartesian + product of points defined by the same slice of grid. + """ + super().__init__(n_sizes) + if not isinstance(tables, tuple): + tables = (tables,) + + self.grids = grids[self.hashes_order] + + self.tables = tuple(sub_tables[self.hashes_order] for sub_tables in tables) + + n_arms, n_points = self.grids.shape[-2:] + self.shape = tuple(n_points for _ in range(n_arms)) + + def at(self, data): + y = data[:, 0] + n = data[:, 1] + idx = self.search(n) + grid = self.grids[idx] + return tuple( + interpn(grid, values[idx].reshape(self.shape + values[idx].shape[1:]), y) + for values in self.tables + ) + + +class LookupTable(BaseTable): + def __init__( + self, + n_sizes, + tables, + ): + """ + Constructs a lookup table given a list of n sizes + and their corresponding table of values corresponding to + all enumerations of the sizes. + + Parameters: + ----------- + n_sizes: a 2-D array of shape (n, d) where n is the number + of configurations and d is the number of arms. + tables: a list of list of/list of/table of values. + If it is not a list of list of tables, + it will be converted in such a form. 
+ In that form, tables[i] corresponds to the ith table + where tables[i][j] is a sub-table of values + corresponding to the configuration n_sizes[j]. + tables[i][j] is assumed to be of shape + (jnp.prod(n_sizes[j]), ...). + Each row is a value corresponding to a row of + a d-dimensional possible configuration y, where + 0 <= y < n_sizes, where the first index increments slowest + and the last index increments fastest. + """ + super().__init__(n_sizes) + + # force tables to be a tuple of (tuple of sub-tables) + if not isinstance(tables, tuple): + tables = ((tables,),) + if not isinstance(tables[0], tuple): + tables = (tables,) + + tables_reordered = tuple(self.hash_ordered(sub_tables) for sub_tables in tables) + self.tables = tuple( + jnp.row_stack(sub_tables) for sub_tables in tables_reordered + ) + + # reorder based on hash order + n_sizes = n_sizes[self.hashes_order] + + # compute offsets corresponding to each n_size + sizes = jnp.array([0] + [jnp.prod(ns) for ns in n_sizes]) + sizes_cumsum = jnp.cumsum(sizes) + self.offsets = sizes_cumsum[:-1] + self.sizes = sizes[1:] + + def at(self, data): + index = data[:, 0] + n = data[:, 1] + idx = self.search(n) + offset = self.offsets[idx] + size = self.sizes[idx] + slices = tuple(jwp.slice0(t, offset, offset + size) for t in self.tables) + slices_reshaped = tuple(jwp.reshape0(a, n) for a in slices) + return tuple(a[index] for a in slices_reshaped) diff --git a/confirm/confirm/outlaw/interp.py b/confirm/confirm/outlaw/interp.py new file mode 100644 index 00000000..abfc94c3 --- /dev/null +++ b/confirm/confirm/outlaw/interp.py @@ -0,0 +1,98 @@ +import jax.numpy as jnp + + +def interpn(points, values, xi): + """ + A JAX reimplementation of scipy.interpolate.interpn. Most of the input + validity checks have been removed, so make sure your inputs are correct or + go implement those checks yourself. 
+ + In addition, the keyword arguments are: + - `method="linear"` + - `bounds_error=False` + - `fill_value=None` + + The scipy source is here: + https://github.com/scipy/scipy/blob/651a9b717deb68adde9416072c1e1d5aa14a58a1/scipy/interpolate/_rgi.py#L445-L614 + + The original docstring from scipy: + Multidimensional interpolation on regular or rectilinear grids. + + Strictly speaking, not all regular grids are supported - this function + works on *rectilinear* grids, that is, a rectangular grid with even or + uneven spacing. + + Args: + points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, ) + The points defining the regular grid in n dimensions. The points in + each dimension (i.e. every elements of the points tuple) must be + strictly ascending or descending. + values : array_like, shape (m1, ..., mn, ...) + The data on the regular grid in n dimensions. Complex data can be + acceptable. + xi : ndarray of shape (..., ndim) + The coordinates to sample the gridded data at + + Returns: + values_x : ndarray, shape xi.shape[:-1] + values.shape[ndim:] + Interpolated values at input coordinates. + """ + + grid = tuple([jnp.asarray(p) for p in points]) + indices, norm_distances = _find_indices(grid, xi) + return _evaluate_linear(grid, values, indices, norm_distances) + + +# This code is copied from scipy.interpolate.interpn and modified for working with JAX. 
+def _find_indices(grid, xi): + + # find relevant edges between which xi are situated + indices = [] + # compute distance to lower edge in unity units + norm_distances = [] + + for i in range(len(grid)): + g = grid[i] + idx = jnp.searchsorted(g, xi[i]) - 1 + idx = jnp.where(idx > 0, idx, 0) + idx = jnp.where(idx > g.size - 2, g.size - 2, idx) + indices = indices + [idx] + denom = g[idx + 1] - g[idx] + norm_distances = norm_distances + [ + jnp.where(denom != 0, (xi[i] - g[idx]) / denom, 0) + ] + + indices = jnp.array(indices) + norm_distances = jnp.array(norm_distances) + return indices, norm_distances + + +def _evaluate_linear(grid, values, indices, norm_distances): + d = len(grid) + # Construct the unit d-dimensional cube. + unit_cube = jnp.meshgrid(*[jnp.array([0, 1]) for i in range(d)], indexing="ij") + + # Choose the left or right index for each corner of the hypercube. these + # are 1D indices which get used in will later be used to construct the ND + # indices of each corner. + hypercube_dim_indices = [ + jnp.array([indices[i], indices[i] + 1])[unit_cube[i]] for i in range(d) + ] + # the final indices will be the unraveled ND indices produced from the 1D + # indices above. + hypercube_indices = tuple(hypercube_dim_indices[i].flatten() for i in range(d)) + + # the weights for the left and right sides of each 1D interval. + # norm_distance is the normalized distance from the left edge so the weight + # will be (1 - norm_distance) for the left edge + hypercube_dim_weights = jnp.array( + [ + jnp.array([1 - norm_distances[i], norm_distances[i]])[unit_cube[i]] + for i in range(d) + ] + ) + # the final weights will be the product of the weights for each dimension + hypercube_weights = jnp.prod(hypercube_dim_weights, axis=0).ravel() + + # finally, select the values to interpolate and multiply by the weights. 
+ return hypercube_weights @ values[hypercube_indices] diff --git a/confirm/tests/lewis/test_hash.py b/confirm/tests/lewis/test_hash.py new file mode 100644 index 00000000..8694a944 --- /dev/null +++ b/confirm/tests/lewis/test_hash.py @@ -0,0 +1,85 @@ +import jax +import jax.numpy as jnp +import numpy as np + +from confirm.lewislib.table import LookupTable + + +default_params = { + "n_arms": 3, + "n_stage_1": 3, + "n_stage_2": 3, + "n_stage_1_interims": 1, + "n_stage_1_add_per_interim": 10, + "n_stage_2_add_per_interim": 4, + "stage_1_futility_threshold": 0.1, + "stage_2_futility_threshold": 0.1, + "stage_1_efficacy_threshold": 0.1, + "stage_2_efficacy_threshold": 0.9, + "inter_stage_futility_threshold": 0.8, + "posterior_difference_threshold": 0.05, + "rejection_threshold": 0.05, +} + + +def test_stable_sort_1(): + n = jnp.array([20, 20, 10, 20]) + order = jnp.flip(n.shape[0] - 1 - jnp.argsort(jnp.flip(n), kind="stable")) + expected = jnp.array([0, 1, 3, 2]) + assert jnp.array_equal(order, expected) + + +def test_stable_sort_2(): + n = jnp.array([30, 10, 20, 30]) + order = jnp.flip(n.shape[0] - 1 - jnp.argsort(jnp.flip(n), kind="stable")) + expected = jnp.array([0, 3, 2, 1]) + assert jnp.array_equal(order, expected) + + +def test_hash_undo(): + n = jnp.array([12, 5, 5, 12]) + n_order = jnp.flip(n.shape[0] - 1 - jnp.argsort(jnp.flip(n), kind="stable")) + + # test if this piece of code gives us the correct undoing + actual = jnp.argsort(n_order) + expected = jnp.array([0, 2, 3, 1]) + + assert jnp.array_equal(actual, expected) + + +def test_y_to_index(): + n = np.array([5, 5, 5, 2]) + max_idx = np.prod(n + 1) + y = np.zeros(n.shape[0]) + + def increment(y): + carry = 1 + for i in range(n.shape[-1] - 1, -1, -1): + y[i] += carry + carry = y[i] // (n[i] + 1) + y[i] -= carry * (n[i] + 1) + return y + + for i in range(max_idx): + actual = y[-1] + jnp.sum(jnp.flip(y[:-1]) * jnp.cumprod(jnp.flip(n[1:] + 1))) + expected = i + assert jnp.array_equal(actual, expected) + y = 
increment(y) + + +def test_hash(): + n = jnp.array([12, 5, 5, 12]) + dims = n + 1 + values = jnp.arange(0, jnp.prod(dims))[:, None] + table = LookupTable(dims[None], values) + y = jnp.array([5, 1, 0, 10]) + + # tests if the at function is jit-able also. + @jax.jit + def internal(): + data = jnp.stack((y, dims), axis=-1) + return table.at(data)[0].squeeze() + + actual = internal() + expected = 2428 + assert jnp.array_equal(actual, expected) diff --git a/confirm/tests/lewis/test_n_configs.py b/confirm/tests/lewis/test_n_configs.py new file mode 100644 index 00000000..26690d79 --- /dev/null +++ b/confirm/tests/lewis/test_n_configs.py @@ -0,0 +1,235 @@ +import numpy as np + +from confirm.lewislib import lewis + + +default_params = { + "n_arms": 3, + "n_stage_1": 10, + "n_stage_2": 10, + "n_stage_1_interims": 3, + "n_stage_1_add_per_interim": 4, + "n_stage_2_add_per_interim": 4, + "stage_1_futility_threshold": 0.1, + "stage_2_futility_threshold": 0.1, + "stage_1_efficacy_threshold": 0.9, + "stage_2_efficacy_threshold": 0.9, + "inter_stage_futility_threshold": 0.8, + "posterior_difference_threshold": 0.05, + "rejection_threshold": 0.05, +} + + +def run_n_configs_test(actual, expected): + def check_equality(actual, expected): + # lexicographical sort to order the rows consistently + actual_sorted = np.lexsort(actual) + expected_sorted = np.lexsort(expected) + assert np.array_equal(actual_sorted, expected_sorted) + + for n_configs, n_configs_expected in zip(actual, expected): + check_equality(n_configs, n_configs_expected) + + +def make_expected( + n_configs_pr_best_pps_1_expected, + n_stage_2, + n_stage_2_add_per_interim, +): + n_configs_pps_2_expected = np.copy(n_configs_pr_best_pps_1_expected) + n_configs_pps_2_expected[:, :2] += n_stage_2 + n_configs_pd_expected = np.copy(n_configs_pps_2_expected) + n_configs_pd_expected[:, :2] += n_stage_2_add_per_interim + expected = ( + n_configs_pr_best_pps_1_expected, + n_configs_pps_2_expected, + n_configs_pd_expected, + ) + 
return expected + + +def test_3_arms_0_interim(): + # re-setting parameters to make it clear which parameters affect this function. + default_params["n_arms"] = 3 + default_params["n_stage_1"] = 10 + default_params["n_stage_2"] = 10 + default_params["n_stage_1_interims"] = 0 + default_params["n_stage_1_add_per_interim"] = 4 + default_params["n_stage_2_add_per_interim"] = 4 + + lewis_obj = lewis.Lewis45(**default_params) + actual = lewis_obj.make_n_configs__() + + # expected values + n_configs_pr_best_pps_1_expected = np.array( + [ + [10, 10, 10], + ] + ) + expected = make_expected( + n_configs_pr_best_pps_1_expected, + default_params["n_stage_2"], + default_params["n_stage_2_add_per_interim"], + ) + + # tests + run_n_configs_test(actual, expected) + + +def test_3_arms_1_interim(): + # re-setting parameters to make it clear which parameters affect this function. + default_params["n_arms"] = 3 + default_params["n_stage_1"] = 10 + default_params["n_stage_2"] = 10 + default_params["n_stage_1_interims"] = 1 + default_params["n_stage_1_add_per_interim"] = 4 + default_params["n_stage_2_add_per_interim"] = 4 + + lewis_obj = lewis.Lewis45(**default_params) + actual = lewis_obj.make_n_configs__() + + # expected values + n_configs_pr_best_pps_1_expected = np.array( + [ + [10, 10, 10], + [11, 11, 11], + [12, 12, 10], + ] + ) + expected = make_expected( + n_configs_pr_best_pps_1_expected, + default_params["n_stage_2"], + default_params["n_stage_2_add_per_interim"], + ) + + # tests + run_n_configs_test(actual, expected) + + +def test_3_arms_2_interim(): + # re-setting parameters to make it clear which parameters affect this function. 
+ default_params["n_arms"] = 3 + default_params["n_stage_1"] = 5 + default_params["n_stage_2"] = 15 + default_params["n_stage_1_interims"] = 2 + default_params["n_stage_1_add_per_interim"] = 7 + default_params["n_stage_2_add_per_interim"] = 4 + + lewis_obj = lewis.Lewis45(**default_params) + actual = lewis_obj.make_n_configs__() + + # expected values + n_configs_pr_best_pps_1_expected = np.array( + [ + [5, 5, 5], + [7, 7, 7], + [8, 8, 5], + [9, 9, 9], + [10, 10, 7], + [11, 11, 5], + ] + ) + expected = make_expected( + n_configs_pr_best_pps_1_expected, + default_params["n_stage_2"], + default_params["n_stage_2_add_per_interim"], + ) + + # tests + run_n_configs_test(actual, expected) + + +def test_4_arms_0_interim(): + # re-setting parameters to make it clear which parameters affect this function. + default_params["n_arms"] = 4 + default_params["n_stage_1"] = 10 + default_params["n_stage_2"] = 10 + default_params["n_stage_1_interims"] = 0 + default_params["n_stage_1_add_per_interim"] = 4 + default_params["n_stage_1_add_per_interim"] = 10 + + lewis_obj = lewis.Lewis45(**default_params) + actual = lewis_obj.make_n_configs__() + + # expected values + n_configs_pr_best_pps_1_expected = np.array( + [ + [10, 10, 10, 10], + ] + ) + expected = make_expected( + n_configs_pr_best_pps_1_expected, + default_params["n_stage_2"], + default_params["n_stage_2_add_per_interim"], + ) + + # tests + run_n_configs_test(actual, expected) + + +def test_4_arms_1_interim(): + # re-setting parameters to make it clear which parameters affect this function. 
+ default_params["n_arms"] = 4 + default_params["n_stage_1"] = 10 + default_params["n_stage_2"] = 10 + default_params["n_stage_1_interims"] = 1 + default_params["n_stage_1_add_per_interim"] = 4 + default_params["n_stage_2_add_per_interim"] = 20 + + lewis_obj = lewis.Lewis45(**default_params) + actual = lewis_obj.make_n_configs__() + + # expected values + n_configs_pr_best_pps_1_expected = np.array( + [ + [10, 10, 10, 10], + [11, 11, 11, 11], + [11, 11, 11, 10], + [12, 12, 10, 10], + ] + ) + expected = make_expected( + n_configs_pr_best_pps_1_expected, + default_params["n_stage_2"], + default_params["n_stage_2_add_per_interim"], + ) + + # tests + run_n_configs_test(actual, expected) + + +def test_4_arms_2_interim(): + # re-setting parameters to make it clear which parameters affect this function. + default_params["n_arms"] = 4 + default_params["n_stage_1"] = 10 + default_params["n_stage_2"] = 10 + default_params["n_stage_1_interims"] = 2 + default_params["n_stage_1_add_per_interim"] = 4 + default_params["n_stage_2_add_per_interim"] = 1 + + lewis_obj = lewis.Lewis45(**default_params) + actual = lewis_obj.make_n_configs__() + + # expected values + n_configs_pr_best_pps_1_expected = np.array( + [ + [10, 10, 10, 10], + [11, 11, 11, 11], + [11, 11, 11, 10], + [12, 12, 10, 10], + [12, 12, 12, 12], + [12, 12, 12, 11], + [13, 13, 11, 11], + [12, 12, 12, 10], + [13, 13, 11, 10], + [14, 14, 10, 10], + ] + ) + expected = make_expected( + n_configs_pr_best_pps_1_expected, + default_params["n_stage_2"], + default_params["n_stage_2_add_per_interim"], + ) + + # tests + run_n_configs_test(actual, expected) diff --git a/confirm/tests/lewis/test_permute_invariance.py b/confirm/tests/lewis/test_permute_invariance.py new file mode 100644 index 00000000..ccb8e6f1 --- /dev/null +++ b/confirm/tests/lewis/test_permute_invariance.py @@ -0,0 +1,86 @@ +import jax +import jax.numpy as jnp + +from confirm.lewislib import lewis + + +default_params = { + "n_arms": 3, + "n_stage_1": 3, + 
"n_stage_2": 3, + "n_stage_1_interims": 1, + "n_stage_1_add_per_interim": 10, + "n_stage_2_add_per_interim": 4, + "stage_1_futility_threshold": 0.1, + "stage_2_futility_threshold": 0.1, + "stage_1_efficacy_threshold": 0.1, + "stage_2_efficacy_threshold": 0.9, + "inter_stage_futility_threshold": 0.8, + "posterior_difference_threshold": 0.05, + "rejection_threshold": 0.05, +} + + +def test_posterior_difference_permute(): + lewis_obj = lewis.Lewis45(**default_params) + n = default_params["n_stage_1"] + + data_1 = jnp.array( + [ + [0, n], + [1, n], + [2, n], + ] + ) + out_1 = lewis_obj.posterior_difference(data_1) + + data_2 = jnp.array( + [ + [0, n], + [2, n], + [1, n], + ] + ) + out_2 = lewis_obj.posterior_difference(data_2) + + permute = jnp.array([1, 0]) + + assert jnp.allclose(out_1, out_2[permute]) + + +def test_pr_best_permute(): + lewis_obj = lewis.Lewis45(**default_params) + key = jax.random.PRNGKey(10) + + thetas_1 = jax.random.normal(key=key, shape=(100, 3)) + out_1 = lewis_obj.pr_best(thetas_1) + + permute = jnp.array([1, 0, 2]) + + thetas_2 = thetas_1[:, permute] + out_2 = lewis_obj.pr_best(thetas_2) + + assert jnp.allclose(out_1, out_2[permute]) + + +def test_pps_permute(): + lewis_obj = lewis.Lewis45(**default_params) + lewis_obj.pd_table = lewis_obj.posterior_difference_table__(batch_size=int(2**16)) + + n = default_params["n_stage_1"] + key = jax.random.PRNGKey(10) + + data_1 = jnp.array([[0, n], [1, n], [2, n]]) + thetas_1 = jax.random.normal(key=key, shape=(100, 3)) + _, key = jax.random.split(key) + unifs_1 = jax.random.uniform(key=key, shape=(100, 10, 3)) + out_1 = lewis_obj.pps(data_1, thetas_1, unifs_1) + + permute = jnp.array([0, 2, 1]) + + data_2 = data_1[permute] + thetas_2 = thetas_1[:, permute] + unifs_2 = unifs_1[..., permute] + out_2 = lewis_obj.pps(data_2, thetas_2, unifs_2) + + assert jnp.allclose(out_1, out_2[permute[1:] - 1]) diff --git a/confirm/tests/lewis/test_posterior_difference.py 
b/confirm/tests/lewis/test_posterior_difference.py new file mode 100644 index 00000000..7936d02d --- /dev/null +++ b/confirm/tests/lewis/test_posterior_difference.py @@ -0,0 +1,33 @@ +import jax.numpy as jnp + +from confirm.lewislib import lewis + + +default_params = { + "n_arms": 3, + "n_stage_1": 3, + "n_stage_2": 3, + "n_stage_1_interims": 1, + "n_stage_1_add_per_interim": 10, + "n_stage_2_add_per_interim": 4, + "stage_1_futility_threshold": 0.1, + "stage_2_futility_threshold": 0.1, + "stage_1_efficacy_threshold": 0.1, + "stage_2_efficacy_threshold": 0.9, + "inter_stage_futility_threshold": 0.8, + "posterior_difference_threshold": 0.05, + "rejection_threshold": 0.05, +} + + +def test_get_posterior_difference(): + lewis_obj = lewis.Lewis45(**default_params) + lewis_obj.pd_table = lewis_obj.posterior_difference_table__(batch_size=int(2**16)) + n = lewis_obj.n_configs_pd[2] + y = jnp.array([5, 1, 2]) + data = jnp.stack((y, n), axis=-1) + out_1 = lewis_obj.get_posterior_difference__(data) + permute = jnp.array([0, 2, 1]) + data_2 = data[permute] + out_2 = lewis_obj.get_posterior_difference__(data_2) + assert jnp.array_equal(out_1, out_2[permute[1:] - 1]) diff --git a/confirm/tests/lewis/test_simulation.py b/confirm/tests/lewis/test_simulation.py new file mode 100644 index 00000000..a5b541bb --- /dev/null +++ b/confirm/tests/lewis/test_simulation.py @@ -0,0 +1,77 @@ +import jax +import jax.numpy as jnp + +from confirm.lewislib import lewis + + +default_params = { + "n_arms": 3, + "n_stage_1": 1, + "n_stage_2": 1, + "n_stage_1_interims": 2, + "n_stage_1_add_per_interim": 3, + "n_stage_2_add_per_interim": 1, + "stage_1_futility_threshold": 0.2, + "stage_2_futility_threshold": 0.1, + "stage_1_efficacy_threshold": 0.9, + "stage_2_efficacy_threshold": 0.9, + "inter_stage_futility_threshold": 0.8, + "posterior_difference_threshold": 0.05, + "rejection_threshold": 0.05, + "batch_size": 2**16, + "key": jax.random.PRNGKey(1), + "n_pr_sims": 100, + "n_sig2_sims": 20, + 
"cache_tables": True, +} + +key = jax.random.PRNGKey(0) +lewis_obj = lewis.Lewis45(**default_params) +unifs = jax.random.uniform(key=key, shape=lewis_obj.unifs_shape()) +p = jnp.array([0.25, 0.5, 0.75]) +berns = unifs < p[None] +berns_order = jnp.arange(0, berns.shape[0]) + + +def test_stage_1(): + # actual + ( + early_exit_futility, + data, + non_dropped_idx, + pps, + berns_start, + ) = lewis_obj.stage_1(berns, berns_order) + + # expected + early_exit_futility_expected = False + data_expected = jnp.array([[0, 3], [0, 1], [2, 3]], dtype=int) + non_dropped_idx_expected = jnp.array([False, True]) + _, pps_expected = lewis_obj.get_pr_best_pps_1__(data_expected) + berns_start_expected = 3 + + # test + assert jnp.array_equal(early_exit_futility, early_exit_futility_expected) + assert jnp.array_equal(data, data_expected) + assert jnp.array_equal(non_dropped_idx, non_dropped_idx_expected) + assert jnp.array_equal(pps, pps_expected) + assert jnp.array_equal(berns_start, berns_start_expected) + + +def test_stage_2(): + # expected stage 1 + data = jnp.array([[1, 3], [0, 1], [2, 3]], dtype=int) + best_arm = 2 + berns_start = 3 + + # actual stage 2 + rej, _ = lewis_obj.stage_2(data, best_arm, berns, berns_order, berns_start) + + # test + assert jnp.array_equal(rej, False) + + +def test_inter_stage(): + null_truths = jnp.zeros(default_params["n_arms"] - 1, dtype=bool) + rej, _ = lewis_obj.simulate(p, null_truths, unifs, berns_order) + assert jnp.array_equal(rej, False) diff --git a/confirm/tests/test_interp.py b/confirm/tests/test_interp.py new file mode 100644 index 00000000..7d5ad48e --- /dev/null +++ b/confirm/tests/test_interp.py @@ -0,0 +1,29 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import scipy.interpolate + +from confirm.outlaw.interp import interpn + + +def test_interpn(): + grid = jnp.array([[0, 1], [0, 1]]) + values = jnp.array([[0, 1], [2, 3]]) + xi = jnp.array([[0.5, 0.5], [1.0, 0.0]]) + result = jax.vmap(interpn, in_axes=(None, 
None, 0))(grid, values, xi) + np.testing.assert_allclose(result, [1.5, 2.0]) + + +@pytest.mark.parametrize("dim", [1, 3]) +def test_against_scipy_multi_value(dim): + for i in range(3): + np.random.seed(10) + grid = [np.sort(np.random.uniform(size=10)) for _ in range(2)] + values = jnp.array(np.random.uniform(size=(10, 10, dim)).squeeze()) + xi = np.random.uniform(size=(10, 2)) + result = jax.vmap(interpn, in_axes=(None, None, 0))(grid, values, xi) + scipy_result = scipy.interpolate.interpn( + grid, values, xi, method="linear", bounds_error=False, fill_value=None + ) + np.testing.assert_allclose(result, scipy_result) diff --git a/imprint/.vscode/build.sh b/imprint/.vscode/build.sh index dbb89470..43dadd87 100755 --- a/imprint/.vscode/build.sh +++ b/imprint/.vscode/build.sh @@ -1,5 +1,5 @@ #!/bin/bash eval "$(conda shell.bash hook)" conda activate imprint -bazel build //python:pyimprint/core.so -ln -sf ./bazel-bin/python/pyimprint/core.so python/pyimprint/core.so \ No newline at end of file +bazel build -c opt --config gcc //python:pyimprint/core.so +cp -f ./bazel-bin/python/pyimprint/core.so python/pyimprint/ diff --git a/imprint/frontend/.gitignore b/imprint/frontend/.gitignore index 95004389..d4bfd63c 100644 --- a/imprint/frontend/.gitignore +++ b/imprint/frontend/.gitignore @@ -44,3 +44,6 @@ yarn-error.log* npm-debug.log* yarn-debug.log* yarn-error.log* + +# local folders +my-app/ diff --git a/imprint/frontend/tsconfig.json b/imprint/frontend/tsconfig.json index a273b0cf..f199ca8f 100644 --- a/imprint/frontend/tsconfig.json +++ b/imprint/frontend/tsconfig.json @@ -18,7 +18,7 @@ "resolveJsonModule": true, "isolatedModules": true, "noEmit": true, - "jsx": "react-jsx" + "jsx": "preserve" }, "include": [ "src" diff --git a/install.sh b/install.sh index 9bc63458..04508cce 100755 --- a/install.sh +++ b/install.sh @@ -27,4 +27,4 @@ fi if [[ -n "$CONFIRM_IMPRINT_SSH" ]]; then git remote add -f imprint git@github.com:Confirm-Solutions/imprint.git -fi \ No newline at end 
of file +fi diff --git a/research/berry/berry_part1.ipynb b/research/berry/berry_part1.ipynb index 4cc613e4..7ae8629a 100644 --- a/research/berry/berry_part1.ipynb +++ b/research/berry/berry_part1.ipynb @@ -1243,14 +1243,11 @@ } ], "metadata": { - "interpreter": { - "hash": "a9637099bd81b2ef0895c64d539356b45819bc945d59d426757b1f51ae370d50" - }, "jupytext": { "formats": "ipynb,md" }, "kernelspec": { - "display_name": "Python 3.10.2 ('imprint')", + "display_name": "Python 3.10.5 ('confirm')", "language": "python", "name": "python3" }, @@ -1264,7 +1261,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.10.5" + }, + "vscode": { + "interpreter": { + "hash": "b4c6ec5b2d6c7b38df115d547b82cd53ca25eea58d87299956d35a9dc79f19f1" + } } }, "nbformat": 4, diff --git a/research/berry/berry_part1.md b/research/berry/berry_part1.md index ac0ce915..b70a312c 100644 --- a/research/berry/berry_part1.md +++ b/research/berry/berry_part1.md @@ -8,7 +8,7 @@ jupyter: format_version: '1.3' jupytext_version: 1.13.8 kernelspec: - display_name: Python 3.10.2 ('imprint') + display_name: Python 3.10.5 ('confirm') language: python name: python3 --- diff --git a/research/lei/.gitignore b/research/lei/.gitignore new file mode 100644 index 00000000..afed0735 --- /dev/null +++ b/research/lei/.gitignore @@ -0,0 +1 @@ +*.csv diff --git a/research/lei/analyze/analyze.ipynb b/research/lei/analyze/analyze.ipynb new file mode 100644 index 00000000..b725efb8 --- /dev/null +++ b/research/lei/analyze/analyze.ipynb @@ -0,0 +1,321 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Analyze Upper Bound of Type I Error for Lei Example" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import 
jax\n", + "import os\n", + "import numpy as np\n", + "from confirm.mini_imprint import grid\n", + "from confirm.lewislib import grid as lewgrid\n", + "from confirm.lewislib import lewis\n", + "from confirm.mini_imprint import binomial" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration used during simulation\n", + "params = {\n", + " \"n_arms\" : 4,\n", + " \"n_stage_1\" : 50,\n", + " \"n_stage_2\" : 100,\n", + " \"n_stage_1_interims\" : 2,\n", + " \"n_stage_1_add_per_interim\" : 100,\n", + " \"n_stage_2_add_per_interim\" : 100,\n", + " \"stage_1_futility_threshold\" : 0.15,\n", + " \"stage_1_efficacy_threshold\" : 0.7,\n", + " \"stage_2_futility_threshold\" : 0.2,\n", + " \"stage_2_efficacy_threshold\" : 0.95,\n", + " \"inter_stage_futility_threshold\" : 0.6,\n", + " \"posterior_difference_threshold\" : 0,\n", + " \"rejection_threshold\" : 0.05,\n", + " \"key\" : jax.random.PRNGKey(0),\n", + " \"n_pr_sims\" : 100,\n", + " \"n_sig2_sims\" : 20,\n", + " \"batch_size\" : int(2**20),\n", + " \"cache_tables\" : False,\n", + "}\n", + "size = 52\n", + "n_sim_batches = 500\n", + "sim_batch_size = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# construct Lei object\n", + "lei_obj = lewis.Lewis45(**params)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# construct the same grid used during simulation\n", + "n_arms = params['n_arms']\n", + "lower = np.full(n_arms, -1)\n", + "upper = np.full(n_arms, 1)\n", + "thetas, radii = lewgrid.make_cartesian_grid_range(\n", + " size=size,\n", + " lower=lower,\n", + " upper=upper,\n", + ")\n", + "ns = np.concatenate(\n", + " [np.ones(n_arms-1)[:, None], -np.eye(n_arms-1)],\n", + " axis=-1,\n", + ")\n", + "null_hypos = [\n", + " grid.HyperPlane(n, 0)\n", + " for n in ns\n", + "]\n", + "gr = grid.build_grid(\n", + " 
thetas=thetas,\n", + " radii=radii,\n", + " null_hypos=null_hypos,\n", + ")\n", + "gr = grid.prune(gr)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# construct tile informations used during simulation\n", + "theta_tiles = gr.thetas[gr.grid_pt_idx]\n", + "p_tiles = jax.scipy.special.expit(theta_tiles)\n", + "tile_radii = gr.radii[gr.grid_pt_idx]\n", + "null_truths = gr.null_truth.astype(bool)\n", + "sim_size = 2 * n_sim_batches * sim_batch_size # 2 instances parallelized\n", + "sim_sizes = np.full(gr.n_tiles, sim_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# get type I sum and score\n", + "cwd = '.'\n", + "data_dir = os.path.join(cwd, '../data')\n", + "output_dir = os.path.join(data_dir, 'output_1')\n", + "typeI_sum = np.loadtxt(os.path.join(output_dir, 'typeI_sum.csv'), delimiter=',')\n", + "typeI_score = np.loadtxt(os.path.join(output_dir, 'typeI_score.csv'), delimiter=',')\n", + "output_dir = os.path.join(data_dir, 'output_2')\n", + "typeI_sum += np.loadtxt(os.path.join(output_dir, 'typeI_sum.csv'), delimiter=',')\n", + "typeI_score += np.loadtxt(os.path.join(output_dir, 'typeI_score.csv'), delimiter=',')" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "delta = 0.025\n", + "n_arm_samples = int(lei_obj.unifs_shape()[0])\n", + "tile_corners = gr.vertices" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "# construct Holder upper bound\n", + "d0, d0u = binomial.zero_order_bound(\n", + " typeI_sum=typeI_sum, \n", + " sim_sizes=sim_sizes, \n", + " delta=delta, \n", + " delta_prop_0to1=1,\n", + ")\n", + "typeI_bound = d0 + d0u\n", + "\n", + "total_holder = binomial.holder_odi_bound(\n", + " typeI_bound=typeI_bound, \n", + " theta_tiles=theta_tiles,\n", + " tile_corners=tile_corners,\n", + " 
n_arm_samples=n_arm_samples, \n", + " holderq=16,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# construct classical upper bound\n", + "total, d0, d0u, d1w, d1uw, d2uw = binomial.upper_bound(\n", + " theta_tiles,\n", + " tile_radii,\n", + " gr.vertices,\n", + " sim_sizes,\n", + " n_arm_samples,\n", + " typeI_sum,\n", + " typeI_score,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare bound components\n", + "\n", + "# classical\n", + "bound_components = np.array([\n", + " d0,\n", + " d0u,\n", + " d1w,\n", + " d1uw,\n", + " d2uw,\n", + " total,\n", + "]).T\n", + "\n", + "# holder\n", + "dummy = np.zeros_like(d0)\n", + "bound_components_holder = np.array([\n", + " d0,\n", + " d0u,\n", + " dummy,\n", + " dummy,\n", + " dummy,\n", + " total_holder,\n", + "]).T" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([-0.98076923, -0.94230769, -0.90384615, -0.86538462, -0.82692308,\n", + " -0.78846154, -0.75 , -0.71153846, -0.67307692, -0.63461538,\n", + " -0.59615385, -0.55769231, -0.51923077, -0.48076923, -0.44230769,\n", + " -0.40384615, -0.36538462, -0.32692308, -0.28846154, -0.25 ,\n", + " -0.21153846, -0.17307692, -0.13461538, -0.09615385, -0.05769231,\n", + " -0.01923077, 0.01923077, 0.05769231, 0.09615385, 0.13461538,\n", + " 0.17307692, 0.21153846, 0.25 , 0.28846154, 0.32692308,\n", + " 0.36538462, 0.40384615, 0.44230769, 0.48076923, 0.51923077,\n", + " 0.55769231, 0.59615385, 0.63461538, 0.67307692, 0.71153846,\n", + " 0.75 , 0.78846154, 0.82692308, 0.86538462, 0.90384615,\n", + " 0.94230769, 0.98076923]),\n", + " array([-0.98076923, -0.94230769, -0.90384615, -0.86538462, -0.82692308,\n", + " -0.78846154, -0.75 , -0.71153846, -0.67307692, -0.63461538,\n", + " -0.59615385, -0.55769231, -0.51923077, -0.48076923, 
-0.44230769,\n", + " -0.40384615, -0.36538462, -0.32692308, -0.28846154, -0.25 ,\n", + " -0.21153846, -0.17307692, -0.13461538, -0.09615385, -0.05769231,\n", + " -0.01923077, 0.01923077, 0.05769231, 0.09615385, 0.13461538,\n", + " 0.17307692, 0.21153846, 0.25 , 0.28846154, 0.32692308,\n", + " 0.36538462, 0.40384615, 0.44230769, 0.48076923, 0.51923077,\n", + " 0.55769231, 0.59615385, 0.63461538, 0.67307692, 0.71153846,\n", + " 0.75 , 0.78846154, 0.82692308, 0.86538462, 0.90384615,\n", + " 0.94230769, 0.98076923]))" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t2_uniques = np.unique(theta_tiles[:, 2])\n", + "t3_uniques = np.unique(theta_tiles[:, 3])\n", + "t2_uniques, t3_uniques" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "# slice and save P, B\n", + "t2 = t2_uniques[25]\n", + "t3 = t3_uniques[20]\n", + "selection = (theta_tiles[:, 2] == t2) & (theta_tiles[:, 3] == t3)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "bound_dir = os.path.join(data_dir, 'bound')\n", + "if not os.path.exists(bound_dir):\n", + " os.makedirs(bound_dir)\n", + "\n", + "np.savetxt(f'{bound_dir}/P_lei.csv', theta_tiles[selection, :].T, fmt=\"%s\", delimiter=\",\")\n", + "np.savetxt(f'{bound_dir}/B_lei.csv', bound_components[selection, :], fmt=\"%s\", delimiter=\",\")\n", + "np.savetxt(f'{bound_dir}/B_lei_holder.csv', bound_components_holder[selection, :], fmt=\"%s\", delimiter=\",\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.5 ('confirm')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + }, + "orig_nbformat": 
4, + "vscode": { + "interpreter": { + "hash": "d8e1ca1b3fede25e3995e2b26ea544fa1b75b9a17984e6284a43c1dc286640dd" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/research/lei/analyze/analyze.md b/research/lei/analyze/analyze.md new file mode 100644 index 00000000..7e1a307b --- /dev/null +++ b/research/lei/analyze/analyze.md @@ -0,0 +1,196 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.13.8 + kernelspec: + display_name: Python 3.10.5 ('confirm') + language: python + name: python3 +--- + +# Analyze Upper Bound of Type I Error for Lei Example + +```python +%load_ext autoreload +%autoreload 2 +``` + +```python +import jax +import os +import numpy as np +from confirm.mini_imprint import grid +from confirm.lewislib import grid as lewgrid +from confirm.lewislib import lewis +from confirm.mini_imprint import binomial +``` + +```python +# Configuration used during simulation +params = { + "n_arms" : 4, + "n_stage_1" : 50, + "n_stage_2" : 100, + "n_stage_1_interims" : 2, + "n_stage_1_add_per_interim" : 100, + "n_stage_2_add_per_interim" : 100, + "stage_1_futility_threshold" : 0.15, + "stage_1_efficacy_threshold" : 0.7, + "stage_2_futility_threshold" : 0.2, + "stage_2_efficacy_threshold" : 0.95, + "inter_stage_futility_threshold" : 0.6, + "posterior_difference_threshold" : 0, + "rejection_threshold" : 0.05, + "key" : jax.random.PRNGKey(0), + "n_pr_sims" : 100, + "n_sig2_sims" : 20, + "batch_size" : int(2**20), + "cache_tables" : False, +} +size = 52 +n_sim_batches = 500 +sim_batch_size = 100 +``` + +```python +# construct Lei object +lei_obj = lewis.Lewis45(**params) +``` + +```python +# construct the same grid used during simulation +n_arms = params['n_arms'] +lower = np.full(n_arms, -1) +upper = np.full(n_arms, 1) +thetas, radii = lewgrid.make_cartesian_grid_range( + size=size, + lower=lower, + upper=upper, +) +ns = np.concatenate( + [np.ones(n_arms-1)[:, None], 
-np.eye(n_arms-1)], + axis=-1, +) +null_hypos = [ + grid.HyperPlane(n, 0) + for n in ns +] +gr = grid.build_grid( + thetas=thetas, + radii=radii, + null_hypos=null_hypos, +) +gr = grid.prune(gr) +``` + +```python +# construct tile informations used during simulation +theta_tiles = gr.thetas[gr.grid_pt_idx] +p_tiles = jax.scipy.special.expit(theta_tiles) +tile_radii = gr.radii[gr.grid_pt_idx] +null_truths = gr.null_truth.astype(bool) +sim_size = 2 * n_sim_batches * sim_batch_size # 2 instances parallelized +sim_sizes = np.full(gr.n_tiles, sim_size) +``` + +```python +# get type I sum and score +cwd = '.' +data_dir = os.path.join(cwd, '../data') +output_dir = os.path.join(data_dir, 'output_1') +typeI_sum = np.loadtxt(os.path.join(output_dir, 'typeI_sum.csv'), delimiter=',') +typeI_score = np.loadtxt(os.path.join(output_dir, 'typeI_score.csv'), delimiter=',') +output_dir = os.path.join(data_dir, 'output_2') +typeI_sum += np.loadtxt(os.path.join(output_dir, 'typeI_sum.csv'), delimiter=',') +typeI_score += np.loadtxt(os.path.join(output_dir, 'typeI_score.csv'), delimiter=',') +``` + +```python +delta = 0.025 +n_arm_samples = int(lei_obj.unifs_shape()[0]) +tile_corners = gr.vertices +``` + +```python +# construct Holder upper bound +d0, d0u = binomial.zero_order_bound( + typeI_sum=typeI_sum, + sim_sizes=sim_sizes, + delta=delta, + delta_prop_0to1=1, +) +typeI_bound = d0 + d0u + +total_holder = binomial.holder_odi_bound( + typeI_bound=typeI_bound, + theta_tiles=theta_tiles, + tile_corners=tile_corners, + n_arm_samples=n_arm_samples, + holderq=16, +) +``` + +```python +# construct classical upper bound +total, d0, d0u, d1w, d1uw, d2uw = binomial.upper_bound( + theta_tiles, + tile_radii, + gr.vertices, + sim_sizes, + n_arm_samples, + typeI_sum, + typeI_score, +) +``` + +```python +# prepare bound components + +# classical +bound_components = np.array([ + d0, + d0u, + d1w, + d1uw, + d2uw, + total, +]).T + +# holder +dummy = np.zeros_like(d0) +bound_components_holder = 
np.array([ + d0, + d0u, + dummy, + dummy, + dummy, + total_holder, +]).T +``` + +```python +t2_uniques = np.unique(theta_tiles[:, 2]) +t3_uniques = np.unique(theta_tiles[:, 3]) +t2_uniques, t3_uniques +``` + +```python +# slice and save P, B +t2 = t2_uniques[25] +t3 = t3_uniques[20] +selection = (theta_tiles[:, 2] == t2) & (theta_tiles[:, 3] == t3) +``` + +```python +bound_dir = os.path.join(data_dir, 'bound') +if not os.path.exists(bound_dir): + os.makedirs(bound_dir) + +np.savetxt(f'{bound_dir}/P_lei.csv', theta_tiles[selection, :].T, fmt="%s", delimiter=",") +np.savetxt(f'{bound_dir}/B_lei.csv', bound_components[selection, :], fmt="%s", delimiter=",") +np.savetxt(f'{bound_dir}/B_lei_holder.csv', bound_components_holder[selection, :], fmt="%s", delimiter=",") +``` diff --git a/research/lei/analyze/download_data.sh b/research/lei/analyze/download_data.sh new file mode 100755 index 00000000..62caf74e --- /dev/null +++ b/research/lei/analyze/download_data.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# directory where current shell script resides +PROJECTDIR=$(dirname "$BASH_SOURCE") +cd "$PROJECTDIR" +cd .. +mkdir -p data +cd data +aws s3 cp s3://imprint-dump/output_lei4d/ output_1/ --recursive +aws s3 cp s3://imprint-dump/output_lei4d2/ output_2/ --recursive \ No newline at end of file diff --git a/research/lei/lei.ipynb b/research/lei/lei.ipynb new file mode 100644 index 00000000..3bb3f140 --- /dev/null +++ b/research/lei/lei.ipynb @@ -0,0 +1,1177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "import confirm.outlaw\n", + "import confirm.outlaw.berry as berry\n", + "import confirm.outlaw.quad as quad\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "import jax\n", + "import time\n", + "import confirm.outlaw.inla as inla\n", + "import matplotlib.pyplot as plt\n", + "import numpyro.distributions as dist\n", + "from functools import partial\n", + "from itertools import combinations\n", + "\n", + "from confirm.lewislib import lewis\n", + "from confirm.lewislib import batch\n", + "from confirm.mini_imprint import grid\n", + "from confirm.lewislib import grid as lewgrid\n", + "from confirm.mini_imprint import binomial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Lei Example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following description is a clinical trial design using a Bayesian model with early-stopping rules for futility or efficacy of a drug.\n", + "This design was explicitly requested to be studied by an FDA member (Lei) in the CID team." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> The following is a randomized, double-blind, placebo-controlled two-stage adaptive design intended to identify an optimal treatment regimen \n", + "> from three possible regimens (for example, different dosages or different combinations of agents) and \n", + "> to assess the efficacy of that regimen with respect to a primary binary response endpoint measured at month 6.\n", + "> \n", + "> In Stage 1, one of four experimental regimens will be selected, or the trial will stop for futility. \n", + "> In this stage, a minimum of 200 and a maximum of 400 will be randomized 1:1:1:1 to one of the three experimental arms or one placebo arm. 
\n", + "> Interim analyses will be conducted after 200, 300 and 400 subjects have been enrolled to select the best experimental regimen and to potentially stop \n", + "> the trial for futility. \n", + "> If an experimental regimen is dropped for futility at an interim analysis, \n", + "> the next 100 subjects to be randomized will be allocated equally among the remaining arms in the study. \n", + "> At each of these three analysis time points (N = 200, 300, 400), \n", + "> the probabilities of being the best regimen (PrBest) and predictive probability of success (PPS) \n", + "> are calculated for each experimental regimen using a Bayesian approach, \n", + "> and the trial will either stop for futility, \n", + "> continue to the next interim analysis, \n", + "> or proceed to Stage 2 depending on the results of these PrBest and PPS calculations.\n", + "> \n", + "> In Stage 2, a minimum of 200 and a maximum of 400 additional subjects will be randomized 1:1 to the chosen regimen or placebo. \n", + "> The two groups (pooling both Stage 1 and Stage 2 subjects) will be compared for efficacy and futility assessment at an interim analysis \n", + "> after 200 subjects have been enrolled in Stage 2, \n", + "> and for efficacy at a final analysis after 400 subjects have been enrolled in Stage 2 and fully followed-up for response. \n", + "> The study may be stopped for futility or efficacy based on PPS at the interim analysis. 
\n", + "> If the study continues to the final analysis, \n", + "> the posterior distribution of the difference in response rates between placebo and the chosen experimental arm \n", + "> will be evaluated against a pre-specified decision criterion.\n", + "> \n", + "> - Scenario 1: interim analyses are based on available data on the primary endpoint (measured at month 6)\n", + "> - Scenario 2: interim analyses are based on available data on a secondary endpoint (measured at month 3) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook breaks down and discusses the components of the trial." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The notation is as follows:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- $y \\in \\mathbb{N}^d$: Binomial responses.\n", + "- $p \\in [0,1]^d$: probability parameter to the Binomial distribution.\n", + "- $n \\in \\mathbb{N}^d$: size parameter to the Binomial distribution.\n", + "- $q \\in [0,1]^d$: base probability value to offset $p$.\n", + "- $\\theta \\in \\mathbb{R}^d$: logit parameter that determines $p$.\n", + "- $\\mu \\in \\mathbb{R}$: shared mean parameter among $\\theta_i$.\n", + "- $\\sigma^2 \\in \\mathbb{R}_+$: shared variance parameter among $\\theta_i$.\n", + "- $\\mu_0, \\sigma_0^2, \\alpha_0, \\beta_0 \\in \\mathbb{R}$: hyper-parameters." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Bayesian model is described below:\n", + "\\begin{align*}\n", + "y_i | p_i &\\sim \\mathrm{Binom}(n_i, p_i) \\quad i = 1,\\ldots, d \\\\\n", + "p_i &= {\\sf expit}(\\theta_i + \\mathrm{logit}(q_i) ) \\quad i = 1,\\ldots, d \\\\\n", + "\\theta | \\mu, \\sigma^2 &\\sim \\mathcal{N}(\\mu \\mathbb{1}, \\sigma^2 I) \\\\\n", + "\\mu &\\sim \\mathcal{N}(\\mu_0, \\sigma_0^2) \\\\\n", + "\\sigma^2 &\\sim \\Gamma^{-1}(\\alpha_0, \\beta_0) \\\\\n", + "\\end{align*}\n", + "\n", + "We note in passing that the model can be collapsed along $\\mu$ to get:\n", + "\\begin{align*}\n", + "y_i | p_i &\\sim \\mathrm{Binom}(n_i, p_i) \\quad i = 1,\\ldots, d \\\\\n", + "p_i &= {\\sf expit}(\\theta_i + \\mathrm{logit}(q_i) ) \\quad i = 1,\\ldots, d \\\\\n", + "\\theta | \\sigma^2 &\\sim \\mathcal{N}(\\mu_0 \\mathbb{1}, \\sigma^2 I + \\sigma_0^2 \\mathbb{1} \\mathbb{1}^\\top) \\\\\n", + "\\sigma^2 &\\sim \\Gamma^{-1}(\\alpha_0, \\beta_0) \\\\\n", + "\\end{align*}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Probability of Best Arm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first quantity of interest is probability of best (treatment) arm.\n", + "Concretely, letting $i = 1$ denote the control arm, we wish to compute for each $1 < i \\leq d$:\n", + "\\begin{align*}\n", + "\\mathbb{P}(p_i > \\max\\limits_{j \\neq i} p_j | y, n)\n", + "&=\n", + "\\int \\mathbb{P}(p_i > \\max\\limits_{j \\neq i} p_j | y, n, \\sigma^2) p(\\sigma^2 | y, n) \\, d\\sigma^2\n", + "\\\\&=\n", + "\\int \\mathbb{P}(\\theta_i + c_i > \\max\\limits_{j \\neq i} (\\theta_j + c_j) | y, n, \\sigma^2) p(\\sigma^2 | y, n) \\, d\\sigma^2\n", + "\\end{align*}\n", + "where $c = \\mathrm{logit}(q)$.\n", + "We can approximate this quantity by estimating the two integrand terms separately. 
\n", + "By approximating $\\theta_i | y, n$ as normal, the first integrand term can be estimated by Monte Carlo.\n", + "The second term can be estimated by INLA." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def pr_normal_best(mean, cov, key, n_sims):\n", + " '''\n", + " Estimates P[X_i > max_{j != i} X_j] where X ~ N(mean, cov) via sampling.\n", + " '''\n", + " out_shape = (n_sims, *mean.shape[:-1])\n", + " sims = jax.random.multivariate_normal(key, mean, cov, shape=out_shape)\n", + " order = jnp.arange(1, mean.shape[-1])\n", + " compute_pr_best_all = jax.vmap(lambda i: jnp.mean(jnp.argmax(sims, axis=-1) == i, axis=0))\n", + " return compute_pr_best_all(order)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "d = 4\n", + "mean = jnp.array([2, 2, 2, 5])\n", + "cov = jnp.eye(d)\n", + "key = jax.random.PRNGKey(0)\n", + "n_sims = 100000\n", + "jax.jit(pr_normal_best, static_argnums=(3,))(mean, cov, key, n_sims=n_sims)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we perform INLA to estimate $p(\\sigma^2 | y, n)$ on a grid of values for $\\sigma^2$." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sig2_rule = quad.log_gauss_rule(15, 1e-6, 1e3)\n", + "sig2_rule_ops = berry.optimized(sig2_rule.pts, n_arms=4).config(\n", + " opt_tol=1e-3\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def posterior_sigma_sq(data, sig2_rule, sig2_rule_ops):\n", + " n_arms, _ = data.shape\n", + " sig2 = sig2_rule.pts\n", + " n_sig2 = sig2.shape[0]\n", + " p_pinned = dict(sig2=sig2, theta=None)\n", + "\n", + " f = sig2_rule_ops.laplace_logpost\n", + " logpost, x_max, hess, iters = f(\n", + " np.zeros((n_sig2, n_arms)), p_pinned, data\n", + " )\n", + " post = inla.exp_and_normalize(\n", + " logpost, sig2_rule.wts, axis=-1)\n", + "\n", + " return post, x_max, hess, iters " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dtype = jnp.float64\n", + "N = 1\n", + "data = berry.figure2_data(N).astype(dtype)[0]\n", + "n_arms, _ = data.shape\n", + "posterior_sigma_sq_jit = jax.jit(lambda data: posterior_sigma_sq(data, sig2_rule, sig2_rule_ops))\n", + "post, _, hess, _ = posterior_sigma_sq_jit(data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Putting the two pieces together, we have the following function to compute the probability of best treatment arm." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def pr_best(data, sig2_rule, sig2_rule_ops, key, n_sims):\n", + " n_arms, _ = data.shape\n", + " post, x_max, hess, _ = posterior_sigma_sq(data, sig2_rule, sig2_rule_ops) \n", + " mean = x_max\n", + " hess_fn = jax.vmap(lambda h: jnp.diag(h[0]) + jnp.full(shape=(n_arms, n_arms), fill_value=h[1]))\n", + " prec = -hess_fn(hess) # (n_sigs, n_arms, n_arms)\n", + " cov = jnp.linalg.inv(prec)\n", + " pr_normal_best_out = pr_normal_best(mean, cov, key=key, n_sims=n_sims)\n", + " return jnp.matmul(pr_normal_best_out, post * sig2_rule.wts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_sims = 13\n", + "out = pr_best(data, sig2_rule, sig2_rule_ops, key, n_sims)\n", + "out" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Phase III Final Analysis\n", + "\n", + "\\begin{align*}\n", + "\\mathbb{P}(\\theta_i - \\theta_0 < t | y, n) < 0.1\n", + "\\end{align*}\n", + "\n", + "\\begin{align*}\n", + "\\mathbb{P}(\\theta_i - \\theta_0 < t | y, n)\n", + "&=\n", + "\\mathbb{P}(q_1^\\top \\theta < t | y,n)\n", + "=\n", + "\\int \\mathbb{P}(q_1^\\top \\theta < t | y, n, \\sigma^2) p(\\sigma^2 | y, n) \\, d\\sigma^2\n", + "\\\\&=\n", + "\\int \\mathbb{P}(q_1^\\top \\theta < t | y, n, \\sigma^2) p(\\sigma^2 | y, n) \\, d\\sigma^2\n", + "\\\\\n", + "q_1^\\top \\theta | y, n, \\sigma^2 &\\sim \\mathcal{N}(q_1^\\top \\theta^*, -q_1^\\top (H\\log p(\\theta^*, y, \\sigma^2))^{-1} q_1)\n", + "\\end{align*}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "posterior_difference_threshold = 0.2\n", + "rejection_threshold = 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def posterior_difference(data, arm, sig2_rule, sig2_rule_ops, thresh):\n", + " 
n_arms, _ = data.shape\n", + " post, x_max, hess, _ = posterior_sigma_sq(data, sig2_rule, sig2_rule_ops)\n", + " hess_fn = jax.vmap(lambda h: jnp.diag(h[0]) + jnp.full(shape=(n_arms, n_arms), fill_value=h[1]))\n", + " prec = -hess_fn(hess) # (n_sigs, n_arms, n_arms)\n", + " order = jnp.arange(0, n_arms)\n", + " q1 = jnp.where(order == 0, -1, 0)\n", + " q1 = jnp.where(order == arm, 1, q1)\n", + " loc = x_max @ q1\n", + " scale = jnp.linalg.solve(prec, q1[None,:]) @ q1\n", + " normal_term = jax.scipy.stats.norm.cdf(thresh, loc=loc, scale=scale)\n", + " post_weighted = sig2_rule.wts * post\n", + " out = normal_term @ post_weighted\n", + " return out\n", + "\n", + "posterior_difference(data, 1, sig2_rule, sig2_rule_ops, posterior_difference_threshold)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Posterior Probability of Success\n", + "\n", + "The next quantity we need to compute is the posterior probability of success (PPS).\n", + "For convenience of implementation, we will take this to mean the following:\n", + "let $y, n$ denote the currently observed data\n", + "and $A_i = \\{ \\text{Phase III rejects using treatment arm i} \\}$.\n", + "Then, we wish to compute\n", + "\\begin{align*}\n", + "\\mathbb{P}(A_i | y, n)\n", + "\\end{align*}\n", + "Expanding the quantity,\n", + "\\begin{align*}\n", + "\\mathbb{P}(A_i | y, n) &=\n", + "\\int \\mathbb{P}(A_i | y, n, \\theta_i, \\theta_0) p(\\theta_0, \\theta_i | y, n) \\, d\\theta_i d\\theta_0 \\\\&=\n", + "\\int \\mathbb{P}(A_i | y, n, \\theta_i, \\theta_0) p(\\theta_0, \\theta_i | y, n) \\, d\\theta_i d\\theta_0\n", + "\\end{align*}\n", + "\n", + "Once we have an estimate for $p(\\theta_0, \\theta_i | y, n)$, \n", + "we can use 2-D quadrature to numerically integrate the integrand.\n", + "Similar to computing the probability of best arm,\n", + "\\begin{align*}\n", + "p(\\theta_0, \\theta_i | y, n)\n", + "&=\n", + "\\int p(\\theta_0, \\theta_i | y, n, \\sigma^2) p(\\sigma^2 | y, n) \\, 
d\\sigma^2\n", + "\\end{align*}\n", + "We will use the Gaussian approximation for $p(\\theta_0, \\theta_i | y, n, \\sigma^2)$\n", + "and use INLA to estimate $p(\\sigma^2 | y, n)$.\n", + "\n", + "\\begin{align*}\n", + "p(\\theta | y, n, \\sigma^2)\n", + "\\approx\n", + "\\mathcal{N}(\\theta^*, -(H\\log p(\\theta^*, y, \\sigma^2))^{-1})\n", + "\\\\\n", + "\\implies\n", + "p(\\theta_0, \\theta_i | y, n, \\sigma^2)\n", + "\\approx\n", + "\\mathcal{N}(\\theta^*_{[0,i]}, -(H\\log p(\\theta^*, y, \\sigma^2))^{-1}_{[0,i], [0,i]})\n", + "\\end{align*}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# input parameters\n", + "n_Ai_sims = 1000\n", + "p = jnp.full(n_arms, 0.5)\n", + "n_stage_2 = 100\n", + "pps_threshold_lower = 0.1\n", + "pps_threshold_upper = 0.9\n", + "posterior_difference_threshold = 0.1\n", + "rejection_threshold = 0.1\n", + "\n", + "subset = jnp.array([0, 1])\n", + "non_futile_idx = np.zeros(n_arms)\n", + "non_futile_idx[subset] = 1\n", + "non_futile_idx = jnp.array(non_futile_idx)\n", + "\n", + "# create a dense grid of sig2 values\n", + "n_sig2 = 100\n", + "sig2_grid = 10**jnp.linspace(-6, 3, n_sig2)\n", + "dsig2_grid = jnp.diff(sig2_grid)\n", + "sig2_grid_ops = berry.optimized(sig2_grid, n_arms=data.shape[0]).config(\n", + " opt_tol=1e-3\n", + ")\n", + "\n", + "_, key = jax.random.split(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def pr_Ai(\n", + " data, p, key, best_arm, non_futile_idx, \n", + " sig2_rule, sig2_rule_ops,\n", + " sig2_grid, sig2_grid_ops, dsig2_grid,\n", + "):\n", + " n_arms, _ = data.shape\n", + "\n", + " # compute p(sig2 | y, n), mode, hessian\n", + " p_pinned = dict(sig2=sig2_grid, theta=None)\n", + " logpost, x_max, hess, _ = jax.jit(sig2_grid_ops.laplace_logpost)(\n", + " np.zeros((len(sig2_grid), n_arms)), p_pinned, data\n", + " )\n", + " max_logpost = jnp.max(logpost)\n", + " max_post = 
jnp.exp(max_logpost)\n", + " post = jnp.exp(logpost - max_logpost) * max_post\n", + "\n", + " # sample sigma^2 | y, n\n", + " dFx = post[:-1] * dsig2_grid\n", + " Fx = jnp.cumsum(dFx)\n", + " Fx /= Fx[-1]\n", + " _, key = jax.random.split(key)\n", + " unifs = jax.random.uniform(key=key, shape=(n_Ai_sims,))\n", + " i_star = jnp.searchsorted(Fx, unifs)\n", + "\n", + " # sample theta | y, n, sigma^2\n", + " mean = x_max[i_star+1]\n", + " hess_fn = jax.vmap(\n", + " lambda h: jnp.diag(h[0]) + jnp.full(shape=(n_arms, n_arms), fill_value=h[1])\n", + " )\n", + " prec = -hess_fn(hess)\n", + " cov = jnp.linalg.inv(prec)[i_star+1]\n", + " _, key = jax.random.split(key)\n", + " theta = jax.random.multivariate_normal(\n", + " key=key, mean=mean, cov=cov,\n", + " )\n", + " p_samples = jax.scipy.special.expit(theta)\n", + "\n", + " # estimate P(A_i | y, n, theta_0, theta_i)\n", + "\n", + " def simulate_Ai(data, best_arm, diff_thresh, rej_thresh, non_futile_idx, key, p):\n", + " # add n_stage_2 number of patients to each\n", + " # of the control and selected treatment arms.\n", + " n_new = jnp.where(non_futile_idx, n_stage_2, 0)\n", + " y_new = dist.Binomial(total_count=n_new, probs=p).sample(key)\n", + "\n", + " # pool outcomes for each arm\n", + " data = data + jnp.stack((y_new, n_new), axis=-1)\n", + "\n", + " return posterior_difference(data, best_arm, sig2_rule, sig2_rule_ops, diff_thresh) < rej_thresh\n", + "\n", + " simulate_Ai_vmapped = jax.vmap(\n", + " simulate_Ai, in_axes=(None, None, None, None, None, 0, 0)\n", + " )\n", + " keys = jax.random.split(key, num=p_samples.shape[0])\n", + " Ai_indicators = simulate_Ai_vmapped(\n", + " data,\n", + " best_arm,\n", + " posterior_difference_threshold,\n", + " rejection_threshold,\n", + " non_futile_idx,\n", + " keys,\n", + " p_samples,\n", + " )\n", + " out = jnp.mean(Ai_indicators)\n", + " return out\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", 
+ "jax.jit(lambda data, p, key, best_arm, non_futile_idx:\n", + " pr_Ai(data, p, key, best_arm, non_futile_idx, \n", + " sig2_rule, sig2_rule_ops, sig2_grid, sig2_grid_ops, dsig2_grid),\n", + " static_argnums=(3,))(data, p, key, 1, non_futile_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sampling based on pdf values and linearly interpolating\n", + "n_sims = 1000\n", + "n_unifs = 1000\n", + "key = jax.random.PRNGKey(2)\n", + "\n", + "#x = jnp.linspace(-3, 3, num=n_sims)\n", + "#px_orig = 0.5*jax.scipy.stats.norm.pdf(x, -1, 0.5) + 0.5*jax.scipy.stats.norm.pdf(x, 1, 0.5)\n", + "\n", + "#x = jnp.linspace(0, 10, num=n_sims)\n", + "#px_orig = jax.scipy.stats.gamma.pdf(x, 10)\n", + "\n", + "x = jnp.linspace(0, 1, num=n_sims)\n", + "px_orig = jax.scipy.stats.beta.pdf(x, 4, 2)\n", + "\n", + "px = 2 * px_orig\n", + "dx = jnp.diff(x)\n", + "dFx = px[:-1] * dx\n", + "Fx = jnp.cumsum(dFx)\n", + "Fx /= Fx[-1]\n", + "_, key = jax.random.split(key)\n", + "unifs = jax.random.uniform(key=key, shape=(n_unifs,))\n", + "i_star = jnp.searchsorted(Fx, unifs)\n", + "\n", + "# point mass approx\n", + "#samples = x[i_star+1]\n", + "\n", + "# constant approx\n", + "#samples = x[i_star+1] - (Fx[i_star] - unifs) / px[i_star]\n", + "\n", + "# linear approx\n", + "a = 0.5 * (px[i_star+1] - px[i_star]) / dx[i_star]\n", + "b = px[i_star]\n", + "c = Fx[i_star] - unifs - px[i_star] * dx[i_star] - a * dx[i_star]**2\n", + "discr = jnp.sqrt(jnp.maximum(b**2 - 4*a*c, 0))\n", + "quad_solve = jnp.where(jnp.abs(a) < 1e-8, -c/b, (-b + discr) / (2*a))\n", + "samples = x[i_star] + quad_solve\n", + "\n", + "#plt.plot(x[1:], Fx)\n", + "#plt.plot(x[1:], jax.scipy.stats.norm.cdf(x[1:]))\n", + "plt.hist(x[i_star+1], density=True, alpha=0.5)\n", + "plt.hist(samples, density = True, alpha=0.5)\n", + "plt.plot(x, px_orig)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Design Implementation" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.2 s, sys: 1.19 s, total: 3.39 s\n", + "Wall time: 4.23 s\n" + ] + } + ], + "source": [ + "%%time\n", + "params = {\n", + " \"n_arms\" : 4,\n", + " \"n_stage_1\" : 50,\n", + " \"n_stage_2\" : 100,\n", + " \"n_stage_1_interims\" : 2,\n", + " \"n_stage_1_add_per_interim\" : 100,\n", + " \"n_stage_2_add_per_interim\" : 100,\n", + " \"stage_1_futility_threshold\" : 0.15,\n", + " \"stage_1_efficacy_threshold\" : 0.7,\n", + " \"stage_2_futility_threshold\" : 0.2,\n", + " \"stage_2_efficacy_threshold\" : 0.95,\n", + " \"inter_stage_futility_threshold\" : 0.6,\n", + " \"posterior_difference_threshold\" : 0,\n", + " \"rejection_threshold\" : 0.05,\n", + " \"key\" : jax.random.PRNGKey(0),\n", + " \"n_pr_sims\" : 100,\n", + " \"n_sig2_sims\" : 20,\n", + " \"batch_size\" : int(2**20),\n", + " \"cache_tables\" : False,\n", + "}\n", + "lei_obj = lewis.Lewis45(**params)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 2**12\n", + "key = jax.random.PRNGKey(0)\n", + "n_points = 20\n", + "n_pr_sims = 100\n", + "n_sig2_sim = 20" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 5.58 s, sys: 453 ms, total: 6.04 s\n", + "Wall time: 6.46 s\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "lei_obj.pd_table = lei_obj.posterior_difference_table__(\n", + " batch_size=batch_size,\n", + " n_points=n_points, \n", + ")\n", + "lei_obj.pd_table" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"CPU times: user 17.7 s, sys: 4.76 s, total: 22.4 s\n", + "Wall time: 20.3 s\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "lei_obj.pr_best_pps_1_table = lei_obj.pr_best_pps_1_table__(\n", + " key, \n", + " n_pr_sims,\n", + " batch_size=batch_size,\n", + " n_points=n_points,\n", + ")\n", + "lei_obj.pr_best_pps_1_table" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 13.3 s, sys: 3.29 s, total: 16.6 s\n", + "Wall time: 14 s\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "_, key = jax.random.split(key)\n", + "lei_obj.pps_2_table = lei_obj.pps_2_table__(\n", + " key, \n", + " n_pr_sims,\n", + " batch_size=batch_size,\n", + " n_points=n_points,\n", + ")\n", + "lei_obj.pps_2_table" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "n_arms = params['n_arms']\n", + "size = 52\n", + "lower = np.full(n_arms, -1)\n", + "upper = np.full(n_arms, 1)\n", + "thetas, radii = lewgrid.make_cartesian_grid_range(\n", + " size=size,\n", + " lower=lower,\n", + " upper=upper,\n", + ") \n", + "ns = np.concatenate(\n", + " [np.ones(n_arms-1)[:, None], -np.eye(n_arms-1)],\n", + " axis=-1,\n", + ")\n", + "null_hypos = [\n", + " grid.HyperPlane(n, 0)\n", + " for n in ns\n", + "]\n", + "gr = grid.build_grid(\n", + " thetas=thetas,\n", + " radii=radii,\n", + " null_hypos=null_hypos,\n", + ")\n", + "gr = grid.prune(gr)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "theta_tiles = gr.thetas[gr.grid_pt_idx]\n", + "null_truths = gr.null_truth.astype(bool)\n", + "grid_batch_size = int(2**12)\n", + 
"n_sim_batches = 500\n", + "sim_batch_size = 50\n", + "\n", + "p_tiles = jax.scipy.special.expit(theta_tiles)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class LeiSimulator:\n", + " def __init__(\n", + " self,\n", + " lei_obj,\n", + " p_tiles,\n", + " null_truths,\n", + " grid_batch_size,\n", + " reduce_func=None,\n", + " ):\n", + " self.lei_obj = lei_obj\n", + " self.unifs_shape = self.lei_obj.unifs_shape()\n", + " self.unifs_order = np.arange(0, self.unifs_shape[0])\n", + " self.p_tiles = p_tiles\n", + " self.null_truths = null_truths\n", + " self.grid_batch_size = grid_batch_size\n", + "\n", + " self.reduce_func = (\n", + " lambda x: np.sum(x, axis=0) if not reduce_func else reduce_func\n", + " )\n", + "\n", + " self.f_batch_sim_batch_grid_jit = jax.jit(self.f_batch_sim_batch_grid)\n", + " self.batch_all = batch.batch_all(\n", + " self.f_batch_sim_batch_grid_jit,\n", + " batch_size=self.grid_batch_size,\n", + " in_axes=(0, 0, None, None),\n", + " )\n", + "\n", + " self.typeI_sum = None\n", + " self.typeI_score = None\n", + "\n", + " def f_batch_sim_batch_grid(self, p_batch, null_batch, unifs_batch, unifs_order):\n", + " return jax.vmap(\n", + " jax.vmap(\n", + " self.lei_obj.simulate,\n", + " in_axes=(0, 0, None, None),\n", + " ),\n", + " in_axes=(None, None, 0, None),\n", + " )(p_batch, null_batch, unifs_batch, unifs_order)\n", + "\n", + " def simulate_batch_sim(self, sim_batch_size, i, key):\n", + " start = time.perf_counter()\n", + "\n", + " unifs = jax.random.uniform(key=key, shape=(sim_batch_size,) + self.unifs_shape)\n", + " rejs_scores, n_padded = self.batch_all(\n", + " self.p_tiles, self.null_truths, unifs, self.unifs_order\n", + " )\n", + " rejs, scores = tuple(\n", + " np.concatenate(\n", + " tuple(x[i] for x in rejs_scores),\n", + " axis=1,\n", + " )\n", + " for i in range(2)\n", + " )\n", + " rejs, scores = (\n", + " (rejs[:, :-n_padded], scores[:, :-n_padded, :])\n", + " if 
n_padded\n", + " else (rejs, scores)\n", + " )\n", + " rejs_reduced = self.reduce_func(rejs)\n", + " scores_reduced = self.reduce_func(scores)\n", + "\n", + " end = time.perf_counter()\n", + " elapsed_time = (end-start)\n", + " print(f\"Batch {i}: {elapsed_time:.03f}s\")\n", + " return rejs_reduced, scores_reduced\n", + "\n", + " def simulate(\n", + " self,\n", + " key,\n", + " n_sim_batches,\n", + " sim_batch_size,\n", + " ):\n", + " keys = jax.random.split(key, num=n_sim_batches)\n", + " self.typeI_sum = np.zeros(self.p_tiles.shape[0])\n", + " self.typeI_score = np.zeros(self.p_tiles.shape)\n", + " for i, key in enumerate(keys):\n", + " out = self.simulate_batch_sim(sim_batch_size, i, key)\n", + " self.typeI_sum += out[0]\n", + " self.typeI_score += out[1]\n", + " return self.typeI_sum, self.typeI_score\n", + "\n", + "\n", + "simulator = LeiSimulator(\n", + " lei_obj=lei_obj,\n", + " p_tiles=p_tiles,\n", + " null_truths=null_truths,\n", + " grid_batch_size=grid_batch_size,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-09-12 13:40:16.727209: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.09GiB (rounded to 1173514496)requested by op \n", + "2022-09-12 13:40:16.729045: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] *********************************************************************************************_______\n", + "2022-09-12 13:40:16.730879: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1173514448 bytes.\n", + "BufferAssignment OOM Debugging.\n", + "BufferAssignment stats:\n", + " parameter allocation: 1.21MiB\n", + " constant allocation: 128.20MiB\n", + " maybe_live_out 
allocation: 12.89MiB\n", + " preallocated temp allocation: 1.09GiB\n", + " preallocated temp fragmentation: 0B (0.00%)\n", + " total allocation: 1.23GiB\n", + " total fragmentation: 128.21MiB (10.16%)\n", + "Peak buffers:\n", + "\tBuffer 1:\n", + "\t\tSize: 546.88MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/reduce_sum[axes=(2,)]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=704\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: pred[100,4096,350,4]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 2:\n", + "\t\tSize: 250.00MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 4, 20) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/table.py\" source_line=66\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: s64[100,4096,4,20]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 3:\n", + "\t\tSize: 62.50MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: s64[100,4096,20]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 4:\n", + "\t\tSize: 62.50MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: s64[100,4096,20]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 5:\n", + "\t\tSize: 62.50MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" 
source_line=41\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: s64[100,4096,20]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 6:\n", + "\t\tSize: 62.50MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: s64[100,4096,20]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 7:\n", + "\t\tSize: 36.62MiB\n", + "\t\tXLA Label: constant\n", + "\t\tShape: f64[10,160000,3]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 8:\n", + "\t\tSize: 25.00MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/concatenate[dimension=3]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=703\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: s64[100,4096,4,2]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 9:\n", + "\t\tSize: 18.31MiB\n", + "\t\tXLA Label: constant\n", + "\t\tShape: f32[10,160000,3]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 10:\n", + "\t\tSize: 18.31MiB\n", + "\t\tXLA Label: constant\n", + "\t\tShape: f32[10,160000,3]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 11:\n", + "\t\tSize: 18.31MiB\n", + "\t\tXLA Label: constant\n", + "\t\tShape: f32[10,160000,3]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 12:\n", + "\t\tSize: 18.31MiB\n", + "\t\tXLA Label: constant\n", + "\t\tShape: f32[10,160000,3]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 13:\n", + "\t\tSize: 18.31MiB\n", + "\t\tXLA Label: constant\n", + "\t\tShape: f32[10,160000,3]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 14:\n", + "\t\tSize: 12.50MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/select_n\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=963\n", + "\t\tXLA Label: fusion\n", + 
"\t\tShape: f64[100,4096,4]\n", + "\t\t==========================\n", + "\n", + "\tBuffer 15:\n", + "\t\tSize: 3.12MiB\n", + "\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=54\n", + "\t\tXLA Label: fusion\n", + "\t\tShape: s64[100,4096]\n", + "\t\t==========================\n", + "\n", + "\n" + ] + }, + { + "ename": "ValueError", + "evalue": "RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1173514448 bytes.\nBufferAssignment OOM Debugging.\nBufferAssignment stats:\n parameter allocation: 1.21MiB\n constant allocation: 128.20MiB\n maybe_live_out allocation: 12.89MiB\n preallocated temp allocation: 1.09GiB\n preallocated temp fragmentation: 0B (0.00%)\n total allocation: 1.23GiB\n total fragmentation: 128.21MiB (10.16%)\nPeak buffers:\n\tBuffer 1:\n\t\tSize: 546.88MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/reduce_sum[axes=(2,)]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=704\n\t\tXLA Label: fusion\n\t\tShape: pred[100,4096,350,4]\n\t\t==========================\n\n\tBuffer 2:\n\t\tSize: 250.00MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 4, 20) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/table.py\" source_line=66\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,4,20]\n\t\t==========================\n\n\tBuffer 3:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 
4:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 5:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 6:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 7:\n\t\tSize: 36.62MiB\n\t\tXLA Label: constant\n\t\tShape: f64[10,160000,3]\n\t\t==========================\n\n\tBuffer 8:\n\t\tSize: 25.00MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/concatenate[dimension=3]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=703\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,4,2]\n\t\t==========================\n\n\tBuffer 9:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 10:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 11:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 12:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 13:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 14:\n\t\tSize: 12.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/select_n\" 
source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=963\n\t\tXLA Label: fusion\n\t\tShape: f64[100,4096,4]\n\t\t==========================\n\n\tBuffer 15:\n\t\tSize: 3.12MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=54\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096]\n\t\t==========================\n\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m:2\u001b[0m\n", + "Cell \u001b[0;32mIn [10], line 77\u001b[0m, in \u001b[0;36mLeiSimulator.simulate\u001b[0;34m(self, key, n_sim_batches, sim_batch_size)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtypeI_score \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mp_tiles\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(keys):\n\u001b[0;32m---> 77\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msimulate_batch_sim\u001b[49m\u001b[43m(\u001b[49m\u001b[43msim_batch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtypeI_sum \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m out[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtypeI_score \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m 
out[\u001b[38;5;241m1\u001b[39m]\n", + "Cell \u001b[0;32mIn [10], line 44\u001b[0m, in \u001b[0;36mLeiSimulator.simulate_batch_sim\u001b[0;34m(self, sim_batch_size, i, key)\u001b[0m\n\u001b[1;32m 41\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mperf_counter()\n\u001b[1;32m 43\u001b[0m unifs \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39muniform(key\u001b[38;5;241m=\u001b[39mkey, shape\u001b[38;5;241m=\u001b[39m(sim_batch_size,) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39munifs_shape)\n\u001b[0;32m---> 44\u001b[0m rejs_scores, n_padded \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_all\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 45\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mp_tiles\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnull_truths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munifs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munifs_order\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m rejs, scores \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[1;32m 48\u001b[0m np\u001b[38;5;241m.\u001b[39mconcatenate(\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28mtuple\u001b[39m(x[i] \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m rejs_scores),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 53\u001b[0m )\n\u001b[1;32m 54\u001b[0m rejs, scores \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 55\u001b[0m (rejs[:, 
:\u001b[38;5;241m-\u001b[39mn_padded], scores[:, :\u001b[38;5;241m-\u001b[39mn_padded, :])\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_padded\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m (rejs, scores)\n\u001b[1;32m 58\u001b[0m )\n", + "File \u001b[0;32m/workspaces/confirmasaurus/research/lei/lewis/batch.py:81\u001b[0m, in \u001b[0;36mbatch_all..internal\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39minternal\u001b[39m(\u001b[39m*\u001b[39margs):\n\u001b[0;32m---> 81\u001b[0m outs \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39;49m(out \u001b[39mfor\u001b[39;49;00m out \u001b[39min\u001b[39;49;00m f_batch(\u001b[39m*\u001b[39;49margs))\n\u001b[1;32m 82\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mtuple\u001b[39m(out[\u001b[39m0\u001b[39m] \u001b[39mfor\u001b[39;00m out \u001b[39min\u001b[39;00m outs), outs[\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m][\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m]\n", + "File \u001b[0;32m/workspaces/confirmasaurus/research/lei/lewis/batch.py:81\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39minternal\u001b[39m(\u001b[39m*\u001b[39margs):\n\u001b[0;32m---> 81\u001b[0m outs \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(out \u001b[39mfor\u001b[39;00m out \u001b[39min\u001b[39;00m f_batch(\u001b[39m*\u001b[39margs))\n\u001b[1;32m 82\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mtuple\u001b[39m(out[\u001b[39m0\u001b[39m] \u001b[39mfor\u001b[39;00m out \u001b[39min\u001b[39;00m outs), outs[\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m][\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m]\n", + "File \u001b[0;32m/workspaces/confirmasaurus/research/lei/lewis/batch.py:60\u001b[0m, in \u001b[0;36mbatch..internal\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(n_full_batches):\n\u001b[1;32m 54\u001b[0m batched_args 
\u001b[39m=\u001b[39m create_batched_args__(\n\u001b[1;32m 55\u001b[0m args\u001b[39m=\u001b[39margs,\n\u001b[1;32m 56\u001b[0m in_axes\u001b[39m=\u001b[39min_axes,\n\u001b[1;32m 57\u001b[0m start\u001b[39m=\u001b[39mstart,\n\u001b[1;32m 58\u001b[0m end\u001b[39m=\u001b[39mend,\n\u001b[1;32m 59\u001b[0m )\n\u001b[0;32m---> 60\u001b[0m \u001b[39myield\u001b[39;00m (f(\u001b[39m*\u001b[39;49mbatched_args), \u001b[39m0\u001b[39m)\n\u001b[1;32m 61\u001b[0m start \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m batch_size_new\n\u001b[1;32m 62\u001b[0m end \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m batch_size_new\n", + "\u001b[0;31mValueError\u001b[0m: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1173514448 bytes.\nBufferAssignment OOM Debugging.\nBufferAssignment stats:\n parameter allocation: 1.21MiB\n constant allocation: 128.20MiB\n maybe_live_out allocation: 12.89MiB\n preallocated temp allocation: 1.09GiB\n preallocated temp fragmentation: 0B (0.00%)\n total allocation: 1.23GiB\n total fragmentation: 128.21MiB (10.16%)\nPeak buffers:\n\tBuffer 1:\n\t\tSize: 546.88MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/reduce_sum[axes=(2,)]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=704\n\t\tXLA Label: fusion\n\t\tShape: pred[100,4096,350,4]\n\t\t==========================\n\n\tBuffer 2:\n\t\tSize: 250.00MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(2, 3), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 4, 20) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/table.py\" source_line=66\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,4,20]\n\t\t==========================\n\n\tBuffer 3:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" 
source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 4:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 5:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 6:\n\t\tSize: 62.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=41\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,20]\n\t\t==========================\n\n\tBuffer 7:\n\t\tSize: 36.62MiB\n\t\tXLA Label: constant\n\t\tShape: f64[10,160000,3]\n\t\t==========================\n\n\tBuffer 8:\n\t\tSize: 25.00MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/concatenate[dimension=3]\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=703\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096,4,2]\n\t\t==========================\n\n\tBuffer 9:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 10:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 11:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 12:\n\t\tSize: 18.31MiB\n\t\tXLA Label: constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 13:\n\t\tSize: 18.31MiB\n\t\tXLA Label: 
constant\n\t\tShape: f32[10,160000,3]\n\t\t==========================\n\n\tBuffer 14:\n\t\tSize: 12.50MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/select_n\" source_file=\"/workspaces/confirmasaurus/research/lei/lewis/lewis.py\" source_line=963\n\t\tXLA Label: fusion\n\t\tShape: f64[100,4096,4]\n\t\t==========================\n\n\tBuffer 15:\n\t\tSize: 3.12MiB\n\t\tOperator: op_name=\"jit(f_batch_sim_batch_grid)/jit(main)/squeeze[dimensions=(2,)]\" source_file=\"/workspaces/confirmasaurus/outlaw/outlaw/interp.py\" source_line=54\n\t\tXLA Label: fusion\n\t\tShape: s64[100,4096]\n\t\t==========================\n\n" + ] + } + ], + "source": [ + "%%time\n", + "key = jax.random.PRNGKey(3)\n", + "typeI_sum, typeI_score = simulator.simulate(\n", + " key=key,\n", + " n_sim_batches=n_sim_batches,\n", + " sim_batch_size=sim_batch_size,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(\"output_lei4d2\", exist_ok=True)\n", + "np.savetxt(\"output_lei4d2/typeI_sum.csv\", typeI_sum, fmt=\"%s\", delimiter=\",\")\n", + "np.savetxt(\"output_lei4d2/typeI_score.csv\", typeI_score, fmt=\"%s\", delimiter=\",\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sim_size = sim_batch_size * n_sim_batches\n", + "\n", + "plt.figure(figsize=(8,4), constrained_layout=True)\n", + "for i, t2_idx in enumerate([4, 8]):\n", + " t2 = np.unique(theta_tiles[:, 2])[t2_idx]\n", + " selection = (theta_tiles[:,2] == t2)\n", + "\n", + " plt.subplot(1,2,i+1)\n", + " plt.title(f'slice: $\\\\theta_2 \\\\approx$ {t2:.1f}')\n", + " plt.scatter(theta_tiles[selection,0], theta_tiles[selection,1], c=typeI_sum[selection]/sim_size, s=90)\n", + " cbar = plt.colorbar()\n", + " plt.xlabel(r'$\\theta_0$')\n", + " plt.ylabel(r'$\\theta_1$')\n", + " cbar.set_label('Simulated fraction of Type I errors')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "tile_radii = gr.radii[gr.grid_pt_idx]\n", + "sim_sizes = np.full(gr.n_tiles, sim_size)\n", + "n_arm_samples = (\n", + " params['n_stage_1'] +\n", + " params['n_stage_1_add_per_interim'] // 2 * params['n_stage_1_interims'] + \n", + " params['n_stage_2'] +\n", + " params['n_stage_2_add_per_interim'] // 2\n", + ")\n", + "total, d0, d0u, d1w, d1uw, d2uw = binomial.upper_bound(\n", + " theta_tiles,\n", + " tile_radii,\n", + " gr.vertices,\n", + " sim_sizes,\n", + " n_arm_samples,\n", + " typeI_sum,\n", + " typeI_score,\n", + ")\n", + "bound_components = np.array([\n", + " d0,\n", + " d0u,\n", + " d1w,\n", + " d1uw,\n", + " d2uw,\n", + " total,\n", + "]).T" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "t2_uniques = np.unique(theta_tiles[:, 2])\n", + "t3_uniques = np.unique(theta_tiles[:, 3])\n", + "t2 = t2_uniques[8]\n", + "t3 = t3_uniques[8]\n", + "selection = (theta_tiles[:, 2] == t2) & (theta_tiles[:, 3] == t3)\n", + "\n", + "np.savetxt('output_lei4d/P_lei.csv', theta_tiles[selection, :].T, fmt=\"%s\", delimiter=\",\")\n", + "np.savetxt('output_lei4d/B_lei.csv', 
bound_components[selection, :], fmt=\"%s\", delimiter=\",\")" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.0625, 0.0625)" + ] + }, + "execution_count": 80, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t2_uniques[8], t3_uniques[8]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sandbox" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.5 ('confirm')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "d8e1ca1b3fede25e3995e2b26ea544fa1b75b9a17984e6284a43c1dc286640dd" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/research/lei/lei.md b/research/lei/lei.md new file mode 100644 index 00000000..985fed83 --- /dev/null +++ b/research/lei/lei.md @@ -0,0 +1,712 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.13.8 + kernelspec: + display_name: Python 3.10.5 ('confirm') + language: python + name: python3 +--- + +```python +%load_ext autoreload +%autoreload 2 +``` + +```python +import os +import confirm.outlaw +import confirm.outlaw.berry as berry +import confirm.outlaw.quad as quad +import numpy as np +import jax.numpy as jnp +import jax +import time +import confirm.outlaw.inla as inla +import matplotlib.pyplot as plt +import numpyro.distributions as dist +from functools import partial +from itertools import combinations + +from confirm.lewislib import lewis +from confirm.lewislib import batch +from confirm.mini_imprint import grid +from confirm.lewislib import grid as 
lewgrid +from confirm.mini_imprint import binomial +``` + +# Lei Example + + +The following describes a clinical trial design that uses a Bayesian model with early-stopping rules for futility or efficacy of a drug. +This design was explicitly requested to be studied by an FDA member (Lei) in the CID team. + + +> The following is a randomized, double-blind, placebo-controlled two-stage adaptive design intended to identify an optimal treatment regimen +> from three possible regimens (for example, different dosages or different combinations of agents) and +> to assess the efficacy of that regimen with respect to a primary binary response endpoint measured at month 6. +> +> In Stage 1, one of the three experimental regimens will be selected, or the trial will stop for futility. +> In this stage, a minimum of 200 and a maximum of 400 subjects will be randomized 1:1:1:1 to one of the three experimental arms or one placebo arm. +> Interim analyses will be conducted after 200, 300 and 400 subjects have been enrolled to select the best experimental regimen and to potentially stop +> the trial for futility. +> If an experimental regimen is dropped for futility at an interim analysis, +> the next 100 subjects to be randomized will be allocated equally among the remaining arms in the study. +> At each of these three analysis time points (N = 200, 300, 400), +> the probabilities of being the best regimen (PrBest) and predictive probability of success (PPS) +> are calculated for each experimental regimen using a Bayesian approach, +> and the trial will either stop for futility, +> continue to the next interim analysis, +> or proceed to Stage 2 depending on the results of these PrBest and PPS calculations. +> +> In Stage 2, a minimum of 200 and a maximum of 400 additional subjects will be randomized 1:1 to the chosen regimen or placebo. 
+> The two groups (pooling both Stage 1 and Stage 2 subjects) will be compared for efficacy and futility assessment at an interim analysis +> after 200 subjects have been enrolled in Stage 2, +> and for efficacy at a final analysis after 400 subjects have been enrolled in Stage 2 and fully followed-up for response. +> The study may be stopped for futility or efficacy based on PPS at the interim analysis. +> If the study continues to the final analysis, +> the posterior distribution of the difference in response rates between placebo and the chosen experimental arm +> will be evaluated against a pre-specified decision criterion. +> +> - Scenario 1: interim analyses are based on available data on the primary endpoint (measured at month 6) +> - Scenario 2: interim analyses are based on available data on a secondary endpoint (measured at month 3) + + +This notebook breaks down and discusses the components of the trial. + + +## Model + + +The notation is as follows: + + +- $y \in \mathbb{N}^d$: Binomial responses. +- $p \in [0,1]^d$: probability parameter of the Binomial distribution. +- $n \in \mathbb{N}^d$: size parameter of the Binomial distribution. +- $q \in [0,1]^d$: base probability value that offsets $p$. +- $\theta \in \mathbb{R}^d$: logit-scale parameter that determines $p$ (as an offset from $\mathrm{logit}(q)$). +- $\mu \in \mathbb{R}$: shared mean parameter among $\theta_i$. +- $\sigma^2 \in \mathbb{R}_+$: shared variance parameter among $\theta_i$. +- $\mu_0, \sigma_0^2, \alpha_0, \beta_0 \in \mathbb{R}$: hyper-parameters. 
def pr_normal_best(mean, cov, key, n_sims):
    """Monte Carlo estimate of P[X_i > max_{j != i} X_j] for X ~ N(mean, cov).

    Index 0 is skipped: estimates are returned for indices
    ``1, ..., mean.shape[-1] - 1`` only, stacked along the leading axis of
    the result. Any leading (batch) dimensions of ``mean`` are kept after
    that axis.

    Parameters
    ----------
    mean : array whose last axis indexes the components of X.
    cov : covariance passed straight to ``jax.random.multivariate_normal``.
    key : JAX PRNG key used for the draws.
    n_sims : number of Monte Carlo draws per batch element.
    """
    n_components = mean.shape[-1]
    draw_shape = (n_sims,) + tuple(mean.shape[:-1])
    draws = jax.random.multivariate_normal(key, mean, cov, shape=draw_shape)

    # The winning index of each draw does not depend on the candidate
    # index, so it is computed once and shared by every comparison below.
    winners = jnp.argmax(draws, axis=-1)

    def fraction_won(candidate):
        return jnp.mean(winners == candidate, axis=0)

    candidates = jnp.arange(1, n_components)
    return jax.vmap(fraction_won)(candidates)


def posterior_sigma_sq(data, sig2_rule, sig2_rule_ops):
    """Evaluate the normalized posterior of sigma^2 on the quadrature grid.

    Runs ``sig2_rule_ops.laplace_logpost`` with sigma^2 pinned to
    ``sig2_rule.pts`` and a zero initial guess of shape
    ``(n_sig2, n_arms)``, then exponentiates and normalizes the log
    posterior against the quadrature weights ``sig2_rule.wts``.

    Returns
    -------
    (post, x_max, hess, iters), where the last three are returned unchanged
    from ``laplace_logpost``.
    NOTE(review): presumably ``x_max`` is the conditional mode of theta and
    ``hess`` the Hessian at that mode -- confirm against ``berry``.
    """
    n_arms, _ = data.shape
    grid_pts = sig2_rule.pts
    initial_guess = np.zeros((grid_pts.shape[0], n_arms))
    pinned = dict(sig2=grid_pts, theta=None)

    logpost, x_max, hess, iters = sig2_rule_ops.laplace_logpost(
        initial_guess, pinned, data
    )
    post = inla.exp_and_normalize(logpost, sig2_rule.wts, axis=-1)
    return post, x_max, hess, iters
def pr_best(data, sig2_rule, sig2_rule_ops, key, n_sims):
    """Estimate the probability that each non-control arm is the best arm.

    For every sigma^2 quadrature point a Gaussian with mean ``x_max`` and
    covariance ``inv(-hessian)`` is sampled by ``pr_normal_best``; the
    per-point estimates are then averaged with weights
    ``post * sig2_rule.wts``.

    Returns an array with one entry per arm index ``1, ..., n_arms - 1``.
    """
    n_arms, _ = data.shape
    post, x_max, hess, _ = posterior_sigma_sq(data, sig2_rule, sig2_rule_ops)
    mean = x_max
    # The Hessian is stored as (diagonal, scalar): the full matrix is
    # diag(h[0]) plus the constant h[1] in every entry.
    hess_fn = jax.vmap(
        lambda h: jnp.diag(h[0]) + jnp.full(shape=(n_arms, n_arms), fill_value=h[1])
    )
    prec = -hess_fn(hess)  # (n_sigs, n_arms, n_arms)
    cov = jnp.linalg.inv(prec)
    pr_normal_best_out = pr_normal_best(mean, cov, key=key, n_sims=n_sims)
    return jnp.matmul(pr_normal_best_out, post * sig2_rule.wts)


def posterior_difference(data, arm, sig2_rule, sig2_rule_ops, thresh):
    """Approximate P(theta[arm] - theta[0] < thresh | data).

    Conditional on each sigma^2 quadrature point, the contrast
    ``q1 @ theta`` (``q1`` has -1 at index 0 and +1 at index ``arm``) is
    treated as Gaussian with mean ``q1 @ x_max`` and variance
    ``q1 @ inv(prec) @ q1``. The conditional CDF values are then averaged
    with weights ``sig2_rule.wts * post``.

    Parameters
    ----------
    data : array of shape (n_arms, 2).
        NOTE(review): presumably columns are (y, n) -- confirm with caller.
    arm : index of the treatment arm compared against arm 0. With
        ``arm == 0`` the contrast degenerates to ``+theta[0]``.
    thresh : threshold applied to the contrast.
    """
    n_arms, _ = data.shape
    post, x_max, hess, _ = posterior_sigma_sq(data, sig2_rule, sig2_rule_ops)
    hess_fn = jax.vmap(
        lambda h: jnp.diag(h[0]) + jnp.full(shape=(n_arms, n_arms), fill_value=h[1])
    )
    prec = -hess_fn(hess)  # (n_sigs, n_arms, n_arms)

    order = jnp.arange(0, n_arms)
    q1 = jnp.where(order == 0, -1, 0)
    q1 = jnp.where(order == arm, 1, q1)
    q1 = q1.astype(x_max.dtype)

    loc = x_max @ q1

    # Solve with an explicit (n_sigs, n_arms, 1) right-hand side so the call
    # does not depend on how a 2-D `b` is interpreted by linalg.solve.
    rhs = jnp.broadcast_to(q1[:, None], prec.shape[:-1] + (1,))
    variance = jnp.linalg.solve(prec, rhs)[..., 0] @ q1

    # norm.cdf takes a standard deviation, not a variance.
    scale = jnp.sqrt(variance)
    normal_term = jax.scipy.stats.norm.cdf(thresh, loc=loc, scale=scale)

    post_weighted = sig2_rule.wts * post
    out = normal_term @ post_weighted
    return out
+## Predictive Probability of Success + +The next quantity we need to compute is the predictive probability of success (PPS). +For convenience of implementation, we will take this to mean the following: +let $y, n$ denote the currently observed data +and $A_i = \{ \text{Phase III rejects using treatment arm } i \}$. +Then, we wish to compute +\begin{align*} +\mathbb{P}(A_i | y, n) +\end{align*} +Expanding the quantity by conditioning on $\theta$, and noting that $A_i$ depends on $\theta$ only through $(\theta_0, \theta_i)$, +\begin{align*} +\mathbb{P}(A_i | y, n) &= +\int \mathbb{P}(A_i | y, n, \theta) p(\theta | y, n) \, d\theta \\&= +\int \mathbb{P}(A_i | y, n, \theta_i, \theta_0) p(\theta_0, \theta_i | y, n) \, d\theta_i d\theta_0 +\end{align*} + +Once we have an estimate for $p(\theta_0, \theta_i | y, n)$, +we can use 2-D quadrature to numerically integrate the integrand. +Similar to computing the probability of best arm, +\begin{align*} +p(\theta_0, \theta_i | y, n) +&= +\int p(\theta_0, \theta_i | y, n, \sigma^2) p(\sigma^2 | y, n) \, d\sigma^2 +\end{align*} +We will use the Gaussian approximation for $p(\theta_0, \theta_i | y, n, \sigma^2)$ +and use INLA to estimate $p(\sigma^2 | y, n)$. 
def pr_Ai(
    data, p, key, best_arm, non_futile_idx,
    sig2_rule, sig2_rule_ops,
    sig2_grid, sig2_grid_ops, dsig2_grid,
):
    """Monte Carlo estimate of P(A_i | y, n), A_i = "final analysis rejects".

    Steps:
      1. Evaluate the (unnormalized) posterior of sigma^2 on ``sig2_grid``.
      2. Draw grid indices by inverting a left-Riemann CDF built from it.
      3. Draw theta from the Gaussian approximation at each drawn index.
      4. For each draw, simulate the additional stage-2 outcomes and check
         whether ``posterior_difference`` falls below the rejection
         threshold.
      5. Return the fraction of simulations that reject.

    Reads the module-level names ``n_Ai_sims``, ``n_stage_2``,
    ``posterior_difference_threshold`` and ``rejection_threshold``.

    Parameters
    ----------
    data : (n_arms, 2) array of currently observed data.
    p : unused by this function; kept for interface compatibility.
    key : JAX PRNG key.
    best_arm : index of the selected treatment arm.
    non_futile_idx : per-arm indicator; arms with a truthy entry receive
        ``n_stage_2`` additional subjects.
    sig2_rule, sig2_rule_ops : quadrature rule and ops forwarded to
        ``posterior_difference``.
    sig2_grid, sig2_grid_ops : dense sigma^2 grid and matching ops used for
        sampling.
    dsig2_grid : spacings of ``sig2_grid``; multiplied elementwise with
        ``post[:-1]``.
    """
    n_arms, _ = data.shape

    # compute p(sig2 | y, n) up to a constant, plus mode and hessian
    p_pinned = dict(sig2=sig2_grid, theta=None)
    logpost, x_max, hess, _ = jax.jit(sig2_grid_ops.laplace_logpost)(
        np.zeros((len(sig2_grid), n_arms)), p_pinned, data
    )
    # Only the shape of the density matters because the CDF below is
    # normalized. Multiplying back by exp(max) could underflow to 0 and
    # turn the CDF into NaN, so the density stays max-shifted.
    post = jnp.exp(logpost - jnp.max(logpost))

    # Independent keys per random stage; no key is used for both a draw
    # and a split.
    key, sig2_key, theta_key = jax.random.split(key, 3)

    # sample sigma^2 | y, n
    dFx = post[:-1] * dsig2_grid
    Fx = jnp.cumsum(dFx)
    Fx /= Fx[-1]
    unifs = jax.random.uniform(key=sig2_key, shape=(n_Ai_sims,))
    i_star = jnp.searchsorted(Fx, unifs)

    # sample theta | y, n, sigma^2
    mean = x_max[i_star + 1]
    hess_fn = jax.vmap(
        lambda h: jnp.diag(h[0]) + jnp.full(shape=(n_arms, n_arms), fill_value=h[1])
    )
    prec = -hess_fn(hess)
    cov = jnp.linalg.inv(prec)[i_star + 1]
    theta = jax.random.multivariate_normal(
        key=theta_key, mean=mean, cov=cov,
    )
    p_samples = jax.scipy.special.expit(theta)

    # estimate P(A_i | y, n, theta_0, theta_i)

    def simulate_Ai(data, best_arm, diff_thresh, rej_thresh, non_futile_idx, key, p):
        # add n_stage_2 number of patients to each
        # of the control and selected treatment arms.
        n_new = jnp.where(non_futile_idx, n_stage_2, 0)
        y_new = dist.Binomial(total_count=n_new, probs=p).sample(key)

        # pool outcomes for each arm
        data = data + jnp.stack((y_new, n_new), axis=-1)

        return (
            posterior_difference(data, best_arm, sig2_rule, sig2_rule_ops, diff_thresh)
            < rej_thresh
        )

    simulate_Ai_vmapped = jax.vmap(
        simulate_Ai, in_axes=(None, None, None, None, None, 0, 0)
    )
    keys = jax.random.split(key, num=p_samples.shape[0])
    Ai_indicators = simulate_Ai_vmapped(
        data,
        best_arm,
        posterior_difference_threshold,
        rejection_threshold,
        non_futile_idx,
        keys,
        p_samples,
    )
    out = jnp.mean(Ai_indicators)
    return out
- px[i_star]) / dx[i_star] +b = px[i_star] +c = Fx[i_star] - unifs - px[i_star] * dx[i_star] - a * dx[i_star]**2 +discr = jnp.sqrt(jnp.maximum(b**2 - 4*a*c, 0)) +quad_solve = jnp.where(jnp.abs(a) < 1e-8, -c/b, (-b + discr) / (2*a)) +samples = x[i_star] + quad_solve + +#plt.plot(x[1:], Fx) +#plt.plot(x[1:], jax.scipy.stats.norm.cdf(x[1:])) +plt.hist(x[i_star+1], density=True, alpha=0.5) +plt.hist(samples, density = True, alpha=0.5) +plt.plot(x, px_orig) +plt.show() +``` + +## Design Implementation + +```python +%%time +params = { + "n_arms" : 4, + "n_stage_1" : 50, + "n_stage_2" : 100, + "n_stage_1_interims" : 2, + "n_stage_1_add_per_interim" : 100, + "n_stage_2_add_per_interim" : 100, + "stage_1_futility_threshold" : 0.15, + "stage_1_efficacy_threshold" : 0.7, + "stage_2_futility_threshold" : 0.2, + "stage_2_efficacy_threshold" : 0.95, + "inter_stage_futility_threshold" : 0.6, + "posterior_difference_threshold" : 0, + "rejection_threshold" : 0.05, + "key" : jax.random.PRNGKey(0), + "n_pr_sims" : 100, + "n_sig2_sims" : 20, + "batch_size" : int(2**20), + "cache_tables" : False, +} +lei_obj = lewis.Lewis45(**params) +``` + +```python +batch_size = 2**12 +key = jax.random.PRNGKey(0) +n_points = 20 +n_pr_sims = 100 +n_sig2_sim = 20 +``` + +```python +%%time +lei_obj.pd_table = lei_obj.posterior_difference_table__( + batch_size=batch_size, + n_points=n_points, +) +lei_obj.pd_table +``` + +```python +%%time +lei_obj.pr_best_pps_1_table = lei_obj.pr_best_pps_1_table__( + key, + n_pr_sims, + batch_size=batch_size, + n_points=n_points, +) +lei_obj.pr_best_pps_1_table +``` + +```python +%%time +_, key = jax.random.split(key) +lei_obj.pps_2_table = lei_obj.pps_2_table__( + key, + n_pr_sims, + batch_size=batch_size, + n_points=n_points, +) +lei_obj.pps_2_table +``` + +```python +n_arms = params['n_arms'] +size = 52 +lower = np.full(n_arms, -1) +upper = np.full(n_arms, 1) +thetas, radii = lewgrid.make_cartesian_grid_range( + size=size, + lower=lower, + upper=upper, +) +ns = 
class LeiSimulator:
    """Run batched simulations of ``lei_obj.simulate`` over a grid of tiles.

    Simulations are batched along two axes: the grid axis (tiles,
    processed in chunks of ``grid_batch_size`` by ``batch.batch_all``) and
    the simulation axis (uniform draws, ``sim_batch_size`` at a time).
    After each simulation batch the outputs are reduced over the
    simulation axis and accumulated.

    Parameters
    ----------
    lei_obj : object providing ``unifs_shape()`` and
        ``simulate(p, null_truth, unifs, unifs_order)``; ``simulate`` must
        return two outputs (rejections and scores).
    p_tiles : per-tile probability vectors, one row per tile.
    null_truths : per-tile null-hypothesis indicators.
    grid_batch_size : number of tiles per grid batch.
    reduce_func : optional callable applied to the stacked per-simulation
        outputs; defaults to a sum over axis 0.
    """

    def __init__(
        self,
        lei_obj,
        p_tiles,
        null_truths,
        grid_batch_size,
        reduce_func=None,
    ):
        self.lei_obj = lei_obj
        self.unifs_shape = self.lei_obj.unifs_shape()
        self.unifs_order = np.arange(0, self.unifs_shape[0])
        self.p_tiles = p_tiles
        self.null_truths = null_truths
        self.grid_batch_size = grid_batch_size

        # The former one-liner parsed as
        # `lambda x: (sum if not reduce_func else reduce_func)`, so a custom
        # reducer was returned instead of being applied.
        self.reduce_func = self._resolve_reduce_func(reduce_func)

        self.f_batch_sim_batch_grid_jit = jax.jit(self.f_batch_sim_batch_grid)
        self.batch_all = batch.batch_all(
            self.f_batch_sim_batch_grid_jit,
            batch_size=self.grid_batch_size,
            in_axes=(0, 0, None, None),
        )

        self.typeI_sum = None
        self.typeI_score = None

    @staticmethod
    def _resolve_reduce_func(reduce_func):
        """Return ``reduce_func``, or the default sum over axis 0 if None."""
        if reduce_func is None:
            return lambda x: np.sum(x, axis=0)
        return reduce_func

    def f_batch_sim_batch_grid(self, p_batch, null_batch, unifs_batch, unifs_order):
        """Apply ``lei_obj.simulate`` to every (simulation, tile) pair.

        The inner vmap maps over tiles (``p_batch``, ``null_batch``), the
        outer vmap over simulations (``unifs_batch``), so the outputs have
        the simulation axis first and the tile axis second.
        """
        return jax.vmap(
            jax.vmap(
                self.lei_obj.simulate,
                in_axes=(0, 0, None, None),
            ),
            in_axes=(None, None, 0, None),
        )(p_batch, null_batch, unifs_batch, unifs_order)

    def simulate_batch_sim(self, sim_batch_size, i, key):
        """Run one simulation batch over all tiles and reduce it.

        ``i`` is only used to label the timing message. Returns the
        reduced rejections and the reduced scores.
        """
        start = time.perf_counter()

        unifs = jax.random.uniform(key=key, shape=(sim_batch_size,) + self.unifs_shape)
        rejs_scores, n_padded = self.batch_all(
            self.p_tiles, self.null_truths, unifs, self.unifs_order
        )
        # Each grid batch yields a (rejs, scores) pair; stitch the batches
        # back together along the tile axis (axis 1).
        rejs, scores = tuple(
            np.concatenate(
                tuple(batch_out[out_idx] for batch_out in rejs_scores),
                axis=1,
            )
            for out_idx in range(2)
        )
        # Drop the tiles that were appended as padding to fill the last
        # grid batch.
        if n_padded:
            rejs = rejs[:, :-n_padded]
            scores = scores[:, :-n_padded, :]

        rejs_reduced = self.reduce_func(rejs)
        scores_reduced = self.reduce_func(scores)

        end = time.perf_counter()
        elapsed_time = (end - start)
        print(f"Batch {i}: {elapsed_time:.03f}s")
        return rejs_reduced, scores_reduced

    def simulate(
        self,
        key,
        n_sim_batches,
        sim_batch_size,
    ):
        """Accumulate reduced outputs over ``n_sim_batches`` batches.

        Each batch gets its own key split from ``key``. Returns
        ``(typeI_sum, typeI_score)``, which are also stored on the
        instance.
        """
        keys = jax.random.split(key, num=n_sim_batches)
        self.typeI_sum = np.zeros(self.p_tiles.shape[0])
        self.typeI_score = np.zeros(self.p_tiles.shape)
        for i, batch_key in enumerate(keys):
            out = self.simulate_batch_sim(sim_batch_size, i, batch_key)
            self.typeI_sum += out[0]
            self.typeI_score += out[1]
        return self.typeI_sum, self.typeI_score
d0, d0u, d1w, d1uw, d2uw = binomial.upper_bound( + theta_tiles, + tile_radii, + gr.vertices, + sim_sizes, + n_arm_samples, + typeI_sum, + typeI_score, +) +bound_components = np.array([ + d0, + d0u, + d1w, + d1uw, + d2uw, + total, +]).T +``` + +```python +t2_uniques = np.unique(theta_tiles[:, 2]) +t3_uniques = np.unique(theta_tiles[:, 3]) +t2 = t2_uniques[8] +t3 = t3_uniques[8] +selection = (theta_tiles[:, 2] == t2) & (theta_tiles[:, 3] == t3) + +np.savetxt('output_lei4d/P_lei.csv', theta_tiles[selection, :].T, fmt="%s", delimiter=",") +np.savetxt('output_lei4d/B_lei.csv', bound_components[selection, :], fmt="%s", delimiter=",") +``` + +```python +t2_uniques[8], t3_uniques[8] +``` + +# Sandbox diff --git a/research/stat/poisson_process.ipynb b/research/stat/poisson_process.ipynb new file mode 100644 index 00000000..81e0c0fd --- /dev/null +++ b/research/stat/poisson_process.ipynb @@ -0,0 +1,171 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Poisson Process Fun Time" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import jax.numpy as jnp\n", + "import jax\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [], + "source": [ + "def h(p):\n", + " return p\n", + "\n", + "def method_1(lam, n_sims, t, seed):\n", + " key = jax.random.PRNGKey(seed)\n", + "\n", + " Ns = jax.random.poisson(key=key, lam=lam, shape=(n_sims,))\n", + " max_Ns = jnp.max(Ns)\n", + " order = jnp.arange(0, max_Ns)\n", + " \n", + " def stat(N, key): \n", + " p = jax.random.uniform(key=key, shape=(max_Ns,))\n", + " p_sub = jnp.where(order < N, p, jnp.nan)\n", + " return jnp.sum(h(p_sub) * (p_sub < t))\n", + " \n", + " keys = jax.random.split(key, num=n_sims)\n", + " \n", + " stat_vmapped = jax.vmap(stat, in_axes=(0,0))\n", + " stat_vmapped_jit = jax.jit(stat_vmapped)\n", + " out = 
stat_vmapped_jit(Ns, keys)\n", + " return out\n", + "\n", + "def method_2(lam, n_sims, t, seed, n_begin=10):\n", + " key = jax.random.PRNGKey(seed)\n", + "\n", + " # sample Exp(lam) until the running sum is >= 1, then take everything before that point.\n", + " # If X_1,..., X_n ~ Exp(lam) and T_i = sum_{j=1}^i X_j,\n", + " # then (T_1,..., T_{n-1}) | T_n = t ~ (U_{(1)}, ..., U_{(n-1)}) where each U_i ~ Unif(0, t)\n", + " # \n", + " # Sampling procedure:\n", + " # - Increase n until T_n >= 1\n", + " # - Sample (T_1,..., T_{n-1}) | T_n via formula above.\n", + " # - Sum over h(T_i) 1{T_i < t}\n", + "\n", + " def find_n_T_n(n_begin, key):\n", + " n = n_begin\n", + " T = 0\n", + " def body_fun(tup, key):\n", + " n, _ = tup\n", + " n = n + n_begin\n", + " _, key = jax.random.split(key)\n", + " return (n, jax.random.gamma(key=key, a=n) / lam)\n", + " out = jax.lax.while_loop(\n", + " lambda tup: tup[1] < 1, \n", + " lambda tup: body_fun(tup, key), \n", + " (n, T))\n", + " return jnp.array(out)\n", + "\n", + " keys = jax.random.split(key, num=n_sims)\n", + " NT = jax.jit(jax.vmap(find_n_T_n, in_axes=(None, 0)))(n_begin, keys)\n", + " \n", + " N_max = int(jnp.max(NT[:,0]))\n", + " order = jnp.arange(0, N_max)\n", + " def stat(nt, key):\n", + " n, t_n = nt\n", + " unifs = jax.random.uniform(key=key, shape=(N_max,))\n", + " unifs = jnp.where(order < (n-1), unifs, jnp.inf)\n", + " unifs_sorted = jnp.sort(unifs)\n", + " Ts = t_n * unifs_sorted\n", + " return jnp.sum(h(Ts) * (Ts < t))\n", + "\n", + " stat_vmapped = jax.vmap(stat, in_axes=(0,0))\n", + " stat_vmapped_jit = jax.jit(stat_vmapped)\n", + "\n", + " keys = jax.random.split(keys[-1], num=n_sims)\n", + " return stat_vmapped_jit(NT, keys)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 158, + "metadata": {}, + "outputs": [], + "source": [ + "lam = 100\n", + "n_sims = 100000\n", + "t = 0.2\n", + "seed = 69" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "metadata": {}, + "outputs": [], + 
"source": [ + "out_1 = method_1(lam=lam, n_sims=n_sims, t=t, seed=seed)\n", + "out_2 = method_2(lam=lam, n_sims=n_sims, t=t, seed=seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
def h(p):
    """Function summed over the event times; the identity in this study."""
    return p


def method_1(lam, n_sims, t, seed):
    """Simulate sum_i h(U_i) 1{U_i < t} with N ~ Poisson(lam), U_i ~ Unif(0, 1).

    Every simulation draws a fixed-size buffer of ``max(N)`` uniforms so the
    computation can be vmapped; entries at index >= N are masked out.

    Returns an array of shape ``(n_sims,)``.
    """
    key = jax.random.PRNGKey(seed)
    # Separate keys for the counts and for the uniforms: a key must not be
    # used for a draw and then split again.
    count_key, unif_key = jax.random.split(key)

    Ns = jax.random.poisson(key=count_key, lam=lam, shape=(n_sims,))
    max_Ns = int(jnp.max(Ns))
    order = jnp.arange(0, max_Ns)

    def stat(N, key):
        p = jax.random.uniform(key=key, shape=(max_Ns,))
        # Mask with `where` rather than padding with NaN and multiplying by
        # a boolean: nan * 0 is nan and would poison the sum.
        keep = (order < N) & (p < t)
        return jnp.sum(jnp.where(keep, h(p), 0.0))

    keys = jax.random.split(unif_key, num=n_sims)

    stat_vmapped_jit = jax.jit(jax.vmap(stat, in_axes=(0, 0)))
    return stat_vmapped_jit(Ns, keys)


def method_2(lam, n_sims, t, seed, n_begin=10):
    """Simulate the same statistic from arrival times instead of counts.

    Idea (from the original notes): if X_1, ..., X_n ~ Exp(lam) and
    T_i = sum_{j<=i} X_j, then (T_1, ..., T_{n-1}) | T_n = s is distributed
    like the order statistics of n-1 Unif(0, s) variables.

    Procedure:
      - increase n in steps of ``n_begin`` until a draw of T_n is >= 1,
      - sample T_1, ..., T_{n-1} | T_n as scaled uniforms,
      - return sum_i h(T_i) 1{T_i < t}.

    NOTE(review): T_n is drawn afresh as Gamma(n) / lam for every candidate
    n instead of extending a running sum; the stopping rule therefore
    conditions on earlier draws. Confirm that this matches the intended
    distribution before relying on agreement with ``method_1``.

    Returns an array of shape ``(n_sims,)``.
    """
    key = jax.random.PRNGKey(seed)
    arrival_key, unif_key = jax.random.split(key)

    def find_n_T_n(n_begin, key):
        def cond_fun(state):
            return state[1] < 1

        def body_fun(state):
            n, T, key = state
            n = n + n_begin
            # The key is part of the loop state so that every iteration
            # draws with a fresh subkey.
            key, subkey = jax.random.split(key)
            T_new = jax.random.gamma(key=subkey, a=jnp.asarray(n, dtype=T.dtype)) / lam
            return (n, T_new, key)

        n, T, _ = jax.lax.while_loop(
            cond_fun, body_fun, (n_begin, jnp.zeros(()), key)
        )
        return jnp.array([n, T])

    keys = jax.random.split(arrival_key, num=n_sims)
    NT = jax.jit(jax.vmap(find_n_T_n, in_axes=(None, 0)))(n_begin, keys)

    N_max = int(jnp.max(NT[:, 0]))
    order = jnp.arange(0, N_max)

    def stat(nt, key):
        n, t_n = nt
        unifs = jax.random.uniform(key=key, shape=(N_max,))
        Ts = t_n * unifs
        # Only the first n-1 buffer entries are real arrival times. Sorting
        # is unnecessary because the sum does not depend on the order, and
        # `where` avoids the inf * 0 = nan of a padded product.
        keep = (order < (n - 1)) & (Ts < t)
        return jnp.sum(jnp.where(keep, h(Ts), 0.0))

    stat_vmapped_jit = jax.jit(jax.vmap(stat, in_axes=(0, 0)))

    keys = jax.random.split(unif_key, num=n_sims)
    return stat_vmapped_jit(NT, keys)