Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mackelab/labproject into main
Browse files Browse the repository at this point in the history
  • Loading branch information
jaivardhankapoor committed Jan 31, 2024
2 parents d310bdf + 4e2289e commit 13f4a9d
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 20 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/push_to_overleaf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
- main
jobs:
build:
runs-on: ubuntu-latest
runs-on: self-hosted
steps:
- name: Checkout Overleaf project
run: |
Expand Down Expand Up @@ -45,4 +45,10 @@ jobs:
else
git commit -m "Update figures automatically"
git push
fi
fi
- name: Clean up - delete repositories
  # Self-hosted runners keep their disk between runs, so run this step even
  # when an earlier step failed; otherwise stale clones are left behind.
  if: always()
  run: |
    # Quote the path so a workspace containing spaces still resolves.
    cd "${{ github.workspace }}/.."
    rm -rf overleaf/ github/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: stable
rev: 24.1.1
hooks:
- id: black
args: ['--line-length=100']
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ conda create -n labproject python=3.9
conda activate labproject

# install labproject package with its development and documentation dependencies
pip install -e .
# if you want to edit the tutorials, install the docs dependencies
pip install -e ".[docs]"
pip install -e ".[dev,docs]"

# install pre-commit hooks for black auto-formatting
pre-commit install
Expand Down
2 changes: 2 additions & 0 deletions labproject/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch

torch.manual_seed(0)


def random_dataset(n=1000, d=10, generator=None):
    """Draw a dataset of independent standard-normal samples.

    Args:
        n: Number of samples (rows of the returned tensor).
        d: Dimensionality of each sample (columns of the returned tensor).
        generator: Optional ``torch.Generator`` to sample from. When ``None``
            (the default) the global torch RNG is used, so existing callers
            get exactly the same values as before.

    Returns:
        A tensor of shape ``(n, d)`` as produced by ``torch.randn``.
    """
    return torch.randn(n, d, generator=generator)
11 changes: 5 additions & 6 deletions labproject/metrics/sliced_wasserstein.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# STOLEN from Julius: https://github.com/mackelab/wasserstein_source/blob/main/wasser/sliced_wasserstein.py

import numpy as np

np.random.seed(0)
import torch

torch.manual_seed(0)


Expand All @@ -19,8 +21,7 @@ def rand_projections(embedding_dim, num_samples: int):
"""

projection = [
w / np.sqrt((w**2).sum())
for w in np.random.normal(size=(num_samples, embedding_dim))
w / np.sqrt((w**2).sum()) for w in np.random.normal(size=(num_samples, embedding_dim))
]
projection = np.array(projection)
return torch.from_numpy(projection).type(torch.FloatTensor)
Expand Down Expand Up @@ -49,9 +50,7 @@ def sliced_wasserstein_distance(

encoded_projections = encoded_samples.matmul(projections.transpose(-2, -1))

distribution_projections = distribution_samples.matmul(
projections.transpose(-2, -1)
)
distribution_projections = distribution_samples.matmul(projections.transpose(-2, -1))

wasserstein_distance = (
torch.sort(encoded_projections.transpose(-2, -1), dim=-1)[0]
Expand All @@ -64,4 +63,4 @@ def sliced_wasserstein_distance(
# No p-th root is applied

# return torch.pow(torch.mean(wasserstein_distance, dim=(-2, -1)), 1 / p)
return torch.mean(wasserstein_distance, dim=(-2, -1))
return torch.mean(wasserstein_distance, dim=(-2, -1))
8 changes: 6 additions & 2 deletions labproject/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ def plot_scaling_metric_dimensionality(dimensionality, distances, metric_name, d
plt.xlabel("Dimensionality")
plt.ylabel(metric_name)
plt.title(f"{metric_name} with increasing dimensionality for {dataset_name}")
plt.savefig(os.path.join(plots_path,
f"{metric_name.lower().replace(' ', '_')}_dimensionality_{dataset_name.lower().replace(' ', '_')}.png"))
plt.savefig(
os.path.join(
plots_path,
f"{metric_name.lower().replace(' ', '_')}_dimensionality_{dataset_name.lower().replace(' ', '_')}.png",
)
)
plt.close()
6 changes: 4 additions & 2 deletions labproject/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@

print("Running experiments...")
dimensionality, distances = scaling_sliced_wasserstein_samples()
plot_scaling_metric_dimensionality(dimensionality, distances, "Sliced Wasserstein", "Random Dataset")
print("Finished running experiments.")
plot_scaling_metric_dimensionality(
dimensionality, distances, "Sliced Wasserstein", "Random Dataset"
)
print("Finished running experiments.")
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ name = "labproject"
version = "0.1"

dependencies = [
"black",
"pre-commit",
"jupyter",
"numpy",
"scipy",
"matplotlib",
Expand All @@ -22,6 +19,11 @@ docs = [
# Newer versions currently produce a "'mermaid.js' is undefined" error, see https://github.com/danielfrg/mkdocs-jupyter/issues/176
"mkdocs-jupyter==0.24.2"
]
dev = [
"black",
"pre-commit",
"jupyter"
]

[tool.setuptools.packages.find]
include=["labproject"]
Expand All @@ -33,7 +35,7 @@ line-length = 100 # Example: set the line length to 100
# Pre-commit Configuration
[[tool.pre-commit.repos]]
repo = "https://github.com/psf/black"
rev = "stable"
rev = "24.1.1"
hooks = [
{ id = "black", language_version = "python3" }
]

0 comments on commit 13f4a9d

Please sign in to comment.