Skip to content

Commit

Permalink
Eigengame updates and a surprising (to me) mistake in the API that for…
Browse files Browse the repository at this point in the history
…ces users to use it in the way I expected
  • Loading branch information
jameschapman19 committed Dec 23, 2022
1 parent a55ed2c commit 98c820c
Show file tree
Hide file tree
Showing 56 changed files with 1,305 additions and 976 deletions.
16 changes: 10 additions & 6 deletions cca_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from .data import *
from .model_selection import *
from .models import *
from .plotting import *
__all__ = [
"data",
"model_selection",
"models",
"plotting",]

# if deepmodels can be imported (optional deps installed), add it to __all__
try:
from cca_zoo.deepmodels import *
import cca_zoo.deepmodels
__all__.append("deepmodels")
except ModuleNotFoundError:
pass
try:
from cca_zoo.probabilisticmodels import *
import cca_zoo.probabilisticmodels
__all__.append("probabilisticmodels")
except ModuleNotFoundError:
pass
13 changes: 3 additions & 10 deletions cca_zoo/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
from . import simulated

try:
from . import deep
import cca_zoo.data.deep

__all__ = [
"simulated",
"deep"
]
__all__ = ["simulated", "deep"]
except ModuleNotFoundError:
__all__ = [
"simulated"
]
__all__ = ["simulated"]
59 changes: 46 additions & 13 deletions cca_zoo/data/deep.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import Iterable

import numpy as np
from torch.utils.data import Dataset, DataLoader


class NumpyDataset(Dataset):
"""
Class that turns numpy arrays into a torch dataset
"""

def __init__(self, views, labels=None):
def __init__(self, views, labels=None, scale=False, centre=False):
"""
:param views: list/tuple of numpy arrays or array likes with the same number of rows (samples)
"""
self.views = [view for view in views]
self.labels = labels
self.centre = centre
self.scale = scale
self.views = self._centre_scale(views)

def __len__(self):
return len(self.views[0])
Expand All @@ -27,6 +30,36 @@ def __getitem__(self, index):
else:
return {"views": views}

def _centre_scale(self, views: Iterable[np.ndarray]):
"""
Centers and scales the data
Parameters
----------
views : list/tuple of numpy arrays or array likes with the same number of rows (samples)
Returns
-------
views : list of numpy arrays
"""
self.view_means = []
self.view_stds = []
transformed_views = []
for view in views:
if self.centre:
view_mean = view.mean(axis=0)
self.view_means.append(view_mean)
view = view - self.view_means[-1]
if self.scale:
view_std = view.std(axis=0, ddof=1)
view_std[view_std == 0.0] = 1.0
self.view_stds.append(view_std)
view = view / self.view_stds[-1]
transformed_views.append(view)
return transformed_views


def check_dataset(dataset):
"""
Expand All @@ -53,16 +86,16 @@ def check_dataset(dataset):


def get_dataloaders(
dataset,
val_dataset=None,
batch_size=None,
val_batch_size=None,
drop_last=True,
val_drop_last=False,
shuffle_train=False,
pin_memory=True,
num_workers=0,
persistent_workers=True,
dataset,
val_dataset=None,
batch_size=None,
val_batch_size=None,
drop_last=True,
val_drop_last=False,
shuffle_train=False,
pin_memory=True,
num_workers=0,
persistent_workers=True,
):
"""
A utility function to allow users to quickly get hold of the dataloaders required by pytorch lightning
Expand Down
62 changes: 35 additions & 27 deletions cca_zoo/data/simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@
from scipy.linalg import block_diag
from sklearn.utils.validation import check_random_state

from ..utils import _process_parameter
from cca_zoo.utils import _process_parameter


