diff --git a/cca_zoo/models/_stochastic/_base.py b/cca_zoo/models/_stochastic/_base.py index 76422727..33c0d546 100644 --- a/cca_zoo/models/_stochastic/_base.py +++ b/cca_zoo/models/_stochastic/_base.py @@ -86,7 +86,7 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs): ) self.track = [] self.weights = [ - np.random.rand(view.shape[1], self.latent_dims) for view in views + self.random_state.normal(0, 1, size=(view.shape[1],self.latent_dims)) for view in views ] # normalize weights self.weights = [weight / np.linalg.norm(weight, axis=0) for weight in self.weights] diff --git a/cca_zoo/test/test_models.py b/cca_zoo/test/test_models.py index 98b08f4f..4ae60a97 100644 --- a/cca_zoo/test/test_models.py +++ b/cca_zoo/test/test_models.py @@ -226,14 +226,14 @@ def test_stochastic_pls(): from torch import manual_seed manual_seed(42) pls = PLS(latent_dims=3).fit((X, Y)) - ipls = IncrementalPLS(latent_dims=3, epochs=150, simple=False, batch_size=10).fit( + ipls = IncrementalPLS(latent_dims=3, epochs=150, simple=False, batch_size=10, random_state=1).fit( (X, Y) ) - spls = PLSStochasticPower(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit( + spls = PLSStochasticPower(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2, random_state=1).fit( (X, Y) ) - egpls = PLSEigenGame(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit((X, Y)) - ghapls = PLSGHAGEP(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit((X, Y)) + egpls = PLSEigenGame(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2, random_state=1).fit((X, Y)) + ghapls = PLSGHAGEP(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2, random_state=1).fit((X, Y)) pls_score = pls.score((X, Y)) ipls_score = ipls.score((X, Y)) spls_score = spls.score((X, Y))