Skip to content

Commit

Permalink
Merge pull request #1561 from helmholtz-analytics/features/1457-Add_r…
Browse files Browse the repository at this point in the history
…andomized_SVD

Distributed randomized SVD
  • Loading branch information
mrfh92 authored Oct 18, 2024
2 parents 4b3e570 + 18a570d commit 7e44cd3
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 50 deletions.
129 changes: 107 additions & 22 deletions heat/core/linalg/svdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..dndarray import DNDarray
from .. import factories
from .. import types
from ..linalg import matmul, vector_norm
from ..linalg import matmul, vector_norm, qr, svd
from ..indexing import where
from ..random import randn

Expand All @@ -21,7 +21,21 @@
from math import log, ceil, floor, sqrt


__all__ = ["hsvd_rank", "hsvd_rtol", "hsvd"]
__all__ = ["hsvd_rank", "hsvd_rtol", "hsvd", "rsvd"]


def _check_SVD_input(A):
if not isinstance(A, DNDarray):
raise TypeError(f"Argument needs to be a DNDarray but is {type(A)}.")
if not A.ndim == 2:
raise ValueError("A needs to be a 2D matrix")
if not types.heat_type_is_realfloating(A.dtype):
raise TypeError(
"Argument needs to be a DNDarray with datatype float32 or float64, but data type is {}.".format(
A.dtype
)
)
return None


