diff --git a/matfree/backend/typing.py b/matfree/backend/typing.py index c57b00b..48256c4 100644 --- a/matfree/backend/typing.py +++ b/matfree/backend/typing.py @@ -2,7 +2,15 @@ # fmt: off from collections.abc import Callable # noqa: F401 -from typing import Any, Generic, Iterable, Sequence, Tuple, TypeVar # noqa: F401, UP035 +from typing import ( # noqa: F401, UP035 + Any, + Generic, + Iterable, + Sequence, + Tuple, + TypeVar, + Union, +) from jax import Array # noqa: F401 diff --git a/matfree/decomp.py b/matfree/decomp.py index 8e2613c..f7147f3 100644 --- a/matfree/decomp.py +++ b/matfree/decomp.py @@ -9,12 +9,20 @@ """ from matfree.backend import containers, control_flow, func, linalg, np, tree_util -from matfree.backend.typing import Array, Callable +from matfree.backend.typing import Array, Callable, Union class _DecompResult(containers.NamedTuple): - Q_tall: Array | tuple[Array, ...] - J_small: Array | tuple[Array, ...] + # If an algorithm returns a single Q, place it here. + # If it returns multiple Qs, stack them + # into a tuple and place them here. + Q_tall: Union[Array, tuple[Array, ...]] + + # If an algorithm returns a materialized matrix, + # place it here. If it returns a sparse representation + # (e.g. two vectors representing diagonals), place it here + J_small: Union[Array, tuple[Array, ...]] + residual: Array init_length_inv: Array