Skip to content

Commit

Permalink
FID added
Browse files Browse the repository at this point in the history
  • Loading branch information
aesagtekin committed Feb 6, 2024
1 parent 89a0f6d commit 2810cfc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,4 @@ results/
docs/notebooks/sample_size_auguste-2.ipynb
configs/conf_default.yaml
configs/conf_default.yaml
configs/conf_FIDinc.yaml
28 changes: 24 additions & 4 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
assert min_samples > 2, "min_samples must be greater than 2 to compute covariance for KL"
self.metric_name = metric_name
self.metric_fn = metric_fn
# TODO: add logarithmic scale or only keep pass in run experiment
if sample_sizes is not None:
self.sample_sizes = sample_sizes
else:
Expand All @@ -78,9 +79,6 @@ def run_experiment(self, dataset1, dataset2, nb_runs=5, sample_sizes=None):
"""
Computes for each subset 5 different random subsets and averages performance across the subsets.
"""
assert sample_sizes[-1] < dataset1.size(
0
), "Sample size must be smaller than the dataset size."
final_distances = []
final_errors = []
if sample_sizes is None:
Expand All @@ -103,7 +101,16 @@ def run_experiment(self, dataset1, dataset2, nb_runs=5, sample_sizes=None):
return sample_sizes, final_distances, final_errors

def plot_experiment(
self, sample_sizes, distances, errors, dataset_name, ax=None, label=None, **kwargs
self,
sample_sizes,
distances,
errors,
dataset_name,
ax=None,
color=None,
label=None,
linestyle="-",
**kwargs
):
plot_scaling_metric_sample_size(
sample_sizes,
Expand All @@ -112,7 +119,9 @@ def plot_experiment(
self.metric_name,
dataset_name,
ax=ax,
color=color,
label=label,
linestyle=linestyle,
**kwargs
)

Expand Down Expand Up @@ -160,6 +169,17 @@ def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
)


class ScaleSampleSizeFID(ScaleSampleSize):
def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
super().__init__(
"FID",
gaussian_squared_w2_distance,
min_samples=min_samples,
sample_sizes=sample_sizes,
**kwargs
)


class CIFAR10_FID_Train_Test(Experiment):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 2810cfc

Please sign in to comment.