Skip to content

Commit

Permalink
Implement optional matvec-parameters for Hutchinson's estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Dec 18, 2023
1 parent d2d89f1 commit ba18b49
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
24 changes: 12 additions & 12 deletions matfree/hutchinson.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def integrand_diagonal(matvec, /):
where ``*args_like`` is an argument of the sampler.
"""

def integrand(v, /):
Qv = matvec(v)
def integrand(v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
return unflatten(v_flat * Qv_flat)
Expand All @@ -31,8 +31,8 @@ def integrand(v, /):
def integrand_trace(matvec, /):
"""Construct the integrand for estimating the trace."""

def integrand(v, /):
Qv = matvec(v)
def integrand(v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
return linalg.vecdot(v_flat, Qv_flat)
Expand All @@ -43,8 +43,8 @@ def integrand(v, /):
def integrand_trace_and_diagonal(matvec, /):
"""Construct the integrand for estimating the trace and diagonal jointly."""

def integrand(v, /):
Qv = matvec(v)
def integrand(v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
trace_form = linalg.vecdot(v_flat, Qv_flat)
Expand All @@ -57,8 +57,8 @@ def integrand(v, /):
def integrand_frobeniusnorm_squared(matvec, /):
"""Construct the integrand for estimating the squared Frobenius norm."""

def integrand(vec, /):
x = matvec(vec)
def integrand(vec, *parameters):
x = matvec(vec, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(x)
return linalg.vecdot(v_flat, v_flat)

Expand All @@ -71,8 +71,8 @@ def integrand_trace_moments(matvec, moments, /):
def moment_fun(x):
return tree_util.tree_map(lambda m: x**m, moments)

def integrand(vec, /):
x = matvec(vec)
def integrand(vec, *parameters):
x = matvec(vec, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(vec)
x_flat, _unflatten = tree_util.ravel_pytree(x)
fx = linalg.vecdot(x_flat, v_flat)
Expand Down Expand Up @@ -139,9 +139,9 @@ def hutchinson(integrand_fun, /, sample_fun, stats_fun=np.mean):
"""

def sample(key):
def sample(key, *parameters):
samples = sample_fun(key)
Qs = func.vmap(integrand_fun)(samples)
Qs = func.vmap(lambda vec: integrand_fun(vec, *parameters))(samples)
return tree_util.tree_map(lambda s: stats_fun(s, axis=0), Qs)

return sample
10 changes: 5 additions & 5 deletions matfree/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def integrand_slq_spd(matfun, order, matvec, /):
This function assumes a symmetric, positive definite matrix.
"""

def quadform(v0, /):
def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Av = matvec(v)
Av = matvec(v, *parameters)
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

Expand Down Expand Up @@ -83,12 +83,12 @@ def integrand_slq_product(matfun, depth, matvec, vecmat, /):
Here, "product" refers to $X = A^\top A$.
"""

def quadform(v0, /):
def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Av = matvec(v)
Av = matvec(v, *parameters)
flat, unflatten = tree_util.ravel_pytree(Av)
return flat, tree_util.partial_pytree(unflatten)

Expand All @@ -97,7 +97,7 @@ def matvec_flat(v_flat):

def vecmat_flat(w_flat):
w = w_unflatten(w_flat)
wA = vecmat(w)
wA = vecmat(w, *parameters)
return tree_util.ravel_pytree(wA)[0]

# Decompose into orthogonal-bidiag-orthogonal
Expand Down

0 comments on commit ba18b49

Please sign in to comment.