#######################################################################################
Expand Down Expand Up @@ -85,16 +99,7 @@ def hsvd_rank(
[1] Iwen, Ong. A distributed and incremental SVD algorithm for agglomerative data analysis on large networks. SIAM J. Matrix Anal. Appl., 37(4), 2016.
[2] Himpe, Leibner, Rave. Hierarchical approximate proper orthogonal decomposition. SIAM J. Sci. Comput., 40 (5), 2018.
"""
if not isinstance(A, DNDarray):
raise TypeError(f"Argument needs to be a DNDarray but is {type(A)}.")
if not A.ndim == 2:
raise ValueError("A needs to be a 2D matrix")
if not A.dtype == types.float32 and not A.dtype == types.float64:
raise TypeError(
"Argument needs to be a DNDarray with datatype float32 or float64, but data type is {}.".format(
A.dtype
)
)
_check_SVD_input(A) # check if A is suitable input
A_local_size = max(A.lshape_map[:, 1])

if maxmergedim is not None and maxmergedim < 2 * (maxrank + safetyshift) + 1:
Expand Down Expand Up @@ -197,16 +202,7 @@ def hsvd_rtol(
[1] Iwen, Ong. A distributed and incremental SVD algorithm for agglomerative data analysis on large networks. SIAM J. Matrix Anal. Appl., 37(4), 2016.
[2] Himpe, Leibner, Rave. Hierarchical approximate proper orthogonal decomposition. SIAM J. Sci. Comput., 40 (5), 2018.
"""
if not isinstance(A, DNDarray):
raise TypeError(f"Argument needs to be a DNDarray but is {type(A)}.")
if not A.ndim == 2:
raise ValueError("A needs to be a 2D matrix")
if not A.dtype == types.float32 and not A.dtype == types.float64:
raise TypeError(
"Argument needs to be a DNDarray with datatype float32 or float64, but data type is {}.".format(
A.dtype
)
)
_check_SVD_input(A) # check if A is suitable input
A_local_size = max(A.lshape_map[:, 1])

if maxmergedim is not None and maxrank is None:
Expand Down Expand Up @@ -529,3 +525,92 @@ def compute_local_truncated_svd(
sigma_loc = torch.zeros(1, dtype=U_loc.dtype, device=U_loc.device)
U_loc = torch.zeros(U_loc.shape[0], 1, dtype=U_loc.dtype, device=U_loc.device)
return U_loc, sigma_loc, err_squared_loc


##############################################################################################
# Randomized SVD
##############################################################################################


def rsvd(
A: DNDarray,
rank: int,
n_oversamples: int = 10,
power_iter: int = 0,
qr_procs_to_merge: int = 2,
) -> Union[Tuple[DNDarray, DNDarray, DNDarray], Tuple[DNDarray, DNDarray]]:
r"""
Randomized SVD (rSVD) with prescribed truncation rank `rank`.
If :math:`A = U \operatorname{diag}(S) V^T` is the true SVD of A, this routine computes an approximation for U[:,:rank] (and S[:rank], V[:,:rank]).
The accuracy of this approximation depends on the structure of A ("low-rank" is best) and appropriate choice of parameters.
Parameters
----------
A : DNDarray
2D-array (float32/64) of which the rSVD has to be computed.
rank : int
truncation rank. (This parameter corresponds to `n_components` in scikit-learn's TruncatedSVD.)
n_oversamples : int, optional
number of oversamples. The default is 10.
power_iter : int, optional
number of power iterations. The default is 0.
Choosing `power_iter > 0` can improve the accuracy of the SVD approximation in the case of slowly decaying singular values, but increases the computational cost.
qr_procs_to_merge : int, optional
number of processes to merge at each step of QR decomposition in the power iteration (if power_iter > 0). The default is 2. See the corresponding remarks for :func:`heat.linalg.qr() <heat.core.linalg.qr.qr()>` for more details.
Notes
------
Memory requirements: the SVD computation of a matrix of size (rank + n_oversamples) x (rank + n_oversamples) must fit into the memory of a single process.
The implementation follows Algorithm 4.4 (randomized range finder) and Algorithm 5.1 (direct SVD) in [1].
References
-----------
[1] Halko, N., Martinsson, P. G., & Tropp, J. A. (2011). Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions. SIAM review, 53(2), 217-288.
"""
_check_SVD_input(A) # check if A is suitable input
if not isinstance(rank, int):
raise TypeError(f"rank must be an integer, but is {type(rank)}.")
if rank < 1:
raise ValueError(f"rank must be positive, but is {rank}.")
if not isinstance(n_oversamples, int):
raise TypeError(
f"if provided, n_oversamples must be an integer, but is {type(n_oversamples)}."
)
if n_oversamples < 0:
raise ValueError(f"n_oversamples must be non-negative, but is {n_oversamples}.")
if not isinstance(power_iter, int):
raise TypeError(f"if provided, power_iter must be an integer, but is {type(power_iter)}.")
if power_iter < 0:
raise ValueError(f"power_iter must be non-negative, but is {power_iter}.")

ell = rank + n_oversamples
q = power_iter

# random matrix
splitOmega = 1 if A.split == 0 else 0
Omega = randn(A.shape[1], ell, dtype=A.dtype, device=A.device, split=splitOmega)

# compute the range of A
Y = matmul(A, Omega)
Q, _ = qr(Y, procs_to_merge=qr_procs_to_merge)

# power iterations
for _ in range(q):
Y = matmul(A.T, Q)
Q, _ = qr(Y, procs_to_merge=qr_procs_to_merge)
Y = matmul(A, Q)
Q, _ = qr(Y, procs_to_merge=qr_procs_to_merge)

# compute the SVD of the projected matrix
B = matmul(Q.T, A)
B.resplit_(
None
) # B will be of size ell x ell and thus small enough to fit into memory of a single process
U, sigma, V = svd.svd(B) # actually just torch svd as input is not split anymore
U = matmul(Q, U)[:, :rank]
U.balance_()
S = sigma[:rank]
V = V[:, :rank]
V.balance_()
return U, S, V
55 changes: 55 additions & 0 deletions heat/core/linalg/tests/test_svdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,58 @@ def test_hsvd_rank_part2(self):
self.assertTrue(U_orth_err <= dtype_tol)
self.assertTrue(V_orth_err <= dtype_tol)
self.assertTrue(true_rel_err <= dtype_tol)


class TestRSVD(TestCase):
def test_rsvd(self):
for dtype in [ht.float32, ht.float64]:
dtype_tol = 1e-4 if dtype == ht.float32 else 1e-10
for split in [0, 1, None]:
X = ht.random.randn(200, 200, dtype=dtype, split=split)
for rank in [ht.MPI_WORLD.size, 10]:
for n_oversamples in [5, 10]:
for power_iter in [0, 1, 2, 3]:
U, S, V = ht.linalg.rsvd(
X, rank, n_oversamples=n_oversamples, power_iter=power_iter
)
self.assertEqual(U.shape, (X.shape[0], rank))
self.assertEqual(S.shape, (rank,))
self.assertEqual(V.shape, (X.shape[1], rank))
self.assertTrue(ht.all(S >= 0))
self.assertTrue(
ht.allclose(
U.T @ U,
ht.eye(rank, dtype=U.dtype, split=U.split),
rtol=dtype_tol,
atol=dtype_tol,
)
)
self.assertTrue(
ht.allclose(
V.T @ V,
ht.eye(rank, dtype=V.dtype, split=V.split),
rtol=dtype_tol,
atol=dtype_tol,
)
)

def test_rsvd_catch_wrong_inputs(self):
X = ht.random.randn(10, 10)
# wrong dtype for rank
with self.assertRaises(TypeError):
ht.linalg.rsvd(X, "a")
# rank zero
with self.assertRaises(ValueError):
ht.linalg.rsvd(X, 0)
# wrong dtype for n_oversamples
with self.assertRaises(TypeError):
ht.linalg.rsvd(X, 10, n_oversamples="a")
# n_oversamples negative
with self.assertRaises(ValueError):
ht.linalg.rsvd(X, 10, n_oversamples=-1)
# wrong dtype for power_iter
with self.assertRaises(TypeError):
ht.linalg.rsvd(X, 10, power_iter="a")
# power_iter negative
with self.assertRaises(ValueError):
ht.linalg.rsvd(X, 10, power_iter=-1)
1 change: 0 additions & 1 deletion heat/core/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,6 @@ def randint(
x_0, x_1 = __threefry32(x_0, x_1, seed=__seed)
else: # torch.int64
x_0, x_1 = __threefry64(x_0, x_1, seed=__seed)

# stack the resulting sequence and normalize to given range
values = torch.stack([x_0, x_1], dim=1).flatten()[lslice].reshape(lshape)
# ATTENTION: this is biased and known, bias-free rejection sampling is difficult to do in parallel
Expand Down
21 changes: 20 additions & 1 deletion heat/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
"canonical_heat_type",
"heat_type_is_exact",
"heat_type_is_inexact",
"heat_type_is_realfloating",
"heat_type_is_complexfloating",
"iscomplex",
"isreal",
"issubdtype",
Expand Down Expand Up @@ -547,9 +549,26 @@ def heat_type_is_inexact(ht_dtype: Type[datatype]) -> bool:
return ht_dtype in _inexact


def heat_type_is_realfloating(ht_dtype: Type[datatype]) -> bool:
"""
Check if Heat type is a real floating point number, i.e float32 or float64
Parameters
----------
ht_dtype: Type[datatype]
Heat type to check
Returns
-------
out: bool
True if ht_dtype is a real float, False otherwise
"""
return ht_dtype in (float32, float64)


def heat_type_is_complexfloating(ht_dtype: Type[datatype]) -> bool:
"""
Check if HeAT type is a complex floating point number, i.e complex64
Check if Heat type is a complex floating point number, i.e complex64
Parameters
----------
Expand Down
54 changes: 33 additions & 21 deletions heat/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,36 @@ class PCA(ht.TransformMixin, ht.BaseEstimator):
svd_solver : {'full', 'hierarchical'}, default='hierarchical'
'full' : Full SVD is performed. In general, this is more accurate, but also slower. So far, this is only supported for tall-skinny or short-fat data.
'hierarchical' : Hierarchical SVD, i.e., an algorithm for computing an approximate, truncated SVD, is performed. Only available for data split along axis no. 0.
'randomized' : Randomized SVD is performed.
tol : float, default=None
Not yet necessary as iterative methods for PCA are not yet implemented.
iterated_power : {'auto', int}, default='auto'
if svd_solver='randomized', ... (not yet supported)
iterated_power : int, default=0
if svd_solver='randomized', this parameter is the number of iterations for the power method.
Choosing `iterated_power > 0` can lead to better results in the case of slowly decaying singular values but is computationally more expensive.
n_oversamples : int, default=10
if svd_solver='randomized', ... (not yet supported)
if svd_solver='randomized', this parameter is the number of additional random vectors to sample the range of X so that the range of X can be approximated more accurately.
power_iteration_normalizer : {'qr'}, default='qr'
if svd_solver='randomized', ... (not yet supported)
if svd_solver='randomized', this parameter is the normalization form of the iterated power method. So far, only QR is supported.
random_state : int, default=None
if svd_solver='randomized', ... (not yet supported)
if svd_solver='randomized', this parameter allows to set the seed for the random number generator.
Attributes
----------
components_ : DNDarray of shape (n_components, n_features)
Principal axes in feature space, representing the directions of maximum variance in the data. The components are sorted by explained_variance_.
explained_variance_ : DNDarray of shape (n_components,)
The amount of variance explained by each of the selected components.
Not supported by svd_solver='hierarchical'.
Not supported by svd_solver='hierarchical' and svd_solver='randomized'.
explained_variance_ratio_ : DNDarray of shape (n_components,)
Percentage of variance explained by each of the selected components.
Not supported by svd_solver='hierarchical'.
Not supported by svd_solver='hierarchical' and svd_solver='randomized'.
total_explained_variance_ratio_ : float
The percentage of total variance explained by the selected components together.
For svd_solver='hierarchical', an lower estimate for this quantity is provided; see :func:`ht.linalg.hsvd_rtol` and :func:`ht.linalg.hsvd_rank` for details.
Not supported by svd_solver='randomized'.
singular_values_ : DNDarray of shape (n_components,)
The singular values corresponding to each of the selected components.
Not supported by svd_solver='hierarchical'.
Not supported by svd_solver='hierarchical' and svd_solver='randomized'.
mean_ : DNDarray of shape (n_features,)
Per-feature empirical mean, estimated from the training set.
n_components_ : int
Expand All @@ -74,8 +77,9 @@ class PCA(ht.TransformMixin, ht.BaseEstimator):
Notes
------------
Hieararchical SVD (`svd_solver = "hierarchical"`) computes and approximate, truncated SVD. Thus, the results are not exact, in general, unless the
truncation rank chose is larger than the actual rank (matrix rank) of the underlying data; see :func:`ht.linalg.hsvd_rank` and :func:`ht.linalg.hsvd_rtol` for details.
Hierarchical SVD (`svd_solver = "hierarchical"`) computes an approximate, truncated SVD. Thus, the results are not exact, in general, unless the
truncation rank chosen is larger than the actual rank (matrix rank) of the underlying data; see :func:`ht.linalg.hsvd_rank` and :func:`ht.linalg.hsvd_rtol` for details.
Randomized SVD (`svd_solver = "randomized"`) is a stochastic algorithm that computes an approximate, truncated SVD.
"""

def __init__(
Expand All @@ -85,7 +89,7 @@ def __init__(
whiten: bool = False,
svd_solver: str = "hierarchical",
tol: Optional[float] = None,
iterated_power: Union[str, int] = "auto",
iterated_power: Union[str, int] = 0,
n_oversamples: int = 10,
power_iteration_normalizer: str = "qr",
random_state: Optional[int] = None,
Expand All @@ -99,10 +103,12 @@ def __init__(
raise NotImplementedError("Whitening is not yet supported. Please set whiten=False.")
if not (svd_solver == "full" or svd_solver == "hierarchical" or svd_solver == "randomized"):
raise ValueError(
"At the moment, only svd_solver='full' (for tall-skinny or short-fat data) and svd_solver='hierarchical' are supported. \n An implementation of the 'full' option for arbitrarily shaped data as well as the option 'randomized' are already planned."
"At the moment, only svd_solver='full' (for tall-skinny or short-fat data), svd_solver='hierarchical', and svd_solver='randomized' are supported. \n An implementation of the 'full' option for arbitrarily shaped data is already planned."
)
if not isinstance(iterated_power, int):
raise TypeError(
"iterated_power must be an integer. The option 'auto' is not yet supported."
)
if iterated_power != "auto" and not isinstance(iterated_power, int):
raise TypeError("iterated_power must be 'auto' or an integer.")
if isinstance(iterated_power, int) and iterated_power < 0:
raise ValueError("if an integer, iterated_power must be greater or equal to 0.")
if power_iteration_normalizer != "qr":
Expand All @@ -113,10 +119,8 @@ def __init__(
raise ValueError(
"Argument tol is not yet necessary as iterative methods for PCA are not yet implemented. Please set tol=None."
)
if random_state is None:
random_state = 0
if not isinstance(random_state, int):
raise ValueError("random_state must be None or an integer.")
if random_state is not None and not isinstance(random_state, int):
raise ValueError(f"random_state must be None or an integer, was {type(random_state)}.")
if (
n_components is not None
and not (isinstance(n_components, int) and n_components >= 1)
Expand All @@ -135,6 +139,9 @@ def __init__(
self.n_oversamples = n_oversamples
self.power_iteration_normalizer = power_iteration_normalizer
self.random_state = random_state
if self.random_state is not None:
# set random seed accordingly
ht.random.seed(self.random_state)

# set future attributes to None to initialize those that will not be computed later on with None (e.g., explained_variance_ for svd_solver='hierarchical')
self.components_ = None
Expand Down Expand Up @@ -220,10 +227,15 @@ def fit(self, X: ht.DNDarray, y=None) -> Self:
self.total_explained_variance_ratio_ = 1 - info.larray.item() ** 2

else:
# here one could add other computational backends
raise NotImplementedError(
f"The chosen svd_solver {self.svd_solver} is not yet implemented."
# compute SVD via "randomized" SVD
_, S, V = ht.linalg.rsvd(
X_centered,
self.n_components_,
n_oversamples=self.n_oversamples,
power_iter=self.iterated_power,
)
self.components_ = V.T
self.n_components_ = V.shape[1]

self.n_samples_ = X.shape[0]
self.noise_variance_ = None # not yet implemented
Expand Down
Loading

0 comments on commit 7e44cd3

Please sign in to comment.