diff --git a/matfree/hutchinson.py b/matfree/hutchinson.py index 1908bd2..aea93c2 100644 --- a/matfree/hutchinson.py +++ b/matfree/hutchinson.py @@ -19,8 +19,8 @@ def integrand_diagonal(matvec, /): where ``*args_like`` is an argument of the sampler. """ - def integrand(v, /): - Qv = matvec(v) + def integrand(v, *parameters): + Qv = matvec(v, *parameters) v_flat, unflatten = tree_util.ravel_pytree(v) Qv_flat, _unflatten = tree_util.ravel_pytree(Qv) return unflatten(v_flat * Qv_flat) @@ -31,8 +31,8 @@ def integrand(v, /): def integrand_trace(matvec, /): """Construct the integrand for estimating the trace.""" - def integrand(v, /): - Qv = matvec(v) + def integrand(v, *parameters): + Qv = matvec(v, *parameters) v_flat, unflatten = tree_util.ravel_pytree(v) Qv_flat, _unflatten = tree_util.ravel_pytree(Qv) return linalg.vecdot(v_flat, Qv_flat) @@ -43,8 +43,8 @@ def integrand(v, /): def integrand_trace_and_diagonal(matvec, /): """Construct the integrand for estimating the trace and diagonal jointly.""" - def integrand(v, /): - Qv = matvec(v) + def integrand(v, *parameters): + Qv = matvec(v, *parameters) v_flat, unflatten = tree_util.ravel_pytree(v) Qv_flat, _unflatten = tree_util.ravel_pytree(Qv) trace_form = linalg.vecdot(v_flat, Qv_flat) @@ -57,8 +57,8 @@ def integrand(v, /): def integrand_frobeniusnorm_squared(matvec, /): """Construct the integrand for estimating the squared Frobenius norm.""" - def integrand(vec, /): - x = matvec(vec) + def integrand(vec, *parameters): + x = matvec(vec, *parameters) v_flat, unflatten = tree_util.ravel_pytree(x) return linalg.vecdot(v_flat, v_flat) @@ -71,8 +71,8 @@ def integrand_trace_moments(matvec, moments, /): def moment_fun(x): return tree_util.tree_map(lambda m: x**m, moments) - def integrand(vec, /): - x = matvec(vec) + def integrand(vec, *parameters): + x = matvec(vec, *parameters) v_flat, unflatten = tree_util.ravel_pytree(vec) x_flat, _unflatten = tree_util.ravel_pytree(x) fx = linalg.vecdot(x_flat, v_flat) @@ -139,9 +139,9 @@ def hutchinson(integrand_fun, /, sample_fun, stats_fun=np.mean): """ - def sample(key): + def sample(key, *parameters): samples = sample_fun(key) - Qs = func.vmap(integrand_fun)(samples) + Qs = func.vmap(lambda vec: integrand_fun(vec, *parameters))(samples) return tree_util.tree_map(lambda s: stats_fun(s, axis=0), Qs) return sample diff --git a/matfree/slq.py b/matfree/slq.py index e9d6ba3..5303a9a 100644 --- a/matfree/slq.py +++ b/matfree/slq.py @@ -23,12 +23,12 @@ def integrand_slq_spd(matfun, order, matvec, /): This function assumes a symmetric, positive definite matrix. """ - def quadform(v0, /): + def quadform(v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) def matvec_flat(v_flat): v = v_unflatten(v_flat) - Av = matvec(v) + Av = matvec(v, *parameters) flat, unflatten = tree_util.ravel_pytree(Av) return flat @@ -83,12 +83,12 @@ def integrand_slq_product(matfun, depth, matvec, vecmat, /): Here, "product" refers to $X = A^\top A$. """ - def quadform(v0, /): + def quadform(v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) def matvec_flat(v_flat): v = v_unflatten(v_flat) - Av = matvec(v) + Av = matvec(v, *parameters) flat, unflatten = tree_util.ravel_pytree(Av) return flat, tree_util.partial_pytree(unflatten) @@ -97,7 +97,7 @@ def matvec_flat(v_flat): def vecmat_flat(w_flat): w = w_unflatten(w_flat) - wA = vecmat(w) + wA = vecmat(w, *parameters) return tree_util.ravel_pytree(wA)[0] # Decompose into orthogonal-bidiag-orthogonal