Skip to content

Commit

Permalink
add notebook for beta transition point
Browse files Browse the repository at this point in the history
  • Loading branch information
zimea committed Jan 10, 2024
1 parent ba6e008 commit b0657ba
Show file tree
Hide file tree
Showing 5 changed files with 1,159 additions and 7 deletions.
3 changes: 3 additions & 0 deletions configs/surrogate.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
log_det_estimator:
name: surrogate
hutchinson_samples: 1
11 changes: 4 additions & 7 deletions configs/toy.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
model: mlae.model.MaximumLikelihoodAutoencoder

data_set:
name: sine
accelerator: "cpu"
num_workers: 0

noise: 0.1 # This is varied

Expand All @@ -23,8 +23,5 @@ models:
batch_size: 128
optimizer:
name: adam
lr: 0.001
max_epochs: 50

accelerator: "cpu"
num_workers: 0
lr: 0.0001
max_epochs: 5000
26 changes: 26 additions & 0 deletions mlae/evaluate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from torch.distributions import Normal
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -697,3 +698,28 @@ def compute_decoder_singular_values_ge_one(model, n_samples, temperature):
singular_values_ge_one_mean_m_std,
singular_values_ge_one_mean_p_std,
)

def scatter_filaments(model, eval_batch, normalize=True, text=None, text_loc=1, xlim=[-3.5,3.5], **kwargs):
z = model.encode(eval_batch, c=None)
x_rec = model.decode(z, c=None).detach().numpy()
z_np = z.detach().numpy()
x_np = eval_batch.detach().numpy()

if normalize:
min, max = np.percentile(z_np, [0.5,99.5])
z_np = np.where(z_np >= min, z_np, min)
z_np = np.where(z_np <= max, z_np, max)

if not 'color' in kwargs:
kwargs['c'] = z_np
fig = plt.scatter(x_np[:,0], x_np[:,1], **kwargs)
plt.xlabel('X',fontsize=18)
plt.ylabel('Y',fontsize=18)
plt.tick_params(labelsize=16)
plt.xlim(xlim[0],xlim[1])
if not text == None:
anchored_text = AnchoredText(text, loc=text_loc, prop=dict(size=16))
ax = plt.gca()
ax.add_artist(anchored_text)
plt.tight_layout()
return fig
4 changes: 4 additions & 0 deletions mlae/model/res_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ def __init__(self, hparams: dict | ResNetHParams):
self.model = self.build_model()

def encode(self, x, c):
if c is None:
return self.model.encoder(x)
return self.model.encoder(torch.cat([x, c], -1))

def decode(self, z, c):
if c is None:
return self.model.decoder(z)[..., :self.hparams.data_dim]
return self.model.decoder(torch.cat([z, c], -1))[..., :self.hparams.data_dim]

def build_model(self) -> nn.Module:
Expand Down
Loading

0 comments on commit b0657ba

Please sign in to comment.