Skip to content

Commit

Permalink
Eigengame updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Nov 28, 2022
1 parent 004e30e commit c509ad3
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 16 deletions.
1 change: 0 additions & 1 deletion cca_zoo/models/_rcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ class CCA(rCCA):
accept_sparse : Union[bool, str], optional
Whether to accept sparse data, by default None
References
--------
Expand Down
37 changes: 33 additions & 4 deletions cca_zoo/models/_stochastic/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Iterable

import numpy as np
from scipy.linalg import block_diag
from torch.utils import data

from cca_zoo.data.deep import NumpyDataset
Expand Down Expand Up @@ -62,6 +63,10 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
views = self._validate_inputs(views)
self._check_params()
dataset = NumpyDataset(views)
if self.val_split is not None:
train_size = int((1 - self.val_split) * len(dataset))
val_size = len(dataset) - train_size
dataset, val_dataset = data.random_split(dataset, [train_size, val_size])
dataloader = data.DataLoader(
dataset,
batch_size=self.batch_size,
Expand All @@ -74,6 +79,11 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
timeout=self.timeout,
worker_init_fn=self.worker_init_fn,
)
if self.val_split is not None:
val_dataloader = data.DataLoader(
val_dataset,
batch_size=len(val_dataset),
)
self.track = []
self.weights = [
np.random.rand(view.shape[1], self.latent_dims) for view in views
Expand All @@ -83,7 +93,9 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):
for _ in range(self.epochs):
for i, sample in enumerate(dataloader):
self.update([view.numpy() for view in sample["views"]])
self.track.append(self.objective(sample["views"]))
if self.val_split is not None:
for i, sample in enumerate(val_dataloader):
self.track.append(self.objective(sample["views"]))
return self

@abstractmethod
Expand All @@ -95,9 +107,26 @@ def objective(self, views, **kwargs):
return self.tcc(views)

def tv(self, views):
    """Score ``views`` with the module-level ``tv`` after orthonormal projection.

    Each entry of ``self.weights`` is replaced by the ``Q`` factor of its QR
    decomposition, ``views`` are passed through
    ``self._centre_scale_transform``, every view is multiplied by its ``Q``
    factor, and the resulting list is handed to the module-level ``tv``.

    The stale early ``return PLS(...)`` lines of the previous implementation
    are removed: they made the QR-based computation unreachable.
    """
    # Q factors of the current weights, one per view.
    q = [np.linalg.qr(weight)[0] for weight in self.weights]
    views = self._centre_scale_transform(views)
    transformed_views = [view @ q[i] for i, view in enumerate(views)]
    # Inside a method body the bare name ``tv`` resolves to the module-level
    # function, not to this method.
    return tv(transformed_views)


def tcc(self, views):
    """Score ``views`` by passing ``self.transform(views)`` to module-level ``tcc``.

    The stale ``return CCA(self.latent_dims)...`` line of the previous
    implementation is removed: it returned first and left the delegation to
    the shared ``tcc`` helper unreachable.
    """
    z = self.transform(views)
    # The bare name ``tcc`` resolves to the module-level function here.
    return tcc(z)

def tv(z):
    """Return the sum of singular values of the between-view covariance of ``z``.

    ``z`` is a sequence of 2-D arrays that are stacked column-wise. The
    covariance of the stacked array is computed, the block-diagonal of the
    per-view covariances is subtracted so that only the between-view blocks
    remain, the result is divided by the number of rows of ``z[0]``, and the
    singular values of that matrix are summed.
    """
    stacked = np.hstack(z)
    cross_cov = np.cov(stacked, rowvar=False)
    within_blocks = [np.cov(latent, rowvar=False) for latent in z]
    # Remove the within-view blocks, leaving only the off-diagonal blocks.
    cross_cov -= block_diag(*within_blocks)
    # NOTE(review): np.cov already normalises by the sample count; this extra
    # division by the number of rows is kept as in the original. Confirm intent.
    n_rows = z[0].shape[0]
    cross_cov /= n_rows
    singular_values = np.linalg.svd(cross_cov, compute_uv=False)
    return singular_values.sum()

