diff --git a/benchmarks/cb/decomposition.py b/benchmarks/cb/decomposition.py
new file mode 100644
index 000000000..44d9cf1c4
--- /dev/null
+++ b/benchmarks/cb/decomposition.py
@@ -0,0 +1,17 @@
+# flake8: noqa
+import heat as ht
+from mpi4py import MPI
+from perun import monitor
+from heat.decomposition import IncrementalPCA
+
+
+@monitor()
+def incremental_pca_split0(list_of_X, n_components):
+    ipca = IncrementalPCA(n_components=n_components)
+    for X in list_of_X:
+        ipca.partial_fit(X)
+
+
+def run_decomposition_benchmarks():
+    list_of_X = [ht.random.rand(50000, 500, split=0) for _ in range(10)]
+    incremental_pca_split0(list_of_X, 50)
diff --git a/benchmarks/cb/main.py b/benchmarks/cb/main.py
index 52cd18d76..a4decaaa1 100644
--- a/benchmarks/cb/main.py
+++ b/benchmarks/cb/main.py
@@ -10,8 +10,10 @@
 from cluster import run_cluster_benchmarks
 from manipulations import run_manipulation_benchmarks
 from preprocessing import run_preprocessing_benchmarks
+from decomposition import run_decomposition_benchmarks
 
 run_linalg_benchmarks()
 run_cluster_benchmarks()
 run_manipulation_benchmarks()
 run_preprocessing_benchmarks()
+run_decomposition_benchmarks()
diff --git a/heat/core/linalg/qr.py b/heat/core/linalg/qr.py
index f3cc5afe5..a0d9559cb 100644
--- a/heat/core/linalg/qr.py
+++ b/heat/core/linalg/qr.py
@@ -154,6 +154,7 @@ def qr(
     for i in range(last_row_reached + 1):
         # this loop goes through all the column-blocks (i.e. local arrays) of the matrix
         # this corresponds to the loop over all columns in classical Gram-Schmidt
+
         if i < nprocs - 1:
             k_loc_i = min(A.shape[0], A.lshape_map[i, 1])
             Q_buf = torch.zeros(
@@ -172,8 +173,7 @@
 
         if i < nprocs - 1:
             # broadcast the orthogonalized block of columns to all other processes
-            req = A.comm.Ibcast(Q_buf, root=i)
-            req.Wait()
+            A.comm.Bcast(Q_buf, root=i)
 
         if A.comm.rank > i:
             # subtract the contribution of the current block of columns from the remaining columns
diff --git a/heat/core/linalg/svdtools.py b/heat/core/linalg/svdtools.py
index 70a4c838c..f217291a4 100644
--- a/heat/core/linalg/svdtools.py
+++ b/heat/core/linalg/svdtools.py
@@ -14,32 +14,18 @@
 from ..linalg import matmul, vector_norm, qr, svd
 from ..indexing import where
 from ..random import randn
-
+from ..sanitation import sanitize_in_nd_realfloating
 from ..manipulations import vstack, hstack, diag, balance
 from .. import statistics
 
 from math import log, ceil, floor, sqrt
 
 
-__all__ = ["hsvd_rank", "hsvd_rtol", "hsvd", "rsvd"]
-
-
-def _check_SVD_input(A):
-    if not isinstance(A, DNDarray):
-        raise TypeError(f"Argument needs to be a DNDarray but is {type(A)}.")
-    if not A.ndim == 2:
-        raise ValueError("A needs to be a 2D matrix")
-    if not types.heat_type_is_realfloating(A.dtype):
-        raise TypeError(
-            "Argument needs to be a DNDarray with datatype float32 or float64, but data type is {}.".format(
-                A.dtype
-            )
-        )
-    return None
+__all__ = ["hsvd_rank", "hsvd_rtol", "hsvd", "rsvd", "isvd"]
 
 
 #######################################################################################
-# user-friendly versions of hSVD
+# hierarchical SVD "hSVD"
 #######################################################################################
 
 
@@ -99,7 +85,7 @@ def hsvd_rank(
     [1] Iwen, Ong. A distributed and incremental SVD algorithm for agglomerative data analysis on large networks. SIAM J. Matrix Anal. Appl., 37(4), 2016.
     [2] Himpe, Leibner, Rave. Hierarchical approximate proper orthogonal decomposition. SIAM J. Sci. Comput., 40 (5), 2018.
""" - _check_SVD_input(A) # check if A is suitable input + sanitize_in_nd_realfloating(A, "A", [2]) A_local_size = max(A.lshape_map[:, 1]) if maxmergedim is not None and maxmergedim < 2 * (maxrank + safetyshift) + 1: @@ -202,7 +188,7 @@ def hsvd_rtol( [1] Iwen, Ong. A distributed and incremental SVD algorithm for agglomerative data analysis on large networks. SIAM J. Matrix Anal. Appl., 37(4), 2016. [2] Himpe, Leibner, Rave. Hierarchical approximate proper orthogonal decomposition. SIAM J. Sci. Comput., 40 (5), 2018. """ - _check_SVD_input(A) # check if A is suitable input + sanitize_in_nd_realfloating(A, "A", [2]) A_local_size = max(A.lshape_map[:, 1]) if maxmergedim is not None and maxrank is None: @@ -248,11 +234,6 @@ def hsvd_rtol( ) -################################################################################################ -# hSVD - "full" routine for the experts -################################################################################################ - - def hsvd( A: DNDarray, maxrank: Optional[int] = None, @@ -334,7 +315,7 @@ def hsvd( "\t\t".join(["%d" % an for an in active_nodes]), ) - U_loc, sigma_loc, err_squared_loc = compute_local_truncated_svd( + U_loc, sigma_loc, err_squared_loc = _compute_local_truncated_svd( level, A.comm.rank, A.larray, maxrank, loc_atol, safetyshift ) U_loc = torch.matmul(U_loc, torch.diag(sigma_loc)) @@ -412,7 +393,7 @@ def hsvd( if len(future_nodes) == 1: safetyshift = 0 - U_loc, sigma_loc, err_squared_loc_new = compute_local_truncated_svd( + U_loc, sigma_loc, err_squared_loc_new = _compute_local_truncated_svd( level, A.comm.rank, U_loc, maxrank, loc_atol, safetyshift ) @@ -466,12 +447,7 @@ def hsvd( return U, rel_error_estimate -############################################################################################## -# AUXILIARY ROUTINES -############################################################################################## - - -def compute_local_truncated_svd( +def _compute_local_truncated_svd( level: int, proc_id: int, U_loc: torch.Tensor, @@ -528,7 +504,7 @@ def compute_local_truncated_svd( ############################################################################################## -# Randomized SVD +# Randomized SVD "rSVD" ############################################################################################## @@ -568,7 +544,7 @@ def rsvd( ----------- [1] Halko, N., Martinsson, P. G., & Tropp, J. A. (2011). Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions. SIAM review, 53(2), 217-288. """ - _check_SVD_input(A) # check if A is suitable input + sanitize_in_nd_realfloating(A, "A", [2]) if not isinstance(rank, int): raise TypeError(f"rank must be an integer, but is {type(rank)}.") if rank < 1: @@ -614,3 +590,199 @@ def rsvd( V = V[:, :rank] V.balance_() return U, S, V + + +############################################################################################## +# Incremental SVD "iSVD" +############################################################################################## + + +def _isvd( + new_data: DNDarray, + U_old: DNDarray, + S_old: DNDarray, + V_old: Optional[DNDarray] = None, + maxrank: Optional[int] = None, + old_matrix_size: Optional[int] = None, + old_rowwise_mean: Optional[DNDarray] = None, +) -> Union[Tuple[DNDarray, DNDarray, DNDarray], Tuple[DNDarray, DNDarray, DNDarray, DNDarray]]: + """ + Helper function for iSVD and iPCA; follows roughly the "incremental PCA with mean update", Fig.1 in: + David A. 
Ross, Jongwoo Lim, Ruei-Sung Lin, Ming-Hsuan Yang. Incremental Learning for Robust Visual Tracking. IJCV, 2008.
+
+    Either incremental SVD / PCA or incremental SVD / PCA with mean subtraction is performed.
+
+    Parameters
+    -----------
+    new_data: DNDarray
+        new data as DNDarray
+    U_old, S_old, V_old: DNDarrays
+        the "old" SVD-factors;
+        if no V_old is provided, only U and S are computed (PCA use-case)
+    maxrank: int, optional
+        rank to which the new SVD should be truncated
+    old_matrix_size: int, optional
+        number of columns of the old matrix; this does not need to be identical to V_old.shape[0], as the "old" SVD might have been truncated
+    old_rowwise_mean: DNDarray, optional
+        row-wise mean of the old matrix; if not provided, no mean subtraction is performed
+    """
+    # the old SVD is the SVD of a matrix of dimension m x n that has rank r
+    # the new data have shape m x d
+    d = new_data.shape[1]
+    n = V_old.shape[0] if V_old is not None else old_matrix_size
+    r = S_old.shape[0]
+    if maxrank is None:
+        maxrank = min(n + d, U_old.shape[0])
+    else:
+        maxrank = min(maxrank, min(n + d, U_old.shape[0]))
+
+    if old_rowwise_mean is not None:
+        new_data_rowwise_mean = statistics.mean(new_data, axis=1)
+        new_rowwise_mean = (old_matrix_size * old_rowwise_mean + d * new_data_rowwise_mean) / (
+            old_matrix_size + d
+        )
+        new_data -= new_data_rowwise_mean.reshape(-1, 1)
+        new_data = hstack(
+            [
+                new_data,
+                (new_data_rowwise_mean - old_rowwise_mean).reshape(-1, 1)
+                * (d * old_matrix_size / (d + old_matrix_size)) ** 0.5,
+            ]
+        )
+        d += 1
+
+    # orthogonalize and decompose new_data
+    UtC = U_old.T @ new_data
+    if U_old.split is not None:
+        new_data = new_data.resplit_(U_old.split) - U_old @ UtC
+    else:
+        new_data = new_data - (U_old @ UtC).resplit_(new_data.split)
+    P, Rc = qr(new_data)
+
+    # prepare one component of the "new" V-factor
+    if V_old is not None:
+        V_new = vstack(
+            [
+                V_old,
+                factories.zeros(
+                    (d, r),
+                    device=V_old.device,
+                    dtype=V_old.dtype,
+                    split=V_old.split,
+                    comm=V_old.comm,
+                ),
+            ]
+        )
+        helper = vstack(
+            [
+                factories.zeros(
+                    (n, d),
+                    device=V_old.device,
+                    dtype=V_old.dtype,
+                    split=V_old.split,
+                    comm=V_old.comm,
+                ),
+                factories.eye(
+                    d, device=V_old.device, dtype=V_old.dtype, split=V_old.split, comm=V_old.comm
+                ),
+            ]
+        )
+        V_new = hstack([V_new, helper])
+        del helper
+
+    # prepare one component of the "new" U-factor
+    U_new = hstack([U_old, P])
+
+    # prepare the "inner" matrix that needs to be decomposed, and decompose it
+    helper1 = vstack(
+        [
+            diag(S_old),
+            factories.zeros(
+                (Rc.shape[0] + UtC.shape[0] - r, r),
+                device=S_old.device,
+                dtype=S_old.dtype,
+                split=S_old.split,
+                comm=S_old.comm,
+            ),
+        ]
+    )
+    if r > d:
+        Rc = Rc.resplit_(UtC.split)
+    else:
+        UtC = UtC.resplit_(Rc.split)
+    helper2 = vstack([UtC, Rc])
+    innermat = hstack([helper1, helper2])
+    del (helper1, helper2)
+    # as innermat is small enough to fit into the memory of a single process, we can compute its SVD on non-distributed data
+    u, s, v = svd.svd(innermat.resplit_(None))
+    del innermat
+
+    # truncate if desired
+    if maxrank < s.shape[0]:
+        u = u[:, :maxrank]
+        s = s[:maxrank]
+        v = v[:, :maxrank]
+
+    U_new = U_new @ u
+    if V_old is not None:
+        V_new = V_new @ v
+
+    if V_old is not None:  # use-case: SVD
+        return U_new, s, V_new
+    if old_rowwise_mean is not None:  # use-case: PCA
+        return U_new, s, new_rowwise_mean
+
+
+def isvd(
+    new_data: DNDarray,
+    U_old: DNDarray,
+    S_old: DNDarray,
+    V_old: DNDarray,
+    maxrank: Optional[int] = None,
+) -> Tuple[DNDarray, DNDarray, DNDarray]:
+    r"""Incremental SVD (iSVD) for the addition of new data to an existing SVD.
+
+    Given the SVD of an "old" matrix, :math:`X_\textnormal{old} = U_\textnormal{old} \cdot S_\textnormal{old} \cdot V_\textnormal{old}^T`, and additional columns :math:`N` (\"`new_data`\"), this routine computes
+    (a possibly approximate) SVD of the extended matrix :math:`X_\textnormal{new} = [ X_\textnormal{old} | N]`.
+
+    Parameters
+    ----------
+    new_data : DNDarray
+        2D-array (float32/64) of columns that are added to the "old" SVD. It must hold `new_data.split != 1` if `U_old.split = 0`.
+    U_old : DNDarray
+        U-factor of the SVD of the "old" matrix, 2D-array (float32/64). It must hold `U_old.split != 0` if `new_data.split = 1`.
+    S_old : DNDarray
+        Sigma-factor of the SVD of the "old" matrix, 1D-array (float32/64)
+    V_old : DNDarray
+        V-factor of the SVD of the "old" matrix, 2D-array (float32/64)
+    maxrank : int, optional
+        truncation rank of the SVD of the extended matrix. The default is None, i.e., no bound on the maximal rank is imposed.
+
+    Notes
+    -----------
+    Inexactness may arise due to truncation to the maximal rank `maxrank` if the rank of the data to be processed exceeds this rank.
+    If you set `maxrank` to a high number (or None) in order to avoid inexactness, you may encounter memory issues.
+    The implementation follows the approach described in Ref. [1], Sect. 2.
+
+    References
+    ------------
+    [1] Brand, M. (2006). Fast low-rank modifications of the thin singular value decomposition. Linear algebra and its applications, 415(1), 20-30.
+    """
+    # check if new_data, U_old, S_old, V_old are DNDarrays of the expected dimensionality with datatype float32/64
+    sanitize_in_nd_realfloating(new_data, "new_data", [2])
+    sanitize_in_nd_realfloating(U_old, "U_old", [2])
+    sanitize_in_nd_realfloating(S_old, "S_old", [1])
+    sanitize_in_nd_realfloating(V_old, "V_old", [2])
+    # check if the number of columns of U_old and V_old match the number of elements in S_old
+    if U_old.shape[1] != S_old.shape[0]:
+        raise ValueError(
+            "The number of columns of U_old must match the number of elements in S_old."
+        )
+    if V_old.shape[1] != S_old.shape[0]:
+        raise ValueError(
+            "The number of columns of V_old must match the number of elements in S_old."
+        )
+    # check if the number of rows of new_data matches the number of rows of U_old
+    if new_data.shape[0] != U_old.shape[0]:
+        raise ValueError("The number of rows of new_data must match the number of rows of U_old.")
+
+    return _isvd(new_data, U_old, S_old, V_old, maxrank)
diff --git a/heat/core/linalg/tests/test_svdtools.py b/heat/core/linalg/tests/test_svdtools.py
index 4f8728bb6..4d3640955 100644
--- a/heat/core/linalg/tests/test_svdtools.py
+++ b/heat/core/linalg/tests/test_svdtools.py
@@ -248,3 +248,60 @@ def test_rsvd_catch_wrong_inputs(self):
         # power_iter negative
         with self.assertRaises(ValueError):
             ht.linalg.rsvd(X, 10, power_iter=-1)
+
+
+class TestISVD(TestCase):
+    def test_isvd(self):
+        ht.random.seed(27183)
+        for dtype in [ht.float32, ht.float64]:
+            dtypetol = 1e-5 if dtype == ht.float32 else 1e-10
+            for old_split in [0, 1, None]:
+                X_old, SVD_old = ht.utils.data.matrixgallery.random_known_rank(
+                    250, 25, 3 * ht.MPI_WORLD.size, split=old_split, dtype=dtype
+                )
+                U_old, S_old, V_old = SVD_old
+                for new_split in [0, 1, None]:
+                    new_data = ht.random.randn(
+                        250, 2 * ht.MPI_WORLD.size, split=new_split, dtype=dtype
+                    )
+                    U_new, S_new, V_new = ht.linalg.isvd(new_data, U_old, S_old, V_old)
+                    # check if U_new, V_new are orthogonal
+                    self.assertTrue(
+                        ht.allclose(
+                            U_new.T @ U_new,
+                            ht.eye(U_new.shape[1], dtype=U_new.dtype, split=U_new.split),
+                            atol=dtypetol,
+                            rtol=dtypetol,
+                        )
+                    )
+                    self.assertTrue(
+                        ht.allclose(
+                            V_new.T @ V_new,
+                            ht.eye(V_new.shape[1], dtype=V_new.dtype, split=V_new.split),
+                            atol=dtypetol,
+                            rtol=dtypetol,
+                        )
+                    )
+                    # check if entries of S_new are positive
+                    self.assertTrue(ht.all(S_new >= 0))
+                    # check if the reconstruction error is small
+                    X_new = ht.hstack([X_old, new_data.resplit_(X_old.split)])
+                    X_rec = U_new @ ht.diag(S_new) @ V_new.T
+                    self.assertTrue(ht.allclose(X_rec, X_new, atol=dtypetol, rtol=dtypetol))
+
+    def test_isvd_catch_wrong_inputs(self):
+        u_old = ht.zeros((10, 2))
+        s_old = ht.zeros((3,))
+        v_old = ht.zeros((5, 3))
+        new_data = ht.zeros((11, 5))
+        # mismatch between the number of columns of U_old and the size of S_old
+        with self.assertRaises(ValueError):
+            ht.linalg.isvd(new_data, u_old, s_old, v_old)
+        s_old = ht.zeros((2,))
+        # mismatch between the number of columns of V_old and the size of S_old
+        with self.assertRaises(ValueError):
+            ht.linalg.isvd(new_data, u_old, s_old, v_old)
+        v_old = ht.zeros((5, 2))
+        # mismatch between the number of rows of new_data and U_old
+        with self.assertRaises(ValueError):
+            ht.linalg.isvd(new_data, u_old, s_old, v_old)
diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py
index 5a5134401..9a776df5c 100644
--- a/heat/core/sanitation.py
+++ b/heat/core/sanitation.py
@@ -174,6 +174,23 @@ def sanitize_in(x: Any):
         raise TypeError(f"Input must be a DNDarray, is {type(x)}")
 
 
+def sanitize_in_nd_realfloating(input: Any, inputname: str, allowed_ns: List[int]) -> None:
+    """
+    Verify that input object ``input`` is a real floating point ``DNDarray`` with number of dimensions contained in ``allowed_ns``.
+    The argument ``inputname`` is used for error messages.
+    """
+    if not isinstance(input, DNDarray):
+        raise TypeError(f"Argument {inputname} needs to be a DNDarray but is {type(input)}.")
+    if input.ndim not in allowed_ns:
+        raise ValueError(
+            f"Argument {inputname} needs to have a number of dimensions contained in {allowed_ns}, but is {input.ndim}-dimensional."
+        )
+    if not types.heat_type_is_realfloating(input.dtype):
+        raise TypeError(
+            f"Argument {inputname} needs to be a DNDarray with datatype float32 or float64, but data type is {input.dtype}."
+        )
+
+
 def sanitize_infinity(x: Union[DNDarray, torch.Tensor]) -> Union[int, float]:
     """
     Returns largest possible value for the ``dtype`` of the input array.
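For illustration, a minimal usage sketch of the new `isvd` routine (not part of the patch; sizes, rank, and tolerances are illustrative and mirror the pattern of `test_isvd` above):

```python
import heat as ht

# exact SVD of an "old" 250 x 25 matrix with known rank 5
X_old, (U_old, S_old, V_old) = ht.utils.data.matrixgallery.random_known_rank(
    250, 25, 5, split=0, dtype=ht.float32
)

# ten new columns arrive later
new_data = ht.random.randn(250, 10, split=0, dtype=ht.float32)

# update the factorization instead of recomputing it from scratch;
# with maxrank=None (the default) no truncation is performed, so the result is exact
U_new, S_new, V_new = ht.linalg.isvd(new_data, U_old, S_old, V_old)

# U_new, S_new, V_new factorize the extended matrix [X_old | new_data]
X_rec = U_new @ ht.diag(S_new) @ V_new.T
assert ht.allclose(X_rec, ht.hstack([X_old, new_data]), atol=1e-4, rtol=1e-4)
```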
diff --git a/heat/core/tests/test_sanitation.py b/heat/core/tests/test_sanitation.py
index 2e79a0810..fd08a1401 100644
--- a/heat/core/tests/test_sanitation.py
+++ b/heat/core/tests/test_sanitation.py
@@ -14,6 +14,20 @@ def test_sanitize_in(self):
         with self.assertRaises(TypeError):
             ht.sanitize_in(np_x)
 
+    def test_sanitize_in_nd_realfloating(self):
+        # input is not a DNDarray at all
+        x = "this is not a DNDarray"
+        with self.assertRaises(TypeError):
+            ht.sanitize_in_nd_realfloating(x, "x", [2])
+        # wrong number of dimensions
+        x = ht.zeros((10, 10, 10), dtype=ht.float32, split=0)
+        with self.assertRaises(ValueError):
+            ht.sanitize_in_nd_realfloating(x, "x", [1, 2])
+        # wrong dtype: int32 instead of float32/float64
+        x = ht.zeros((10, 10), dtype=ht.int32, split=None)
+        with self.assertRaises(TypeError):
+            ht.sanitize_in_nd_realfloating(x, "x", [1, 2])
+
     def test_sanitize_out(self):
         output_shape = (4, 5, 6)
         output_split = 1
diff --git a/heat/decomposition/pca.py b/heat/decomposition/pca.py
index 63d4b99dd..bfaad1416 100644
--- a/heat/decomposition/pca.py
+++ b/heat/decomposition/pca.py
@@ -4,6 +4,7 @@
 import heat as ht
 from typing import Optional, Tuple, Union
 
+from ..core.linalg.svdtools import _isvd
 
 try:
     from typing import Self
@@ -277,3 +278,156 @@ def inverse_transform(self, X: ht.DNDarray) -> ht.DNDarray:
         )
 
         return X @ self.components_ + self.mean_
+
+
+class IncrementalPCA(ht.TransformMixin, ht.BaseEstimator):
+    """
+    Incremental Principal Component Analysis (PCA).
+
+    This class allows for incremental updates of the PCA model, which is especially useful for large data sets that do not fit into memory.
+
+    An example of how to apply this class can be found in `benchmarks/cb/decomposition.py`.
+
+    Parameters
+    ----------
+    n_components : int, optional
+        Number of components to keep. If `n_components` is not set, all components are kept (default).
+    copy : bool, default=True
+        In-place operations are not yet supported. Please set `copy=True`.
+    whiten : bool, default=False
+        Not yet supported.
+    batch_size : int, optional
+        Currently not needed and only added for API consistency and possible future extensions.
+
+    Attributes
+    ----------
+    components_ : DNDarray of shape (n_components, n_features)
+        Principal axes in feature space, representing the directions of maximum variance in the data. The components are sorted by decreasing `singular_values_`.
+    singular_values_ : DNDarray of shape (n_components,)
+        The singular values corresponding to each of the selected components.
+    mean_ : DNDarray of shape (n_features,)
+        Per-feature empirical mean, estimated from the training set.
+    n_components_ : int
+        The estimated number of components.
+    n_samples_seen_ : int
+        Number of samples processed so far.
+    """
+
+    def __init__(
+        self,
+        n_components: Optional[int] = None,
+        copy: bool = True,
+        whiten: bool = False,
+        batch_size: Optional[int] = None,
+    ):
+        if not copy:
+            raise NotImplementedError(
+                "In-place operations for PCA are not supported at the moment. Please set copy=True."
+            )
+        if whiten:
+            raise NotImplementedError("Whitening is not yet supported. Please set whiten=False.")
+        if n_components is not None:
+            if not isinstance(n_components, int):
+                raise TypeError(
+                    f"n_components must be None or an integer, but is {type(n_components)}."
+                )
+            else:
+                if n_components < 1:
+                    raise ValueError("If an integer, n_components must be greater than or equal to 1.")
+        self.whiten = whiten
+        self.n_components = n_components
+        self.batch_size = batch_size
+        self.components_ = None
+        # self.explained_variance_ = None  # not yet supported
+        # self.explained_variance_ratio_ = None  # not yet supported
+        self.singular_values_ = None
+        self.mean_ = None
+        self.n_components_ = None
+        self.batch_size_ = None
+        self.n_samples_seen_ = 0
+
+    def fit(self, X, y=None) -> Self:
+        """
+        Not yet implemented; please use `.partial_fit` instead.
+        Please open an issue on GitHub if you would like to see this method implemented, and make a suggestion on how you would like it to be implemented.
+        """
+        raise NotImplementedError(
+            f"You have called IncrementalPCA's `.fit`-method with an argument of type {type(X)}. \n So far, we have only implemented the method `.partial_fit`, which performs a single-step update of the incremental PCA. \n Please consider using `.partial_fit` for the moment, and open an issue on GitHub in which we can discuss what you would like to see implemented for the `.fit`-method."
+        )
+
+    def partial_fit(self, X: ht.DNDarray, y=None):
+        """
+        A single step of incrementally building up the PCA.
+        The input X is the current batch of data that needs to be added to the existing PCA.
+        """
+        ht.sanitize_in(X)
+        if y is not None:
+            raise ValueError(
+                "Argument y must be None; it is only present for API consistency by convention."
+            )
+        if self.n_samples_seen_ == 0:
+            # this is the first batch of data, hence we need to initialize everything
+            if self.n_components is None:
+                self.n_components_ = min(X.shape)
+            else:
+                self.n_components_ = min(X.shape[0], X.shape[1], self.n_components)
+
+            self.mean_ = X.mean(axis=0)
+            X_centered = X - self.mean_
+            _, S, V = ht.linalg.svd(X_centered)
+            self.components_ = V[:, : self.n_components_].T
+            self.singular_values_ = S[: self.n_components_]
+            self.n_samples_seen_ = X.shape[0]
+
+        else:
+            # if batches of data have already been seen before, only an update is necessary
+            U, S, mean = _isvd(
+                X.T,
+                self.components_.T,
+                self.singular_values_,
+                V_old=None,
+                maxrank=self.n_components,
+                old_matrix_size=self.n_samples_seen_,
+                old_rowwise_mean=self.mean_,
+            )
+            self.components_ = U.T
+            self.singular_values_ = S
+            self.mean_ = mean
+            self.n_samples_seen_ += X.shape[0]
+            self.n_components_ = self.components_.shape[0]
+
+    def transform(self, X: ht.DNDarray) -> ht.DNDarray:
+        """
+        Apply dimensionality reduction based on the fitted PCA to X.
+
+        Parameters
+        ----------
+        X : DNDarray of shape (n_samples, n_features)
+            Data set to be transformed.
+        """
+        ht.sanitize_in(X)
+        if X.shape[1] != self.mean_.shape[0]:
+            raise ValueError(
+                f"X must have the same number of features as the training data. Expected {self.mean_.shape[0]} but got {X.shape[1]}."
+            )
+
+        # center data and apply PCA
+        X_centered = X - self.mean_
+        return X_centered @ self.components_.T
+
+    def inverse_transform(self, X: ht.DNDarray) -> ht.DNDarray:
+        """
+        Transform data back to its original space.
+
+        Parameters
+        ----------
+        X : DNDarray of shape (n_samples, n_components)
+            Data set to be transformed back.
+        """
+        ht.sanitize_in(X)
+        if X.shape[1] != self.n_components_:
+            raise ValueError(
+                f"Dimension mismatch. Expected input of shape n_points x {self.n_components_} but got {X.shape}."
+ ) + + return X @ self.components_ + self.mean_ diff --git a/heat/decomposition/tests/test_pca.py b/heat/decomposition/tests/test_pca.py index 58fc361ce..118be6fc2 100644 --- a/heat/decomposition/tests/test_pca.py +++ b/heat/decomposition/tests/test_pca.py @@ -115,7 +115,6 @@ def test_pca_with_hiearchical_rtol(self): and pca.total_explained_variance_ratio_ >= 0.0 and pca.total_explained_variance_ratio_ <= 1.0 ) - print(pca.total_explained_variance_ratio_) self.assertTrue(pca.total_explained_variance_ratio_ >= ratio) if ht.MPI_WORLD.size > 1: self.assertEqual(pca.explained_variance_, None) @@ -210,3 +209,127 @@ def test_pca_randomized(self): pca = ht.decomposition.PCA(n_components=None, svd_solver="randomized", random_state=1234) self.assertEqual(ht.random.get_state()[1], 1234) + + +class TestIncrementalPCA(TestCase): + def test_incrementalpca_setup(self): + pca = ht.decomposition.IncrementalPCA(n_components=2) + + # check correct base classes + self.assertTrue(ht.is_estimator(pca)) + self.assertTrue(ht.is_transformer(pca)) + + # check correct default values + self.assertEqual(pca.n_components, 2) + self.assertEqual(pca.whiten, False) + self.assertEqual(pca.batch_size, None) + self.assertEqual(pca.components_, None) + self.assertEqual(pca.singular_values_, None) + self.assertEqual(pca.mean_, None) + self.assertEqual(pca.n_components_, None) + self.assertEqual(pca.batch_size_, None) + self.assertEqual(pca.n_samples_seen_, 0) + + # check catching of invalid parameters + # whitening and in-place are not yet supported + with self.assertRaises(NotImplementedError): + ht.decomposition.IncrementalPCA(whiten=True) + with self.assertRaises(NotImplementedError): + ht.decomposition.IncrementalPCA(copy=False) + # wrong n_components + with self.assertRaises(TypeError): + ht.decomposition.IncrementalPCA(n_components=0.9) + with self.assertRaises(ValueError): + ht.decomposition.IncrementalPCA(n_components=0) + + def test_incrementalpca_full_rank_reached_split0(self): + # full rank is reached, split = 0 + # dtype float32 + pca = ht.decomposition.IncrementalPCA() + data0 = ht.random.randn(150 * ht.MPI_WORLD.size, 2 * ht.MPI_WORLD.size + 1, split=0) + data1 = 1.0 + ht.random.rand(50 * ht.MPI_WORLD.size, 2 * ht.MPI_WORLD.size + 1, split=0) + data = ht.vstack([data0, data1]) + data0_np = data0.numpy() + data_np = data.numpy() + + # test partial_fit, step 0 + pca.partial_fit(data0) + self.assertEqual( + pca.components_.shape, (2 * ht.MPI_WORLD.size + 1, 2 * ht.MPI_WORLD.size + 1) + ) + self.assertEqual(pca.n_components_, 2 * ht.MPI_WORLD.size + 1) + self.assertEqual(pca.mean_.shape, (2 * ht.MPI_WORLD.size + 1,)) + self.assertEqual(pca.singular_values_.shape, (2 * ht.MPI_WORLD.size + 1,)) + self.assertEqual(pca.n_samples_seen_, 150 * ht.MPI_WORLD.size) + s0_np = np.linalg.svd(data0_np - data0_np.mean(axis=0), compute_uv=False, hermitian=False) + self.assertTrue(np.allclose(s0_np, pca.singular_values_.numpy())) + + # test partial_fit, step 1 + pca.partial_fit(data1) + self.assertEqual( + pca.components_.shape, (2 * ht.MPI_WORLD.size + 1, 2 * ht.MPI_WORLD.size + 1) + ) + self.assertEqual(pca.n_components_, 2 * ht.MPI_WORLD.size + 1) + self.assertTrue(ht.allclose(pca.mean_, ht.mean(data, axis=0))) + self.assertEqual(pca.singular_values_.shape, (2 * ht.MPI_WORLD.size + 1,)) + self.assertEqual(pca.n_samples_seen_, 200 * ht.MPI_WORLD.size) + s_np = np.linalg.svd(data_np - data_np.mean(axis=0), compute_uv=False, hermitian=False) + self.assertTrue(np.allclose(s_np, pca.singular_values_.numpy())) + + # test transform 
(only possible here, as in the next test truncation happens)
+        new_data = ht.random.rand(100, 2 * ht.MPI_WORLD.size + 1, split=1)
+        Y = pca.transform(new_data)
+        Z = pca.inverse_transform(Y)
+        self.assertTrue(ht.allclose(new_data, Z, atol=1e-4, rtol=1e-4))
+
+    def test_incrementalpca_truncation_happens_split1(self):
+        # full rank not reached, but truncation happens, split = 1
+        # dtype float64
+        pca = ht.decomposition.IncrementalPCA(n_components=15)
+        data0 = ht.random.randn(9, 100 * ht.MPI_WORLD.size + 1, split=1, dtype=ht.float64)
+        data1 = 1.0 + ht.random.rand(11, 100 * ht.MPI_WORLD.size + 1, split=1, dtype=ht.float64)
+        data = ht.vstack([data0, data1])
+        data0_np = data0.numpy()
+        data_np = data.numpy()
+
+        # test partial_fit, step 0
+        pca.partial_fit(data0)
+        self.assertEqual(pca.components_.shape, (9, 100 * ht.MPI_WORLD.size + 1))
+        self.assertEqual(pca.components_.dtype, ht.float64)
+        self.assertEqual(pca.n_components_, 9)
+        self.assertEqual(pca.mean_.shape, (100 * ht.MPI_WORLD.size + 1,))
+        self.assertEqual(pca.mean_.dtype, ht.float64)
+        self.assertEqual(pca.singular_values_.shape, (9,))
+        self.assertEqual(pca.singular_values_.dtype, ht.float64)
+        self.assertEqual(pca.n_samples_seen_, 9)
+        s0_np = np.linalg.svd(data0_np - data0_np.mean(axis=0), compute_uv=False, hermitian=False)
+        self.assertTrue(np.allclose(s0_np, pca.singular_values_.numpy(), atol=1e-12))
+
+        # test partial_fit, step 1
+        # here truncation actually happens, as the stacked data have rank 20 but n_components=15
+        pca.partial_fit(data1)
+        self.assertEqual(pca.components_.shape, (15, 100 * ht.MPI_WORLD.size + 1))
+        self.assertEqual(pca.n_components_, 15)
+        self.assertEqual(pca.mean_.shape, (100 * ht.MPI_WORLD.size + 1,))
+        self.assertEqual(pca.singular_values_.shape, (15,))
+        self.assertEqual(pca.n_samples_seen_, 20)
+        s_np = np.linalg.svd(data_np - data_np.mean(axis=0), compute_uv=False, hermitian=False)
+        self.assertTrue(np.allclose(s_np[:15], pca.singular_values_.numpy()))
+
+    def test_incrementalpca_catch_wrong_inputs(self):
+        pca = ht.decomposition.IncrementalPCA(n_components=1)
+        data0 = ht.random.randn(15, 15, split=None)
+
+        # fit is not yet implemented
+        with self.assertRaises(NotImplementedError):
+            pca.fit(data0)
+        # wrong input for partial_fit
+        with self.assertRaises(ValueError):
+            pca.partial_fit(data0, y="Why can't we get rid of this argument?")
+
+        pca.partial_fit(data0)
+        # wrong inputs for transform and inverse transform
+        with self.assertRaises(ValueError):
+            pca.transform(ht.zeros((15, 16), split=None))
+        with self.assertRaises(ValueError):
+            pca.inverse_transform(ht.zeros((17, 2), split=None))
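Finally, a minimal sketch of the batch-wise workflow enabled by `IncrementalPCA` (not part of the patch; batch sizes and `n_components` are illustrative — the benchmark in `benchmarks/cb/decomposition.py` follows the same pattern on larger, split=0 data):

```python
import heat as ht
from heat.decomposition import IncrementalPCA

ipca = IncrementalPCA(n_components=3)

# feed the data batch by batch; only one batch needs to be in memory at a time
for _ in range(5):
    batch = ht.random.randn(1000, 10, split=0)
    ipca.partial_fit(batch)

# project new samples onto the learned principal axes and map them back
X = ht.random.randn(100, 10, split=0)
Y = ipca.transform(X)                 # shape (100, 3)
X_approx = ipca.inverse_transform(Y)  # shape (100, 10)
```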