Skip to content

Commit

Permalink
Eigengame updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Nov 28, 2022
1 parent c509ad3 commit a55ed2c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cca_zoo/models/_stochastic/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
)
self.track = []
self.weights = [
np.random.rand(view.shape[1], self.latent_dims) for view in views
self.random_state.normal(0, 1, size=(view.shape[1],self.latent_dims)) for view in views
]
# normalize weights
self.weights = [weight / np.linalg.norm(weight, axis=0) for weight in self.weights]
Expand Down
8 changes: 4 additions & 4 deletions cca_zoo/test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,14 @@ def test_stochastic_pls():
from torch import manual_seed
manual_seed(42)
pls = PLS(latent_dims=3).fit((X, Y))
ipls = IncrementalPLS(latent_dims=3, epochs=150, simple=False, batch_size=10).fit(
ipls = IncrementalPLS(latent_dims=3, epochs=150, simple=False, batch_size=10, random_state=1).fit(
(X, Y)
)
spls = PLSStochasticPower(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit(
spls = PLSStochasticPower(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2, random_state=1).fit(
(X, Y)
)
egpls = PLSEigenGame(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit((X, Y))
ghapls = PLSGHAGEP(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit((X, Y))
egpls = PLSEigenGame(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2, random_state=1).fit((X, Y))
ghapls = PLSGHAGEP(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2, random_state=1).fit((X, Y))
pls_score = pls.score((X, Y))
ipls_score = ipls.score((X, Y))
spls_score = spls.score((X, Y))
Expand Down

0 comments on commit a55ed2c

Please sign in to comment.