Skip to content

Commit

Permalink
Bump to version 1.7 given new pytorch 1.9 requirement. Stability for …
Browse files Browse the repository at this point in the history
…DCCA improved!!! Exciting
  • Loading branch information
jameschapman19 committed Jun 17, 2021
1 parent 27ae4e8 commit 01f4695
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cca_zoo/deepmodels/dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class DCCA(_DCCA_base, torch.nn.Module):

def __init__(self, latent_dims: int, objective=objectives.CCA,
encoders: List[BaseEncoder] = [Encoder, Encoder],
learning_rate=1e-3, r: float = 0, eps: float = 1e-9,
learning_rate=1e-3, r: float = 1e-7, eps: float = 1e-7,
schedulers: List = None,
optimizers: List[torch.optim.Optimizer] = None):
"""
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/deepmodels/dccae.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class DCCAE(_DCCA_base):

def __init__(self, latent_dims: int, objective=objectives.MCCA,
encoders: Iterable[BaseEncoder] = [Encoder, Encoder],
decoders: Iterable[BaseDecoder] = [Decoder, Decoder], r: float = 0, eps: float = 1e-9,
decoders: Iterable[BaseDecoder] = [Decoder, Decoder], r: float = 1e-7, eps: float = 1e-7,
learning_rate=1e-3, lam=0.5,
schedulers: Iterable = None, optimizers: Iterable = None):
"""
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/deepmodels/dtcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DTCCA(DCCA, torch.nn.Module):
"""

def __init__(self, latent_dims: int, encoders: Iterable[BaseEncoder] = [Encoder, Encoder],
learning_rate=1e-3, r: float = 0, eps: float = 1e-9,
learning_rate=1e-3, r: float = 1e-7, eps: float = 1e-7,
schedulers: Iterable = None, optimizers: Iterable = None):
"""
Expand Down
8 changes: 4 additions & 4 deletions cca_zoo/deepmodels/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class MCCA:
"""

def __init__(self, latent_dims: int, r: float = 0, eps: float = 1e-9):
def __init__(self, latent_dims: int, r: float = 1e-7, eps: float = 1e-7):
"""
:param latent_dims: the number of latent dimensions
Expand Down Expand Up @@ -73,7 +73,7 @@ class GCCA:
"""

def __init__(self, latent_dims: int, r: float = 0, eps: float = 1e-9):
def __init__(self, latent_dims: int, r: float = 1e-7, eps: float = 1e-7):
"""
:param latent_dims: the number of latent dimensions
Expand Down Expand Up @@ -123,7 +123,7 @@ class CCA:
"""

def __init__(self, latent_dims: int, r: float = 0, eps: float = 1e-9):
def __init__(self, latent_dims: int, r: float = 1e-7, eps: float = 1e-7):
"""
:param latent_dims: the number of latent dimensions
:param r: regularisation as in regularized CCA. Makes the problem well posed when batch size is similar to the number of latent dimensions
Expand Down Expand Up @@ -173,7 +173,7 @@ class TCCA:
"""

def __init__(self, latent_dims: int, r: float = 0, eps: float = 1e-9):
def __init__(self, latent_dims: int, r: float = 1e-7, eps: float = 1e-7):
"""
:param latent_dims: the number of latent dimensions
Expand Down
12 changes: 12 additions & 0 deletions tests/testdeepmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ def test_input_types(self):
dcca_model.fit(self.train_dataset, val_dataset=self.train_dataset, epochs=3)
dcca_model.fit((self.X, self.Y), val_dataset=(self.X, self.Y), epochs=3)

def test_large_p(self):
X = np.random.rand(2000, 2048)
Y = np.random.rand(2000, 2048)
latent_dims = 150
device = 'cpu'
encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=2048)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=2048)
dcca_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2],
objective=objectives.CCA)
dcca_model = DeepWrapper(dcca_model, device=device)
dcca_model.fit((X, Y), epochs=10)

def test_DCCA_methods_cpu(self):
latent_dims = 2
device = 'cpu'
Expand Down

0 comments on commit 01f4695

Please sign in to comment.