Skip to content

Commit

Permalink
Merge pull request #112 from jameschapman19/partial_cca
Browse files Browse the repository at this point in the history
Partial cca
  • Loading branch information
jameschapman19 authored Dec 1, 2021
2 parents 686f76a + 79b993a commit 945602b
Show file tree
Hide file tree
Showing 31 changed files with 923 additions and 927 deletions.
62 changes: 29 additions & 33 deletions cca_zoo/data/simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@


def generate_covariance_data(
n: int,
view_features: List[int],
latent_dims: int = 1,
view_sparsity: List[Union[int, float]] = None,
correlation: Union[List[float], float] = 1,
structure: Union[str, List[str]] = None,
sigma: List[float] = None,
decay: float = 0.5,
positive=None,
random_state: Union[int, np.random.RandomState] = None,
n: int,
view_features: List[int],
latent_dims: int = 1,
view_sparsity: List[Union[int, float]] = None,
correlation: Union[List[float], float] = 1,
structure: Union[str, List[str]] = None,
sigma: Union[List[float], float] = None,
decay: float = 0.5,
positive=None,
random_state: Union[int, np.random.RandomState] = None,
):
"""
Function to generate CCA dataset with defined population correlations
Expand All @@ -29,7 +29,7 @@ def generate_covariance_data(
:param view_features: number of features in each view
:param latent_dims: number of latent dimensions
:param correlation: correlation either as list with element for each latent dimension or as float which is scaled by 'decay'
:param structure: within view covariance structure
:param structure: within view covariance structure ('identity','gaussian','toeplitz','random')
:param sigma: gaussian sigma
:param decay: ratio of second signal to first signal
:return: tuple of numpy arrays: view_1, view_2, true weights from view 1, true weights from view 2, overall covariance structure
Expand Down Expand Up @@ -58,7 +58,7 @@ def generate_covariance_data(
covs = []
true_features = []
for view_p, sparsity, view_structure, view_positive, view_sigma in zip(
view_features, view_sparsity, structure, positive, sigma
view_features, view_sparsity, structure, positive, sigma
):
# Covariance Bit
if view_structure == "identity":
Expand Down Expand Up @@ -86,12 +86,9 @@ def generate_covariance_data(
* latent_dims,
axis=0,
).T
mask = mask.flatten()
random_state.shuffle(mask)
while (
np.sum(np.unique(mask, axis=1, return_counts=True)[1] > 1) > 0
or np.sum(np.sum(mask, axis=0) == 0) > 0
):
random_state.shuffle(mask)
mask = mask.reshape(weights.shape)
weights = weights * mask
if view_positive:
weights[weights < 0] = 0
Expand All @@ -113,12 +110,12 @@ def generate_covariance_data(
# Cross Bit
cross += covs[i] @ A @ covs[j]
cov[
splits[i] : splits[i] + view_features[i],
splits[j] : splits[j] + view_features[j],
splits[i]: splits[i] + view_features[i],
splits[j]: splits[j] + view_features[j],
] = cross
cov[
splits[j] : splits[j] + view_features[j],
splits[i] : splits[i] + view_features[i],
splits[j]: splits[j] + view_features[j],
splits[i]: splits[i] + view_features[i],
] = cross.T

X = np.zeros((n, sum(view_features)))
Expand All @@ -133,12 +130,12 @@ def generate_covariance_data(


def generate_simple_data(
n: int,
view_features: List[int],
view_sparsity: List[int] = None,
eps: float = 0,
transform=True,
random_state=None,
n: int,
view_features: List[int],
view_sparsity: List[Union[int, float]] = None,
eps: float = 0,
transform=True,
random_state=None,
):
"""
Simple latent variable model to generate data with one latent factor
Expand All @@ -165,9 +162,8 @@ def generate_simple_data(
)
for p, sparsity in zip(view_features, view_sparsity):
weights = random_state.randn(p, 1)
if sparsity > 0:
if sparsity < 1:
sparsity = np.ceil(sparsity * p).astype("int")
if sparsity <= 1:
sparsity = np.ceil(sparsity * p).astype("int")
weights[random_state.choice(np.arange(p), p - sparsity, replace=False)] = 0
gaussian_x = random_state.randn(n, p) * eps
view = np.outer(z, weights)
Expand Down Expand Up @@ -200,9 +196,9 @@ def _gaussian(x, mu, sig, dn):
:param dn:
"""
return (
np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sig, 2.0)))
* dn
/ (np.sqrt(2 * np.pi) * sig)
np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sig, 2.0)))
* dn
/ (np.sqrt(2 * np.pi) * sig)
)


Expand Down
192 changes: 77 additions & 115 deletions cca_zoo/data/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import numpy as np
import torch
import torch.utils.data
from PIL import Image
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode


class Split_MNIST_Dataset(Dataset):
Expand All @@ -16,7 +15,7 @@ class Split_MNIST_Dataset(Dataset):
"""

def __init__(
self, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
self, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
):
"""
Expand Down Expand Up @@ -72,7 +71,7 @@ class Noisy_MNIST_Dataset(Dataset):
"""

