Skip to content

Commit

Permalink
default line search to False (can be slow)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Feb 17, 2023
1 parent 58e7cd3 commit 98a2c10
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cca_zoo/data/deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class NumpyDataset(Dataset):
Class that turns numpy arrays into a torch dataset
"""

def __init__(self, views, labels=None, scale=False, centre=False, precision="float64"):
def __init__(self, views, labels=None, scale=False, centre=False, precision="float32"):
"""
:param views: list/tuple of numpy arrays or array likes with the same number of rows (samples)
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/models/_stochastic/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs):

def get_dataloader(self, views: Iterable[np.ndarray]):
if self.batch_size is None:
dataset = BatchNumpyDataset(views)
dataset = BatchNumpyDataset(views,precision="float64")
else:
dataset = NumpyDataset(views)
dataset = NumpyDataset(views,precision="float64")
if self.val_split is not None:
train_size = int((1 - self.val_split) * len(dataset))
val_size = len(dataset) - train_size
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/test/test_deepmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
Z = rng.rand(256, 14)
X_conv = rng.rand(256, 1, 16, 16)
Y_conv = rng.rand(256, 1, 16, 16)
dataset = NumpyDataset([X, Y, Z], scale=True, centre=True)
dataset = NumpyDataset([X, Y, Z], scale=True, centre=True, precision="float32")
check_dataset(dataset)
train_dataset, val_dataset = random_split(dataset, [200, 56])
loader = get_dataloaders(dataset)
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset)
conv_dataset = NumpyDataset((X_conv, Y_conv))
conv_dataset = NumpyDataset((X_conv, Y_conv), precision="float32")
conv_loader = get_dataloaders(conv_dataset)
train_ids = train_dataset.indices
epochs = 100
Expand Down

0 comments on commit 98a2c10

Please sign in to comment.