Skip to content
This repository has been archived by the owner on Oct 26, 2024. It is now read-only.

Commit

Permalink
Lewis model (#58)
Browse files Browse the repository at this point in the history
* Add skeleton lei stuff

* Add lei notebook for now and poisson process fun study

* Add stuff for ben to see

* Add current changes for Ben once again (so needy :P)

* Fixed the hang problem.

* Commit what I have

* Add currently broken code once again

* Simplify notebook

* Add working version of lewis

* Add current progress

* Add batching method and current logic of lei

* Move lei stuff to its own package

* Working on linear interpolation for the Lei problem.

* inlaw -> outlaw.

* Add lei test n_config

* Fix settings.json

* Add unit tests and update notebook with correct simulations

* Update test, add simulation tests, add point batcher, share RNG

* Update comment

* JAX implementation of scipy.interpolate.interpn (#47)

* JAX Interpolation.

* JAX implementation of scipy.interpolate.interpn

* Update todo list.

* Add current version lol

* Fix bugs and integrate good version

* Fix small bug in stage 2 and clean up code

* Modify interpn to work with multi-dimensional values

* Add current version of notebook

* WTF

* Finish final lei

* Fix test in outlaw

* Add python notebook (weird vscode lol)

* Add lei simulator batching method

* Remove unnecessary files cluttering up space

* Add current state

* Add upper bound logic to lei example

* Add ignore to frontend and update lei flow

* Clean up lewis code and include some of Ben's changes

* Add new script

* Add new changes to make memory ok

* Add full changes to everything except key

* Add checkpointing

* Add modified version

* First pass at holder-odi bound in binomial.py

* Holder-ODI, feeling more confident.

* Add analyze lei example

* Move lewis into confirm

* Fix analyze notebook with new import structure

* Add np.isnan check for holder bound and update lei analyze scripts

* Moving files, small tweaks.

* Pre-commit fixes.

* Most tests passing.

* Fix test stage1.

Co-authored-by: James Yang <[email protected]>
  • Loading branch information
tbenthompson and JamesYang007 authored Oct 17, 2022
1 parent b294fc9 commit 5881260
Show file tree
Hide file tree
Showing 29 changed files with 4,763 additions and 136 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true
"justMyCode": false,
}
]
}
249 changes: 124 additions & 125 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,127 +1,126 @@
{
"bazel-cpp-tools.compileCommands.targets": [
"//...",
],
"jupyter.jupyterServerType": "local",
"files.associations": {
"functional": "cpp",
"*.evaluator": "cpp",
"*.traits": "cpp",
"fft": "cpp",
"openglsupport": "cpp",
"regex": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"any": "cpp",
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"*.tcc": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"cinttypes": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"codecvt": "cpp",
"complex": "cpp",
"condition_variable": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"forward_list": "cpp",
"list": "cpp",
"map": "cpp",
"set": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"vector": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"iterator": "cpp",
"memory": "cpp",
"memory_resource": "cpp",
"numeric": "cpp",
"optional": "cpp",
"random": "cpp",
"ratio": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"utility": "cpp",
"hash_map": "cpp",
"fstream": "cpp",
"future": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"mutex": "cpp",
"new": "cpp",
"ostream": "cpp",
"shared_mutex": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"thread": "cpp",
"typeinfo": "cpp",
"valarray": "cpp",
"variant": "cpp",
"filesystem": "cpp",
"locale": "cpp",
"mprealsupport": "cpp",
"nonlinearoptimization": "cpp",
"dense": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__mutex_base": "cpp",
"__node_handle": "cpp",
"__nullptr": "cpp",
"__split_buffer": "cpp",
"__string": "cpp",
"__threading_support": "cpp",
"__tree": "cpp",
"__tuple": "cpp",
"compare": "cpp",
"concepts": "cpp",
"ios": "cpp",
"queue": "cpp",
"stack": "cpp",
"__functional_base": "cpp",
"alignedvector3": "cpp",
"typeindex": "cpp",
"*.ipp": "cpp",
"*.inc": "cpp",
"core": "cpp",
"geometry": "cpp",
"qtalignedmalloc": "cpp",
"matrixfunctions": "cpp",
"bvh": "cpp"
},
"C_Cpp.errorSquiggles": "Enabled",
"C_Cpp.clang_format_fallbackStyle": "{ BasedOnStyle: LLVM, UseTab: Never, IndentWidth: 4, TabWidth: 4, AllowShortIfStatementsOnASingleLine: false, IndentCaseLabels: false, ColumnLimit: 100, AccessModifierOffset: -4, NamespaceIndentation: All, FixNamespaceComments: false, PointerAlignment: Left}",
"cmake.configureOnOpen": false,
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": true,
"python.linting.enabled": true,
"python.formatting.provider": "black",
"autoDocstring.docstringFormat": "google-notypes",
"r.bracketedPaste": true,
"r.plot.useHttpgd": true
"bazel-cpp-tools.compileCommands.targets": ["//..."],
"jupyter.jupyterServerType": "local",
"files.associations": {
"functional": "cpp",
"*.evaluator": "cpp",
"*.traits": "cpp",
"fft": "cpp",
"openglsupport": "cpp",
"regex": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"any": "cpp",
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"*.tcc": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"cinttypes": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"codecvt": "cpp",
"complex": "cpp",
"condition_variable": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"forward_list": "cpp",
"list": "cpp",
"map": "cpp",
"set": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"vector": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"iterator": "cpp",
"memory": "cpp",
"memory_resource": "cpp",
"numeric": "cpp",
"optional": "cpp",
"random": "cpp",
"ratio": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"utility": "cpp",
"hash_map": "cpp",
"fstream": "cpp",
"future": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"mutex": "cpp",
"new": "cpp",
"ostream": "cpp",
"shared_mutex": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"thread": "cpp",
"typeinfo": "cpp",
"valarray": "cpp",
"variant": "cpp",
"filesystem": "cpp",
"locale": "cpp",
"mprealsupport": "cpp",
"nonlinearoptimization": "cpp",
"dense": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__mutex_base": "cpp",
"__node_handle": "cpp",
"__nullptr": "cpp",
"__split_buffer": "cpp",
"__string": "cpp",
"__threading_support": "cpp",
"__tree": "cpp",
"__tuple": "cpp",
"compare": "cpp",
"concepts": "cpp",
"ios": "cpp",
"queue": "cpp",
"stack": "cpp",
"__functional_base": "cpp",
"alignedvector3": "cpp",
"typeindex": "cpp",
"*.ipp": "cpp",
"*.inc": "cpp",
"core": "cpp",
"geometry": "cpp",
"qtalignedmalloc": "cpp",
"matrixfunctions": "cpp",
"bvh": "cpp"
},
"C_Cpp.errorSquiggles": "Enabled",
"C_Cpp.clang_format_fallbackStyle": "{ BasedOnStyle: LLVM, UseTab: Never, IndentWidth: 4, TabWidth: 4, AllowShortIfStatementsOnASingleLine: false, IndentCaseLabels: false, ColumnLimit: 100, AccessModifierOffset: -4, NamespaceIndentation: All, FixNamespaceComments: false, PointerAlignment: Left}",
"cmake.configureOnOpen": false,
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.linting.pylintEnabled": false,
"python.linting.flake8Enabled": true,
"python.linting.enabled": true,
"python.formatting.provider": "black",
"autoDocstring.docstringFormat": "google-notypes",
"r.bracketedPaste": true,
"r.plot.useHttpgd": true,
"python.analysis.extraPaths": ["./outlaw", "./imprint/python"]
}
Empty file.
84 changes: 84 additions & 0 deletions confirm/confirm/lewislib/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np


