From 970dc1104f3bdf01e615a37f7eb2a22834458885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Mon, 15 Jan 2024 10:23:35 +0100 Subject: [PATCH] Rename lanczos.integrand_* functions --- matfree/lanczos.py | 16 ++++++++-------- .../test_integrand_logdet_product.py | 4 ++-- tests/test_lanczos/test_integrand_logdet_spd.py | 4 ++-- .../test_lanczos/test_integrand_schatten_norm.py | 2 +- tutorials/1_log_determinants.py | 4 ++-- tutorials/2_pytree_logdeterminants.py | 2 +- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/matfree/lanczos.py b/matfree/lanczos.py index 7d4a69f..a9e3a90 100644 --- a/matfree/lanczos.py +++ b/matfree/lanczos.py @@ -28,15 +28,15 @@ from matfree.backend.typing import Array, Callable, Tuple -def integrand_logdet_spd(order, matvec, /): +def integrand_spd_logdet(order, matvec, /): """Construct the integrand for the log-determinant. This function assumes a symmetric, positive definite matrix. """ - return integrand_slq_spd(np.log, order, matvec) + return integrand_spd(np.log, order, matvec) -def integrand_slq_spd(matfun, order, matvec, /): +def integrand_spd(matfun, order, matvec, /): """Quadratic form for stochastic Lanczos quadrature. This function assumes a symmetric, positive definite matrix. @@ -65,25 +65,25 @@ def matvec_flat(v_flat, *p): return quadform -def integrand_logdet_product(depth, matvec, vecmat, /): +def integrand_product_logdet(depth, matvec, vecmat, /): r"""Construct the integrand for the log-determinant of a matrix-product. Here, "product" refers to $X = A^\top A$. """ - return integrand_slq_product(np.log, depth, matvec, vecmat) + return integrand_product(np.log, depth, matvec, vecmat) -def integrand_schatten_norm(power, depth, matvec, vecmat, /): +def integrand_product_schatten_norm(power, depth, matvec, vecmat, /): r"""Construct the integrand for the p-th power of the Schatten-p norm.""" def matfun(x): """Matrix-function for Schatten-p norms.""" return x ** (power / 2) - return integrand_slq_product(matfun, depth, matvec, vecmat) + return integrand_product(matfun, depth, matvec, vecmat) -def integrand_slq_product(matfun, depth, matvec, vecmat, /): +def integrand_product(matfun, depth, matvec, vecmat, /): r"""Construct the integrand for the trace of a function of a matrix-product. Instead of the trace of a function of a matrix, diff --git a/tests/test_lanczos/test_integrand_logdet_product.py b/tests/test_lanczos/test_integrand_logdet_product.py index e08bd88..402a2a4 100644 --- a/tests/test_lanczos/test_integrand_logdet_product.py +++ b/tests/test_lanczos/test_integrand_logdet_product.py @@ -31,7 +31,7 @@ def vecmat(x): x_like = {"fx": np.ones((ncols,), dtype=float)} fun = hutchinson.sampler_normal(x_like, num=400) - problem = lanczos.integrand_logdet_product(order, matvec, vecmat) + problem = lanczos.integrand_product_logdet(order, matvec, vecmat) estimate = hutchinson.hutchinson(problem, fun) received = estimate(key) @@ -53,7 +53,7 @@ def test_logdet_product_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = lanczos.integrand_logdet_product( + integrand = lanczos.integrand_product_logdet( order, lambda v: A @ v, lambda v: v @ A ) diff --git a/tests/test_lanczos/test_integrand_logdet_spd.py b/tests/test_lanczos/test_integrand_logdet_spd.py index ed9d15d..a5629a0 100644 --- a/tests/test_lanczos/test_integrand_logdet_spd.py +++ b/tests/test_lanczos/test_integrand_logdet_spd.py @@ -29,7 +29,7 @@ def matvec(x): key = prng.prng_key(1) args_like = {"fx": np.ones((n,), dtype=float)} sampler = hutchinson.sampler_normal(args_like, num=10) - integrand = lanczos.integrand_logdet_spd(order, matvec) + integrand = lanczos.integrand_spd_logdet(order, matvec) estimate = hutchinson.hutchinson(integrand, sampler) received = estimate(key) @@ -49,7 +49,7 @@ def test_logdet_spd_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = lanczos.integrand_logdet_spd(order, lambda v: A @ v) + integrand = lanczos.integrand_spd_logdet(order, lambda v: A @ v) # Construct a vector without that does not have expected 2-norm equal to "dim" x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10 diff --git a/tests/test_lanczos/test_integrand_schatten_norm.py b/tests/test_lanczos/test_integrand_schatten_norm.py index ef67700..6516890 100644 --- a/tests/test_lanczos/test_integrand_schatten_norm.py +++ b/tests/test_lanczos/test_integrand_schatten_norm.py @@ -27,7 +27,7 @@ def test_schatten_norm(A, order, power): _, ncols = np.shape(A) args_like = np.ones((ncols,), dtype=float) sampler = hutchinson.sampler_normal(args_like, num=500) - integrand = lanczos.integrand_schatten_norm( + integrand = lanczos.integrand_product_schatten_norm( power, order, lambda v: A @ v, lambda v: A.T @ v ) estimate = hutchinson.hutchinson(integrand, sampler) diff --git a/tutorials/1_log_determinants.py b/tutorials/1_log_determinants.py index 4123b31..c73e8ef 100644 --- a/tutorials/1_log_determinants.py +++ b/tutorials/1_log_determinants.py @@ -27,7 +27,7 @@ def matvec(x): # Estimate log-determinants with stochastic Lanczos quadrature. order = 3 -problem = lanczos.integrand_logdet_spd(order, matvec) +problem = lanczos.integrand_spd_logdet(order, matvec) sampler = hutchinson.sampler_normal(x_like, num=1_000) estimator = hutchinson.hutchinson(problem, sample_fun=sampler) logdet = estimator(jax.random.PRNGKey(1)) @@ -58,7 +58,7 @@ def vecmat_left(x): order = 3 -problem = lanczos.integrand_logdet_product(order, matvec_right, vecmat_left) +problem = lanczos.integrand_product_logdet(order, matvec_right, vecmat_left) sampler = hutchinson.sampler_normal(x_like, num=1_000) estimator = hutchinson.hutchinson(problem, sample_fun=sampler) logdet = estimator(jax.random.PRNGKey(1)) diff --git a/tutorials/2_pytree_logdeterminants.py b/tutorials/2_pytree_logdeterminants.py index 8415e7e..30f7ccc 100644 --- a/tutorials/2_pytree_logdeterminants.py +++ b/tutorials/2_pytree_logdeterminants.py @@ -53,7 +53,7 @@ def fun(fx, /): matvec = make_matvec(alpha=0.1) order = 3 -integrand = lanczos.integrand_logdet_spd(order, matvec) +integrand = lanczos.integrand_spd_logdet(order, matvec) sample_fun = hutchinson.sampler_normal(f0, num=10) estimator = hutchinson.hutchinson(integrand, sample_fun=sample_fun) key = jax.random.PRNGKey(1)