From ae0496dd027b19df1e15b9c3e6aa22f3e8ca0ff4 Mon Sep 17 00:00:00 2001 From: <> Date: Wed, 8 Jan 2025 11:04:16 +0000 Subject: [PATCH] Deployed 524474c with MkDocs version: 1.6.1 --- .nojekyll | 0 404.html | 876 +++ API_documentation/bounds/index.html | 1192 +++ API_documentation/decomp/index.html | 1925 +++++ API_documentation/eig/index.html | 1437 ++++ API_documentation/funm/index.html | 2260 ++++++ API_documentation/low_rank/index.html | 1408 ++++ API_documentation/stochtrace/index.html | 1757 +++++ API_documentation/test_util/index.html | 1450 ++++ .../index.html | 977 +++ .../2_contribute_to_matfree/index.html | 975 +++ .../index.html | 973 +++ .../index.html | 966 +++ .../index.html | 1855 +++++ .../index.html | 1854 +++++ .../index.html | 1779 +++++ .../index.html | 1847 +++++ .../index.html | 1850 +++++ .../index.html | 1869 +++++ .../index.html | 1831 +++++ assets/_mkdocstrings.css | 143 + assets/images/favicon.png | Bin 0 -> 1870 bytes assets/javascripts/bundle.88dd0f4e.min.js | 16 + assets/javascripts/bundle.88dd0f4e.min.js.map | 7 + assets/javascripts/lunr/min/lunr.ar.min.js | 1 + assets/javascripts/lunr/min/lunr.da.min.js | 18 + assets/javascripts/lunr/min/lunr.de.min.js | 18 + assets/javascripts/lunr/min/lunr.du.min.js | 18 + assets/javascripts/lunr/min/lunr.el.min.js | 1 + assets/javascripts/lunr/min/lunr.es.min.js | 18 + assets/javascripts/lunr/min/lunr.fi.min.js | 18 + assets/javascripts/lunr/min/lunr.fr.min.js | 18 + assets/javascripts/lunr/min/lunr.he.min.js | 1 + assets/javascripts/lunr/min/lunr.hi.min.js | 1 + assets/javascripts/lunr/min/lunr.hu.min.js | 18 + assets/javascripts/lunr/min/lunr.hy.min.js | 1 + assets/javascripts/lunr/min/lunr.it.min.js | 18 + assets/javascripts/lunr/min/lunr.ja.min.js | 1 + assets/javascripts/lunr/min/lunr.jp.min.js | 1 + assets/javascripts/lunr/min/lunr.kn.min.js | 1 + assets/javascripts/lunr/min/lunr.ko.min.js | 1 + assets/javascripts/lunr/min/lunr.multi.min.js | 1 + assets/javascripts/lunr/min/lunr.nl.min.js | 
18 + assets/javascripts/lunr/min/lunr.no.min.js | 18 + assets/javascripts/lunr/min/lunr.pt.min.js | 18 + assets/javascripts/lunr/min/lunr.ro.min.js | 18 + assets/javascripts/lunr/min/lunr.ru.min.js | 18 + assets/javascripts/lunr/min/lunr.sa.min.js | 1 + .../lunr/min/lunr.stemmer.support.min.js | 1 + assets/javascripts/lunr/min/lunr.sv.min.js | 18 + assets/javascripts/lunr/min/lunr.ta.min.js | 1 + assets/javascripts/lunr/min/lunr.te.min.js | 1 + assets/javascripts/lunr/min/lunr.th.min.js | 1 + assets/javascripts/lunr/min/lunr.tr.min.js | 18 + assets/javascripts/lunr/min/lunr.vi.min.js | 1 + assets/javascripts/lunr/min/lunr.zh.min.js | 1 + assets/javascripts/lunr/tinyseg.js | 206 + assets/javascripts/lunr/wordcut.js | 6708 +++++++++++++++++ .../workers/search.6ce7567c.min.js | 42 + .../workers/search.6ce7567c.min.js.map | 7 + assets/stylesheets/main.6f8fc17f.min.css | 1 + assets/stylesheets/main.6f8fc17f.min.css.map | 1 + assets/stylesheets/palette.06af60db.min.css | 1 + .../stylesheets/palette.06af60db.min.css.map | 1 + index.html | 1026 +++ javascripts/mathjax.js | 18 + objects.inv | Bin 0 -> 636 bytes search/search_index.json | 1 + sitemap.xml | 3 + sitemap.xml.gz | Bin 0 -> 127 bytes 70 files changed, 37549 insertions(+) create mode 100644 .nojekyll create mode 100644 404.html create mode 100644 API_documentation/bounds/index.html create mode 100644 API_documentation/decomp/index.html create mode 100644 API_documentation/eig/index.html create mode 100644 API_documentation/funm/index.html create mode 100644 API_documentation/low_rank/index.html create mode 100644 API_documentation/stochtrace/index.html create mode 100644 API_documentation/test_util/index.html create mode 100644 Developer_documentation/1_use_matfree's_continuous_integration/index.html create mode 100644 Developer_documentation/2_contribute_to_matfree/index.html create mode 100644 Developer_documentation/3_extend_matfree's_documentation/index.html create mode 100644 
Developer_documentation/4_understand_matfree's_api_policy/index.html create mode 100644 Tutorials/1_compute_log_determinants_with_stochastic_lanczos_quadrature/index.html create mode 100644 Tutorials/2_estimate_log_determinants_of_py_tree_valued_functions/index.html create mode 100644 Tutorials/3_implement_uncertainty_quantification_for_trace_estimation/index.html create mode 100644 Tutorials/4_combine_trace_estimation_with_control_variates/index.html create mode 100644 Tutorials/5_implement_vector_calculus_in_linear_complexity/index.html create mode 100644 Tutorials/6_carry_out_stochastic_trace_estimation_with_minimal_memory/index.html create mode 100644 Tutorials/7_compute_matrix_functions_without_materializing_large_matrices/index.html create mode 100644 assets/_mkdocstrings.css create mode 100644 assets/images/favicon.png create mode 100644 assets/javascripts/bundle.88dd0f4e.min.js create mode 100644 assets/javascripts/bundle.88dd0f4e.min.js.map create mode 100644 assets/javascripts/lunr/min/lunr.ar.min.js create mode 100644 assets/javascripts/lunr/min/lunr.da.min.js create mode 100644 assets/javascripts/lunr/min/lunr.de.min.js create mode 100644 assets/javascripts/lunr/min/lunr.du.min.js create mode 100644 assets/javascripts/lunr/min/lunr.el.min.js create mode 100644 assets/javascripts/lunr/min/lunr.es.min.js create mode 100644 assets/javascripts/lunr/min/lunr.fi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.fr.min.js create mode 100644 assets/javascripts/lunr/min/lunr.he.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hu.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hy.min.js create mode 100644 assets/javascripts/lunr/min/lunr.it.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ja.min.js create mode 100644 assets/javascripts/lunr/min/lunr.jp.min.js create mode 100644 assets/javascripts/lunr/min/lunr.kn.min.js create mode 100644 
assets/javascripts/lunr/min/lunr.ko.min.js create mode 100644 assets/javascripts/lunr/min/lunr.multi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.nl.min.js create mode 100644 assets/javascripts/lunr/min/lunr.no.min.js create mode 100644 assets/javascripts/lunr/min/lunr.pt.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ro.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ru.min.js create mode 100644 assets/javascripts/lunr/min/lunr.sa.min.js create mode 100644 assets/javascripts/lunr/min/lunr.stemmer.support.min.js create mode 100644 assets/javascripts/lunr/min/lunr.sv.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ta.min.js create mode 100644 assets/javascripts/lunr/min/lunr.te.min.js create mode 100644 assets/javascripts/lunr/min/lunr.th.min.js create mode 100644 assets/javascripts/lunr/min/lunr.tr.min.js create mode 100644 assets/javascripts/lunr/min/lunr.vi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.zh.min.js create mode 100644 assets/javascripts/lunr/tinyseg.js create mode 100644 assets/javascripts/lunr/wordcut.js create mode 100644 assets/javascripts/workers/search.6ce7567c.min.js create mode 100644 assets/javascripts/workers/search.6ce7567c.min.js.map create mode 100644 assets/stylesheets/main.6f8fc17f.min.css create mode 100644 assets/stylesheets/main.6f8fc17f.min.css.map create mode 100644 assets/stylesheets/palette.06af60db.min.css create mode 100644 assets/stylesheets/palette.06af60db.min.css.map create mode 100644 index.html create mode 100644 javascripts/mathjax.js create mode 100644 objects.inv create mode 100644 search/search_index.json create mode 100644 sitemap.xml create mode 100644 sitemap.xml.gz diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000..e69de29 diff --git a/404.html b/404.html new file mode 100644 index 0000000..b8b4569 --- /dev/null +++ b/404.html @@ -0,0 +1,876 @@ + + + +
+ + + + + + + + + + + + + + +matfree.bounds
+
+
+Matrix-free bounds on functions of matrices.
+ + + + + + + + +matfree.bounds.baigolub96_logdet_spd(bound_spectrum, /, nrows, trace, norm_frobenius_squared)
+
+Bound the log-determinant of a symmetric, positive definite matrix.
+This function implements Theorem 2 in the paper by Bai and Golub (1996).
+bound_spectrum
is either an upper or a lower bound
+on the spectrum of the matrix.
+If it is an upper bound,
+the function returns an upper bound of the log-determinant.
+If it is a lower bound,
+the function returns a lower bound of the log-determinant.
@article{bai1996bounds,
+ title={Bounds for the trace of the inverse and the
+ determinant of symmetric positive definite matrices},
+ author={Bai, Zhaojun and Golub, Gene H},
+ journal={Annals of Numerical Mathematics},
+ volume={4},
+ pages={29--38},
+ year={1996},
+ publisher={Citeseer}
+}
+
matfree/bounds.py
6 + 7 + 8 + 9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 |
|
matfree.decomp
+
+
+Matrix-free matrix decompositions.
+This module includes various Lanczos-decompositions of matrices +(tri-diagonal, bi-diagonal, etc.).
+For stochastic Lanczos quadrature and +matrix-function-vector products, see +matfree.funm.
+ + + + + + + + +matfree.decomp.bidiag(num_matvecs: int, /, materialize: bool = True)
+
+Construct an implementation of bidiagonalisation.
+Uses pre-allocation and full reorthogonalisation.
+Works for arbitrary matrices. No symmetry required.
+Decompose a matrix into a product of orthogonal-bidiagonal-orthogonal matrices. +Use this algorithm for approximate singular value decompositions.
+Internally, Matfree uses JAX to turn matrix-vector- into vector-matrix-products.
+Unlike tridiag_sym or +hessenberg, this function's reverse-mode +derivatives are very efficient. Custom gradients for bidiagonalisation +are a work in progress, and if you need to differentiate the decompositions, +consider using tridiag_sym for the time being.
+matfree/decomp.py
584 +585 +586 +587 +588 +589 +590 +591 +592 +593 +594 +595 +596 +597 +598 +599 +600 +601 +602 +603 +604 +605 +606 +607 +608 +609 +610 +611 +612 +613 +614 +615 +616 +617 +618 +619 +620 +621 +622 +623 +624 +625 +626 +627 +628 +629 +630 +631 +632 +633 +634 +635 +636 +637 +638 +639 +640 +641 +642 +643 +644 +645 +646 +647 +648 +649 +650 +651 +652 +653 +654 +655 +656 +657 +658 +659 +660 +661 +662 +663 +664 +665 +666 +667 +668 +669 +670 +671 +672 +673 +674 +675 +676 +677 +678 +679 +680 +681 +682 +683 +684 +685 +686 +687 +688 +689 +690 +691 +692 +693 +694 +695 +696 +697 +698 +699 +700 +701 +702 +703 +704 +705 +706 +707 +708 |
|
matfree.decomp.hessenberg(num_matvecs, /, *, reortho: str, custom_vjp: bool = True, reortho_vjp: str = 'match')
+
+Construct a Hessenberg-factorisation via the Arnoldi iteration.
+Uses pre-allocation, and full reorthogonalisation if reortho
is set to "full"
.
+It tends to be a good idea to use full reorthogonalisation.
This algorithm works for arbitrary matrices.
+Setting custom_vjp
to True
implies using efficient, numerically stable
+gradients of the Arnoldi iteration according to what has been proposed by
+Krämer et al. (2024).
+These gradients are exact, so there is little reason not to use them.
+If you use this configuration,
+please consider citing Krämer et al. (2024; bibtex below).
@article{kraemer2024gradients,
+ title={Gradients of functions of large matrices},
+ author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and
+ Roy, Hrittik and Hauberg, S{\o}ren},
+ journal={arXiv preprint arXiv:2405.17277},
+ year={2024}
+}
+
matfree/decomp.py
337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 |
|
matfree.decomp.tridiag_sym(num_matvecs: int, /, *, materialize: bool = True, reortho: str = 'full', custom_vjp: bool = True)
+
+Construct an implementation of tridiagonalisation.
+Uses pre-allocation, and full reorthogonalisation if reortho
is set to "full"
.
+It tends to be a good idea to use full reorthogonalisation.
This algorithm assumes a symmetric matrix.
+Decompose a matrix into a product of orthogonal-tridiagonal-orthogonal matrices. +Use this algorithm for approximate eigenvalue decompositions.
+Setting custom_vjp
to True
implies using efficient, numerically stable
+gradients of the Lanczos iteration according to what has been proposed by
+Krämer et al. (2024).
+These gradients are exact, so there is little reason not to use them.
+If you use this configuration, please consider
+citing Krämer et al. (2024; bibtex below).
@article{kraemer2024gradients,
+ title={Gradients of functions of large matrices},
+ author={Kr\"amer, Nicholas and Moreno-Mu\~noz, Pablo and
+ Roy, Hrittik and Hauberg, S{\o}ren},
+ journal={arXiv preprint arXiv:2405.17277},
+ year={2024}
+}
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ num_matvecs
+ |
+
+ int
+ |
+
+
+
+ The number of matrix-vector products aka the depth of the Krylov space. +The deeper the Krylov space, the more accurate the factorisation tends to be. +However, the computational complexity increases linearly +with the number of matrix-vector products. + |
+ + required + | +
+ materialize
+ |
+
+ bool
+ |
+
+
+
+ The value of this flag indicates whether the tridiagonal matrix +should be returned in a sparse format (which means, as a tuple of diagonas) +or as a dense matrix. +The dense matrix is helpful if different decompositions should be used +interchangeably. The sparse representation requires less memory. + |
+
+ True
+ |
+
+ reortho
+ |
+
+ str
+ |
+
+
+
+ The value of this parameter indicates whether to reorthogonalise the +basis vectors during the forward pass. +Reorthogonalisation makes the forward pass more expensive, but helps +(significantly) with numerical stability. + |
+
+ 'full'
+ |
+
+ custom_vjp
+ |
+
+ bool
+ |
+
+
+
+ The value of this flag indicates whether to use a custom vector-Jacobian +product as proposed by Krämer et al. (2024; bibtex above). +Generally, using a custom VJP tends to be a good idea. +However, due to JAX's mechanics, a custom VJP precludes the use of forward-mode +differentiation +(see here), +so don't use a custom VJP if you need forward-mode differentiation. + |
+
+ True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ decompose
+ |
+
+
+
+ A decomposition function that maps
+ |
+
matfree/decomp.py
30 + 31 + 32 + 33 + 34 + 35 + 36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 |
|
matfree.eig
+
+
+Matrix-free eigenvalue and singular-value analysis.
+ + + + + + + + +matfree.eig.eig_partial(hessenberg: Callable) -> Callable
+
+Partial eigenvalue decomposition.
+Combines Hessenberg factorisation with a decomposition +of the (small) Hessenberg matrix.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ hessenberg
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of Hessenberg factorisation. +For example, the output of +decomp.hessenberg. + |
+ + required + | +
matfree/eig.py
63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 |
|
matfree.eig.eigh_partial(tridiag_sym: Callable) -> Callable
+
+Partial symmetric/Hermitian eigenvalue decomposition.
+Combines tridiagonalization with a decomposition +of the (small) tridiagonal matrix.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ tridiag_sym
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of tridiagonalization. +For example, the output of +decomp.tridiag_sym. + |
+ + required + | +
matfree/eig.py
36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 |
|
matfree.eig.svd_partial(bidiag: Callable) -> Callable
+
+Partial singular value decomposition.
+Combines bidiagonalisation with a full SVD of the (small) bidiagonal matrix.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ bidiag
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of bidiagonalisation. +For example, the output of +decomp.bidiag. +Note how this function assumes that the bidiagonalisation +materialises the bidiagonal matrix. + |
+ + required + | +
matfree/eig.py
7 + 8 + 9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 |
|
matfree.funm
+
+
+Matrix-free implementations of functions of matrices.
+This includes matrix-function-vector products
+as well as matrix-function extensions for stochastic trace estimation, +which provide
+Plug these integrands into +matfree.stochtrace.estimator.
+ + +Examples:
+>>> import jax.random
+>>> import jax.numpy as jnp
+>>> from matfree import decomp
+>>>
+>>> M = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 10))
+>>> A = M.T @ M
+>>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,))
+>>>
+>>> # Compute a matrix-logarithm with Lanczos' algorithm
+>>> matfun = dense_funm_sym_eigh(jnp.log)
+>>> tridiag = decomp.tridiag_sym(4)
+>>> matfun_vec = funm_lanczos_sym(matfun, tridiag)
+>>> fAx = matfun_vec(lambda s: A @ s, v)
+>>> print(fAx.shape)
+(10,)
+
matfree.funm.dense_funm_pade_exp()
+
+Implement dense matrix-exponentials using a Pade approximation.
+Use it to construct one of the matrix-free matrix-function implementations, +e.g. matfree.funm.funm_arnoldi.
+ +matfree/funm.py
328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 |
|
matfree.funm.dense_funm_product_svd(matfun)
+
+Implement dense matrix-functions of a product of matrices via SVDs.
+ +matfree/funm.py
284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 |
|
matfree.funm.dense_funm_schur(matfun)
+
+Implement dense matrix-functions via symmetric Schur decompositions.
+Use it to construct one of the matrix-free matrix-function implementations, +e.g. matfree.funm.funm_lanczos_sym.
+ +matfree/funm.py
315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 |
|
matfree.funm.dense_funm_sym_eigh(matfun)
+
+Implement dense matrix-functions via symmetric eigendecompositions.
+Use it to construct one of the matrix-free matrix-function implementations, +e.g. matfree.funm.funm_lanczos_sym.
+ +matfree/funm.py
300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 |
|
matfree.funm.funm_arnoldi(dense_funm: Callable, hessenberg: Callable) -> Callable
+
+Implement a matrix-function-vector product via the Arnoldi iteration.
+This algorithm uses the Arnoldi iteration +and therefore applies only to all square matrices.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dense_funm
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of a function of a dense matrix. +For example, the output of +funm.dense_funm_sym_eigh +funm.dense_funm_schur + |
+ + required + | +
+ hessenberg
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of Hessenberg-factorisation. +E.g., the output of +decomp.hessenberg. + |
+ + required + | +
matfree/funm.py
147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 |
|
matfree.funm.funm_chebyshev(matfun: Callable, num_matvecs: int, matvec: Callable) -> Callable
+
+Compute a matrix-function-vector product via Chebyshev's algorithm.
+This function assumes that the spectrum of the matrix-vector product +is contained in the interval (-1, 1), and that the matrix-function +is analytic on this interval. If this is not the case, +transform the matrix-vector product and the matrix-function accordingly.
+ +matfree/funm.py
44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 |
|
matfree.funm.funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable) -> Callable
+
+Implement a matrix-function-vector product via Lanczos' tridiagonalisation.
+This algorithm uses Lanczos' tridiagonalisation +and therefore applies only to symmetric matrices.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dense_funm
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of a function of a dense matrix. +For example, the output of +funm.dense_funm_sym_eigh +funm.dense_funm_schur + |
+ + required + | +
+ tridiag_sym
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of tridiagonalisation. +E.g., the output of +decomp.tridiag_sym. + |
+ + required + | +
matfree/funm.py
116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 |
|
matfree.funm.integrand_funm_product(dense_funm, algorithm)
+
+Construct the integrand for matrix-function-trace estimation.
+Instead of the trace of a function of a matrix, +compute the trace of a function of the product of matrices. +Here, "product" refers to \(X = A^\top A\).
+ +matfree/funm.py
254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 |
|
matfree.funm.integrand_funm_product_logdet(bidiag: Callable)
+
+Construct the integrand for the log-determinant of a matrix-product.
+Here, "product" refers to \(X = A^\top A\).
+ +matfree/funm.py
234 +235 +236 +237 +238 +239 +240 |
|
matfree.funm.integrand_funm_product_schatten_norm(power, bidiag: Callable)
+
+Construct the integrand for the \(p\)-th power of the Schatten-p norm.
+ +matfree/funm.py
243 +244 +245 +246 +247 +248 +249 +250 +251 |
|
matfree.funm.integrand_funm_sym(dense_funm, tridiag_sym)
+
+Construct the integrand for matrix-function-trace estimation.
+This function assumes a symmetric matrix.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dense_funm
+ |
+ + | +
+
+
+ An implementation of a function of a dense matrix. +For example, the output of +funm.dense_funm_sym_eigh +funm.dense_funm_schur + |
+ + required + | +
+ tridiag_sym
+ |
+ + | +
+
+
+ An implementation of tridiagonalisation. +E.g., the output of +decomp.tridiag_sym. + |
+ + required + | +
matfree/funm.py
195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 |
|
matfree.funm.integrand_funm_sym_logdet(tridiag_sym: Callable)
+
+Construct the integrand for the log-determinant.
+This function assumes a symmetric, positive definite matrix.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ tridiag_sym
+ |
+
+ Callable
+ |
+
+
+
+ An implementation of tridiagonalisation. +E.g., the output of +decomp.tridiag_sym. + |
+ + required + | +
matfree/funm.py
178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 |
|
matfree.low_rank
+
+
+Low-rank approximations (like partial Cholesky decompositions) of matrices.
+ + + + + + + + +matfree.low_rank.cholesky_partial(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable
+
+Compute a partial Cholesky factorisation.
+ +matfree/low_rank.py
56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 |
|
matfree.low_rank.cholesky_partial_pivot(mat_el: Callable, /, *, nrows: int, rank: int) -> Callable
+
+Compute a partial Cholesky factorisation with pivoting.
+ +matfree/low_rank.py
97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 |
|
matfree.low_rank.preconditioner(cholesky: Callable) -> Callable
+
+Turn a low-rank approximation into a preconditioner.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ cholesky
+ |
+
+ Callable
+ |
+
+
+
+ (Partial) Cholesky decomposition. +Usually, the result of either +cholesky_partial +or +cholesky_partial_pivot. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ solve
+ |
+
+
+
+ A function that computes +\[
+(v, s, *p) \mapsto (sI + L(*p) L(*p)^\top)^{-1} v,
+\]
+where \(K = [k(i,j,*p)]_{ij} \approx L(*p) L(*p)^\top\) +and \(L\) comes from the low-rank approximation. + |
+
matfree/low_rank.py
7 + 8 + 9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 |
|
matfree.stochtrace
+
+
+Stochastic estimation of traces, diagonals, and more.
+ + + + + + + + +matfree.stochtrace.estimator(integrand: Callable, /, sampler: Callable) -> Callable
+
+Construct a stochastic trace-/diagonal-estimator.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ integrand
+ |
+
+ Callable
+ |
+
+
+
+ The integrand function. For example, the return-value of +integrand_trace. +But any other integrand works, too. + |
+ + required + | +
+ sampler
+ |
+
+ Callable
+ |
+
+
+
+ The sample function. Usually, either +sampler_normal or +sampler_rademacher. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ estimate
+ |
+
+
+
+ A function that maps a random key to an estimate. +This function can be compiled, vectorised, differentiated, +or looped over as the user desires. + |
+
matfree/stochtrace.py
7 + 8 + 9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 |
|
matfree.stochtrace.integrand_diagonal()
+
+Construct the integrand for estimating the diagonal.
+When plugged into the Monte-Carlo estimator,
+the result will be an Array or PyTree of Arrays with the
+same tree-structure as
+matvec(*args_like)
+where *args_like
is an argument of the sampler.
matfree/stochtrace.py
38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 |
|
matfree.stochtrace.integrand_frobeniusnorm_squared()
+
+Construct the integrand for estimating the squared Frobenius norm.
+ +matfree/stochtrace.py
85 +86 +87 +88 +89 +90 +91 +92 +93 |
|
matfree.stochtrace.integrand_trace()
+
+Construct the integrand for estimating the trace.
+ +matfree/stochtrace.py
59 +60 +61 +62 +63 +64 +65 +66 +67 +68 |
|
matfree.stochtrace.integrand_trace_and_diagonal()
+
+Construct the integrand for estimating the trace and diagonal jointly.
+ +matfree/stochtrace.py
71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 |
|
matfree.stochtrace.integrand_wrap_moments(integrand, /, moments)
+
+Wrap an integrand into another integrand that computes moments.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ integrand
+ |
+ + | +
+
+
+ Any integrand function compatible with Hutchinson-style estimation. + |
+ + required + | +
+ moments
+ |
+ + | +
+
+
+ Any Pytree (tuples, lists, dictionaries) whose leafs that are
+valid inputs to |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ integrand
+ |
+
+
+
+ An integrand function compatible with Hutchinson-style estimation whose
+output has a PyTree-structure that mirrors the structure of the |
+
matfree/stochtrace.py
96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
matfree.stochtrace.sampler_normal(*args_like, num)
+
+Construct a function that samples from a standard-normal distribution.
+ +matfree/stochtrace.py
130 +131 +132 |
|
matfree.stochtrace.sampler_rademacher(*args_like, num)
+
+Construct a function that samples from a Rademacher distribution.
+ +matfree/stochtrace.py
135 +136 +137 |
|
matfree.test_util
+
+
+Test utilities.
+ + + + + + + + +matfree.test_util.assert_allclose(a, b)
+
+Assert that two arrays are close.
+This function uses a different default tolerance to +jax.numpy.allclose. Instead of fixing values, the tolerance +depends on the floating-point precision of the input variables.
+ +matfree/test_util.py
62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 |
|
matfree.test_util.assert_columns_orthonormal(Q)
+
+Assert that the columns in a matrix are orthonormal.
+ +matfree/test_util.py
55 +56 +57 +58 +59 |
|
matfree.test_util.asymmetric_matrix_from_singular_values(vals, /, nrows, ncols)
+
+Generate an asymmetric matrix with specific singular values.
+ +matfree/test_util.py
25 +26 +27 +28 +29 +30 |
|
matfree.test_util.symmetric_matrix_from_eigenvalues(eigvals)
+
+Generate a symmetric matrix with prescribed eigenvalues.
+ +matfree/test_util.py
6 + 7 + 8 + 9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 |
|
matfree.test_util.to_dense_bidiag(d, e, /, offset=1)
+
+Materialize a bidiagonal matrix.
+ +matfree/test_util.py
33 +34 +35 +36 +37 |
|
matfree.test_util.to_dense_tridiag_sym(d, e)
+
+Materialize a symmetric tridiagonal matrix.
+ +matfree/test_util.py
40 +41 +42 +43 +44 +45 |
|
matfree.test_util.tree_random_like(key, pytree, *, generate_func=prng.normal)
+
+Fill a tree with random values.
+ +matfree/test_util.py
48 +49 +50 +51 +52 |
|
To install all test-related dependencies (assuming JAX is installed; if not, run pip install .[cpu]
), execute
+
pip install .[test]
+
make test
+
Install all formatting-related dependencies via +
pip install .[format-and-lint]
+pre-commit install
+
make format-and-lint
+
Install the documentation-related dependencies as
+pip install .[doc]
+
make doc-preview
+
Check whether the docs can be built correctly via
+make doc-build
+
Contributions are welcome!
+Issues:
+Most contributions start with an issue. +Please don't hesitate to create issues in which you ask for features, give performance feedback, or simply want to reach out.
+Pull requests:
+To make a pull request, proceed as follows:
+pip install .[full]
or pip install -e .[full]
.make test
. Check out make format-and-lint
as well. Use the pre-commit hook if you like.When making a pull request, keep in mind the following (rough) guidelines:
+Here is what GitHub considers important for informative pull requests.
+ + + + + + + + + + + + + +Write a new tutorial:
+To add a new tutorial, create a Python file in tutorials/
and fill it with content.
+Use docstrings (mirror the style in the existing tutorials).
+Make sure to satisfy the formatting- and linting-requirements.
+That's all.
Then, the documentation pipeline will automatically convert those into a format compatible with Jupytext, which is subsequently included in the documentation. +If you do not want to make the tutorial part of the documentation, make the filename +have a leading underscore.
+Extend the developer documentation:
+To extend the developer documentation, create a new section in the README. +Use a second-level header (which is a header starting with "##") and fill the section +with content. +Then, the documentation pipeline will turn this section into a page in the developer documentation.
+Create a new module:
+To make a new module appear in the documentation, create the new module in matfree/
,
+and fill it with content.
+Unless the module name starts with an underscore or is placed in the backend,
+the documentation pipeline will take care of the rest.
Matfree is a research project, and parts of its API may change frequently and without warning.
+This stage of development aligns with its current (0.y.z) version. +To quote from semantic versioning:
+++Major version zero (0.y.z) is for initial development. Anything MAY change at any time. The public API SHOULD NOT be considered stable.
+
Matfree does not implement an official deprecation policy (just yet) but handles all API change communication via version increments:
+v0.1.5
to v0.1.6
.v0.1.6
to v0.2.0
.Log-determinant estimation can be implemented with stochastic Lanczos quadrature, +which can be loosely interpreted as an extension of Hutchinson's trace estimator.
+import jax
+import jax.numpy as jnp
+
from matfree import decomp, funm, stochtrace
+
Set up a matrix.
+nhidden, nrows = 6, 5
+A = jnp.reshape(jnp.arange(1.0, 1.0 + nhidden * nrows), (nhidden, nrows))
+A /= nhidden * nrows
+
def matvec(x):
+ """Compute a matrix-vector product."""
+ return A.T @ (A @ x) + x
+
x_like = jnp.ones((nrows,), dtype=float) # use to determine shapes of vectors
+
Estimate log-determinants with stochastic Lanczos quadrature.
+num_matvecs = 3
+tridiag_sym = decomp.tridiag_sym(num_matvecs)
+problem = funm.integrand_funm_sym_logdet(tridiag_sym)
+sampler = stochtrace.sampler_normal(x_like, num=1_000)
+estimator = stochtrace.estimator(problem, sampler=sampler)
+logdet = estimator(matvec, jax.random.PRNGKey(1))
+print(logdet)
+
2.3622565 ++
For comparison:
+print(jnp.linalg.slogdet(A.T @ A + jnp.eye(nrows)))
+
SlogdetResult(sign=Array(1., dtype=float32), logabsdet=Array(2.4568148, dtype=float32)) ++
We can compute the log determinant of a matrix +of the form $M = B^\top B$, purely based +on arithmetic with $B$; no need to assemble $M$:
+A = jnp.reshape(jnp.arange(1.0, 1.0 + nrows**2), (nrows, nrows))
+A += jnp.eye(nrows)
+A /= nrows**2
+
def matvec_half(x):
+ """Compute a matrix-vector product."""
+ return A @ x
+
num_matvecs = 3
+bidiag = decomp.bidiag(num_matvecs)
+problem = funm.integrand_funm_product_logdet(bidiag)
+sampler = stochtrace.sampler_normal(x_like, num=1_000)
+estimator = stochtrace.estimator(problem, sampler=sampler)
+logdet = estimator(matvec_half, jax.random.PRNGKey(1))
+print(logdet)
+
-22.779821 ++
Internally, Matfree uses JAX's vector-Jacobian products to +turn the matrix-vector product into a vector-matrix product.
+For comparison:
+print(jnp.linalg.slogdet(A.T @ A))
+
SlogdetResult(sign=Array(1., dtype=float32), logabsdet=Array(-21.758816, dtype=float32)) ++
Can we compute log-determinants if the matrix-vector +products are pytree-valued? +Yes, we can. Matfree natively supports PyTrees.
+import jax
+import jax.numpy as jnp
+
from matfree import decomp, funm, stochtrace
+
Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple). +Its (regularised) Gauss--Newton Hessian shall be the matrix-vector product +whose log-determinant we estimate.
+def testfunc(x):
+ """Map a dictionary to a tuple with some arbitrary values."""
+ return jnp.linalg.norm(x["weights"]), x["bias"]
+
Create a test-input
+b = jnp.arange(1.0, 40.0)
+W = jnp.stack([b + 1.0, b + 2.0])
+x0 = {"weights": W, "bias": b}
+
Linearise the functions
+f0, jvp = jax.linearize(testfunc, x0)
+_f0, vjp = jax.vjp(testfunc, x0)
+print(jax.tree.map(jnp.shape, f0))
+print(jax.tree.map(jnp.shape, jvp(x0)))
+print(jax.tree.map(jnp.shape, vjp(f0)))
+
((), (39,)) +((), (39,)) +({'bias': (39,), 'weights': (2, 39)},) ++
Use the same API as if the matrix-vector product were array-valued. +Matfree flattens all trees internally.
+def make_matvec(alpha):
+ """Create a matrix-vector product function."""
+
+ def fun(fx, /):
+ r"""Matrix-vector product with $J J^\top + \alpha I$."""
+ vjp_eval = vjp(fx)
+ matvec_eval = jvp(*vjp_eval)
+ return jax.tree.map(lambda x, y: x + alpha * y, matvec_eval, fx)
+
+ return fun
+
matvec = make_matvec(alpha=0.1)
+num_matvecs = 3
+tridiag_sym = decomp.tridiag_sym(num_matvecs)
+integrand = funm.integrand_funm_sym_logdet(tridiag_sym)
+sample_fun = stochtrace.sampler_normal(f0, num=10)
+estimator = stochtrace.estimator(integrand, sampler=sample_fun)
+key = jax.random.PRNGKey(1)
+logdet = estimator(matvec, key)
+print(logdet)
+
3.9901187 ++
For reference: flatten all arguments +and compute the dense log-determinant:
+f0_flat, unravel_func_f = jax.flatten_util.ravel_pytree(f0)
+
def make_matvec_flat(alpha):
+ """Create a flattened matrix-vector-product function."""
+
+ def fun(f_flat):
+ """Evaluate a flattened matrix-vector product."""
+ f_unravelled = unravel_func_f(f_flat)
+ vjp_eval = vjp(f_unravelled)
+ matvec_eval = jvp(*vjp_eval)
+ f_eval, _unravel_func = jax.flatten_util.ravel_pytree(matvec_eval)
+ return f_eval + alpha * f_flat
+
+ return fun
+
matvec_flat = make_matvec_flat(alpha=0.1)
+M = jax.jacfwd(matvec_flat)(f0_flat)
+print(jnp.linalg.slogdet(M))
+
SlogdetResult(sign=Array(1., dtype=float32), logabsdet=Array(3.812408, dtype=float32)) ++
Computing higher moments of trace-estimates can easily +be turned into uncertainty quantification.
+import jax
+import jax.numpy as jnp
+
from matfree import stochtrace
+
A = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
+
def matvec(x):
+ """Evaluate a matrix-vector product."""
+ return A.T @ (A @ x) + x
+
x_like = jnp.ones((6,))
+num_samples = 10_000
+
Trace estimation involves estimating expected values of random variables. +Sometimes, second and higher moments of a random variable are interesting.
+normal = stochtrace.sampler_normal(x_like, num=num_samples)
+integrand = stochtrace.integrand_trace()
+integrand = stochtrace.integrand_wrap_moments(integrand, [1, 2])
+estimator = stochtrace.estimator(integrand, sampler=normal)
+first, second = estimator(matvec, jax.random.PRNGKey(1))
+
For normally-distributed base-samples, +we know that the variance is twice the squared Frobenius norm.
+print(second - first**2)
+print(2 * jnp.linalg.norm(A.T @ A + jnp.eye(6), ord="fro") ** 2)
+
322.09515 +321.78638 ++
Variance estimation leads to uncertainty quantification: +The variance of the estimator is equal to the variance of the random variable +divided by the number of samples.
+variance = (second - first**2) / num_samples
+print(variance)
+
0.032209516 ++
Here is how to implement control variates.
+import jax
+import jax.numpy as jnp
+
from matfree import stochtrace
+
Create a matrix whose trace and diagonal we will approximate.
+nrows, ncols = 4, 4
+A = jnp.reshape(jnp.arange(1.0, 1.0 + nrows * ncols), (nrows, ncols))
+
Set up the sampler.
+x_like = jnp.ones((ncols,), dtype=float)
+sample_fun = stochtrace.sampler_normal(x_like, num=10_000)
+
First, compute the diagonal.
+problem = stochtrace.integrand_diagonal()
+estimate = stochtrace.estimator(problem, sample_fun)
+diagonal_ctrl = estimate(lambda v: A @ v, jax.random.PRNGKey(1))
+
Then, compute trace and diagonal jointly +using the estimate of the diagonal as a control variate.
+def matvec_ctrl(v):
+ """Evaluate a matrix-vector product with a control variate."""
+ return A @ v - diagonal_ctrl * v
+
problem = stochtrace.integrand_trace_and_diagonal()
+estimate = stochtrace.estimator(problem, sample_fun)
+trace_and_diagonal = estimate(matvec_ctrl, jax.random.PRNGKey(2))
+trace, diagonal = trace_and_diagonal["trace"], trace_and_diagonal["diagonal"]
+
We can, of course, compute it without a control variate as well.
+problem = stochtrace.integrand_trace_and_diagonal()
+estimate = stochtrace.estimator(problem, sample_fun)
+trace_and_diagonal = estimate(lambda v: A @ v, jax.random.PRNGKey(2))
+trace_ref, diagonal_ref = trace_and_diagonal["trace"], trace_and_diagonal["diagonal"]
+
Compare the results. +First, the diagonal.
+print("True value:", jnp.diag(A))
+print("Control variate:", diagonal_ctrl, jnp.linalg.norm(jnp.diag(A) - diagonal_ctrl))
+print("Approximation:", diagonal_ref, jnp.linalg.norm(jnp.diag(A) - diagonal_ref))
+print(
+ "Control-variate approximation:",
+ diagonal + diagonal_ctrl,
+ jnp.linalg.norm(jnp.diag(A) - diagonal - diagonal_ctrl),
+)
+
True value: [ 1. 6. 11. 16.] +Control variate: [ 1.0441655 5.7610655 10.704792 15.640959 ] 0.52449846 +Approximation: [ 1.0695375 5.773019 11.284297 15.853068 ] 0.39845905 +Control-variate approximation: [ 1.0738611 5.8862205 11.102717 15.987889 ] 0.17058367 ++
Then, the trace.
+print("True value:", jnp.trace(A))
+print(
+ "Control variate:",
+ jnp.sum(diagonal_ctrl),
+ jnp.abs(jnp.trace(A) - jnp.sum(diagonal_ctrl)),
+)
+print("Approximation:", trace_ref, jnp.abs(jnp.trace(A) - trace_ref))
+print(
+ "Control variate approximation:",
+ trace + jnp.sum(diagonal_ctrl),
+ jnp.abs(jnp.trace(A) - trace - jnp.sum(diagonal_ctrl)),
+)
+
True value: 34.0 +Control variate: 33.15098 0.8490181 +Approximation: 33.97992 0.020080566 +Control variate approximation: 34.05069 0.050689697 ++
Implementing vector calculus with conventional +algorithmic differentiation can be inefficient. +For example, computing the divergence of a +vector field requires computing the trace of a Jacobian. +The divergence of a vector field is +important when evaluating Laplacians of scalar functions.
+Here is how we can implement divergences and +Laplacians without forming full Jacobian matrices:
+import jax
+import jax.numpy as jnp
+
from matfree import stochtrace
+
The divergence of a vector field is the trace of its Jacobian. +The conventional implementation would look like this:
+def divergence_dense(vf):
+ """Compute the divergence of a vector field."""
+
+ def div_fn(x):
+ J = jax.jacfwd(vf)
+ return jnp.trace(J(x))
+
+ return div_fn
+
This implementation computes the divergence of a vector field:
+def fun(x):
+ """Evaluate a scalar valued function."""
+ return jnp.dot(x, x) ** 2
+
x0 = jnp.arange(1.0, 4.0)
+gradient = jax.grad(fun)
+laplacian = divergence_dense(gradient)
+print(jax.hessian(fun)(x0))
+print(laplacian(x0))
+
[[ 64. 16. 24.] + [ 16. 88. 48.] + [ 24. 48. 128.]] +280.0 ++
But the implementation above requires $O(d^2)$ storage +because it evaluates the dense Jacobian. +This is problematic for high-dimensional problems.
+If we have access to Jacobian-vector products (which we usually do), +we can use matrix-free trace estimation +to approximate divergences and Laplacians without forming full Jacobians:
+def divergence_matfree(vf, /, *, num):
+ """Compute the divergence with Hutchinson's estimator."""
+
+ def divergence(k, x):
+ _fx, jvp = jax.linearize(vf, x)
+ integrand_laplacian = stochtrace.integrand_trace()
+ normal = stochtrace.sampler_normal(x, num=num)
+ estimator = stochtrace.estimator(integrand_laplacian, sampler=normal)
+ return estimator(jvp, k)
+
+ return divergence
+
laplacian_matfree = divergence_matfree(gradient, num=10_000)
+print(laplacian(x0))
+print(laplacian_matfree(jax.random.PRNGKey(1), x0))
+
280.0 ++
281.30106 ++
In summary, combine matrix-free linear algebra +with algorithmic differentiation to implement vector calculus.
+If we replace trace estimation with diagonal estimation, +we can compute the diagonal of Jacobian matrices in +$O(d)$ memory and $O(dN)$ operations.
+Matfree's implementation of stochastic trace estimation +via Hutchinson's method defaults to computing all +Monte-Carlo samples at once, because this is the fastest +implementation as long as all samples fit into memory.
+Some matrix-vector products, however, are so large that +we can only store a single sample in memory at once. +Here is how we can wrap calls around the trace estimators +in such a scenario to save memory.
+import functools
+
import jax
+import jax.numpy as jnp
+
from matfree import stochtrace
+
The conventional setup for estimating the trace of a large matrix +would look like this.
+nrows = 100 # but imagine nrows=100,000,000,000 instead
+nsamples = 1_000
+
def large_matvec(v):
+ """Evaluate a (dummy for a) large matrix-vector product."""
+ return 1.2345 * v
+
integrand = stochtrace.integrand_trace()
+x0 = jnp.ones((nrows,))
+sampler = stochtrace.sampler_rademacher(x0, num=nsamples)
+estimate = stochtrace.estimator(integrand, sampler)
+
key = jax.random.PRNGKey(1)
+trace = estimate(large_matvec, key)
+print(trace)
+
123.4499 ++
The above code requires nrows $\times$ nsamples storage, which +is prohibitive for extremely large matrices. +Instead, we can loop around estimate() to do the following: +The below code requires nrows $\times$ 1 storage:
+sampler = stochtrace.sampler_rademacher(x0, num=1)
+estimate = stochtrace.estimator(integrand, sampler)
+estimate = functools.partial(estimate, large_matvec)
+
key = jax.random.PRNGKey(2)
+keys = jax.random.split(key, num=nsamples)
+traces = jax.lax.map(estimate, keys)
+trace = jnp.mean(traces)
+print(trace)
+
123.4499 ++
In practice, we often combine both approaches by choosing +the largest nsamples (in the first implementation) so that +nrows $\times$ nsamples fits into memory, and handle all samples beyond +that via the split-and-map combination.
+If we reverse-mode differentiate through the sampler, we have to +be careful because by default, reverse-mode differentiation +stores all intermediate results (and the memory-efficiency of using +jax.lax.map is void). +To solve this problem, place a jax.checkpoint around the estimator:
+traces = jax.lax.map(jax.checkpoint(estimate), keys)
+trace = jnp.mean(traces)
+print(trace)
+
123.4499 ++
This implementation recomputes the forward pass for each key during the +backward pass, but preserves the memory-efficiency on the backward pass.
+In summary, memory efficiency can be achieved by calling estimators +inside jax.lax.map (with or without checkpoints).
+Sometimes, we need to compute matrix exponentials, log-determinants, +or similar functions of matrices, but our matrices are too big to +use functions from +scipy.linalg +or +jax.scipy.linalg. +However, matrix-free linear algebra scales to even the largest of matrices. +Here is how to use Matfree to compute functions of large matrices.
+import functools
+
import jax
+
from matfree import decomp, funm
+
n = 7 # imagine n = 10^5 or larger
+
key = jax.random.PRNGKey(1)
+key, subkey = jax.random.split(key, num=2)
+large_matrix = jax.random.normal(subkey, shape=(n, n))
+
The expected value is computed with jax.scipy.linalg.
+key, subkey = jax.random.split(key, num=2)
+vector = jax.random.normal(subkey, shape=(n,))
+expected = jax.scipy.linalg.expm(large_matrix) @ vector
+print(expected)
+
[ 0.5121861 1.0731273 -1.1475035 -1.6931866 0.06646963 -1.1467085 + 0.66265297] ++
Instead of using jax.scipy.linalg, we can use matrix-vector products +in combination with the Arnoldi iteration to approximate the +matrix-function-vector product.
+def large_matvec(v):
+ """Evaluate a matrix-vector product."""
+ return large_matrix @ v
+
num_matvecs = 5
+arnoldi = decomp.hessenberg(num_matvecs, reortho="full")
+dense_funm = funm.dense_funm_pade_exp()
+matfun_vec = funm.funm_arnoldi(dense_funm, arnoldi)
+received = matfun_vec(large_matvec, vector)
+print(received)
+
[ 0.5136445 1.0897965 -1.1209555 -1.7069302 0.03098169 -1.1719893 + 0.67968863] ++
The matrix-function vector product can be combined with all usual +JAX transformations. For example, after fixing the matvec-function +as the first argument, we can vectorize the matrix function with jax.vmap +and compile it with jax.jit.
+matfun_vec = functools.partial(matfun_vec, large_matvec)
+key, subkey = jax.random.split(key, num=2)
+vector_batch = jax.random.normal(subkey, shape=(5, n)) # a batch of 5 vectors
+received = jax.jit(jax.vmap(matfun_vec))(vector_batch)
+print(received.shape)
+
(5, 7) ++
Talking about function transformations: we can also +reverse-mode-differentiate the matrix functions efficiently.
+jac = jax.jacrev(matfun_vec)(vector)
+print(jac)
+
[[ 3.68775666e-01 3.48348975e-01 -1.14449523e-01 -3.22446883e-01 + 3.28712702e-01 -6.60334349e-01 3.08125526e-01] + [ 8.88347626e-04 9.77235258e-01 2.68623352e+00 -5.51655173e-01 + -1.45154142e+00 -1.11724639e+00 7.45091677e-01] + [ 4.17882234e-01 -9.98488367e-01 -3.91192406e-01 8.76782537e-01 + -9.65307474e-01 5.19365370e-01 -6.68987870e-01] + [ 2.65466452e-01 -8.89071941e-01 -2.17203140e+00 7.52809644e-01 + 4.79240775e-01 8.03415000e-01 -8.45992625e-01] + [-4.26323414e-01 -8.46019328e-01 -2.89584970e+00 1.10395364e-01 + 2.57722950e+00 1.75358319e+00 -3.07614803e-01] + [-1.35615468e-01 -5.94067991e-01 -1.90474641e+00 1.77025393e-01 + 1.02040839e+00 7.22389579e-01 -3.67944658e-01] + [-3.23790073e-01 1.21016252e+00 1.78035736e+00 -1.12524259e+00 + -1.80692703e-01 -1.32690465e+00 1.32771575e+00]] ++
Under the hood, reverse-mode derivatives of Arnoldi- and Lanczos-based +matrix functions use the fast algorithm for gradients of the +Lanczos and Arnoldi iterations from +this paper. +Please consider citing it if you use reverse-mode derivatives +of functions of matrices +(a BibTex is here).
+