From bfd9ad66bf2e3089714dda1a0ed46484e0613222 Mon Sep 17 00:00:00 2001
From: franzigrkn
Date: Wed, 7 Feb 2024 10:19:27 +0100
Subject: [PATCH] Adapted experiments for scaling dimensionality

---
 labproject/data.py        |  3 +-
 labproject/experiments.py | 73 +++++++++++++++++++++++++++++++++++----
 2 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/labproject/data.py b/labproject/data.py
index 74ce497..4878bae 100644
--- a/labproject/data.py
+++ b/labproject/data.py
@@ -277,7 +277,8 @@ def multivariate_normal(n=3000, dims=100, means=None, vars=None, distort=None):
         samples = samples + shift
     elif distort == "shift_one":
         # randomly choose one index among dims
-        idx = torch.randint(dims, size=(1,))[0]
+        # idx = torch.randint(dims, size=(1,))[0]
+        idx = 0
         shift = torch.zeros(n) + 1
         samples[:, idx] = samples[:, idx] + shift
     return samples
diff --git a/labproject/experiments.py b/labproject/experiments.py
index 3501af7..85559d6 100644
--- a/labproject/experiments.py
+++ b/labproject/experiments.py
@@ -25,21 +25,65 @@ def log_results(self, results, log_path):
 
 
 class ScaleDim(Experiment):
-    def __init__(self, metric_name, metric_fn, min_dim=1, max_dim=1000, step=100):
+    def __init__(self, metric_name, metric_fn, dim_sizes=None, min_dim=1, max_dim=1000, step=100):
         self.metric_name = metric_name
         self.metric_fn = metric_fn
-        self.dimensionality = list(range(min_dim, max_dim, step))
+        if dim_sizes is not None:
+            self.dim_sizes = dim_sizes
+        else:
+            self.dim_sizes = list(range(min_dim, max_dim, step))
         super().__init__()
 
-    def run_experiment(self, dataset1, dataset2):
-        distances = []
+    def run_experiment(self, dataset1, dataset2, nb_runs=5, dim_sizes=None):
+        """distances = []
         for d in self.dimensionality:
             distances.append(self.metric_fn(dataset1[:, :d], dataset2[:, :d]))
-        return self.dimensionality, distances
+        return self.dimensionality, distances"""
+        final_distances = []
+        final_errors = []
+        if dim_sizes is None:
+            dim_sizes = self.dim_sizes
+        for idx in range(nb_runs):
+            distances = []
+            for n in dim_sizes:
+                data1 = dataset1[:, :n]
+                data2 = dataset2[:, :n]
+                distances.append(self.metric_fn(data1, data2))
+            final_distances.append(distances)
+        final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
+        final_errors = (
+            torch.tensor([torch.std(d) for d in final_distances])
+            if nb_runs > 1
+            else torch.zeros_like(torch.tensor(dim_sizes))
+        )
+        final_distances = torch.tensor([torch.mean(d) for d in final_distances])
+
+        return dim_sizes, final_distances, final_errors
+
+    def plot_experiment(
+        self,
+        dim_sizes,
+        distances,
+        errors,
+        dataset_name,
+        ax=None,
+        color=None,
+        label=None,
+        linestyle="-",
+        **kwargs
+    ):
 
-    def plot_experiment(self, dimensionality, distances, dataset_name, ax=None):
         plot_scaling_metric_dimensionality(
-            dimensionality, distances, self.metric_name, dataset_name, ax=ax
+            dim_sizes,
+            distances,
+            errors,
+            self.metric_name,
+            dataset_name,
+            ax=ax,
+            color=color,
+            label=label,
+            linestyle=linestyle,
+            **kwargs
         )
 
     def log_results(self, results, log_path):
@@ -60,6 +104,21 @@ def __init__(self, min_dim=2, **kwargs):
         super().__init__("Sliced Wasserstein", sliced_wasserstein_distance, **kwargs)
 
+
+class ScaleDimC2ST(ScaleDim):
+    def __init__(self, min_dim=2, **kwargs):
+        super().__init__("C2ST", c2st_nn, **kwargs)
+
+
+class ScaleDimMMD(ScaleDim):
+    def __init__(self, min_dim=2, **kwargs):
+        super().__init__("MMD", compute_rbf_mmd, **kwargs)
+
+
+"""class ScaleDimMMD(ScaleDim):
+    def __init__(self, min_dim=2, **kwargs):
+        super().__init__("FID", compute_rbf_mmd, **kwargs)"""
+
 
 class ScaleSampleSize(Experiment):
     def __init__(
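
Usage note (not part of the patch): the sketch below shows one way the updated ScaleDim API could be driven end to end. It assumes that multivariate_normal and ScaleDimMMD are importable from the modules touched by this diff, and that plot_scaling_metric_dimensionality accepts the extra errors/color/label/linestyle arguments now forwarded by plot_experiment; the dim_sizes values, dataset label, and output filename are illustrative only.

    import matplotlib.pyplot as plt

    from labproject.data import multivariate_normal
    from labproject.experiments import ScaleDimMMD

    # Two Gaussian datasets; the second shifts one coordinate (distort="shift_one"),
    # which after this patch is always dimension 0.
    samples_a = multivariate_normal(n=3000, dims=100)
    samples_b = multivariate_normal(n=3000, dims=100, distort="shift_one")

    # Evaluate MMD on growing dimensionality slices; run_experiment now repeats the
    # computation nb_runs times and returns per-dimension means and standard deviations.
    experiment = ScaleDimMMD(dim_sizes=[2, 5, 10, 20, 50, 100])  # illustrative dim_sizes
    dim_sizes, distances, errors = experiment.run_experiment(samples_a, samples_b, nb_runs=5)

    # Plot the mean distance with error bars against dimensionality.
    fig, ax = plt.subplots()
    experiment.plot_experiment(dim_sizes, distances, errors, "multivariate_normal", ax=ax, label="MMD")
    fig.savefig("scale_dim_mmd.png")  # hypothetical output path

Note that run_experiment reuses the same input tensors in every run, so the reported standard deviation only reflects stochasticity in the metric itself (e.g. random projections or classifier training); for a deterministic metric the error bars will be zero.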