Commit
add running_user to cfg
jaivardhankapoor committed Feb 5, 2024
1 parent 6d9e90b commit a1bc2fb
Showing 6 changed files with 51 additions and 22 deletions.
10 changes: 9 additions & 1 deletion labproject/metrics/__init__.py
@@ -1,5 +1,13 @@
"""
Best practices for developing metrics:
1. Implement everything in torch; if that is not possible, cast the output to torch.Tensor.
2. The function should be well-documented, including type hints.
3. The function should be tested with a simple example.
4. Add an assert at the beginning to check input shapes (N, D); see the examples below.
"""

from labproject.metrics.gaussian_kl import gaussian_kl_divergence
from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance

METRICS = {}
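
A minimal sketch of a metric written to these conventions (the mean_distance name and its registration are illustrative, not part of this commit):

import torch
from torch import Tensor

def mean_distance(real_samples: Tensor, fake_samples: Tensor) -> Tensor:
    """Euclidean distance between the sample means of two (N, D) sample sets."""
    # shape check (N, D), as recommended above
    assert len(real_samples.size()) == 2, "Real samples must be 2-dimensional, (n,d)"
    assert len(fake_samples.size()) == 2, "Fake samples must be 2-dimensional, (n,d)"
    return torch.norm(real_samples.mean(dim=0) - fake_samples.mean(dim=0), p=2)

METRICS["mean_distance"] = mean_distance  # reuses the METRICS dict above; the registration pattern is assumed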

2 changes: 1 addition & 1 deletion labproject/metrics/c2st.py
@@ -344,7 +344,7 @@ def c2st_scores(

return scores


def c2st_optimal(density1: Any, density2: Any, n_monte_carlo: int = 10_000) -> Tensor:
r"""Return the c2st that can be achieved by an optimal classifier.
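The body of c2st_optimal is collapsed above. For densities with tractable likelihoods, the Bayes-optimal classifier predicts whichever density assigns the higher likelihood, so the achievable accuracy can be estimated by Monte Carlo. A sketch under the assumption that density1 and density2 follow the torch.distributions interface (the diff does not show the actual body, and the parameters are typed Any, so this interface is a guess):

import torch
from torch import Tensor
from torch.distributions import Distribution

def c2st_optimal_sketch(density1: Distribution, density2: Distribution, n_monte_carlo: int = 10_000) -> Tensor:
    # draw half the budget from each density and score the Bayes rule
    x1 = density1.sample((n_monte_carlo // 2,))
    x2 = density2.sample((n_monte_carlo // 2,))
    correct1 = (density1.log_prob(x1) >= density2.log_prob(x1)).float()
    correct2 = (density2.log_prob(x2) > density1.log_prob(x2)).float()
    # equals 0.5 for identical densities and approaches 1.0 as they separate
    return torch.cat([correct1, correct2]).mean()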
17 changes: 11 additions & 6 deletions labproject/metrics/gaussian_kl.py
@@ -20,7 +20,7 @@ def gaussian_kl_divergence(real_samples: Tensor, fake_samples: Tensor) -> Tensor
Then we calculate the KL divergence between the two Gaussian approximations:
$$
D_{KL}(N(\mu_{\text{real}}, \Sigma_{\text{real}}) || N(\mu_{\text{fake}}, \Sigma_{\text{fake}})) =
\frac{1}{2} \left( \text{tr}(\Sigma_{\text{fake}}^{-1} \Sigma_{\text{real}}) + (\mu_{\text{fake}} - \mu_{\text{real}})^T \Sigma_{\text{fake}}^{-1} (\mu_{\text{fake}} - \mu_{\text{real}}) - k + \log \frac{|\Sigma_{\text{fake}}|}{|\Sigma_{\text{real}}|} \right)
$$
@@ -31,13 +31,18 @@ def gaussian_kl_divergence(real_samples: Tensor, fake_samples: Tensor) -> Tensor
Returns:
torch.Tensor: The KL divergence between the two Gaussian approximations.
Examples:
>>> real_samples = torch.randn(100, 2) # 100 samples, 2-dimensional
>>> fake_samples = torch.randn(100, 2) # 100 samples, 2-dimensional
>>> kl_div = gaussian_kl_divergence(real_samples, fake_samples)
>>> print(kl_div)
"""

# check input (n,d only)
assert len(real_samples.size()) == 2, "Real samples must be 2-dimensional, (n,d)"
assert len(fake_samples.size()) == 2, "Fake samples must be 2-dimensional, (n,d)"

# calculate mean and covariance of real and fake samples
mu_real = real_samples.mean(dim=0)
mu_fake = fake_samples.mean(dim=0)
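The rest of gaussian_kl_divergence is collapsed. A sketch of how the formula above maps onto torch operations once the means and covariances are in hand (variable names follow the snippet above, but this is not the commit's exact body):

import torch
from torch import Tensor

def gaussian_kl_from_moments(mu_real: Tensor, cov_real: Tensor, mu_fake: Tensor, cov_fake: Tensor) -> Tensor:
    k = mu_real.size(0)
    inv_cov_fake = torch.inverse(cov_fake)
    diff = mu_fake - mu_real
    # tr(Σ_fake^{-1} Σ_real) + diff^T Σ_fake^{-1} diff - k + log(|Σ_fake| / |Σ_real|)
    return 0.5 * (
        torch.trace(inv_cov_fake @ cov_real)
        + diff @ inv_cov_fake @ diff
        - k
        + torch.logdet(cov_fake)
        - torch.logdet(cov_real)
    )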
20 changes: 12 additions & 8 deletions labproject/metrics/gaussian_squared_wasserstein.py
@@ -30,17 +30,22 @@ def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor) ->
Returns:
torch.Tensor: The squared Wasserstein-2 distance between the two Gaussian approximations.
References:
[1] https://en.wikipedia.org/wiki/Wasserstein_metric
[2] https://arxiv.org/pdf/1706.08500.pdf
Examples:
>>> real_samples = torch.randn(100, 2) # 100 samples, 2-dimensional
>>> fake_samples = torch.randn(100, 2) # 100 samples, 2-dimensional
>>> w2 = gaussian_squared_w2_distance(real_samples, fake_samples)
>>> print(w2)
"""

# check input (n,d only)
assert len(real_samples.size()) == 2, "Real samples must be 2-dimensional, (n,d)"
assert len(fake_samples.size()) == 2, "Fake samples must be 2-dimensional, (n,d)"

# calculate mean and covariance of real and fake samples
mu_real = real_samples.mean(dim=0)
mu_fake = fake_samples.mean(dim=0)
@@ -53,10 +58,9 @@ def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor) ->
cov_fake += torch.eye(cov_fake.size(0)) * eps

# compute squared W2 distance
mean_dist = torch.norm(mu_real - mu_fake, p=2)
cov_dist = torch.trace(cov_real + cov_fake - 2 * torch.linalg.cholesky(cov_real @ cov_fake))
w2_squared_dist = mean_dist**2 + cov_dist

return w2_squared_dist

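Note that torch.linalg.cholesky expects a symmetric positive-definite argument, whereas cov_real @ cov_fake is in general not symmetric; the Bures cross term tr((Σ_real Σ_fake)^{1/2}) is often computed from the eigenvalues of the product instead, which are real and non-negative for SPD factors. A sketch of that alternative (not the code in this commit):

import torch
from torch import Tensor

def w2_squared_from_moments(mu_real: Tensor, cov_real: Tensor, mu_fake: Tensor, cov_fake: Tensor) -> Tensor:
    mean_dist_sq = torch.norm(mu_real - mu_fake, p=2) ** 2
    # tr((Σ_real Σ_fake)^{1/2}) = sum of square roots of the product's eigenvalues
    eigvals = torch.linalg.eigvals(cov_real @ cov_fake)
    cross_term = eigvals.real.clamp(min=0.0).sqrt().sum()
    cov_dist = torch.trace(cov_real) + torch.trace(cov_fake) - 2 * cross_term
    return mean_dist_sq + cov_dist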
15 changes: 11 additions & 4 deletions labproject/metrics/sliced_wasserstein.py
@@ -6,7 +6,11 @@


def sliced_wasserstein_distance(
encoded_samples: Tensor,
distribution_samples: Tensor,
num_projections: int = 50,
p: int = 2,
device: str = "cpu",
):
"""
Sliced Wasserstein distance between encoded samples and distribution samples
@@ -22,6 +26,10 @@ def sliced_wasserstein_distance(
torch.Tensor: Sliced Wasserstein distance, averaged over the num_projections random projections (scalar tensor)
"""

# check input (n,d only)
assert len(encoded_samples.size()) == 2, "Encoded samples must be 2-dimensional, (n,d)"
assert len(distribution_samples.size()) == 2, "Distribution samples must be 2-dimensional, (n,d)"

embedding_dim = distribution_samples.size(-1)

projections = rand_projections(embedding_dim, num_projections).to(device)
@@ -44,7 +52,6 @@ def sliced_wasserstein_distance(
return torch.mean(wasserstein_distance, dim=(-2, -1))
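
The middle of sliced_wasserstein_distance is collapsed above. A sketch of the standard steps that fill such a gap — project both sample sets, sort to pair quantiles, and take p-th-power differences (assumed, not the commit's exact code):

import torch
from torch import Tensor

def sliced_wasserstein_body_sketch(encoded_samples: Tensor, distribution_samples: Tensor, projections: Tensor, p: int = 2) -> Tensor:
    # project the (n, d) samples onto each of the (num_projections, d) directions
    encoded_proj = encoded_samples.matmul(projections.transpose(0, 1))
    distribution_proj = distribution_samples.matmul(projections.transpose(0, 1))
    # sorting each 1-D projection pairs up quantiles, giving the closed-form 1-D coupling
    diffs = (
        torch.sort(encoded_proj.transpose(0, 1), dim=1)[0]
        - torch.sort(distribution_proj.transpose(0, 1), dim=1)[0]
    )
    return torch.pow(diffs.abs(), p)  # (num_projections, n); the caller averages over both dims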



def rand_projections(embedding_dim: int, num_samples: int):
"""
This function generates num_samples random samples from the latent space's unit sphere.
@@ -67,7 +74,7 @@ def rand_projections(embedding_dim: int, num_samples: int):
# Generate random samples
samples1 = torch.randn(100, 2)
samples2 = torch.randn(100, 2)

# Compute sliced wasserstein distance
sw_distance = sliced_wasserstein_distance(samples1, samples2)
print(sw_distance)
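
rand_projections is only partially shown here; the conventional implementation draws Gaussian vectors and normalizes them onto the unit sphere, so that projection directions are uniformly distributed. A sketch under that assumption:

import torch
from torch import Tensor

def rand_projections_sketch(embedding_dim: int, num_samples: int) -> Tensor:
    # Gaussian directions normalized to unit length are uniform on the sphere
    projections = torch.randn(num_samples, embedding_dim)
    return projections / torch.norm(projections, dim=1, keepdim=True)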
9 changes: 7 additions & 2 deletions labproject/utils.py
@@ -6,7 +6,10 @@

from omegaconf import OmegaConf

CONF_PATH = STYLE_PATH = os.path.abspath(
os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs")
)


def set_seed(seed: int) -> None:
"""Set seed for reproducibility
@@ -21,9 +24,10 @@ def set_seed(seed: int) -> None:
torch.backends.cudnn.benchmark = False
return seed


def get_cfg() -> OmegaConf:
"""This function returns the configuration file for the current experiment run.
The configuration file is expected to be located at ../configs/conf_{name}.yaml, where name will match the name of the run_{name}.py file.
Raises:
@@ -37,6 +41,7 @@ def get_cfg() -> OmegaConf:
name = filename.split("/")[-1].split(".")[0].split("_")[-1]
try:
config = OmegaConf.load(CONF_PATH + f"/conf_{name}.yaml")
config.running_user = name
except FileNotFoundError:
msg = f"Config file not found for {name}. Please create a config file at ../configs/conf_{name}.yaml"
raise FileNotFoundError(msg)
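The tail of get_cfg is collapsed. A hypothetical run_example.py showing how the new running_user field would surface (the file name and the seed key are illustrative, not from this repository):

# hypothetical contents of labproject/run_example.py, paired with configs/conf_example.yaml
from labproject.utils import get_cfg, set_seed

cfg = get_cfg()          # loads configs/conf_example.yaml and sets cfg.running_user
set_seed(cfg.seed)       # assumes the config defines a seed key
print(cfg.running_user)  # -> "example", derived from the run_example.py filename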
