Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename lanczos.integrand_* functions #178

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions matfree/lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@
from matfree.backend.typing import Array, Callable, Tuple


def integrand_logdet_spd(order, matvec, /):
def integrand_spd_logdet(order, matvec, /):
"""Construct the integrand for the log-determinant.

This function assumes a symmetric, positive definite matrix.
"""
return integrand_slq_spd(np.log, order, matvec)
return integrand_spd(np.log, order, matvec)


def integrand_slq_spd(matfun, order, matvec, /):
def integrand_spd(matfun, order, matvec, /):
"""Quadratic form for stochastic Lanczos quadrature.

This function assumes a symmetric, positive definite matrix.
Expand Down Expand Up @@ -65,25 +65,25 @@ def matvec_flat(v_flat, *p):
return quadform


def integrand_logdet_product(depth, matvec, vecmat, /):
def integrand_product_logdet(depth, matvec, vecmat, /):
r"""Construct the integrand for the log-determinant of a matrix-product.

Here, "product" refers to $X = A^\top A$.
"""
return integrand_slq_product(np.log, depth, matvec, vecmat)
return integrand_product(np.log, depth, matvec, vecmat)


def integrand_schatten_norm(power, depth, matvec, vecmat, /):
def integrand_product_schatten_norm(power, depth, matvec, vecmat, /):
r"""Construct the integrand for the p-th power of the Schatten-p norm."""

def matfun(x):
"""Matrix-function for Schatten-p norms."""
return x ** (power / 2)

return integrand_slq_product(matfun, depth, matvec, vecmat)
return integrand_product(matfun, depth, matvec, vecmat)


def integrand_slq_product(matfun, depth, matvec, vecmat, /):
def integrand_product(matfun, depth, matvec, vecmat, /):
r"""Construct the integrand for the trace of a function of a matrix-product.

Instead of the trace of a function of a matrix,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_lanczos/test_integrand_logdet_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def vecmat(x):

x_like = {"fx": np.ones((ncols,), dtype=float)}
fun = hutchinson.sampler_normal(x_like, num=400)
problem = lanczos.integrand_logdet_product(order, matvec, vecmat)
problem = lanczos.integrand_product_logdet(order, matvec, vecmat)
estimate = hutchinson.hutchinson(problem, fun)
received = estimate(key)

Expand All @@ -53,7 +53,7 @@ def test_logdet_product_exact_for_full_order_lanczos(n):

# Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm
order = n - 1
integrand = lanczos.integrand_logdet_product(
integrand = lanczos.integrand_product_logdet(
order, lambda v: A @ v, lambda v: v @ A
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_lanczos/test_integrand_logdet_spd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def matvec(x):
key = prng.prng_key(1)
args_like = {"fx": np.ones((n,), dtype=float)}
sampler = hutchinson.sampler_normal(args_like, num=10)
integrand = lanczos.integrand_logdet_spd(order, matvec)
integrand = lanczos.integrand_spd_logdet(order, matvec)
estimate = hutchinson.hutchinson(integrand, sampler)
received = estimate(key)

Expand All @@ -49,7 +49,7 @@ def test_logdet_spd_exact_for_full_order_lanczos(n):

# Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm
order = n - 1
integrand = lanczos.integrand_logdet_spd(order, lambda v: A @ v)
integrand = lanczos.integrand_spd_logdet(order, lambda v: A @ v)

# Construct a vector without that does not have expected 2-norm equal to "dim"
x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lanczos/test_integrand_schatten_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_schatten_norm(A, order, power):
_, ncols = np.shape(A)
args_like = np.ones((ncols,), dtype=float)
sampler = hutchinson.sampler_normal(args_like, num=500)
integrand = lanczos.integrand_schatten_norm(
integrand = lanczos.integrand_product_schatten_norm(
power, order, lambda v: A @ v, lambda v: A.T @ v
)
estimate = hutchinson.hutchinson(integrand, sampler)
Expand Down
4 changes: 2 additions & 2 deletions tutorials/1_log_determinants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def matvec(x):
# Estimate log-determinants with stochastic Lanczos quadrature.

order = 3
problem = lanczos.integrand_logdet_spd(order, matvec)
problem = lanczos.integrand_spd_logdet(order, matvec)
sampler = hutchinson.sampler_normal(x_like, num=1_000)
estimator = hutchinson.hutchinson(problem, sample_fun=sampler)
logdet = estimator(jax.random.PRNGKey(1))
Expand Down Expand Up @@ -58,7 +58,7 @@ def vecmat_left(x):


order = 3
problem = lanczos.integrand_logdet_product(order, matvec_right, vecmat_left)
problem = lanczos.integrand_product_logdet(order, matvec_right, vecmat_left)
sampler = hutchinson.sampler_normal(x_like, num=1_000)
estimator = hutchinson.hutchinson(problem, sample_fun=sampler)
logdet = estimator(jax.random.PRNGKey(1))
Expand Down
2 changes: 1 addition & 1 deletion tutorials/2_pytree_logdeterminants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def fun(fx, /):

matvec = make_matvec(alpha=0.1)
order = 3
integrand = lanczos.integrand_logdet_spd(order, matvec)
integrand = lanczos.integrand_spd_logdet(order, matvec)
sample_fun = hutchinson.sampler_normal(f0, num=10)
estimator = hutchinson.hutchinson(integrand, sample_fun=sample_fun)
key = jax.random.PRNGKey(1)
Expand Down
Loading