This repository has been archived by the owner on Oct 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add skeleton lei stuff * Add lei notebook for now and poisson process fun study * Add stuff for ben to see * Add current changes for Ben once again (so needy :P) * Fixed the hang problem. * Commit what I have * Add currently broken code once again * Simplify notebook * Add working version of lewis * Add current progress * Add batching method and current logic of lei * Move lei stuff to its own package * Working on linear interpolation for the Lei problem. * inlaw -> outlaw. * Add lei test n_config * Fix settings.json * Add unit tests and update notebook with correct simulations * Update test, add simulation tests, add point batcher, share RNG * Update comment * JAX implementation of scipy.interpolate.interpn (#47) * JAX Interpolation. * JAX implementation of scipy.interpolate.interpn * Update todo list. * Add current version lol * Fix bugs and integrate good version * Fix small bug in stage 2 and clean up code * Modify interpn to work with multi-dimensional values * Add current version of notebook * WTF * Finish final lei * Fix test in outlaw * Add python notebook (weird vscode lol) * Add lei simulator batching method * Remove unnecessary files cluttering up space * Add current state * Add upper bound logic to lei example * Add ignore to frontend and update lei flow * Clean up lewis code and include some of Ben's changes * Add new script * Add new changes to make memory ok * Add full changes to everything except key * Add checkpointing * Add modified version * First pass at holder-odi bound in binomial.py * Holder-ODI, feeling more confident. * Add analyze lei example * Move lewis into confirm * Fix analyze notebook with new import structure * Add np.isnan check for holder bound and update lei analyze scripts * Moving files, small tweaks. * Pre-commit fixes. * Most tests passing. * Fix test stage1. Co-authored-by: James Yang <[email protected]>
- Loading branch information
1 parent
b294fc9
commit 5881260
Showing
29 changed files
with
4,763 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,127 +1,126 @@ | ||
{ | ||
"bazel-cpp-tools.compileCommands.targets": [ | ||
"//...", | ||
], | ||
"jupyter.jupyterServerType": "local", | ||
"files.associations": { | ||
"functional": "cpp", | ||
"*.evaluator": "cpp", | ||
"*.traits": "cpp", | ||
"fft": "cpp", | ||
"openglsupport": "cpp", | ||
"regex": "cpp", | ||
"tuple": "cpp", | ||
"type_traits": "cpp", | ||
"any": "cpp", | ||
"array": "cpp", | ||
"atomic": "cpp", | ||
"bit": "cpp", | ||
"*.tcc": "cpp", | ||
"bitset": "cpp", | ||
"cctype": "cpp", | ||
"chrono": "cpp", | ||
"cinttypes": "cpp", | ||
"clocale": "cpp", | ||
"cmath": "cpp", | ||
"codecvt": "cpp", | ||
"complex": "cpp", | ||
"condition_variable": "cpp", | ||
"cstdarg": "cpp", | ||
"cstddef": "cpp", | ||
"cstdint": "cpp", | ||
"cstdio": "cpp", | ||
"cstdlib": "cpp", | ||
"cstring": "cpp", | ||
"ctime": "cpp", | ||
"cwchar": "cpp", | ||
"cwctype": "cpp", | ||
"deque": "cpp", | ||
"forward_list": "cpp", | ||
"list": "cpp", | ||
"map": "cpp", | ||
"set": "cpp", | ||
"unordered_map": "cpp", | ||
"unordered_set": "cpp", | ||
"vector": "cpp", | ||
"exception": "cpp", | ||
"algorithm": "cpp", | ||
"iterator": "cpp", | ||
"memory": "cpp", | ||
"memory_resource": "cpp", | ||
"numeric": "cpp", | ||
"optional": "cpp", | ||
"random": "cpp", | ||
"ratio": "cpp", | ||
"string": "cpp", | ||
"string_view": "cpp", | ||
"system_error": "cpp", | ||
"utility": "cpp", | ||
"hash_map": "cpp", | ||
"fstream": "cpp", | ||
"future": "cpp", | ||
"initializer_list": "cpp", | ||
"iomanip": "cpp", | ||
"iosfwd": "cpp", | ||
"iostream": "cpp", | ||
"istream": "cpp", | ||
"limits": "cpp", | ||
"mutex": "cpp", | ||
"new": "cpp", | ||
"ostream": "cpp", | ||
"shared_mutex": "cpp", | ||
"sstream": "cpp", | ||
"stdexcept": "cpp", | ||
"streambuf": "cpp", | ||
"thread": "cpp", | ||
"typeinfo": "cpp", | ||
"valarray": "cpp", | ||
"variant": "cpp", | ||
"filesystem": "cpp", | ||
"locale": "cpp", | ||
"mprealsupport": "cpp", | ||
"nonlinearoptimization": "cpp", | ||
"dense": "cpp", | ||
"__bit_reference": "cpp", | ||
"__bits": "cpp", | ||
"__config": "cpp", | ||
"__debug": "cpp", | ||
"__errc": "cpp", | ||
"__hash_table": "cpp", | ||
"__locale": "cpp", | ||
"__mutex_base": "cpp", | ||
"__node_handle": "cpp", | ||
"__nullptr": "cpp", | ||
"__split_buffer": "cpp", | ||
"__string": "cpp", | ||
"__threading_support": "cpp", | ||
"__tree": "cpp", | ||
"__tuple": "cpp", | ||
"compare": "cpp", | ||
"concepts": "cpp", | ||
"ios": "cpp", | ||
"queue": "cpp", | ||
"stack": "cpp", | ||
"__functional_base": "cpp", | ||
"alignedvector3": "cpp", | ||
"typeindex": "cpp", | ||
"*.ipp": "cpp", | ||
"*.inc": "cpp", | ||
"core": "cpp", | ||
"geometry": "cpp", | ||
"qtalignedmalloc": "cpp", | ||
"matrixfunctions": "cpp", | ||
"bvh": "cpp" | ||
}, | ||
"C_Cpp.errorSquiggles": "Enabled", | ||
"C_Cpp.clang_format_fallbackStyle": "{ BasedOnStyle: LLVM, UseTab: Never, IndentWidth: 4, TabWidth: 4, AllowShortIfStatementsOnASingleLine: false, IndentCaseLabels: false, ColumnLimit: 100, AccessModifierOffset: -4, NamespaceIndentation: All, FixNamespaceComments: false, PointerAlignment: Left}", | ||
"cmake.configureOnOpen": false, | ||
"python.testing.unittestEnabled": false, | ||
"python.testing.pytestEnabled": true, | ||
"python.linting.pylintEnabled": false, | ||
"python.linting.flake8Enabled": true, | ||
"python.linting.enabled": true, | ||
"python.formatting.provider": "black", | ||
"autoDocstring.docstringFormat": "google-notypes", | ||
"r.bracketedPaste": true, | ||
"r.plot.useHttpgd": true | ||
"bazel-cpp-tools.compileCommands.targets": ["//..."], | ||
"jupyter.jupyterServerType": "local", | ||
"files.associations": { | ||
"functional": "cpp", | ||
"*.evaluator": "cpp", | ||
"*.traits": "cpp", | ||
"fft": "cpp", | ||
"openglsupport": "cpp", | ||
"regex": "cpp", | ||
"tuple": "cpp", | ||
"type_traits": "cpp", | ||
"any": "cpp", | ||
"array": "cpp", | ||
"atomic": "cpp", | ||
"bit": "cpp", | ||
"*.tcc": "cpp", | ||
"bitset": "cpp", | ||
"cctype": "cpp", | ||
"chrono": "cpp", | ||
"cinttypes": "cpp", | ||
"clocale": "cpp", | ||
"cmath": "cpp", | ||
"codecvt": "cpp", | ||
"complex": "cpp", | ||
"condition_variable": "cpp", | ||
"cstdarg": "cpp", | ||
"cstddef": "cpp", | ||
"cstdint": "cpp", | ||
"cstdio": "cpp", | ||
"cstdlib": "cpp", | ||
"cstring": "cpp", | ||
"ctime": "cpp", | ||
"cwchar": "cpp", | ||
"cwctype": "cpp", | ||
"deque": "cpp", | ||
"forward_list": "cpp", | ||
"list": "cpp", | ||
"map": "cpp", | ||
"set": "cpp", | ||
"unordered_map": "cpp", | ||
"unordered_set": "cpp", | ||
"vector": "cpp", | ||
"exception": "cpp", | ||
"algorithm": "cpp", | ||
"iterator": "cpp", | ||
"memory": "cpp", | ||
"memory_resource": "cpp", | ||
"numeric": "cpp", | ||
"optional": "cpp", | ||
"random": "cpp", | ||
"ratio": "cpp", | ||
"string": "cpp", | ||
"string_view": "cpp", | ||
"system_error": "cpp", | ||
"utility": "cpp", | ||
"hash_map": "cpp", | ||
"fstream": "cpp", | ||
"future": "cpp", | ||
"initializer_list": "cpp", | ||
"iomanip": "cpp", | ||
"iosfwd": "cpp", | ||
"iostream": "cpp", | ||
"istream": "cpp", | ||
"limits": "cpp", | ||
"mutex": "cpp", | ||
"new": "cpp", | ||
"ostream": "cpp", | ||
"shared_mutex": "cpp", | ||
"sstream": "cpp", | ||
"stdexcept": "cpp", | ||
"streambuf": "cpp", | ||
"thread": "cpp", | ||
"typeinfo": "cpp", | ||
"valarray": "cpp", | ||
"variant": "cpp", | ||
"filesystem": "cpp", | ||
"locale": "cpp", | ||
"mprealsupport": "cpp", | ||
"nonlinearoptimization": "cpp", | ||
"dense": "cpp", | ||
"__bit_reference": "cpp", | ||
"__bits": "cpp", | ||
"__config": "cpp", | ||
"__debug": "cpp", | ||
"__errc": "cpp", | ||
"__hash_table": "cpp", | ||
"__locale": "cpp", | ||
"__mutex_base": "cpp", | ||
"__node_handle": "cpp", | ||
"__nullptr": "cpp", | ||
"__split_buffer": "cpp", | ||
"__string": "cpp", | ||
"__threading_support": "cpp", | ||
"__tree": "cpp", | ||
"__tuple": "cpp", | ||
"compare": "cpp", | ||
"concepts": "cpp", | ||
"ios": "cpp", | ||
"queue": "cpp", | ||
"stack": "cpp", | ||
"__functional_base": "cpp", | ||
"alignedvector3": "cpp", | ||
"typeindex": "cpp", | ||
"*.ipp": "cpp", | ||
"*.inc": "cpp", | ||
"core": "cpp", | ||
"geometry": "cpp", | ||
"qtalignedmalloc": "cpp", | ||
"matrixfunctions": "cpp", | ||
"bvh": "cpp" | ||
}, | ||
"C_Cpp.errorSquiggles": "Enabled", | ||
"C_Cpp.clang_format_fallbackStyle": "{ BasedOnStyle: LLVM, UseTab: Never, IndentWidth: 4, TabWidth: 4, AllowShortIfStatementsOnASingleLine: false, IndentCaseLabels: false, ColumnLimit: 100, AccessModifierOffset: -4, NamespaceIndentation: All, FixNamespaceComments: false, PointerAlignment: Left}", | ||
"cmake.configureOnOpen": false, | ||
"python.testing.unittestEnabled": false, | ||
"python.testing.pytestEnabled": true, | ||
"python.linting.pylintEnabled": false, | ||
"python.linting.flake8Enabled": true, | ||
"python.linting.enabled": true, | ||
"python.formatting.provider": "black", | ||
"autoDocstring.docstringFormat": "google-notypes", | ||
"r.bracketedPaste": true, | ||
"r.plot.useHttpgd": true, | ||
"python.analysis.extraPaths": ["./outlaw", "./imprint/python"] | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import numpy as np | ||
|
||
|
||
def pad_arg__(a, axis, n_pad: int): | ||
pad_element = np.take(a, indices=0, axis=axis) | ||
pad_element = np.expand_dims(pad_element, axis=axis) | ||
new_shape = tuple(a.shape[i] if i != axis else n_pad for i in range(a.ndim)) | ||
return np.concatenate((a, np.full(new_shape, pad_element)), axis=axis) | ||
|
||
|
||
def create_batched_args__(args, in_axes, start, end, n_pad=None): | ||
def arg_transform(arg, axis): | ||
return pad_arg__(arg, axis, n_pad) if n_pad is not None else arg | ||
|
||
return [ | ||
arg_transform( | ||
np.take(arg, indices=range(start, end), axis=axis), | ||
axis, | ||
) | ||
if axis is not None | ||
else arg | ||
for arg, axis in zip(args, in_axes) | ||
] | ||
|
||
|
||
def batch(f, batch_size: int, in_axes): | ||
def internal(*args): | ||
dims = np.array( | ||
[arg.shape[axis] for arg, axis in zip(args, in_axes) if axis is not None] | ||
) | ||
if len(dims) <= 0: | ||
raise ValueError( | ||
"f must take at least one argument " | ||
"whose corresponding in_axes is not None." | ||
) | ||
|
||
dims_all_equal = np.sum(dims != dims[0]) == 0 | ||
if not dims_all_equal: | ||
raise ValueError( | ||
"All batched arguments must have the same dimension " | ||
"along their corresopnding in_axes." | ||
) | ||
|
||
dim = dims[0] | ||
batch_size_new = min(batch_size, dim) | ||
n_full_batches = dim // batch_size_new | ||
remainder = dim % batch_size_new | ||
n_pad = batch_size_new - remainder | ||
pad_last = remainder > 0 | ||
start = 0 | ||
end = batch_size_new | ||
|
||
for _ in range(n_full_batches): | ||
batched_args = create_batched_args__( | ||
args=args, | ||
in_axes=in_axes, | ||
start=start, | ||
end=end, | ||
) | ||
yield (f(*batched_args), 0) | ||
start += batch_size_new | ||
end += batch_size_new | ||
|
||
if pad_last: | ||
batched_args = create_batched_args__( | ||
args=args, | ||
in_axes=in_axes, | ||
start=start, | ||
end=dim, | ||
n_pad=n_pad, | ||
) | ||
yield (f(*batched_args), n_pad) | ||
|
||
return internal | ||
|
||
|
||
def batch_all(f, batch_size: int, in_axes): | ||
f_batch = batch(f, batch_size, in_axes) | ||
|
||
def internal(*args): | ||
outs = tuple(out for out in f_batch(*args)) | ||
return tuple(out[0] for out in outs), outs[-1][-1] | ||
|
||
return internal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import numpy as np | ||
import pyimprint.grid as pygrid | ||
|
||
|
||
def make_cartesian_grid_range(size, lower, upper): | ||
assert lower.shape[0] == upper.shape[0] | ||
|
||
# make initial 1d grid | ||
center_grids = ( | ||
pygrid.Gridder.make_grid(size, lower[i], upper[i]) for i in range(len(lower)) | ||
) | ||
|
||
# make a grid of centers | ||
coords = np.meshgrid(*center_grids) | ||
centers = np.concatenate([c.flatten().reshape(-1, 1) for c in coords], axis=1) | ||
|
||
# make corresponding radius | ||
radius = np.array( | ||
[pygrid.Gridder.radius(size, lower[i], upper[i]) for i in range(len(lower))] | ||
) | ||
radii = np.full(shape=centers.shape, fill_value=radius) | ||
|
||
return centers, radii |
Oops, something went wrong.