Skip to content

Commit

Permalink
Update the decomposition calls
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Aug 29, 2024
1 parent 9033220 commit 3dd9075
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable
def estimate(matvec: Callable, vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
(basis, matrix), _ = tridiag_sym(matvec, vec, *parameters)
Q, matrix, *_ = tridiag_sym(matvec, vec, *parameters)
# matrix = _todense_tridiag_sym(diag, off_diag)

funm = dense_funm(matrix)
e1 = np.eye(len(matrix))[0, :]
return length * (basis.T @ funm @ e1)
return length * (Q @ funm @ e1)

return estimate

Expand Down Expand Up @@ -205,7 +205,7 @@ def matvec_flat(v_flat, *p):
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

(_, dense), _ = algorithm(matvec_flat, v0_flat, *parameters)
_, dense, *_ = algorithm(matvec_flat, v0_flat, *parameters)

fA = dense_funm(dense)
e1 = np.eye(len(fA))[0, :]
Expand Down Expand Up @@ -264,7 +264,7 @@ def vecmat_flat(w_flat):
# Decompose into orthogonal-bidiag-orthogonal
matvec_flat_p = lambda v: matvec_flat(v)[0] # noqa: E731
output = algorithm(matvec_flat_p, vecmat_flat, v0_flat, *parameters)
u, B, vt, *_ = output
_u, B, *_ = output

# Compute SVD of factorisation
# todo: turn the following lines into dense_funm_svd()
Expand Down

0 comments on commit 3dd9075

Please sign in to comment.