def pad_arg__(a, axis, n_pad: int):
pad_element = np.take(a, indices=0, axis=axis)
pad_element = np.expand_dims(pad_element, axis=axis)
new_shape = tuple(a.shape[i] if i != axis else n_pad for i in range(a.ndim))
return np.concatenate((a, np.full(new_shape, pad_element)), axis=axis)


def create_batched_args__(args, in_axes, start, end, n_pad=None):
def arg_transform(arg, axis):
return pad_arg__(arg, axis, n_pad) if n_pad is not None else arg

return [
arg_transform(
np.take(arg, indices=range(start, end), axis=axis),
axis,
)
if axis is not None
else arg
for arg, axis in zip(args, in_axes)
]


def batch(f, batch_size: int, in_axes):
def internal(*args):
dims = np.array(
[arg.shape[axis] for arg, axis in zip(args, in_axes) if axis is not None]
)
if len(dims) <= 0:
raise ValueError(
"f must take at least one argument "
"whose corresponding in_axes is not None."
)

dims_all_equal = np.sum(dims != dims[0]) == 0
if not dims_all_equal:
raise ValueError(
"All batched arguments must have the same dimension "
"along their corresopnding in_axes."
)

dim = dims[0]
batch_size_new = min(batch_size, dim)
n_full_batches = dim // batch_size_new
remainder = dim % batch_size_new
n_pad = batch_size_new - remainder
pad_last = remainder > 0
start = 0
end = batch_size_new

for _ in range(n_full_batches):
batched_args = create_batched_args__(
args=args,
in_axes=in_axes,
start=start,
end=end,
)
yield (f(*batched_args), 0)
start += batch_size_new
end += batch_size_new

if pad_last:
batched_args = create_batched_args__(
args=args,
in_axes=in_axes,
start=start,
end=dim,
n_pad=n_pad,
)
yield (f(*batched_args), n_pad)

return internal


def batch_all(f, batch_size: int, in_axes):
f_batch = batch(f, batch_size, in_axes)

def internal(*args):
outs = tuple(out for out in f_batch(*args))
return tuple(out[0] for out in outs), outs[-1][-1]

return internal
23 changes: 23 additions & 0 deletions confirm/confirm/lewislib/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
import pyimprint.grid as pygrid


def make_cartesian_grid_range(size, lower, upper):
assert lower.shape[0] == upper.shape[0]

# make initial 1d grid
center_grids = (
pygrid.Gridder.make_grid(size, lower[i], upper[i]) for i in range(len(lower))
)

# make a grid of centers
coords = np.meshgrid(*center_grids)
centers = np.concatenate([c.flatten().reshape(-1, 1) for c in coords], axis=1)

# make corresponding radius
radius = np.array(
[pygrid.Gridder.radius(size, lower[i], upper[i]) for i in range(len(lower))]
)
radii = np.full(shape=centers.shape, fill_value=radius)

return centers, radii
Loading

0 comments on commit 5881260

Please sign in to comment.