def tcc(z):
    """Fit a ``CCA`` model on ``z`` and return the sum of its scores on ``z``.

    The number of latent dimensions given to ``CCA`` is the column count of
    the first array in ``z``.
    """
    latent_dims = z[0].shape[1]
    fitted = CCA(latent_dims).fit(z)
    return fitted.score(z).sum()
8 changes: 7 additions & 1 deletion cca_zoo/models/_stochastic/_eigengame.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
epochs=1,
learning_rate=0.01,
c=0,
**kwargs
):
super().__init__(
latent_dims=latent_dims,
Expand All @@ -92,6 +93,7 @@ def __init__(
worker_init_fn=worker_init_fn,
epochs=epochs,
learning_rate=learning_rate,
**kwargs
)
self.c = c

Expand All @@ -108,7 +110,11 @@ def update(self, views):
Aw = self._Aw(view, projections.sum(axis=0).filled())
projections.mask[i] = False
Bw = self._Bw(view, projections[i].filled(), self.weights[i], self.c[i])
grads = 2 * Aw - (Aw @ np.triu(self.weights[i].T @ Bw) + Bw @ np.triu(self.weights[i].T @ Aw))
wAw = self.weights[i].T @ Aw
wBw = self.weights[i].T @ Bw
wAw[np.diag_indices_from(wAw)] = np.where(np.diag(wAw) > 0, np.diag(wAw), 0)
wBw[np.diag_indices_from(wBw)] = np.where(np.diag(wAw) > 0, np.diag(wBw), 0)
grads = 2 * Aw - (Aw @ np.triu(wBw) + Bw @ np.triu(wAw))
self.weights[i] += self.learning_rate * grads

def _Aw(self, view, projections):
Expand Down
11 changes: 9 additions & 2 deletions cca_zoo/models/_stochastic/_ghagep.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class RCCAGHAGEP(_BaseStochastic):
----------
Chapman, James, Ana Lawry Aguila, and Lennie Wells. "A Generalized EigenGame with Extensions to Multiview Representation Learning." arXiv preprint arXiv:2211.11323 (2022).
"""

def __init__(
self,
latent_dims: int = 1,
Expand All @@ -71,6 +72,7 @@ def __init__(
epochs=1,
learning_rate=0.01,
c=0,
**kwargs
):
super().__init__(
latent_dims=latent_dims,
Expand All @@ -90,6 +92,7 @@ def __init__(
worker_init_fn=worker_init_fn,
epochs=epochs,
learning_rate=learning_rate,
**kwargs
)
self.c = c

Expand All @@ -106,14 +109,16 @@ def update(self, views):
Aw = self._Aw(view, projections.sum(axis=0).filled())
projections.mask[i] = False
Bw = self._Bw(view, projections[i].filled(), self.weights[i], self.c[i])
grads = (Aw - Bw @ np.triu(self.weights[i].T @ Aw))
wAw = self.weights[i].T @ Aw
wAw[np.diag_indices_from(wAw)] = np.where(np.diag(wAw) > 0, np.diag(wAw), 0)
grads = (Aw - Bw @ np.triu(wAw))
self.weights[i] += self.learning_rate * grads

def _Aw(self, view, projections):
    """Return ``view.T @ projections`` divided by the number of rows of ``view``."""
    n_rows = view.shape[0]
    product = view.T @ projections
    return product / n_rows

def _Bw(self, view, projection, weight, c):
    """Return ``c * weight + (1 - c) * (view.T @ projection) / n``.

    ``n`` is the number of rows of ``projection``. Only the data term
    ``view.T @ projection`` is divided by ``n``; the ``c * weight`` term is
    left unscaled.

    The stale first ``return`` of the previous implementation is removed: it
    divided the ``c * weight`` term by ``n`` as well and made the corrected
    formula unreachable.
    """
    n_rows = projection.shape[0]
    return (c * weight) + (1 - c) * (view.T @ projection) / n_rows

def objective(self, views, **kwargs):
    """Return ``self.tcc(views)``; extra keyword arguments are accepted and ignored."""
    score = self.tcc(views)
    return score
Expand Down Expand Up @@ -164,6 +169,7 @@ class CCAGHAGEP(RCCAGHAGEP):
----------
Chapman, James, Ana Lawry Aguila, and Lennie Wells. "A Generalized EigenGame with Extensions to Multiview Representation Learning." arXiv preprint arXiv:2211.11323 (2022).
"""

def __init__(
self,
*args, **kwargs,
Expand Down Expand Up @@ -217,6 +223,7 @@ class PLSGHAGEP(RCCAGHAGEP):
----------
Chapman, James, Ana Lawry Aguila, and Lennie Wells. "A Generalized EigenGame with Extensions to Multiview Representation Learning." arXiv preprint arXiv:2211.11323 (2022).
"""

def __init__(
self,
*args, **kwargs,
Expand Down
11 changes: 8 additions & 3 deletions cca_zoo/models/_stochastic/_stochasticpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,14 @@ def update(self, views):
self.weights[i] += (
self.learning_rate * (view.T @ projections.sum(axis=0).filled()) / view.shape[0]
)
self.weights = [
weight / np.linalg.norm(weight, axis=0) for weight in self.weights
]
#qr decomposition of weights for orthogonality
self.weights[i] = self._orth(self.weights[i])

@staticmethod
def _orth(U):
    """Return the ``Q`` factor of ``U`` with a per-column sign convention.

    ``U`` is QR-decomposed and each column of ``Q`` is multiplied by ``+1`` or
    ``-1`` according to the sign of the matching diagonal entry of ``R``.
    A zero diagonal entry maps to ``+1``.
    """
    q_factor, r_factor = np.linalg.qr(U)
    r_diag = np.diag(r_factor)
    # sign(sign(x) + 0.5) sends negative x to -1 and zero or positive x to +1.
    flips = np.sign(np.sign(r_diag) + 0.5)
    return q_factor @ np.diag(flips)

def objective(self, views, **kwargs):
    """Return ``self.tv(views)``; extra keyword arguments are accepted and ignored."""
    score = self.tv(views)
    return score
12 changes: 7 additions & 5 deletions cca_zoo/test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,17 @@ def test_partialcca():
def test_stochastic_pls():
pytest.importorskip("torch")
from cca_zoo.models import PLSGHAGEP, PLSEigenGame, PLSStochasticPower, IncrementalPLS
pls = PLS(latent_dims=1).fit((X, Y))
ipls = IncrementalPLS(latent_dims=1, epochs=100, simple=False, batch_size=10).fit(
from torch import manual_seed
manual_seed(42)
pls = PLS(latent_dims=3).fit((X, Y))
ipls = IncrementalPLS(latent_dims=3, epochs=150, simple=False, batch_size=10).fit(
(X, Y)
)
spls = PLSStochasticPower(latent_dims=1, epochs=100, batch_size=10, learning_rate=1e-2).fit(
spls = PLSStochasticPower(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit(
(X, Y)
)
egpls = PLSEigenGame(latent_dims=1, epochs=100, batch_size=10, learning_rate=1e-2).fit((X, Y))
ghapls = PLSGHAGEP(latent_dims=1, epochs=100, batch_size=10, learning_rate=1e-2).fit((X, Y))
egpls = PLSEigenGame(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit((X, Y))
ghapls = PLSGHAGEP(latent_dims=3, epochs=150, batch_size=10, learning_rate=1e-2).fit((X, Y))
pls_score = pls.score((X, Y))
ipls_score = ipls.score((X, Y))
spls_score = spls.score((X, Y))
Expand Down

0 comments on commit c509ad3

Please sign in to comment.