From 8403d71dc6bffcb954e386d64eb38f21dbddb1ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 29 Aug 2024 10:48:43 +0200 Subject: [PATCH] Make the DecompResult type compatible with python3.9 --- matfree/backend/typing.py | 10 +++++++++- matfree/decomp.py | 14 +++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/matfree/backend/typing.py b/matfree/backend/typing.py index c57b00b..48256c4 100644 --- a/matfree/backend/typing.py +++ b/matfree/backend/typing.py @@ -2,7 +2,15 @@ # fmt: off from collections.abc import Callable # noqa: F401 -from typing import Any, Generic, Iterable, Sequence, Tuple, TypeVar # noqa: F401, UP035 +from typing import ( # noqa: F401, UP035 + Any, + Generic, + Iterable, + Sequence, + Tuple, + TypeVar, + Union, +) from jax import Array # noqa: F401 diff --git a/matfree/decomp.py b/matfree/decomp.py index 8e2613c..f7147f3 100644 --- a/matfree/decomp.py +++ b/matfree/decomp.py @@ -9,12 +9,20 @@ """ from matfree.backend import containers, control_flow, func, linalg, np, tree_util -from matfree.backend.typing import Array, Callable +from matfree.backend.typing import Array, Callable, Union class _DecompResult(containers.NamedTuple): - Q_tall: Array | tuple[Array, ...] - J_small: Array | tuple[Array, ...] + # If an algorithm returns a single Q, place it here. + # If it returns multiple Qs, stack them + # into a tuple and place them here. + Q_tall: Union[Array, tuple[Array, ...]] + + # If an algorithm returns a materialized matrix, + # place it here. If it returns a sparse representation + # (e.g. two vectors representing diagonals), place it here + J_small: Union[Array, tuple[Array, ...]] + residual: Array init_length_inv: Array