class LinearSimulatedData:
def __init__(self,
view_features: List[int],
latent_dims: int = 1,
view_sparsity: List[Union[int, float]] = None,
correlation: Union[List[float], float] = 0.99,
structure: Union[str, List[str]] = None,
sigma: Union[List[float], float] = None,
decay: float = 0.5,
positive=None,
random_state: Union[int, np.random.RandomState] = None):
def __init__(
self,
view_features: List[int],
latent_dims: int = 1,
view_sparsity: List[Union[int, float]] = None,
correlation: Union[List[float], float] = 0.99,
structure: Union[str, List[str]] = None,
sigma: Union[List[float], float] = None,
decay: float = 0.5,
positive=None,
random_state: Union[int, np.random.RandomState] = None,
):
"""
Parameters
Expand Down Expand Up @@ -57,7 +59,9 @@ def __init__(self,
self.view_sparsity = _process_parameter(
"view_sparsity", view_sparsity, 1, len(view_features)
)
self.positive = _process_parameter("positive", positive, False, len(view_features))
self.positive = _process_parameter(
"positive", positive, False, len(view_features)
)
self.sigma = _process_parameter("sigma", sigma, 0.5, len(view_features))

self.mean, covs, self.true_features = self._generate_covariance_matrices()
Expand Down Expand Up @@ -89,12 +93,12 @@ def _generate_joint_covariance(self, covs):
# Cross Bit
cross += covs[i] @ A @ covs[j]
cov[
splits[i]: splits[i] + self.view_features[i],
splits[j]: splits[j] + self.view_features[j],
splits[i] : splits[i] + self.view_features[i],
splits[j] : splits[j] + self.view_features[j],
] = cross
cov[
splits[j]: splits[j] + self.view_features[j],
splits[i]: splits[i] + self.view_features[i],
splits[j] : splits[j] + self.view_features[j],
splits[i] : splits[i] + self.view_features[i],
] = cross.T
return cov

Expand All @@ -103,7 +107,11 @@ def _generate_covariance_matrices(self):
covs = []
true_features = []
for view_p, sparsity, view_structure, view_positive, view_sigma in zip(
self.view_features, self.view_sparsity, self.structure, self.positive, self.sigma
self.view_features,
self.view_sparsity,
self.structure,
self.positive,
self.sigma,
):
cov = self._generate_covariance_matrix(view_p, view_structure, view_sigma)
weights = self.random_state.randn(view_p, self.latent_dims)
Expand Down Expand Up @@ -146,12 +154,12 @@ def _chol_sample(mean, chol, random_state):


def simple_simulated_data(
n: int,
view_features: List[int],
view_sparsity: List[Union[int, float]] = None,
eps: float = 0,
transform=False,
random_state=None,
n: int,
view_features: List[int],
view_sparsity: List[Union[int, float]] = None,
eps: float = 0,
transform=False,
random_state=None,
):
"""
Generate a simple simulated dataset with a single latent dimension
Expand Down Expand Up @@ -215,9 +223,9 @@ def _gaussian(x, mu, sig, dn):
:param dn:
"""
return (
np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sig, 2.0)))
* dn
/ (np.sqrt(2 * np.pi) * sig)
np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sig, 2.0)))
* dn
/ (np.sqrt(2 * np.pi) * sig)
)


Expand All @@ -233,7 +241,7 @@ def _generate_gaussian_cov(p, sigma):

def _generate_toeplitz_cov(p, sigma):
c = np.arange(0, p)
c = sigma ** c
c = sigma**c
cov = linalg.toeplitz(c, c)
return cov

Expand Down
11 changes: 9 additions & 2 deletions cca_zoo/deepmodels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from . import architectures
from . import callbacks
from . import objectives
from ._discriminative import DCCA, DCCA_NOI, BarlowTwins, DCCA_SDL, DTCCA, DCCA_EigenGame
from ._discriminative import (
DCCA,
DCCA_NOI,
BarlowTwins,
DCCA_SDL,
DTCCA,
DCCA_EigenGame,
)
from ._generative import DVCCA, SplitAE, DCCAE

__all__ = [
Expand All @@ -28,5 +35,5 @@
"BarlowTwins",
"DTCCA",
"SplitAE",
"DCCA_EigenGame"
"DCCA_EigenGame",
]
30 changes: 15 additions & 15 deletions cca_zoo/deepmodels/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@

class _BaseDeep(pl.LightningModule):
def __init__(
self,
latent_dims: int,
optimizer="adam",
scheduler=None,
lr=1e-3,
weight_decay=0,
extra_optimizer_kwargs=None,
max_epochs=1000,
min_lr=1e-9,
lr_decay_steps=None,
correlation=True,
*args,
**kwargs,
self,
latent_dims: int,
optimizer="adam",
scheduler=None,
lr=1e-3,
weight_decay=0,
extra_optimizer_kwargs=None,
max_epochs=1000,
min_lr=1e-9,
lr_decay_steps=None,
correlation=True,
*args,
**kwargs,
):
super().__init__()
if extra_optimizer_kwargs is None:
Expand Down Expand Up @@ -73,8 +73,8 @@ def test_step(self, batch, batch_idx):
return loss["objective"]

def transform(
self,
loader: torch.utils.data.DataLoader,
self,
loader: torch.utils.data.DataLoader,
):
"""
:param loader: a dataloader that matches the structure of that used for training
Expand Down
20 changes: 10 additions & 10 deletions cca_zoo/deepmodels/_discriminative/_dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class DCCA(_BaseDeep, _BaseCCA):
"""

def __init__(
self,
latent_dims: int,
objective=objectives.MCCA,
encoders=None,
r: float = 0,
eps: float = 1e-5,
**kwargs,
self,
latent_dims: int,
objective=objectives.MCCA,
encoders=None,
r: float = 0,
eps: float = 1e-5,
**kwargs,
):
super().__init__(latent_dims=latent_dims, **kwargs)
self.encoders = torch.nn.ModuleList(encoders)
Expand All @@ -41,9 +41,9 @@ def loss(self, views, **kwargs):
return {"objective": self.objective.loss(z)}

def pairwise_correlations(
self,
loader: torch.utils.data.DataLoader,
train=False,
self,
loader: torch.utils.data.DataLoader,
train=False,
):
"""
Calculates correlation for entire batch from dataloader
Expand Down
10 changes: 5 additions & 5 deletions cca_zoo/deepmodels/_discriminative/_dcca_barlow_twins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class BarlowTwins(DCCA):
"""

def __init__(
self,
latent_dims: int,
encoders=None,
lam=1,
**kwargs,
self,
latent_dims: int,
encoders=None,
lam=1,
**kwargs,
):
super().__init__(latent_dims=latent_dims, encoders=encoders, **kwargs)
self.lam = lam
Expand Down
Loading

0 comments on commit 98c820c

Please sign in to comment.