def __init__(
self, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
self, mnist_type: str = "MNIST", train: bool = True, flatten: bool = True
):
"""
Expand All @@ -81,25 +80,45 @@ def __init__(
:param flatten: whether to flatten the data into array or use 2d images
"""
if mnist_type == "MNIST":
self.dataset = datasets.MNIST("../../data", train=train, download=True)
self.dataset = datasets.MNIST(
"../../data",
train=train,
download=True,
transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
),
)
elif mnist_type == "FashionMNIST":
self.dataset = datasets.FashionMNIST(
"../../data", train=train, download=True
"../../data",
train=train,
download=True,
transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
),
)
elif mnist_type == "KMNIST":
self.dataset = datasets.KMNIST("../../data", train=train, download=True)
self.dataset = datasets.KMNIST(
"../../data",
train=train,
download=True,
transform=torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
)

self.data = self.dataset.data
self.base_transform = transforms.ToTensor()
self.a_transform = transforms.Compose(
[
transforms.ToTensor(), # first, convert image to PyTorch tensor
transforms.ToPILImage(),
]
[torchvision.transforms.RandomRotation((-45, 45))]
)
self.a_transform = transforms.Compose(
[torchvision.transforms.RandomRotation((-45, 45))]
)
self.b_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Lambda(_add_mnist_noise),
transforms.Lambda(self.__threshold_func__),
]
Expand All @@ -108,57 +127,26 @@ def __init__(
self.filtered_classes = []
self.filtered_nums = []
for i in range(10):
self.filtered_classes.append(self.data[self.targets == i])
self.filtered_nums.append(self.filtered_classes[i].shape[0])
self.filtered_nums.append(np.where(self.targets == i)[0])
self.flatten = flatten

def __threshold_func__(self, x):
    """
    Cap every entry of ``x`` that exceeds 1 at exactly 1

    The input is modified in place and the same object is returned.

    :param x: array-like supporting boolean-mask assignment
    :return: ``x`` itself, with all values greater than 1 replaced by 1
    """
    over_limit = x > 1
    x[over_limit] = 1
    return x

def __len__(self):
return len(self.data)
return len(self.dataset)

def __getitem__(self, idx):
x_a = self.a_transform(self.data[idx].numpy() / 255)
rot_a = torch.rand(1) * 90 - 45
x_a = transforms.functional.rotate(
x_a, rot_a.item(), interpolation=InterpolationMode.BILINEAR
)
x_a = self.base_transform(x_a) # convert from PIL back to pytorch tensor

label = self.targets[idx]
x_a, label = self.dataset[idx]
x_a = self.a_transform(x_a)
# get random index of image with same class
random_index = np.random.randint(self.filtered_nums[label])
x_b = Image.fromarray(
self.filtered_classes[label][random_index, :, :].numpy() / 255, mode="L"
)
x_b = self.b_transform(x_b)

random_index = np.random.choice(self.filtered_nums[label])
x_b = self.b_transform(self.dataset[random_index][0])
if self.flatten:
x_a = torch.flatten(x_a)
x_b = torch.flatten(x_b)
return (x_b, x_a), (rot_a, label)

def to_numpy(self, indices=None):
    """
    Converts dataset to numpy array form

    Each requested sample is fetched through ``self[n]`` and both views are
    flattened to a row of 28 * 28 = 784 values.

    :param indices: indices of the samples to extract into numpy arrays; ``None`` extracts every sample
    :return: ``(view_1, view_2), (rotations, labels)`` - the views are float arrays of shape ``(len(indices), 784)``, ``rotations`` is a float array and ``labels`` an int array, each of length ``len(indices)``
    """
    if indices is None:
        indices = np.arange(len(self))
    n_samples = len(indices)
    view_1 = np.zeros((n_samples, 784))
    view_2 = np.zeros((n_samples, 784))
    labels = np.zeros(n_samples).astype(int)
    rotations = np.zeros(n_samples)
    for row, index in enumerate(indices):
        (first_view, second_view), (rotation, label) = self[index]
        view_1[row] = first_view.numpy().reshape(28 * 28)
        view_2[row] = second_view.numpy().reshape(28 * 28)
        # The rotation may arrive as a one-element (not 0-d) tensor; storing a
        # shape-(1,) array into a scalar slot is deprecated in NumPy, so the
        # scalar is unwrapped explicitly.
        rotations[row] = rotation.numpy().item()
        labels[row] = int(label.numpy().item())
    return (view_1, view_2), (rotations, labels)
return (x_b, x_a), label


class Tangled_MNIST_Dataset(Dataset):
Expand All @@ -174,86 +162,60 @@ def __init__(self, mnist_type="MNIST", train=True, flatten=True):
:param flatten: whether to flatten the data into array or use 2d images
"""
if mnist_type == "MNIST":
self.dataset = datasets.MNIST("../../data", train=train, download=True)
self.dataset = datasets.MNIST(
"../../data",
train=train,
download=True,
transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
),
)
elif mnist_type == "FashionMNIST":
self.dataset = datasets.FashionMNIST(
"../../data", train=train, download=True
"../../data",
train=train,
download=True,
transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
),
)
elif mnist_type == "KMNIST":
self.dataset = datasets.KMNIST("../../data", train=train, download=True)

self.data = self.dataset.data
self.transform = transforms.Compose([transforms.ToTensor()])
self.dataset = datasets.KMNIST(
"../../data",
train=train,
download=True,
transform=torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
)
self.transform = transforms.Compose(
[torchvision.transforms.RandomRotation((-45, 45))]
)
self.targets = self.dataset.targets
self.OHs = _OH_digits(self.targets.numpy().astype(int))
self.filtered_classes = []
self.filtered_nums = []
for i in range(10):
self.filtered_classes.append(self.data[self.targets == i])
self.filtered_nums.append(self.filtered_classes[i].shape[0])
self.filtered_nums.append(np.where(self.targets == i)[0])
self.flatten = flatten

def __len__(self):
return len(self.data)
return len(self.dataset)

def __getitem__(self, idx):
# get first image from idx and second of same class
label = self.targets[idx]
x_a = Image.fromarray(self.data[idx].numpy() / 255, mode="L")
x_a, label = self.dataset[idx]
x_a = self.transform(x_a)
# get random index of image with same class
random_index = np.random.randint(self.filtered_nums[label])
x_b = Image.fromarray(
self.filtered_classes[label][random_index, :, :].numpy() / 255, mode="L"
)
# get random angles of rotation
rot_a, rot_b = torch.rand(2) * 90 - 45
x_a_rotate = transforms.functional.rotate(
x_a, rot_a.item(), interpolation=InterpolationMode.BILINEAR
)
x_b_rotate = transforms.functional.rotate(
x_b, rot_b.item(), interpolation=InterpolationMode.BILINEAR
)
# convert images to tensors
x_a_rotate = self.transform(x_a_rotate)
x_b_rotate = self.transform(x_b_rotate)

random_index = np.random.choice(self.filtered_nums[label])
x_b = self.transform(self.dataset[random_index][0])
if self.flatten:
x_a_rotate = torch.flatten(x_a_rotate)
x_b_rotate = torch.flatten(x_b_rotate)
return (x_a_rotate, x_b_rotate), (rot_a, rot_b, label)

def to_numpy(self, indices=None):
    """
    Converts dataset to numpy array form

    Each requested sample is fetched through ``self[n]`` and both views are
    flattened to a row of 28 * 28 = 784 values.

    :param indices: indices of the samples to extract into numpy arrays; ``None`` extracts every sample
    :return: ``(view_1, view_2), (rotation_1, rotation_2, labels)`` - the views are float arrays of shape ``(len(indices), 784)``, the rotations are float arrays and ``labels`` an int array, each of length ``len(indices)``
    """
    if indices is None:
        indices = np.arange(len(self))
    n_samples = len(indices)
    view_1 = np.zeros((n_samples, 784))
    view_2 = np.zeros((n_samples, 784))
    labels = np.zeros(n_samples).astype(int)
    rotation_1 = np.zeros(n_samples)
    rotation_2 = np.zeros(n_samples)
    for row, index in enumerate(indices):
        (first_view, second_view), (first_rot, second_rot, label) = self[index]
        view_1[row] = first_view.numpy().reshape(28 * 28)
        view_2[row] = second_view.numpy().reshape(28 * 28)
        # .item() yields a plain Python scalar whether the value is 0-d or a
        # one-element array, so no implicit array-to-scalar conversion occurs.
        rotation_1[row] = first_rot.numpy().item()
        rotation_2[row] = second_rot.numpy().item()
        labels[row] = int(label.numpy().item())
    return (view_1, view_2), (rotation_1, rotation_2, labels)


def _OH_digits(digits, num_classes=None):
    """
    One hot encode numpy array

    :param digits: non-negative integer class labels (array-like); flattened before encoding so that one row is produced per label
    :param num_classes: number of columns in the encoding; when ``None`` it is ``digits.max() + 1`` (0 for empty input)
    :return: float array of shape ``(digits.size, num_classes)`` holding a single 1 in each row
    """
    labels = np.asarray(digits).ravel()
    if num_classes is None:
        # max() of a zero-size array raises, so empty input encodes to zero columns
        num_classes = int(labels.max()) + 1 if labels.size else 0
    encoded = np.zeros((labels.size, num_classes))
    if labels.size:
        # labels is 1-D here, so this pairs row i with column labels[i] only;
        # an un-flattened (n, 1) input would broadcast and fill whole columns.
        encoded[np.arange(labels.size), labels] = 1
    return encoded
x_a = torch.flatten(x_a)
x_b = torch.flatten(x_b)
return (x_b, x_a), label


def _add_mnist_noise(x):
x = x + torch.rand(28, 28)
x = x + torch.rand(28, 28) / 10
return x
Loading

0 comments on commit 945602b

Please sign in to comment.