Commit

Merge branch 'main' of github.com:mackelab/labproject into main
a-darcher committed Feb 7, 2024
2 parents 58ce6c8 + e3f6644 commit 42c3966
Showing 3 changed files with 88 additions and 43 deletions.
labproject/data.py (52 changes: 33 additions & 19 deletions)
@@ -23,13 +23,14 @@
HETZNER_STORAGEBOX_USERNAME = os.getenv("HETZNER_STORAGEBOX_USERNAME")
HETZNER_STORAGEBOX_PASSWORD = os.getenv("HETZNER_STORAGEBOX_PASSWORD")

IMAGENET_UNCONDITIONAL_MODEL_EMBEDDING = "https://drive.google.com/uc?id=1xsGlNig7pCQuMpsvN86hgTGLEDGi6fVD"
IMAGENET_UNCONDITIONAL_MODEL_EMBEDDING = (
"https://drive.google.com/uc?id=1xsGlNig7pCQuMpsvN86hgTGLEDGi6fVD"
)
IMAGENET_CONDITIONAL_MODEL = "https://drive.google.com/uc?id=1FBVFiFcWnVs4i_LK4lAUemx83D7Hb_tU"
IMAGENET_TEST_EMBEDDING = "https://drive.google.com/uc?id=12B5Nkjr611WhXUafv08BciW7nsZ20Dfc"
IMAGENET_VALIDATION_EMBEDDING = "https://drive.google.com/uc?id=1Chc2ygs-Akw0Hlq-Nx7ykF2fp3SqV_aM"



## Hetzner Storage Box API functions ----

DATASETS = {}
@@ -280,16 +281,14 @@ def multivariate_normal(n=3000, dims=100, means=None, vars=None, distort=None):
vars = torch.diag(vars)

samples = torch.distributions.MultivariateNormal(means, vars).sample((n,))
print(f"Shape of samples: {samples.shape}")
if distort == "shift_all":
shift = 0.1
shift = 1
samples = samples + shift
elif distort == "shift_one":
# randomly choose one index among dims
# idx = torch.randint(dims, size=(1,))[0]
idx = 0
shift = torch.zeros(n) + 1
samples[:, idx] = samples[:, idx] + shift
print(f"First 5 rows of dataset distorted: {samples[:5, :5]}")
return samples


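For illustration, a minimal sketch (not from the commit) of how the two distortion modes above differ, assuming multivariate_normal is importable from labproject.data with the defaults shown in the hunk header:

import torch

from labproject.data import multivariate_normal

real = multivariate_normal(n=3000, dims=100)                             # baseline samples with the default means/vars
shift_all = multivariate_normal(n=3000, dims=100, distort="shift_all")   # every dimension shifted by 1
shift_one = multivariate_normal(n=3000, dims=100, distort="shift_one")   # only dimension 0 shifted by 1

# The per-dimension mean difference separates the two distortions:
print((shift_all - real).mean(dim=0)[:5])   # roughly 1 in every dimension
print((shift_one - real).mean(dim=0)[:5])   # roughly 1 in dim 0, near 0 elsewhere
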
@@ -458,15 +457,21 @@ def imagenet_unconditional_model_embedding(n, d=2048, device="cpu", save_path="d
assert d == 2048, "The dimensionality of the embeddings must be 2048"
if not os.path.exists("imagenet_unconditional_model_embedding.pt"):
import gdown
gdown.download(IMAGENET_UNCONDITIONAL_MODEL_EMBEDDING, "imagenet_unconditional_model_embedding.pt", quiet=False)

gdown.download(
IMAGENET_UNCONDITIONAL_MODEL_EMBEDDING,
"imagenet_unconditional_model_embedding.pt",
quiet=False,
)
unconditional_embeddings = torch.load("imagenet_unconditional_model_embedding.pt")

max_n = unconditional_embeddings.shape[0]

assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

return unconditional_embeddings[:n]


@register_dataset("imagenet_test_embedding")
def imagenet_test_embedding(n, d=2048, device="cpu", save_path="data"):
r"""Get the test embeddings for ImageNet
@@ -482,15 +487,15 @@ def imagenet_test_embedding(n, d=2048, device="cpu", save_path="data"):
assert d == 2048, "The dimensionality of the embeddings must be 2048"
if not os.path.exists("imagenet_test_embedding.pt"):
import gdown

gdown.download(IMAGENET_TEST_EMBEDDING, "imagenet_test_embedding.pt", quiet=False)
test_embeddings = torch.load("imagenet_test_embedding.pt")

max_n = test_embeddings.shape[0]

assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

return test_embeddings[:n]



@register_dataset("imagenet_validation_embedding")
@@ -508,17 +513,23 @@ def imagenet_validation_embedding(n, d=2048, device="cpu", save_path="data"):
assert d == 2048, "The dimensionality of the embeddings must be 2048"
if not os.path.exists("imagenet_validation_embedding.pt"):
import gdown
gdown.download(IMAGENET_VALIDATION_EMBEDDING, "imagenet_validation_embedding.pt", quiet=False)

gdown.download(
IMAGENET_VALIDATION_EMBEDDING, "imagenet_validation_embedding.pt", quiet=False
)
validation_embeddings = torch.load("imagenet_validation_embedding.pt")

max_n = validation_embeddings.shape[0]

assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

return validation_embeddings[:n]


@register_dataset("imagenet_conditional_model")
def imagenet_conditional_model(n, d=2048, label:Optional[int]=None, device="cpu", save_path="data"):
def imagenet_conditional_model(
n, d=2048, label: Optional[int] = None, device="cpu", save_path="data"
):
r"""Get the conditional model embeddings for ImageNet
Args:
@@ -533,17 +544,20 @@ def imagenet_conditional_model(n, d=2048, label:Optional[int]=None, device="cpu"
assert d == 2048, "The dimensionality of the embeddings must be 2048"
if not os.path.exists("imagenet_conditional_model.pt"):
import gdown

gdown.download(IMAGENET_CONDITIONAL_MODEL, "imagenet_conditional_model.pt", quiet=False)
conditional_embeddings = torch.load("imagenet_conditional_model.pt")

if label is not None:
conditional_embeddings = conditional_embeddings[label]
else:
conditional_embeddings = conditional_embeddings.flatten(0, 1)
conditional_embeddings = conditional_embeddings[torch.randperm(conditional_embeddings.shape[0])]
conditional_embeddings = conditional_embeddings[
torch.randperm(conditional_embeddings.shape[0])
]

max_n = conditional_embeddings.shape[0]

assert n <= max_n, f"Requested {n} samples, but only {max_n} are available"

return conditional_embeddings[:n]
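
For illustration, a rough usage sketch (not from the commit) of the conditional loader added above. It assumes the gdown download succeeds and that the stored tensor is laid out as [class, sample, 2048], which is what the label indexing and flatten(0, 1) above imply:

from labproject.data import imagenet_conditional_model

# Embeddings for a single ImageNet class: `label` indexes the first axis.
class0 = imagenet_conditional_model(n=500, label=0)    # shape [500, 2048]

# Without a label, all classes are flattened and shuffled before the first n rows are returned.
mixed = imagenet_conditional_model(n=5000)             # shape [5000, 2048]

assert class0.shape[-1] == 2048 and mixed.shape[-1] == 2048
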
labproject/experiments.py (30 changes: 14 additions & 16 deletions)
@@ -34,20 +34,18 @@ def __init__(self, metric_name, metric_fn, dim_sizes=None, min_dim=1, max_dim=10
self.dim_sizes = list(range(min_dim, max_dim, step))
super().__init__()

def run_experiment(self, dataset1, dataset2, nb_runs=5, dim_sizes=None):
"""distances = []
for d in self.dimensionality:
distances.append(self.metric_fn(dataset1[:, :d], dataset2[:, :d]))
return self.dimensionality, distances"""
def run_experiment(self, dataset1, dataset2, dataset_size, nb_runs=5, dim_sizes=None):
final_distances = []
final_errors = []
n = dataset_size
if dim_sizes is None:
dim_sizes = self.dim_sizes
for idx in range(nb_runs):
distances = []
for n in dim_sizes:
data1 = dataset1[:, :n]
data2 = dataset2[:, :n]
for d in dim_sizes:
# subsample n (= dataset_size) rows at random, then keep the first d dimensions
data1 = dataset1[torch.randperm(dataset1.size(0))[:n], :d]
data2 = dataset2[torch.randperm(dataset2.size(0))[:n], :d]
distances.append(self.metric_fn(data1, data2))
final_distances.append(distances)
final_distances = torch.transpose(torch.tensor(final_distances), 0, 1)
@@ -57,7 +55,7 @@ def run_experiment(self, dataset1, dataset2, nb_runs=5, dim_sizes=None):
else torch.zeros_like(torch.tensor(dim_sizes))
)
final_distances = torch.tensor([torch.mean(d) for d in final_distances])

print(f"Final errors: {final_errors}")
return dim_sizes, final_distances, final_errors

def plot_experiment(
@@ -70,7 +68,7 @@ def plot_experiment(
color=None,
label=None,
linestyle="-",
**kwargs
**kwargs,
):

plot_scaling_metric_dimensionality(
@@ -83,7 +81,7 @@ def plot_experiment(
color=color,
label=label,
linestyle=linestyle,
**kwargs
**kwargs,
)

def log_results(self, results, log_path):
@@ -169,7 +167,7 @@ def plot_experiment(
color=None,
label=None,
linestyle="-",
**kwargs
**kwargs,
):
plot_scaling_metric_sample_size(
sample_sizes,
@@ -181,7 +179,7 @@ def plot_experiment(
color=color,
label=label,
linestyle=linestyle,
**kwargs
**kwargs,
)

def log_results(self, results, log_path):
@@ -199,7 +197,7 @@ def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
gaussian_kl_divergence,
min_samples=min_samples,
sample_sizes=sample_sizes,
**kwargs
**kwargs,
)


@@ -210,7 +208,7 @@ def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
sliced_wasserstein_distance,
min_samples=min_samples,
sample_sizes=sample_sizes,
**kwargs
**kwargs,
)


@@ -235,7 +233,7 @@ def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
gaussian_squared_w2_distance,
min_samples=min_samples,
sample_sizes=sample_sizes,
**kwargs
**kwargs,
)


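The substantive change in the run_experiment method above is that each of the nb_runs repetitions now draws a fresh random subsample of dataset_size rows before slicing the first d dimensions, and the per-dimension mean and spread over runs are returned. A condensed, standalone paraphrase of that loop (a sketch, not the class method itself, under the assumption that metric_fn maps two [n, d] tensors to a scalar):

import torch

def scaling_with_dimension(metric_fn, dataset1, dataset2, dataset_size, dim_sizes, nb_runs=5):
    """Mean metric value and run-to-run spread per dimensionality, over random row subsamples."""
    per_run = []
    for _ in range(nb_runs):
        rows1 = torch.randperm(dataset1.size(0))[:dataset_size]
        rows2 = torch.randperm(dataset2.size(0))[:dataset_size]
        per_run.append([float(metric_fn(dataset1[rows1, :d], dataset2[rows2, :d])) for d in dim_sizes])
    dists = torch.tensor(per_run)                    # [nb_runs, len(dim_sizes)]
    means = dists.mean(dim=0)
    errors = dists.std(dim=0) if nb_runs > 1 else torch.zeros_like(means)
    return dim_sizes, means, errors
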
labproject/plotting.py (49 changes: 41 additions & 8 deletions)
@@ -43,25 +43,58 @@ def cm2inch(cm, INCH=2.54):


def plot_scaling_metric_dimensionality(
dimensionality, distances, metric_name, dataset_name, ax=None
dim_sizes,
distances,
errors,
metric_name,
dataset_name,
ax=None,
label=None,
**kwargs,
):
"""Plot the scaling of a metric with increasing dimensionality."""
if ax is None:
plt.plot(dimensionality, distances, label=metric_name)
plt.xlabel("Dimensionality")
plt.plot(
dim_sizes,
distances,
label=metric_name if label is None else label,
**kwargs,
)
plt.fill_between(
dim_sizes,
distances - errors,
distances + errors,
alpha=0.2,
color="black" if kwargs.get("color") is None else kwargs.get("color"),
)
plt.xlabel("Dimension")
plt.ylabel(metric_name)
plt.title(f"{metric_name} with increasing dimensionality for {dataset_name}")
plt.title(f"{metric_name} with increasing dimensionality size for {dataset_name}")
plt.savefig(
os.path.join(
PLOT_PATH,
f"{metric_name.lower().replace(' ', '_')}_dimensionality_{dataset_name.lower().replace(' ', '_')}.png",
f"{metric_name.lower().replace(' ', '_')}_dimensionality_size_{dataset_name.lower().replace(' ', '_')}.png",
)
)
plt.close()
else:
ax.plot(dimensionality, distances, label=metric_name)
ax.set_xlabel("Dimensionality")

ax.plot(
dim_sizes,
distances,
label=metric_name if label is None else label,
**kwargs,
)
ax.fill_between(
dim_sizes,
distances - errors,
distances + errors,
alpha=0.2,
color="black" if kwargs.get("color") is None else kwargs.get("color"),
)
ax.set_xlabel("samples")
ax.set_ylabel(
metric_name, color="black" if kwargs.get("color") is None else kwargs.get("color")
)
return ax


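To exercise the new plotting signature, a quick sketch (not from the commit); the metric values and errors here are placeholders, and passing an Axes avoids the PLOT_PATH save in the ax=None branch:

import torch
import matplotlib.pyplot as plt

from labproject.plotting import plot_scaling_metric_dimensionality

dim_sizes = list(range(1, 101, 10))
distances = torch.linspace(0.1, 1.0, len(dim_sizes))   # placeholder metric values
errors = 0.05 * torch.ones(len(dim_sizes))              # placeholder run-to-run spread

fig, ax = plt.subplots()
plot_scaling_metric_dimensionality(
    dim_sizes, distances, errors, "Sliced Wasserstein", "toy Gaussian", ax=ax, color="C0"
)
ax.legend()
fig.savefig("scaling_with_dimension.png")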
