Skip to content

Commit

Permalink
Adapted experiments for scaling dimensionality
Browse files Browse the repository at this point in the history
  • Loading branch information
franzigrkn committed Feb 7, 2024
1 parent b1275ff commit bfd9ad6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
3 changes: 2 additions & 1 deletion labproject/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def multivariate_normal(n=3000, dims=100, means=None, vars=None, distort=None):
samples = samples + shift
elif distort == "shift_one":
# randomly choose one index among dims
idx = torch.randint(dims, size=(1,))[0]
# idx = torch.randint(dims, size=(1,))[0]
idx = 0
shift = torch.zeros(n) + 1
samples[:, idx] = samples[:, idx] + shift
return samples
Expand Down
73 changes: 66 additions & 7 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,65 @@ def log_results(self, results, log_path):


class ScaleDim(Experiment):
def __init__(self, metric_name, metric_fn, dim_sizes=None, min_dim=1, max_dim=1000, step=100):
    """Configure a dimensionality-scaling experiment.

    Args:
        metric_name: Label of the metric, stored on the instance as
            ``self.metric_name``.
        metric_fn: Callable invoked as ``metric_fn(data1, data2)`` when the
            experiment runs.
        dim_sizes: Optional iterable of dimensionalities to evaluate. When
            given, ``min_dim``, ``max_dim`` and ``step`` are ignored.
        min_dim: First dimensionality of the default grid.
        max_dim: Exclusive upper bound of the default grid.
        step: Spacing of the default grid.
    """
    self.metric_name = metric_name
    self.metric_fn = metric_fn
    if dim_sizes is None:
        dim_sizes = range(min_dim, max_dim, step)
    # Keep a private list so that later mutation of the caller's sequence
    # cannot silently change the experiment's grid.
    self.dim_sizes = list(dim_sizes)
    super().__init__()

def run_experiment(self, dataset1, dataset2, nb_runs=5, dim_sizes=None):
    """Evaluate the metric on growing feature prefixes of two datasets.

    For every size ``d`` in ``dim_sizes`` the metric is computed on
    ``dataset1[:, :d]`` and ``dataset2[:, :d]``. The whole sweep is repeated
    ``nb_runs`` times on the same slices, so the reported spread only
    reflects randomness inside ``self.metric_fn`` itself.

    Args:
        dataset1: First dataset; must support ``[:, :d]`` indexing.
        dataset2: Second dataset; must support ``[:, :d]`` indexing.
        nb_runs: Number of repetitions of the sweep; must be at least 1.
        dim_sizes: Optional override of the dimensionalities to evaluate.
            Falls back to ``self.dim_sizes`` when ``None``.

    Returns:
        A tuple ``(dim_sizes, means, errors)``. ``means`` and ``errors`` are
        1-D tensors with one entry per dimensionality: the mean metric value
        over the runs and its sample standard deviation (all zeros when
        ``nb_runs == 1``).

    Raises:
        ValueError: If ``nb_runs`` is smaller than 1.
    """
    if nb_runs < 1:
        raise ValueError(f"nb_runs must be at least 1, got {nb_runs}")
    if dim_sizes is None:
        dim_sizes = self.dim_sizes

    # float() accepts plain Python numbers as well as one-element tensors,
    # so the nested list converts to a tensor regardless of what the metric
    # function returns.
    runs = [
        [float(self.metric_fn(dataset1[:, :d], dataset2[:, :d])) for d in dim_sizes]
        for _ in range(nb_runs)
    ]
    per_run = torch.tensor(runs)  # shape: (nb_runs, len(dim_sizes))

    final_distances = per_run.mean(dim=0)
    # The sample standard deviation of a single run is NaN; report zero
    # spread instead, in the same dtype as the means.
    if nb_runs > 1:
        final_errors = per_run.std(dim=0)
    else:
        final_errors = torch.zeros_like(final_distances)

    return dim_sizes, final_distances, final_errors

def plot_experiment(
    self,
    dim_sizes,
    distances,
    errors,
    dataset_name,
    ax=None,
    color=None,
    label=None,
    linestyle="-",
    **kwargs
):
    """Plot metric values against dimensionality.

    Thin wrapper that forwards every argument, together with
    ``self.metric_name``, to ``plot_scaling_metric_dimensionality``.

    Args:
        dim_sizes: Dimensionalities that were evaluated.
        distances: Metric value per dimensionality.
        errors: Error value per dimensionality.
        dataset_name: Name of the dataset, passed through to the plot helper.
        ax: Optional axes object, passed through as ``ax``.
        color: Optional colour, passed through as ``color``.
        label: Optional label, passed through as ``label``.
        linestyle: Line style, passed through as ``linestyle``.
        **kwargs: Any further keyword arguments for the plot helper.
    """
    plot_scaling_metric_dimensionality(
        dim_sizes,
        distances,
        errors,
        self.metric_name,
        dataset_name,
        ax=ax,
        color=color,
        label=label,
        linestyle=linestyle,
        **kwargs
    )

def log_results(self, results, log_path):
Expand All @@ -60,6 +104,21 @@ def __init__(self, min_dim=2, **kwargs):
super().__init__("Sliced Wasserstein", sliced_wasserstein_distance, **kwargs)


class ScaleDimC2ST(ScaleDim):
    """Dimensionality-scaling experiment that uses ``c2st_nn`` as its metric."""

    def __init__(self, min_dim=2, **kwargs):
        """Create the experiment under the metric name ``"C2ST"``.

        Args:
            min_dim: First dimensionality of the default grid.
            **kwargs: Remaining keyword arguments for ``ScaleDim.__init__``.
        """
        # min_dim used to be accepted here but never handed to the parent,
        # so both this default and any caller-supplied value were dropped.
        super().__init__("C2ST", c2st_nn, min_dim=min_dim, **kwargs)


class ScaleDimMMD(ScaleDim):
    """Dimensionality-scaling experiment that uses ``compute_rbf_mmd`` as its metric."""

    def __init__(self, min_dim=2, **kwargs):
        """Create the experiment under the metric name ``"MMD"``.

        Args:
            min_dim: First dimensionality of the default grid.
            **kwargs: Remaining keyword arguments for ``ScaleDim.__init__``.
        """
        # min_dim used to be accepted here but never handed to the parent,
        # so both this default and any caller-supplied value were dropped.
        super().__init__("MMD", compute_rbf_mmd, min_dim=min_dim, **kwargs)


"""class ScaleDimMMD(ScaleDim):
def __init__(self, min_dim=2, **kwargs):
super().__init__("FID", compute_rbf_mmd, **kwargs)"""


class ScaleSampleSize(Experiment):

def __init__(
Expand Down

0 comments on commit bfd9ad6

Please sign in to comment.