Move multivariate normal example script into notebook
braun-steven committed Nov 3, 2023
1 parent cd89c64 commit 0f7d842
Showing 2 changed files with 384 additions and 135 deletions.
384 changes: 384 additions & 0 deletions notebooks/multivariate_normal.ipynb

Large diffs are not rendered by default.

135 changes: 0 additions & 135 deletions simple_einet/layers/distributions/multivariate_normal.py
@@ -191,138 +191,3 @@ def mpe(self, num_samples) -> torch.Tensor:
        return samples


if __name__ == "__main__":
    # The following test snippet draws data from several 2D Gaussian blobs, fits
    # an Einet with multivariate normal leaves, and visualizes the data against
    # the fitted density.

    # Imports: torch for the model and optimization, numpy for data
    # manipulation, matplotlib/seaborn for plotting, icecream for debug output
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    import torch
    import torch.optim as optim
    from icecream import ic

    torch.manual_seed(1)
    np.random.seed(1)

    # Apply seaborn's whitegrid style to make plots more aesthetically pleasing
    sns.set_style("whitegrid")

    # Generate synthetic 2D data from four multivariate Gaussian blobs.
    # This serves as the dataset to which we fit the model.
    def generate_data(num_samples=100):
        # Parameters for the first Gaussian blob
        # (each covariance matrix must be symmetric positive definite)
        mean1 = [2.0, 3.0]
        cov1 = [[1.0, 0.6], [0.6, 0.5]]

        # Parameters for the second Gaussian blob
        mean2 = [-1.0, -2.0]
        cov2 = [[0.4, -0.1], [-0.1, 0.3]]

        # Parameters for the third Gaussian blob
        mean3 = [4.0, -1.0]
        cov3 = [[0.3, 0.2], [0.2, 0.5]]

        # Parameters for the fourth Gaussian blob
        mean4 = [-3.0, 2.0]
        cov4 = [[0.5, -0.2], [-0.2, 0.3]]

        # Draw num_samples // 4 points from each blob and stack them
        data1 = np.random.multivariate_normal(mean1, cov1, num_samples // 4)
        data2 = np.random.multivariate_normal(mean2, cov2, num_samples // 4)
        data3 = np.random.multivariate_normal(mean3, cov3, num_samples // 4)
        data4 = np.random.multivariate_normal(mean4, cov4, num_samples // 4)
        data = np.vstack([data1, data2, data3, data4])

        return torch.tensor(data, dtype=torch.float32)
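
    # The pooled sample is a four-component Gaussian mixture, so no single
    # multivariate normal fits it well; the Einet below therefore uses several
    # multivariate normal leaves as mixture components.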

    # Plot both the generated data and the learned probability density
    def plot_data_and_distribution_seaborn(data, samples, model):
        fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=120, sharex=True, sharey=True)

        # Build a grid over which we evaluate the model's density function
        x = np.linspace(data[:, 0].min(), data[:, 0].max(), 100)
        y = np.linspace(data[:, 1].min(), data[:, 1].max(), 100)
        X, Y = np.meshgrid(x, y)
        grid = np.vstack([X.ravel(), Y.ravel()]).T

        # Evaluate the learned density function over the grid
        with torch.no_grad():
            grid_tensor = torch.tensor(grid, dtype=torch.float32)
            log_prob = model(grid_tensor)
            prob_density = log_prob.exp().numpy().ravel()  # ensure this is 1-dimensional

        # Left panel: original data points over the fitted density
        sns.scatterplot(x=data[:, 0], y=data[:, 1], ax=axes[0], color="green", alpha=0.6, label="Original Data")
        sns.kdeplot(x=grid[:, 0], y=grid[:, 1], weights=prob_density, fill=True, ax=axes[0], cmap="viridis", alpha=0.5)
        axes[0].set_title("Original Data and Fitted Density")
        axes[0].legend()

        # Right panel: model samples over the fitted density
        sns.scatterplot(x=samples[:, 0], y=samples[:, 1], ax=axes[1], color="blue", alpha=0.6, label="Sampled Data")
        sns.kdeplot(x=grid[:, 0], y=grid[:, 1], weights=prob_density, fill=True, ax=axes[1], cmap="plasma", alpha=0.5)
        axes[1].set_title("Samples and Fitted Density")
        axes[1].legend()

        plt.tight_layout()
        plt.show()
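
    # Note: weighting a KDE over the grid points by the model density is a quick
    # way to shade the fitted density with seaborn; a contour plot of the density
    # reshaped to the grid, e.g. plt.contourf(X, Y, prob_density.reshape(X.shape)),
    # would render the model output more directly.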

    # Generate synthetic 2D data (uncomment the make_moons line below to try a
    # non-Gaussian dataset instead)
    from sklearn.datasets import make_moons

    n_samples = 400
    data = generate_data(n_samples)
    # data = torch.tensor(make_moons(n_samples=n_samples, noise=0.1, random_state=0)[0], dtype=torch.float32)

    # Initialize the Einet with multivariate normal leaves.
    # The model will be trained to fit the synthetic data.
    num_features = 2
    num_channels = 1
    num_leaves = 4
    num_repetitions = 1
    cardinality = 2

    from simple_einet.einet import Einet, EinetConfig

    cfg = EinetConfig(
        num_features=num_features,
        num_channels=num_channels,
        num_leaves=num_leaves,
        depth=0,
        num_repetitions=num_repetitions,
        num_classes=1,
        leaf_type=MultivariateNormal,
        leaf_kwargs={"cardinality": cardinality},
    )
    model = Einet(cfg)
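
    # With depth=0, num_repetitions=1, and cardinality=2 (so each leaf spans both
    # features), this configuration appears to reduce to a mixture of num_leaves
    # bivariate normal components, one per blob in the synthetic data.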

    # Set up the optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    epochs = 1000

    # Training loop: fit the model by maximum likelihood
    for epoch in range(epochs):
        optimizer.zero_grad()
        log_prob = model(data)

        # Negative log-likelihood as loss function
        loss = -torch.mean(log_prob)
        loss.backward()
        optimizer.step()

        # Logging to monitor progress
        if epoch % 50 == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")

    # Sample from the fitted model and drop the channel dimension
    samples = model.sample(num_samples=n_samples)
    samples.squeeze_(1)
    ic(samples.shape)

    plot_data_and_distribution_seaborn(data, samples, model)
