Skip to content

Commit

Permalink
Support for distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler authored and michaeldeistler committed Feb 5, 2024
1 parent 7591f40 commit 1f49a6a
Showing 1 changed file with 45 additions and 37 deletions.
82 changes: 45 additions & 37 deletions labproject/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DATASETS = {}
DISTRIBUTIONS = {}


def upload_file(local_path: str, remote_path: str):
r"""
Uploads a file to the Hetzner Storage Box.
Expand Down Expand Up @@ -157,6 +158,7 @@ def wrapper(*args, **kwargs):

return decorator


def get_distribution(name: str) -> torch.Tensor:
r"""Get a distribution by name
Expand Down Expand Up @@ -230,45 +232,51 @@ def normal_distribution():
return torch.distributions.Normal(0,1)


@register_dataset("toy_2d")
def toy_mog_2D(n=1000, d=2):
"""Generate samples from a 2D mixture of 4 Gaussians that look funky.
@register_distribution("normal")
def normal_distribution():
return torch.distributions.Normal(0, 1)

Args:
n (int): number of samples to generate
d (int): dimensionality of the samples, always 2. Changing it does nothing.

Returns:
tensor: samples of shape (num_samples, 2)
"""
means = torch.tensor(
[
[0.0, 0.5],
[-3.0, -0.5],
[0.0, -1.0],
[-4.0, -3.0],
]
)
covariances = torch.tensor(
[
[[1.0, 0.8], [0.8, 1.0]],
[[1.0, -0.5], [-0.5, 1.0]],
[[1.0, 0.0], [0.0, 1.0]],
[[0.5, 0.0], [0.0, 0.5]],
]
)
weights = torch.tensor([0.2, 0.3, 0.3, 0.2])

# Create a list of 2D Gaussian distributions
gaussians = [
MultivariateNormal(mean, covariance) for mean, covariance in zip(means, covariances)
]

# Sample from the mixture
categorical = Categorical(weights)
sample_indices = categorical.sample([n])
samples = torch.stack([gaussians[i].sample() for i in sample_indices])
return samples
@register_distribution("toy_2d")
def normal_distribution():
class Toy2D:
def __init__(self):
self.means = torch.tensor(
[
[0.0, 0.5],
[-3.0, -0.5],
[0.0, -1.0],
[-4.0, -3.0],
]
)
self.covariances = torch.tensor(
[
[[1.0, 0.8], [0.8, 1.0]],
[[1.0, -0.5], [-0.5, 1.0]],
[[1.0, 0.0], [0.0, 1.0]],
[[0.5, 0.0], [0.0, 0.5]],
]
)
self.weights = torch.tensor([0.2, 0.3, 0.3, 0.2])

# Create a list of 2D Gaussian distributions
self.gaussians = [
MultivariateNormal(mean, covariance)
for mean, covariance in zip(self.means, self.covariances)
]

def sample(self, sample_shape):
# Sample from the mixture
categorical = Categorical(self.weights)
sample_indices = categorical.sample(sample_shape)
return torch.stack([self.gaussians[i].sample() for i in sample_indices])

def log_prob(self, input):
probs = torch.stack([g.log_prob(input).exp() for g in self.gaussians])
probs = probs.T * self.weights
return torch.sum(probs, dim=1).log()

return Toy2D()


@register_dataset("cifar10_train")
Expand Down

0 comments on commit 1f49a6a

Please sign in to comment.