From 807fdc6f1026bfe244c2826ca22646938fa22c52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Tue, 21 Nov 2023 13:14:19 +0100 Subject: [PATCH 1/4] Merged montecarlo and hutchinson file --- Makefile | 1 + README.md | 2 +- docs/api/montecarlo.md | 2 - docs/benchmarks/control_variates.py | 4 +- docs/benchmarks/jacobian_squared.py | 4 +- docs/control_variates.md | 2 +- docs/higher_moments.md | 6 +- docs/index.md | 2 +- docs/log_determinants.md | 6 +- docs/pytree_logdeterminants.md | 2 +- docs/vector_calculus.md | 2 +- matfree/hutchinson.py | 213 ++++++++++++++++-- matfree/montecarlo.py | 188 ---------------- matfree/slq.py | 6 +- mkdocs.yml | 2 +- tests/__init__.py | 1 - tests/test_decomp/__init__.py | 1 - tests/test_hutchinson/__init__.py | 1 - tests/test_hutchinson/test_diagonal.py | 4 +- .../test_estimate.py | 6 +- .../test_frobeniusnorm_squared.py | 4 +- .../test_multiestimate.py | 7 +- tests/test_hutchinson/test_trace.py | 4 +- .../test_trace_and_diagonal.py | 4 +- tests/test_hutchinson/test_trace_moments.py | 6 +- .../test_van_der_corput.py | 6 +- tests/test_lanczos/__init__.py | 1 - tests/test_montecarlo/__init__.py | 1 - tests/test_pinv/__init__.py | 1 - tests/test_slq/__init__.py | 1 - tests/test_slq/test_logdet_product.py | 4 +- tests/test_slq/test_logdet_spd.py | 4 +- tests/test_slq/test_logdet_spd_autodiff.py | 4 +- tests/test_slq/test_schatten_norm.py | 4 +- 34 files changed, 246 insertions(+), 260 deletions(-) delete mode 100644 docs/api/montecarlo.md delete mode 100644 matfree/montecarlo.py delete mode 100644 tests/__init__.py delete mode 100644 tests/test_decomp/__init__.py delete mode 100644 tests/test_hutchinson/__init__.py rename tests/{test_montecarlo => test_hutchinson}/test_estimate.py (81%) rename tests/{test_montecarlo => test_hutchinson}/test_multiestimate.py (84%) rename tests/{test_montecarlo => test_hutchinson}/test_van_der_corput.py (74%) delete mode 100644 tests/test_lanczos/__init__.py delete mode 100644 
tests/test_montecarlo/__init__.py delete mode 100644 tests/test_pinv/__init__.py delete mode 100644 tests/test_slq/__init__.py diff --git a/Makefile b/Makefile index 914f81b..3048c0b 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,7 @@ clean: rm -rf *.egg-info rm -rf dist site build htmlcov rm -rf *.ipynb_checkpoints + rm matfree/_version.py doc: mkdocs build diff --git a/README.md b/README.md index c786a44..8d1781f 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ Estimate the trace of the matrix: ```python >>> key = jax.random.PRNGKey(1) ->>> normal = montecarlo.normal(shape=(2,)) +>>> normal = hutchinson.normal(shape=(2,)) >>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal) >>> >>> print(jnp.round(trace)) diff --git a/docs/api/montecarlo.md b/docs/api/montecarlo.md deleted file mode 100644 index de77dcf..0000000 --- a/docs/api/montecarlo.md +++ /dev/null @@ -1,2 +0,0 @@ -# Monte-Carlo estimation -:::matfree.montecarlo diff --git a/docs/benchmarks/control_variates.py b/docs/benchmarks/control_variates.py index e8ad1cc..0d3eeb6 100644 --- a/docs/benchmarks/control_variates.py +++ b/docs/benchmarks/control_variates.py @@ -3,7 +3,7 @@ Runtime: ~10 seconds. 
""" -from matfree import benchmark_util, hutchinson, montecarlo +from matfree import benchmark_util, hutchinson from matfree.backend import func, linalg, np, plt, prng, progressbar @@ -22,7 +22,7 @@ def f(x): _, jvp = func.linearize(f, x0) J = func.jacfwd(f)(x0) trace = linalg.trace(J) - sample_fun = montecarlo.normal(shape=(n,), dtype=float) + sample_fun = hutchinson.normal(shape=(n,), dtype=float) return (jvp, trace, J), (key, sample_fun) diff --git a/docs/benchmarks/jacobian_squared.py b/docs/benchmarks/jacobian_squared.py index 7ae2223..5a253ea 100644 --- a/docs/benchmarks/jacobian_squared.py +++ b/docs/benchmarks/jacobian_squared.py @@ -1,5 +1,5 @@ """What is the fastest way of computing trace(A^5).""" -from matfree import benchmark_util, hutchinson, montecarlo, slq +from matfree import benchmark_util, hutchinson, slq from matfree.backend import func, linalg, np, plt, prng from matfree.backend.progressbar import progressbar @@ -20,7 +20,7 @@ def f(x): J = func.jacfwd(f)(x0) A = J @ J @ J @ J trace = linalg.trace(A) - sample_fun = montecarlo.normal(shape=(n,), dtype=float) + sample_fun = hutchinson.normal(shape=(n,), dtype=float) def Av(v): return jvp(jvp(jvp(jvp(v)))) diff --git a/docs/control_variates.md b/docs/control_variates.md index 095632a..ebbfa5a 100644 --- a/docs/control_variates.md +++ b/docs/control_variates.md @@ -14,7 +14,7 @@ Imports: >>> key = jax.random.PRNGKey(1) >>> matvec = lambda x: a.T @ (a @ x) ->>> sample_fun = montecarlo.normal(shape=(2,)) +>>> sample_fun = hutchinson.normal(shape=(2,)) ``` diff --git a/docs/higher_moments.md b/docs/higher_moments.md index 0ee4f00..b94d4ff 100644 --- a/docs/higher_moments.md +++ b/docs/higher_moments.md @@ -9,7 +9,7 @@ >>> key = jax.random.PRNGKey(1) >>> mvp = lambda x: a.T @ (a @ x) ->>> sample_fun = montecarlo.normal(shape=(2,)) +>>> sample_fun = hutchinson.normal(shape=(2,)) ``` @@ -21,7 +21,7 @@ Compute them as such ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 ->>> normal = 
montecarlo.normal(shape=(6,)) +>>> normal = hutchinson.normal(shape=(6,)) >>> mvp = lambda x: a.T @ (a @ x) + x >>> first, second = hutchinson.trace_moments(mvp, key=key, sample_fun=normal) >>> print(jnp.round(first, 1)) @@ -53,7 +53,7 @@ Implement this as follows: ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 ->>> sample_fun = montecarlo.normal(shape=(6,)) +>>> sample_fun = hutchinson.normal(shape=(6,)) >>> num_samples = 10_000 >>> mvp = lambda x: a.T @ (a @ x) + x >>> first, second = hutchinson.trace_moments( diff --git a/docs/index.md b/docs/index.md index c786a44..8d1781f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -65,7 +65,7 @@ Estimate the trace of the matrix: ```python >>> key = jax.random.PRNGKey(1) ->>> normal = montecarlo.normal(shape=(2,)) +>>> normal = hutchinson.normal(shape=(2,)) >>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal) >>> >>> print(jnp.round(trace)) diff --git a/docs/log_determinants.md b/docs/log_determinants.md index b234442..c07e858 100644 --- a/docs/log_determinants.md +++ b/docs/log_determinants.md @@ -12,7 +12,7 @@ Imports: >>> key = jax.random.PRNGKey(1) >>> matvec = lambda x: a.T @ (a @ x) ->>> sample_fun = montecarlo.normal(shape=(2,)) +>>> sample_fun = hutchinson.normal(shape=(2,)) ``` @@ -20,7 +20,7 @@ Imports: Estimate log-determinants as such: ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 ->>> sample_fun = montecarlo.normal(shape=(6,)) +>>> sample_fun = hutchinson.normal(shape=(6,)) >>> matvec = lambda x: a.T @ (a @ x) + x >>> order = 3 >>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun) @@ -37,7 +37,7 @@ on arithmetic with $B$; no need to assemble $M$: ```python >>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 + jnp.eye(6) ->>> sample_fun = montecarlo.normal(shape=(6,)) +>>> sample_fun = hutchinson.normal(shape=(6,)) >>> matvec = lambda x: (a @ x) >>> vecmat = lambda x: (a.T @ x) >>> order = 3 diff --git a/docs/pytree_logdeterminants.md 
b/docs/pytree_logdeterminants.md index f8787ef..f74233d 100644 --- a/docs/pytree_logdeterminants.md +++ b/docs/pytree_logdeterminants.md @@ -70,7 +70,7 @@ Now, we can compute the log-determinant with the flattened inputs as usual: ```python >>> # Compute the log-determinant >>> key = jax.random.PRNGKey(seed=1) ->>> sample_fun = montecarlo.normal(shape=f0_flat.shape) +>>> sample_fun = hutchinson.normal(shape=f0_flat.shape) >>> order = 3 >>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun) diff --git a/docs/vector_calculus.md b/docs/vector_calculus.md index 238af12..021bf40 100644 --- a/docs/vector_calculus.md +++ b/docs/vector_calculus.md @@ -85,7 +85,7 @@ For large-scale problems, it may be the only way of computing Laplacians reliabl ```python >>> laplacian_dense = divergence_dense(gradient) >>> ->>> normal = montecarlo.normal(shape=(3,)) +>>> normal = hutchinson.normal(shape=(3,)) >>> key = jax.random.PRNGKey(1) >>> laplacian_matfree = divergence_matfree(gradient, key=key, sample_fun=normal) >>> diff --git a/matfree/hutchinson.py b/matfree/hutchinson.py index fcb8182..0be0efc 100644 --- a/matfree/hutchinson.py +++ b/matfree/hutchinson.py @@ -1,10 +1,193 @@ """Hutchinson-style trace and diagonal estimation.""" -from matfree import montecarlo from matfree.backend import containers, control_flow, func, linalg, np, prng from matfree.backend.typing import Any, Array, Callable, Sequence +# todo: allow a fun() that returns pytrees instead of arrays. +# why? Because then we rival trace_and_variance as +# trace_and_frobeniusnorm(): y=Ax; return (x@y, y@y) + + +def estimate( + fun: Callable, + /, + *, + key: Array, + sample_fun: Callable, + num_batches: int = 1, + num_samples_per_batch: int = 10_000, + statistic_batch: Callable = np.mean, + statistic_combine: Callable = np.mean, +) -> Array: + """Monte-Carlo estimation: Compute the expected value of a function. + + Parameters + ---------- + fun: + Function whose expected value shall be estimated. 
+ key: + Pseudo-random number generator key. + sample_fun: + Sampling function. + For trace-estimation, use + either [normal(...)][matfree.hutchinson.normal] + or [rademacher(...)][matfree.hutchinson.rademacher]. + num_batches: + Number of batches when computing arithmetic means. + num_samples_per_batch: + Number of samples per batch. + statistic_batch: + The summary statistic to compute on batch-level. + Usually, this is np.mean. But any other + statistical function with a signature like + [one of these functions](https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html) + would work. + statistic_combine: + The summary statistic to combine batch-results. + Usually, this is np.mean. But any other + statistical function with a signature like + [one of these functions](https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html) + would work. + """ + [result] = multiestimate( + fun, + key=key, + sample_fun=sample_fun, + num_batches=num_batches, + num_samples_per_batch=num_samples_per_batch, + statistics_batch=[statistic_batch], + statistics_combine=[statistic_combine], + ) + return result + + +def multiestimate( + fun: Callable, + /, + *, + key: Array, + sample_fun: Callable, + num_batches: int = 1, + num_samples_per_batch: int = 10_000, + statistics_batch: Sequence[Callable] = (np.mean,), + statistics_combine: Sequence[Callable] = (np.mean,), +) -> Array: + """Compute a Monte-Carlo estimate with multiple summary statistics. + + The signature of this function is almost identical to + [estimate(...)][matfree.hutchinson.estimate]. + The only difference is that statistics_batch and statistics_combine are iterables + of summary statistics (of equal lengths). + + The result of this function is an iterable of matching length. + + Parameters + ---------- + fun: + Same as in [estimate(...)][matfree.hutchinson.estimate]. + key: + Same as in [estimate(...)][matfree.hutchinson.estimate]. + sample_fun: + Same as in [estimate(...)][matfree.hutchinson.estimate]. 
+ num_batches: + Same as in [estimate(...)][matfree.hutchinson.estimate]. + num_samples_per_batch: + Same as in [estimate(...)][matfree.hutchinson.estimate]. + statistics_batch: + List or tuple of summary statistics to compute on batch-level. + statistics_combine: + List or tuple of summary statistics to combine batches. + + """ + assert len(statistics_batch) == len(statistics_combine) + fun_mc = _montecarlo(fun, sample_fun=sample_fun, num_stats=len(statistics_batch)) + fun_single_batch = _stats_via_vmap(fun_mc, num_samples_per_batch, statistics_batch) + fun_batched = _stats_via_map(fun_single_batch, num_batches, statistics_combine) + return fun_batched(key) + + +def _montecarlo(f, /, sample_fun, num_stats): + """Turn a function into a Monte-Carlo problem. + + More specifically, f(x) becomes g(key) = f(h(key)), + using a sample function h: key -> x. + This can then be evaluated and averaged in batches, loops, and compositions thereof. + """ + # todo: what about randomised QMC? How do we best implement this? + + def f_mc(key, /): + sample = sample_fun(key) + return [f(sample)] * num_stats + + return f_mc + + +def _stats_via_vmap(f, num, /, statistics: Sequence[Callable]): + """Compute summary statistics via jax.vmap.""" + + def f_mean(key, /): + subkeys = prng.split(key, num) + fx_values = func.vmap(f)(subkeys) + return [stat(fx, axis=0) for stat, fx in zip(statistics, fx_values)] + + return f_mean + + +def _stats_via_map(f, num, /, statistics: Sequence[Callable]): + """Compute summary statistics via jax.lax.map.""" + + def f_mean(key, /): + subkeys = prng.split(key, num) + fx_values = control_flow.array_map(f, subkeys) + return [stat(fx, axis=0) for stat, fx in zip(statistics, fx_values)] + + return f_mean + + +def normal(*, shape, dtype=float): + """Construct a function that samples from a standard normal distribution.""" + + def fun(key): + return prng.normal(key, shape=shape, dtype=dtype) + + return fun + + +def rademacher(*, shape, dtype=float): + """Construct a function that 
samples from a Rademacher distribution.""" + + def fun(key): + return prng.rademacher(key, shape=shape, dtype=dtype) + + return fun + + +class _VDCState(containers.NamedTuple): + n: int + vdc: float + denom: int + + +def van_der_corput(n, /, base=2): + """Compute the 'n'th element of the Van-der-Corput sequence.""" + state = _VDCState(n, vdc=0, denom=1) + + vdc_modify = func.partial(_van_der_corput_modify, base=base) + state = control_flow.while_loop(_van_der_corput_cond, vdc_modify, state) + return state.vdc + + +def _van_der_corput_cond(state: _VDCState): + return state.n > 0 + + +def _van_der_corput_modify(state: _VDCState, *, base): + denom = state.denom * base + num, remainder = divmod(state.n, base) + vdc = state.vdc + remainder / denom + return _VDCState(num, vdc, denom) + def trace(Av: Callable, /, **kwargs) -> Array: """Estimate the trace of a matrix stochastically. @@ -15,13 +198,13 @@ def trace(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [montecarlo.estimate()][matfree.montecarlo.estimate]. + [estimate()][matfree.hutchinson.estimate]. """ def quadform(vec): return linalg.vecdot(vec, Av(vec)) - return montecarlo.estimate(quadform, **kwargs) + return estimate(quadform, **kwargs) def trace_moments(Av: Callable, /, moments: Sequence[int] = (1, 2), **kwargs) -> Array: @@ -36,7 +219,7 @@ def trace_moments(Av: Callable, /, moments: Sequence[int] = (1, 2), **kwargs) -> the first and second moment. **kwargs: Keyword-arguments to be passed to - [montecarlo.multiestimate(...)][matfree.montecarlo.multiestimate]. + [multiestimate(...)][matfree.hutchinson.multiestimate]. 
""" def quadform(vec): @@ -47,7 +230,7 @@ def moment(x, axis, *, power): statistics_batch = [func.partial(moment, power=m) for m in moments] statistics_combine = [np.mean] * len(moments) - return montecarlo.multiestimate( + return multiestimate( quadform, statistics_batch=statistics_batch, statistics_combine=statistics_combine, @@ -72,7 +255,7 @@ def frobeniusnorm_squared(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [montecarlo.estimate()][matfree.montecarlo.estimate]. + [estimate()][matfree.hutchinson.estimate]. """ @@ -80,7 +263,7 @@ def quadform(vec): x = Av(vec) return linalg.vecdot(x, x) - return montecarlo.estimate(quadform, **kwargs) + return estimate(quadform, **kwargs) def diagonal_with_control_variate(Av: Callable, control: Array, /, **kwargs) -> Array: @@ -95,7 +278,7 @@ def diagonal_with_control_variate(Av: Callable, control: Array, /, **kwargs) -> This should be the best-possible estimate of the diagonal of the matrix. **kwargs: Keyword-arguments to be passed to - [montecarlo.estimate()][matfree.montecarlo.estimate]. + [estimate()][matfree.hutchinson.estimate]. """ return diagonal(lambda v: Av(v) - control * v, **kwargs) + control @@ -110,14 +293,14 @@ def diagonal(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [montecarlo.estimate()][matfree.montecarlo.estimate]. + [estimate()][matfree.hutchinson.estimate]. """ def quadform(vec): return vec * Av(vec) - return montecarlo.estimate(quadform, **kwargs) + return estimate(quadform, **kwargs) def trace_and_diagonal(Av: Callable, /, *, sample_fun: Callable, key: Array, **kwargs): @@ -135,13 +318,13 @@ def trace_and_diagonal(Av: Callable, /, *, sample_fun: Callable, key: Array, **k Matrix-vector product function. sample_fun: Sampling function. - Usually, either [montecarlo.normal][matfree.montecarlo.normal] - or [montecarlo.rademacher][matfree.montecarlo.normal]. 
+ Usually, either [normal][matfree.hutchinson.normal] + or [rademacher][matfree.hutchinson.rademacher]. key: Pseudo-random number generator key. **kwargs: Keyword-arguments to be passed to - [diagonal_multilevel()][matfree.hutchinson.diagonal_multilevel]. + [diagonal_multilevel()][matfree.hutchinson.diagonal_multilevel]. See: @@ -185,8 +368,8 @@ def diagonal_multilevel( Pseudo-random number generator key. sample_fun: Sampling function. - Usually, either [montecarlo.normal][matfree.montecarlo.normal] - or [montecarlo.rademacher][matfree.montecarlo.normal]. + Usually, either [normal][matfree.hutchinson.normal] + or [rademacher][matfree.hutchinson.rademacher]. num_levels: Number of levels. num_batches_per_level: diff --git a/matfree/montecarlo.py b/matfree/montecarlo.py deleted file mode 100644 index 3131a90..0000000 --- a/matfree/montecarlo.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Monte-Carlo estimation.""" - -from matfree.backend import containers, control_flow, func, np, prng -from matfree.backend.typing import Array, Callable, Sequence - -# todo: allow a fun() that returns pytrees instead of arrays. -# why? Because then we rival trace_and_variance as -# trace_and_frobeniusnorm(): y=Ax; return (x@y, y@y) - - -def estimate( - fun: Callable, - /, - *, - key: Array, - sample_fun: Callable, - num_batches: int = 1, - num_samples_per_batch: int = 10_000, - statistic_batch: Callable = np.mean, - statistic_combine: Callable = np.mean, -) -> Array: - """Monte-Carlo estimation: Compute the expected value of a function. - - Parameters - ---------- - fun: - Function whose expected value shall be estimated. - key: - Pseudo-random number generator key. - sample_fun: - Sampling function. - For trace-estimation, use - either [montecarlo.normal(...)][matfree.montecarlo.normal] - or [montecarlo.rademacher(...)][matfree.montecarlo.normal]. - num_batches: - Number of batches when computing arithmetic means. - num_samples_per_batch: - Number of samples per batch. - statistic_batch: - The summary statistic to compute on batch-level. 
- Usually, this is np.mean. But any other - statistical function with a signature like - [one of these functions](https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html) - would work. - statistic_combine: - The summary statistic to combine batch-results. - Usually, this is np.mean. But any other - statistical function with a signature like - [one of these functions](https://data-apis.org/array-api/2022.12/API_specification/statistical_functions.html) - would work. - """ - [result] = multiestimate( - fun, - key=key, - sample_fun=sample_fun, - num_batches=num_batches, - num_samples_per_batch=num_samples_per_batch, - statistics_batch=[statistic_batch], - statistics_combine=[statistic_combine], - ) - return result - - -def multiestimate( - fun: Callable, - /, - *, - key: Array, - sample_fun: Callable, - num_batches: int = 1, - num_samples_per_batch: int = 10_000, - statistics_batch: Sequence[Callable] = (np.mean,), - statistics_combine: Sequence[Callable] = (np.mean,), -) -> Array: - """Compute a Monte-Carlo estimate with multiple summary statistics. - - The signature of this function is almost identical to - [montecarlo.estimate(...)][matfree.montecarlo.estimate]. - The only difference is that statistics_batch and statistics_combine are iterables - of summary statistics (of equal lengths). - - The result of this function is an iterable of matching length. - - Parameters - ---------- - fun: - Same as in [montecarlo.estimate(...)][matfree.montecarlo.estimate]. - key: - Same as in [montecarlo.estimate(...)][matfree.montecarlo.estimate]. - sample_fun: - Same as in [montecarlo.estimate(...)][matfree.montecarlo.estimate]. - num_batches: - Same as in [montecarlo.estimate(...)][matfree.montecarlo.estimate]. - num_samples_per_batch: - Same as in [montecarlo.estimate(...)][matfree.montecarlo.estimate]. - statistics_batch: - List or tuple of summary statistics to compute on batch-level. 
- statistics_combine: - List or tuple of summary statistics to combine batches. - - """ - assert len(statistics_batch) == len(statistics_combine) - fun_mc = _montecarlo(fun, sample_fun=sample_fun, num_stats=len(statistics_batch)) - fun_single_batch = _stats_via_vmap(fun_mc, num_samples_per_batch, statistics_batch) - fun_batched = _stats_via_map(fun_single_batch, num_batches, statistics_combine) - return fun_batched(key) - - -def _montecarlo(f, /, sample_fun, num_stats): - """Turn a function into a Monte-Carlo problem. - - More specifically, f(x) becomes g(key) = f(h(key)), - using a sample function h: key -> x. - This can then be evaluated and averaged in batches, loops, and compositions thereof. - """ - # todo: what about randomised QMC? How do we best implement this? - - def f_mc(key, /): - sample = sample_fun(key) - return [f(sample)] * num_stats - - return f_mc - - -def _stats_via_vmap(f, num, /, statistics: Sequence[Callable]): - """Compute summary statistics via jax.vmap.""" - - def f_mean(key, /): - subkeys = prng.split(key, num) - fx_values = func.vmap(f)(subkeys) - return [stat(fx, axis=0) for stat, fx in zip(statistics, fx_values)] - - return f_mean - - -def _stats_via_map(f, num, /, statistics: Sequence[Callable]): - """Compute summary statistics via jax.lax.map.""" - - def f_mean(key, /): - subkeys = prng.split(key, num) - fx_values = control_flow.array_map(f, subkeys) - return [stat(fx, axis=0) for stat, fx in zip(statistics, fx_values)] - - return f_mean - - -def normal(*, shape, dtype=float): - """Construct a function that samples from a standard normal distribution.""" - - def fun(key): - return prng.normal(key, shape=shape, dtype=dtype) - - return fun - - -def rademacher(*, shape, dtype=float): - """Construct a function that samples from a Rademacher distribution.""" - - def fun(key): - return prng.rademacher(key, shape=shape, dtype=dtype) - - return fun - - -class _VDCState(containers.NamedTuple): - n: int - vdc: float - denom: int - - -def 
van_der_corput(n, /, base=2): - """Compute the 'n'th element of the Van-der-Corput sequence.""" - state = _VDCState(n, vdc=0, denom=1) - - vdc_modify = func.partial(_van_der_corput_modify, base=base) - state = control_flow.while_loop(_van_der_corput_cond, vdc_modify, state) - return state.vdc - - -def _van_der_corput_cond(state: _VDCState): - return state.n > 0 - - -def _van_der_corput_modify(state: _VDCState, *, base): - denom = state.denom * base - num, remainder = divmod(state.n, base) - vdc = state.vdc + remainder / denom - return _VDCState(num, vdc, denom) diff --git a/matfree/slq.py b/matfree/slq.py index f820a1d..ee2d1f2 100644 --- a/matfree/slq.py +++ b/matfree/slq.py @@ -1,6 +1,6 @@ """Stochastic Lanczos quadrature.""" -from matfree import decomp, lanczos, montecarlo +from matfree import decomp, hutchinson, lanczos from matfree.backend import func, linalg, np @@ -12,7 +12,7 @@ def logdet_spd(*args, **kwargs): def trace_of_matfun_spd(matfun, order, Av, /, **kwargs): """Compute the trace of the function of a symmetric matrix.""" quadratic_form = _quadratic_form_slq_spd(matfun, order, Av) - return montecarlo.estimate(quadratic_form, **kwargs) + return hutchinson.estimate(quadratic_form, **kwargs) def _quadratic_form_slq_spd(matfun, order, Av, /): @@ -72,7 +72,7 @@ def trace_of_matfun_product(matfun, order, *matvec_funs, matrix_shape, **kwargs) quadratic_form = _quadratic_form_slq_product( matfun, order, *matvec_funs, matrix_shape=matrix_shape ) - return montecarlo.estimate(quadratic_form, **kwargs) + return hutchinson.estimate(quadratic_form, **kwargs) def _quadratic_form_slq_product(matfun, depth, *matvec_funs, matrix_shape): diff --git a/mkdocs.yml b/mkdocs.yml index 68b147e..9cab81f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,7 +11,7 @@ nav: - vector_calculus.md - pytree_logdeterminants.md - API documentation: - - matfree.montecarlo: api/montecarlo.md + - matfree.montecarlo: api/hutchinson.md - matfree.decomp: api/decomp.md - matfree.lanczos: 
api/lanczos.md - matfree.hutchinson: api/hutchinson.md diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 2786ce9..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Test configuration.""" diff --git a/tests/test_decomp/__init__.py b/tests/test_decomp/__init__.py deleted file mode 100644 index 5edfa49..0000000 --- a/tests/test_decomp/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Decomposition tests.""" diff --git a/tests/test_hutchinson/__init__.py b/tests/test_hutchinson/__init__.py deleted file mode 100644 index 60558fa..0000000 --- a/tests/test_hutchinson/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Trace estimator tests.""" diff --git a/tests/test_hutchinson/test_diagonal.py b/tests/test_hutchinson/test_diagonal.py index 0759fa8..5b39b04 100644 --- a/tests/test_hutchinson/test_diagonal.py +++ b/tests/test_hutchinson/test_diagonal.py @@ -1,6 +1,6 @@ """Tests for basic trace estimators.""" -from matfree import hutchinson, montecarlo +from matfree import hutchinson from matfree.backend import func, linalg, np, prng, testing @@ -23,7 +23,7 @@ def fixture_key(): @testing.parametrize("num_batches", [1_000]) @testing.parametrize("num_samples_per_batch", [1_000]) @testing.parametrize("dim", [1, 10]) -@testing.parametrize("sample_fun", [montecarlo.normal, montecarlo.rademacher]) +@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) def test_diagonal(fun, key, num_batches, num_samples_per_batch, dim, sample_fun): """Assert that the estimated diagonal approximates the true diagonal accurately.""" # Linearise function diff --git a/tests/test_montecarlo/test_estimate.py b/tests/test_hutchinson/test_estimate.py similarity index 81% rename from tests/test_montecarlo/test_estimate.py rename to tests/test_hutchinson/test_estimate.py index 6d2dc67..fe5c915 100644 --- a/tests/test_montecarlo/test_estimate.py +++ b/tests/test_hutchinson/test_estimate.py @@ -1,6 +1,6 @@ """Tests for Monte-Carlo machinery.""" -from 
matfree import montecarlo +from matfree import hutchinson from matfree.backend import np, prng, testing @@ -12,11 +12,11 @@ def test_mean(key, num_batches, num_samples): def fun(x): return x**2 - received = montecarlo.estimate( + received = hutchinson.estimate( fun, num_batches=num_batches, num_samples_per_batch=num_samples, key=key, - sample_fun=montecarlo.normal(shape=()), + sample_fun=hutchinson.normal(shape=()), ) assert np.allclose(received, 1.0, rtol=1e-1) diff --git a/tests/test_hutchinson/test_frobeniusnorm_squared.py b/tests/test_hutchinson/test_frobeniusnorm_squared.py index cb9d17b..626abe6 100644 --- a/tests/test_hutchinson/test_frobeniusnorm_squared.py +++ b/tests/test_hutchinson/test_frobeniusnorm_squared.py @@ -1,6 +1,6 @@ """Tests for basic trace estimators.""" -from matfree import hutchinson, montecarlo +from matfree import hutchinson from matfree.backend import func, linalg, np, prng, testing @@ -23,7 +23,7 @@ def fixture_key(): @testing.parametrize("num_batches", [1_000]) @testing.parametrize("num_samples_per_batch", [1_000]) @testing.parametrize("dim", [1, 10]) -@testing.parametrize("sample_fun", [montecarlo.normal, montecarlo.rademacher]) +@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) def test_frobeniusnorm_squared( fun, key, num_batches, num_samples_per_batch, dim, sample_fun ): diff --git a/tests/test_montecarlo/test_multiestimate.py b/tests/test_hutchinson/test_multiestimate.py similarity index 84% rename from tests/test_montecarlo/test_multiestimate.py rename to tests/test_hutchinson/test_multiestimate.py index 769facd..b04609f 100644 --- a/tests/test_montecarlo/test_multiestimate.py +++ b/tests/test_hutchinson/test_multiestimate.py @@ -1,6 +1,5 @@ """Tests for Monte-Carlo machinery.""" - -from matfree import montecarlo +from matfree import hutchinson from matfree.backend import np, prng, testing @@ -12,12 +11,12 @@ def test_mean_and_max(key, num_batches, num_samples): def fun(x): return x**2 - mean, amax = 
montecarlo.multiestimate( + mean, amax = hutchinson.multiestimate( fun, num_batches=num_batches, num_samples_per_batch=num_samples, key=key, - sample_fun=montecarlo.normal(shape=()), + sample_fun=hutchinson.normal(shape=()), statistics_batch=[np.mean, np.array_max], statistics_combine=[np.mean, np.array_max], ) diff --git a/tests/test_hutchinson/test_trace.py b/tests/test_hutchinson/test_trace.py index 690a638..867cbc7 100644 --- a/tests/test_hutchinson/test_trace.py +++ b/tests/test_hutchinson/test_trace.py @@ -1,6 +1,6 @@ """Tests for basic trace estimators.""" -from matfree import hutchinson, montecarlo +from matfree import hutchinson from matfree.backend import func, linalg, np, prng, testing @@ -23,7 +23,7 @@ def fixture_key(): @testing.parametrize("num_batches", [1_000]) @testing.parametrize("num_samples_per_batch", [1_000]) @testing.parametrize("dim", [1, 10]) -@testing.parametrize("sample_fun", [montecarlo.normal, montecarlo.rademacher]) +@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) def test_trace(fun, key, num_batches, num_samples_per_batch, dim, sample_fun): """Assert that the estimated trace approximates the true trace accurately.""" # Linearise function diff --git a/tests/test_hutchinson/test_trace_and_diagonal.py b/tests/test_hutchinson/test_trace_and_diagonal.py index 6176193..f8f0391 100644 --- a/tests/test_hutchinson/test_trace_and_diagonal.py +++ b/tests/test_hutchinson/test_trace_and_diagonal.py @@ -1,6 +1,6 @@ """Tests for basic trace estimators.""" -from matfree import hutchinson, montecarlo +from matfree import hutchinson from matfree.backend import func, linalg, np, prng, testing @@ -22,7 +22,7 @@ def fixture_key(): @testing.parametrize("num_samples", [10_000]) @testing.parametrize("dim", [5]) -@testing.parametrize("sample_fun", [montecarlo.normal, montecarlo.rademacher]) +@testing.parametrize("sample_fun", [hutchinson.normal, hutchinson.rademacher]) def test_trace_and_diagonal(fun, key, num_samples, dim, 
sample_fun): """Assert that the estimated trace and diagonal approximations are accurate.""" # Linearise function diff --git a/tests/test_hutchinson/test_trace_moments.py b/tests/test_hutchinson/test_trace_moments.py index 22a5508..90ddbf3 100644 --- a/tests/test_hutchinson/test_trace_moments.py +++ b/tests/test_hutchinson/test_trace_moments.py @@ -1,6 +1,6 @@ """Tests for estimating traces.""" -from matfree import hutchinson, montecarlo +from matfree import hutchinson from matfree.backend import func, linalg, np, prng, testing @@ -32,7 +32,7 @@ def test_variance_normal(J_and_jvp, key, num_batches, num_samples_per_batch, dim """Assert that the estimated trace approximates the true trace accurately.""" # Estimate the trace J, jvp = J_and_jvp - fun = montecarlo.normal(shape=(dim,), dtype=float) + fun = hutchinson.normal(shape=(dim,), dtype=float) first, second = hutchinson.trace_moments( jvp, key=key, @@ -58,7 +58,7 @@ def test_variance_rademacher(J_and_jvp, key, num_batches, num_samples_per_batch, """Assert that the estimated trace approximates the true trace accurately.""" # Estimate the trace J, jvp = J_and_jvp - fun = montecarlo.rademacher(shape=(dim,), dtype=float) + fun = hutchinson.rademacher(shape=(dim,), dtype=float) first, second = hutchinson.trace_moments( jvp, key=key, diff --git a/tests/test_montecarlo/test_van_der_corput.py b/tests/test_hutchinson/test_van_der_corput.py similarity index 74% rename from tests/test_montecarlo/test_van_der_corput.py rename to tests/test_hutchinson/test_van_der_corput.py index fed0508..8890d3c 100644 --- a/tests/test_montecarlo/test_van_der_corput.py +++ b/tests/test_hutchinson/test_van_der_corput.py @@ -1,15 +1,15 @@ """Tests for Monte-Carlo machinery.""" -from matfree import montecarlo +from matfree import hutchinson from matfree.backend import np def test_van_der_corput(): """Assert that the van-der-Corput sequence yields values as expected.""" expected = np.asarray([0, 0.5, 0.25, 0.75, 0.125, 0.625, 0.375, 0.875, 
0.0625]) - received = np.asarray([montecarlo.van_der_corput(i) for i in range(9)]) + received = np.asarray([hutchinson.van_der_corput(i) for i in range(9)]) assert np.allclose(received, expected) expected = np.asarray([0.0, 1 / 3, 2 / 3, 1 / 9, 4 / 9, 7 / 9, 2 / 9, 5 / 9, 8 / 9]) - received = np.asarray([montecarlo.van_der_corput(i, base=3) for i in range(9)]) + received = np.asarray([hutchinson.van_der_corput(i, base=3) for i in range(9)]) assert np.allclose(received, expected) diff --git a/tests/test_lanczos/__init__.py b/tests/test_lanczos/__init__.py deleted file mode 100644 index 5edfa49..0000000 --- a/tests/test_lanczos/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Decomposition tests.""" diff --git a/tests/test_montecarlo/__init__.py b/tests/test_montecarlo/__init__.py deleted file mode 100644 index 7d7ed7b..0000000 --- a/tests/test_montecarlo/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Monte-Carlo tests.""" diff --git a/tests/test_pinv/__init__.py b/tests/test_pinv/__init__.py deleted file mode 100644 index 0d7748a..0000000 --- a/tests/test_pinv/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for the pinv module.""" diff --git a/tests/test_slq/__init__.py b/tests/test_slq/__init__.py deleted file mode 100644 index 885c2b9..0000000 --- a/tests/test_slq/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for stochastic Lanczos quadrature.""" diff --git a/tests/test_slq/test_logdet_product.py b/tests/test_slq/test_logdet_product.py index 158f4a8..f1c3df2 100644 --- a/tests/test_slq/test_logdet_product.py +++ b/tests/test_slq/test_logdet_product.py @@ -1,6 +1,6 @@ """Test slq.logdet_prod().""" -from matfree import montecarlo, slq, test_util +from matfree import hutchinson, slq, test_util from matfree.backend import linalg, np, prng, testing @@ -22,7 +22,7 @@ def test_logdet_product(A, order): """Assert that logdet_product yields an accurate estimate.""" _, ncols = np.shape(A) key = prng.prng_key(3) - fun = montecarlo.normal(shape=(ncols,)) + fun = 
hutchinson.normal(shape=(ncols,)) received = slq.logdet_product( order, lambda v: A @ v, diff --git a/tests/test_slq/test_logdet_spd.py b/tests/test_slq/test_logdet_spd.py index c0248f5..cfcfb60 100644 --- a/tests/test_slq/test_logdet_spd.py +++ b/tests/test_slq/test_logdet_spd.py @@ -1,6 +1,6 @@ """Tests for Lanczos functionality.""" -from matfree import montecarlo, slq, test_util +from matfree import hutchinson, slq, test_util from matfree.backend import linalg, np, prng, testing @@ -23,7 +23,7 @@ def test_logdet_spd(A, order): """Assert that the log-determinant estimation matches the true log-determinant.""" n, _ = np.shape(A) key = prng.prng_key(1) - fun = montecarlo.normal(shape=(n,)) + fun = hutchinson.normal(shape=(n,)) received = slq.logdet_spd( order, lambda v: A @ v, diff --git a/tests/test_slq/test_logdet_spd_autodiff.py b/tests/test_slq/test_logdet_spd_autodiff.py index 024736a..ca4b063 100644 --- a/tests/test_slq/test_logdet_spd_autodiff.py +++ b/tests/test_slq/test_logdet_spd_autodiff.py @@ -1,7 +1,7 @@ """Tests for (selected) autodiff functionality.""" -from matfree import montecarlo, slq, test_util +from matfree import hutchinson, slq, test_util from matfree.backend import np, prng, testing @@ -32,7 +32,7 @@ def fun(s): def _logdet(A, order, key): n, _ = np.shape(A) - fun = montecarlo.normal(shape=(n,)) + fun = hutchinson.normal(shape=(n,)) return slq.logdet_spd( order, lambda v: A @ v, diff --git a/tests/test_slq/test_schatten_norm.py b/tests/test_slq/test_schatten_norm.py index f68c6f3..f8e3428 100644 --- a/tests/test_slq/test_schatten_norm.py +++ b/tests/test_slq/test_schatten_norm.py @@ -1,6 +1,6 @@ """Test Schatten norm implementations.""" -from matfree import montecarlo, slq, test_util +from matfree import hutchinson, slq, test_util from matfree.backend import linalg, np, prng, testing @@ -26,7 +26,7 @@ def test_schatten_norm(A, order, power): _, ncols = np.shape(A) key = prng.prng_key(1) - fun = montecarlo.normal(shape=(ncols,)) + fun = 
hutchinson.normal(shape=(ncols,)) received = slq.schatten_norm( order, lambda v: A @ v, From 916acc1042e4275d6936a5925f09b84ae7d4bd6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Tue, 21 Nov 2023 13:19:09 +0100 Subject: [PATCH 2/4] All tests pass again --- README.md | 2 +- docs/control_variates.md | 2 +- docs/higher_moments.md | 2 +- docs/index.md | 2 +- docs/log_determinants.md | 2 +- docs/pytree_logdeterminants.md | 2 +- docs/vector_calculus.md | 2 +- matfree/hutchinson.py | 36 +++++++++++++++++----------------- mkdocs.yml | 5 ++--- 9 files changed, 27 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 8d1781f..600b275 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Import matfree and JAX, and set up a test problem. ```python >>> import jax >>> import jax.numpy as jnp ->>> from matfree import hutchinson, montecarlo, slq +>>> from matfree import hutchinson, slq >>> A = jnp.reshape(jnp.arange(12.0), (6, 2)) >>> diff --git a/docs/control_variates.md b/docs/control_variates.md index ebbfa5a..1f4230f 100644 --- a/docs/control_variates.md +++ b/docs/control_variates.md @@ -8,7 +8,7 @@ Imports: ```python >>> import jax >>> import jax.numpy as jnp ->>> from matfree import hutchinson, montecarlo, slq +>>> from matfree import hutchinson, slq >>> a = jnp.reshape(jnp.arange(12.0), (6, 2)) >>> key = jax.random.PRNGKey(1) diff --git a/docs/higher_moments.md b/docs/higher_moments.md index b94d4ff..9f74aa7 100644 --- a/docs/higher_moments.md +++ b/docs/higher_moments.md @@ -3,7 +3,7 @@ ```python >>> import jax >>> import jax.numpy as jnp ->>> from matfree import hutchinson, montecarlo, slq +>>> from matfree import hutchinson, slq >>> a = jnp.reshape(jnp.arange(12.0), (6, 2)) >>> key = jax.random.PRNGKey(1) diff --git a/docs/index.md b/docs/index.md index 8d1781f..600b275 100644 --- a/docs/index.md +++ b/docs/index.md @@ -51,7 +51,7 @@ Import matfree and JAX, and set up a test problem. 
```python >>> import jax >>> import jax.numpy as jnp ->>> from matfree import hutchinson, montecarlo, slq +>>> from matfree import hutchinson, slq >>> A = jnp.reshape(jnp.arange(12.0), (6, 2)) >>> diff --git a/docs/log_determinants.md b/docs/log_determinants.md index c07e858..dd6375c 100644 --- a/docs/log_determinants.md +++ b/docs/log_determinants.md @@ -6,7 +6,7 @@ Imports: ```python >>> import jax >>> import jax.numpy as jnp ->>> from matfree import hutchinson, montecarlo, slq +>>> from matfree import hutchinson, slq >>> a = jnp.reshape(jnp.arange(12.0), (6, 2)) >>> key = jax.random.PRNGKey(1) diff --git a/docs/pytree_logdeterminants.md b/docs/pytree_logdeterminants.md index f74233d..4f09041 100644 --- a/docs/pytree_logdeterminants.md +++ b/docs/pytree_logdeterminants.md @@ -11,7 +11,7 @@ Imports: >>> import jax.flatten_util # this is important! >>> import jax.numpy as jnp >>> ->>> from matfree import slq, montecarlo +>>> from matfree import slq, hutchinson ``` Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple). diff --git a/docs/vector_calculus.md b/docs/vector_calculus.md index 021bf40..558b634 100644 --- a/docs/vector_calculus.md +++ b/docs/vector_calculus.md @@ -11,7 +11,7 @@ Here is how we can implement divergences and Laplacians without forming full Jac ```python >>> import jax >>> import jax.numpy as jnp ->>> from matfree import hutchinson, montecarlo +>>> from matfree import hutchinson ``` diff --git a/matfree/hutchinson.py b/matfree/hutchinson.py index 0be0efc..26f4e6a 100644 --- a/matfree/hutchinson.py +++ b/matfree/hutchinson.py @@ -31,8 +31,8 @@ def estimate( sample_fun: Sampling function. For trace-estimation, use - either [normal(...)][matfree.normal] - or [rademacher(...)][matfree.normal]. + either [normal(...)][matfree.hutchinson.normal] + or [rademacher(...)][matfree.hutchinson.normal]. num_batches: Number of batches when computing arithmetic means. 
num_samples_per_batch: @@ -76,7 +76,7 @@ def multiestimate( """Compute a Monte-Carlo estimate with multiple summary statistics. The signature of this function is almost identical to - [estimate(...)][matfree.estimate]. + [estimate(...)][matfree.hutchinson.estimate]. The only difference is that statistics_batch and statistics_combine are iterables of summary statistics (of equal lengths). @@ -85,15 +85,15 @@ def multiestimate( Parameters ---------- fun: - Same as in [estimate(...)][matfree.estimate]. + Same as in [estimate(...)][matfree.hutchinson.estimate]. key: - Same as in [estimate(...)][matfree.estimate]. + Same as in [estimate(...)][matfree.hutchinson.estimate]. sample_fun: - Same as in [estimate(...)][matfree.estimate]. + Same as in [estimate(...)][matfree.hutchinson.estimate]. num_batches: - Same as in [estimate(...)][matfree.estimate]. + Same as in [estimate(...)][matfree.hutchinson.estimate]. num_samples_per_batch: - Same as in [estimate(...)][matfree.estimate]. + Same as in [estimate(...)][matfree.hutchinson.estimate]. statistics_batch: List or tuple of summary statistics to compute on batch-level. statistics_combine: @@ -198,7 +198,7 @@ def trace(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.estimate]. + [estimate()][matfree.hutchinson.estimate]. """ def quadform(vec): @@ -219,7 +219,7 @@ def trace_moments(Av: Callable, /, moments: Sequence[int] = (1, 2), **kwargs) -> the first and second moment. **kwargs: Keyword-arguments to be passed to - [multiestimate(...)][matfree.multiestimate]. + [multiestimate(...)][matfree.hutchinson.multiestimate]. """ def quadform(vec): @@ -255,7 +255,7 @@ def frobeniusnorm_squared(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.estimate]. + [estimate()][matfree.hutchinson.estimate]. 
""" @@ -278,7 +278,7 @@ def diagonal_with_control_variate(Av: Callable, control: Array, /, **kwargs) -> This should be the best-possible estimate of the diagonal of the matrix. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.estimate]. + [estimate()][matfree.hutchinson.estimate]. """ return diagonal(lambda v: Av(v) - control * v, **kwargs) + control @@ -293,7 +293,7 @@ def diagonal(Av: Callable, /, **kwargs) -> Array: Matrix-vector product function. **kwargs: Keyword-arguments to be passed to - [estimate()][matfree.estimate]. + [estimate()][matfree.hutchinson.estimate]. """ @@ -318,13 +318,13 @@ def trace_and_diagonal(Av: Callable, /, *, sample_fun: Callable, key: Array, **k Matrix-vector product function. sample_fun: Sampling function. - Usually, either [normal][matfree.normal] - or [rademacher][matfree.normal]. + Usually, either [normal][matfree.hutchinson.normal] + or [rademacher][matfree.hutchinson.normal]. key: Pseudo-random number generator key. **kwargs: Keyword-arguments to be passed to - [diagonal_multilevel()][matfree.diagonal_multilevel]. + [diagonal_multilevel()][matfree.hutchinson.diagonal_multilevel]. See: @@ -368,8 +368,8 @@ def diagonal_multilevel( Pseudo-random number generator key. sample_fun: Sampling function. - Usually, either [normal][matfree.normal] - or [rademacher][matfree.normal]. + Usually, either [normal][matfree.hutchinson.normal] + or [rademacher][matfree.hutchinson.normal]. num_levels: Number of levels. 
num_batches_per_level: diff --git a/mkdocs.yml b/mkdocs.yml index 9cab81f..5c97a93 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,7 +11,6 @@ nav: - vector_calculus.md - pytree_logdeterminants.md - API documentation: - - matfree.montecarlo: api/hutchinson.md - matfree.decomp: api/decomp.md - matfree.lanczos: api/lanczos.md - matfree.hutchinson: api/hutchinson.md @@ -47,12 +46,12 @@ plugins: handlers: python: options: - show_root_heading: true + show_root_heading: false show_root_toc_entry: true show_root_full_path: true show_root_members_full_path: true show_object_full_path: false - show_category_heading: true + show_category_heading: false docstring_style: numpy show_if_no_docstring: true members_order: alphabetical From c953e0645d8c5f691d9a74f4e2c39eefe7f9125a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Tue, 21 Nov 2023 13:20:37 +0100 Subject: [PATCH 3/4] Do not show source in docs --- mkdocs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/mkdocs.yml b/mkdocs.yml index 5c97a93..e3e487d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -60,6 +60,7 @@ plugins: show_signature_annotations: true separate_signature: false docstring_section_style: list + show_source: false watch: [matfree] extra: social: From eae875e1768380a026f15c120153307687b17b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Tue, 21 Nov 2023 13:27:55 +0100 Subject: [PATCH 4/4] Merged decompositions and Lanczos --- docs/api/lanczos.md | 3 - matfree/decomp.py | 164 +++++++++++++++++- matfree/lanczos.py | 161 ----------------- matfree/slq.py | 6 +- mkdocs.yml | 1 - .../test_bidiagonal_full_reortho.py | 10 +- .../test_tridiagonal_full_reortho.py | 6 +- 7 files changed, 171 insertions(+), 180 deletions(-) delete mode 100644 docs/api/lanczos.md delete mode 100644 matfree/lanczos.py rename tests/{test_lanczos => test_decomp}/test_bidiagonal_full_reortho.py (90%) rename tests/{test_lanczos => test_decomp}/test_tridiagonal_full_reortho.py (95%) diff --git 
a/docs/api/lanczos.md b/docs/api/lanczos.md deleted file mode 100644 index 3f0d80a..0000000 --- a/docs/api/lanczos.md +++ /dev/null @@ -1,3 +0,0 @@ -# Lanczos-style algorithms - -:::matfree.lanczos diff --git a/matfree/decomp.py b/matfree/decomp.py index 97880f4..07f9362 100644 --- a/matfree/decomp.py +++ b/matfree/decomp.py @@ -1,10 +1,166 @@ """Matrix decomposition algorithms.""" -from matfree import lanczos -from matfree.backend import containers, control_flow, linalg +from matfree.backend import containers, control_flow, linalg, np from matfree.backend.typing import Array, Callable, Tuple +class _Alg(containers.NamedTuple): + """Matrix decomposition algorithm.""" + + init: Callable + """Initialise the state of the algorithm. Usually, this involves pre-allocation.""" + + step: Callable + """Compute the next iteration.""" + + extract: Callable + """Extract the solution from the state of the algorithm.""" + + lower_upper: Tuple[int, int] + """Range of the for-loop used to decompose a matrix.""" + + +def tridiagonal_full_reortho(depth, /): + """Construct an implementation of **tridiagonalisation**. + + Uses pre-allocation. Fully reorthogonalise vectors at every step. + + This algorithm assumes a **symmetric matrix**. + + Decompose a matrix into a product of orthogonal-**tridiagonal**-orthogonal matrices. + Use this algorithm for approximate **eigenvalue** decompositions. + + """ + + class State(containers.NamedTuple): + i: int + basis: Array + tridiag: Tuple[Array, Array] + q: Array + + def init(init_vec: Array) -> State: + (ncols,) = np.shape(init_vec) + if depth >= ncols or depth < 1: + raise ValueError + + diag = np.zeros((depth + 1,)) + offdiag = np.zeros((depth,)) + basis = np.zeros((depth + 1, ncols)) + + return State(0, basis, (diag, offdiag), init_vec) + + def apply(state: State, Av: Callable) -> State: + i, basis, (diag, offdiag), vec = state + + # Re-orthogonalise against ALL basis elements before storing. 
+ # Note: we re-orthogonalise against ALL columns of Q, not just + # the ones we have already computed. This increases the complexity + # of the whole iteration from n(n+1)/2 to n^2, but has the advantage + # that the whole computation has static bounds (thus we can JIT it all). + # Since 'Q' is padded with zeros, the numerical values are identical + # between both modes of computing. + vec, length = _normalise(vec) + vec, _ = _gram_schmidt_orthogonalise_set(vec, basis) + + # I don't know why, but this re-normalisation is soooo crucial + vec, _ = _normalise(vec) + basis = basis.at[i, :].set(vec) + + # When i==0, Q[i-1] is Q[-1] and again, we benefit from the fact + # that Q is initialised with zeros. + vec = Av(vec) + basis_vectors_previous = np.asarray([basis[i], basis[i - 1]]) + vec, (coeff, _) = _gram_schmidt_orthogonalise_set(vec, basis_vectors_previous) + diag = diag.at[i].set(coeff) + offdiag = offdiag.at[i - 1].set(length) + + return State(i + 1, basis, (diag, offdiag), vec) + + def extract(state: State, /): + _, basis, (diag, offdiag), _ = state + return basis, (diag, offdiag) + + return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1)) + + +def bidiagonal_full_reortho(depth, /, matrix_shape): + """Construct an implementation of **bidiagonalisation**. + + Uses pre-allocation. Fully reorthogonalise vectors at every step. + + Works for **arbitrary matrices**. No symmetry required. + + Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices. + Use this algorithm for approximate **singular value** decompositions. + """ + nrows, ncols = matrix_shape + max_depth = min(nrows, ncols) - 1 + if depth > max_depth or depth < 0: + msg1 = f"Depth {depth} exceeds the matrix' dimensions. " + msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} " + msg3 = f"for a matrix with shape {matrix_shape}." 
+ raise ValueError(msg1 + msg2 + msg3) + + class State(containers.NamedTuple): + i: int + Us: Array + Vs: Array + alphas: Array + betas: Array + beta: Array + vk: Array + + def init(init_vec: Array) -> State: + nrows, ncols = matrix_shape + alphas = np.zeros((depth + 1,)) + betas = np.zeros((depth + 1,)) + Us = np.zeros((depth + 1, nrows)) + Vs = np.zeros((depth + 1, ncols)) + v0, _ = _normalise(init_vec) + return State(0, Us, Vs, alphas, betas, 0.0, v0) + + def apply(state: State, Av: Callable, vA: Callable) -> State: + i, Us, Vs, alphas, betas, beta, vk = state + Vs = Vs.at[i].set(vk) + betas = betas.at[i].set(beta) + + uk = Av(vk) - beta * Us[i - 1] + uk, alpha = _normalise(uk) + uk, _ = _gram_schmidt_orthogonalise_set(uk, Us) # full reorthogonalisation + uk, _ = _normalise(uk) + Us = Us.at[i].set(uk) + alphas = alphas.at[i].set(alpha) + + vk = vA(uk) - alpha * vk + vk, beta = _normalise(vk) + vk, _ = _gram_schmidt_orthogonalise_set(vk, Vs) # full reorthogonalisation + vk, _ = _normalise(vk) + + return State(i + 1, Us, Vs, alphas, betas, beta, vk) + + def extract(state: State, /): + _, uk_all, vk_all, alphas, betas, beta, vk = state + return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk) + + return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1)) + + +def _normalise(vec): + length = linalg.vector_norm(vec) + return vec / length, length + + +def _gram_schmidt_orthogonalise_set(vec, vectors): # Gram-Schmidt + vec, coeffs = control_flow.scan(_gram_schmidt_orthogonalise, vec, xs=vectors) + return vec, coeffs + + +def _gram_schmidt_orthogonalise(vec1, vec2): + coeff = linalg.vecdot(vec1, vec2) + vec_ortho = vec1 - coeff * vec2 + return vec_ortho, coeff + + def svd( v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...] ): @@ -29,7 +185,7 @@ def svd( Shape of the matrix involved in matrix-vector and vector-matrix products. 
""" # Factorise the matrix - algorithm = lanczos.bidiagonal_full_reortho(depth, matrix_shape=matrix_shape) + algorithm = bidiagonal_full_reortho(depth, matrix_shape=matrix_shape) u, (d, e), vt, _ = decompose_fori_loop(v0, Av, vA, algorithm=algorithm) # Compute SVD of factorisation @@ -66,7 +222,7 @@ class _DecompAlg(containers.NamedTuple): """Decomposition algorithm type. For example, the output of -[matfree.lanczos.tridiagonal_full_reortho(...)][matfree.lanczos.tridiagonal_full_reortho]. +[matfree.decomp.tridiagonal_full_reortho(...)][matfree.decomp.tridiagonal_full_reortho]. """ diff --git a/matfree/lanczos.py b/matfree/lanczos.py deleted file mode 100644 index 27968a5..0000000 --- a/matfree/lanczos.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Lanczos-style algorithms.""" - -from matfree.backend import containers, control_flow, linalg, np -from matfree.backend.typing import Array, Callable, Tuple - - -class _Alg(containers.NamedTuple): - """Matrix decomposition algorithm.""" - - init: Callable - """Initialise the state of the algorithm. Usually, this involves pre-allocation.""" - - step: Callable - """Compute the next iteration.""" - - extract: Callable - """Extract the solution from the state of the algorithm.""" - - lower_upper: Tuple[int, int] - """Range of the for-loop used to decompose a matrix.""" - - -def tridiagonal_full_reortho(depth, /): - """Construct an implementation of **tridiagonalisation**. - - Uses pre-allocation. Fully reorthogonalise vectors at every step. - - This algorithm assumes a **symmetric matrix**. - - Decompose a matrix into a product of orthogonal-**tridiagonal**-orthogonal matrices. - Use this algorithm for approximate **eigenvalue** decompositions. 
- - """ - - class State(containers.NamedTuple): - i: int - basis: Array - tridiag: Tuple[Array, Array] - q: Array - - def init(init_vec: Array) -> State: - (ncols,) = np.shape(init_vec) - if depth >= ncols or depth < 1: - raise ValueError - - diag = np.zeros((depth + 1,)) - offdiag = np.zeros((depth,)) - basis = np.zeros((depth + 1, ncols)) - - return State(0, basis, (diag, offdiag), init_vec) - - def apply(state: State, Av: Callable) -> State: - i, basis, (diag, offdiag), vec = state - - # Re-orthogonalise against ALL basis elements before storing. - # Note: we re-orthogonalise against ALL columns of Q, not just - # the ones we have already computed. This increases the complexity - # of the whole iteration from n(n+1)/2 to n^2, but has the advantage - # that the whole computation has static bounds (thus we can JIT it all). - # Since 'Q' is padded with zeros, the numerical values are identical - # between both modes of computing. - vec, length = _normalise(vec) - vec, _ = _gram_schmidt_orthogonalise_set(vec, basis) - - # I don't know why, but this re-normalisation is soooo crucial - vec, _ = _normalise(vec) - basis = basis.at[i, :].set(vec) - - # When i==0, Q[i-1] is Q[-1] and again, we benefit from the fact - # that Q is initialised with zeros. - vec = Av(vec) - basis_vectors_previous = np.asarray([basis[i], basis[i - 1]]) - vec, (coeff, _) = _gram_schmidt_orthogonalise_set(vec, basis_vectors_previous) - diag = diag.at[i].set(coeff) - offdiag = offdiag.at[i - 1].set(length) - - return State(i + 1, basis, (diag, offdiag), vec) - - def extract(state: State, /): - _, basis, (diag, offdiag), _ = state - return basis, (diag, offdiag) - - return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1)) - - -def bidiagonal_full_reortho(depth, /, matrix_shape): - """Construct an implementation of **bidiagonalisation**. - - Uses pre-allocation. Fully reorthogonalise vectors at every step. - - Works for **arbitrary matrices**. No symmetry required. 
- - Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices. - Use this algorithm for approximate **singular value** decompositions. - """ - nrows, ncols = matrix_shape - max_depth = min(nrows, ncols) - 1 - if depth > max_depth or depth < 0: - msg1 = f"Depth {depth} exceeds the matrix' dimensions. " - msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} " - msg3 = f"for a matrix with shape {matrix_shape}." - raise ValueError(msg1 + msg2 + msg3) - - class State(containers.NamedTuple): - i: int - Us: Array - Vs: Array - alphas: Array - betas: Array - beta: Array - vk: Array - - def init(init_vec: Array) -> State: - nrows, ncols = matrix_shape - alphas = np.zeros((depth + 1,)) - betas = np.zeros((depth + 1,)) - Us = np.zeros((depth + 1, nrows)) - Vs = np.zeros((depth + 1, ncols)) - v0, _ = _normalise(init_vec) - return State(0, Us, Vs, alphas, betas, 0.0, v0) - - def apply(state: State, Av: Callable, vA: Callable) -> State: - i, Us, Vs, alphas, betas, beta, vk = state - Vs = Vs.at[i].set(vk) - betas = betas.at[i].set(beta) - - uk = Av(vk) - beta * Us[i - 1] - uk, alpha = _normalise(uk) - uk, _ = _gram_schmidt_orthogonalise_set(uk, Us) # full reorthogonalisation - uk, _ = _normalise(uk) - Us = Us.at[i].set(uk) - alphas = alphas.at[i].set(alpha) - - vk = vA(uk) - alpha * vk - vk, beta = _normalise(vk) - vk, _ = _gram_schmidt_orthogonalise_set(vk, Vs) # full reorthogonalisation - vk, _ = _normalise(vk) - - return State(i + 1, Us, Vs, alphas, betas, beta, vk) - - def extract(state: State, /): - _, uk_all, vk_all, alphas, betas, beta, vk = state - return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk) - - return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1)) - - -def _normalise(vec): - length = linalg.vector_norm(vec) - return vec / length, length - - -def _gram_schmidt_orthogonalise_set(vec, vectors): # Gram-Schmidt - vec, coeffs = control_flow.scan(_gram_schmidt_orthogonalise, vec, xs=vectors) - 
return vec, coeffs - - -def _gram_schmidt_orthogonalise(vec1, vec2): - coeff = linalg.vecdot(vec1, vec2) - vec_ortho = vec1 - coeff * vec2 - return vec_ortho, coeff diff --git a/matfree/slq.py b/matfree/slq.py index ee2d1f2..dea6414 100644 --- a/matfree/slq.py +++ b/matfree/slq.py @@ -1,6 +1,6 @@ """Stochastic Lanczos quadrature.""" -from matfree import decomp, hutchinson, lanczos +from matfree import decomp, hutchinson from matfree.backend import func, linalg, np @@ -22,7 +22,7 @@ def _quadratic_form_slq_spd(matfun, order, Av, /): """ def quadform(v0, /): - algorithm = lanczos.tridiagonal_full_reortho(order) + algorithm = decomp.tridiagonal_full_reortho(order) _, tridiag = decomp.decompose_fori_loop(v0, Av, algorithm=algorithm) (diag, off_diag) = tridiag @@ -85,7 +85,7 @@ def _quadratic_form_slq_product(matfun, depth, *matvec_funs, matrix_shape): def quadform(v0, /): # Decompose into orthogonal-bidiag-orthogonal - algorithm = lanczos.bidiagonal_full_reortho(depth, matrix_shape=matrix_shape) + algorithm = decomp.bidiagonal_full_reortho(depth, matrix_shape=matrix_shape) output = decomp.decompose_fori_loop(v0, *matvec_funs, algorithm=algorithm) u, (d, e), vt, _ = output diff --git a/mkdocs.yml b/mkdocs.yml index e3e487d..8859575 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,7 +12,6 @@ nav: - pytree_logdeterminants.md - API documentation: - matfree.decomp: api/decomp.md - - matfree.lanczos: api/lanczos.md - matfree.hutchinson: api/hutchinson.md - matfree.slq: api/slq.md - matfree.pinv: api/pinv.md diff --git a/tests/test_lanczos/test_bidiagonal_full_reortho.py b/tests/test_decomp/test_bidiagonal_full_reortho.py similarity index 90% rename from tests/test_lanczos/test_bidiagonal_full_reortho.py rename to tests/test_decomp/test_bidiagonal_full_reortho.py index df949a6..76c6cf1 100644 --- a/tests/test_lanczos/test_bidiagonal_full_reortho.py +++ b/tests/test_decomp/test_bidiagonal_full_reortho.py @@ -1,6 +1,6 @@ """Tests for GKL bidiagonalisation.""" -from matfree 
import decomp, lanczos, test_util +from matfree import decomp, test_util from matfree.backend import linalg, np, prng, testing @@ -23,7 +23,7 @@ def test_bidiagonal_full_reortho(A, order): nrows, ncols = np.shape(A) key = prng.prng_key(1) v0 = prng.normal(key, shape=(ncols,)) - alg = lanczos.bidiagonal_full_reortho(order, matrix_shape=np.shape(A)) + alg = decomp.bidiagonal_full_reortho(order, matrix_shape=np.shape(A)) def Av(v): return A @ v @@ -74,7 +74,7 @@ def test_error_too_high_depth(A): max_depth = min(nrows, ncols) - 1 with testing.raises(ValueError): - _ = lanczos.bidiagonal_full_reortho(max_depth + 1, matrix_shape=np.shape(A)) + _ = decomp.bidiagonal_full_reortho(max_depth + 1, matrix_shape=np.shape(A)) @testing.parametrize("nrows", [5]) @@ -84,7 +84,7 @@ def test_error_too_low_depth(A): """Assert that a ValueError is raised when the depth is negative.""" min_depth = 0 with testing.raises(ValueError): - _ = lanczos.bidiagonal_full_reortho(min_depth - 1, matrix_shape=np.shape(A)) + _ = decomp.bidiagonal_full_reortho(min_depth - 1, matrix_shape=np.shape(A)) @testing.parametrize("nrows", [15]) @@ -93,7 +93,7 @@ def test_error_too_low_depth(A): def test_no_error_zero_depth(A): """Assert the corner case of zero-depth does not raise an error.""" nrows, ncols = np.shape(A) - algorithm = lanczos.bidiagonal_full_reortho(0, matrix_shape=np.shape(A)) + algorithm = decomp.bidiagonal_full_reortho(0, matrix_shape=np.shape(A)) key = prng.prng_key(1) v0 = prng.normal(key, shape=(ncols,)) diff --git a/tests/test_lanczos/test_tridiagonal_full_reortho.py b/tests/test_decomp/test_tridiagonal_full_reortho.py similarity index 95% rename from tests/test_lanczos/test_tridiagonal_full_reortho.py rename to tests/test_decomp/test_tridiagonal_full_reortho.py index 7857023..beb861e 100644 --- a/tests/test_lanczos/test_tridiagonal_full_reortho.py +++ b/tests/test_decomp/test_tridiagonal_full_reortho.py @@ -1,6 +1,6 @@ """Tests for Lanczos functionality.""" -from matfree import decomp, 
lanczos, test_util +from matfree import decomp, test_util from matfree.backend import linalg, np, prng, testing @@ -22,7 +22,7 @@ def test_max_order(A): order = n - 1 key = prng.prng_key(1) v0 = prng.normal(key, shape=(n,)) - alg = lanczos.tridiagonal_full_reortho(order) + alg = decomp.tridiagonal_full_reortho(order) Q, (d_m, e_m) = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg) # Lanczos is not stable. @@ -63,7 +63,7 @@ def test_identity(A, order): n, _ = np.shape(A) key = prng.prng_key(1) v0 = prng.normal(key, shape=(n,)) - alg = lanczos.tridiagonal_full_reortho(order) + alg = decomp.tridiagonal_full_reortho(order) Q, tridiag = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg) (d_m, e_m) = tridiag