Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Baschdl committed Jan 31, 2024
2 parents a0ce9d7 + c946d46 commit c036a22
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ figures/
*.sublime-workspace
*.sublime-project

# Jupyter notebooks
# ignore jupyter notebooks
*.ipynb

# except those in docs/notebooks
!docs/notebooks/*.ipynb

.idea/
257 changes: 257 additions & 0 deletions docs/notebooks/fid.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
"Collecting torchvision\n",
" Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl.metadata (6.6 kB)\n",
"Requirement already satisfied: numpy in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (1.26.3)\n",
"Requirement already satisfied: requests in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (2.31.0)\n",
"Requirement already satisfied: torch==2.2.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (2.2.0)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (10.2.0)\n",
"Requirement already satisfied: filelock in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (3.13.1)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (4.9.0)\n",
"Requirement already satisfied: sympy in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (1.12)\n",
"Requirement already satisfied: networkx in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (3.1.3)\n",
"Requirement already satisfied: fsspec in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (2023.12.2)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (8.9.2.26)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.3.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (11.0.2.54)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (10.3.2.106)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (11.4.5.107)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.0.106)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (2.19.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n",
"Requirement already satisfied: triton==2.2.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (2.2.0)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0->torchvision) (12.3.101)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (2.2.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (2023.11.17)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from jinja2->torch==2.2.0->torchvision) (2.1.4)\n",
"Requirement already satisfied: mpmath>=0.19 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from sympy->torch==2.2.0->torchvision) (1.3.0)\n",
"Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl (6.9 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.9/6.9 MB\u001b[0m \u001b[31m69.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
"\u001b[?25hInstalling collected packages: torchvision\n",
"Successfully installed torchvision-0.17.0\n"
]
}
],
"source": [
"!pip install torchvision"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /mnt_mount/labproject_data/cifar-10-python.tar.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting /mnt_mount/labproject_data/cifar-10-python.tar.gz to /mnt_mount/labproject_data\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"import torchvision.transforms as transforms\n",
"from torchvision.datasets import CIFAR10\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize((299, 299)),\n",
" transforms.ToTensor(),\n",
" # normalize specific to inception model\n",
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
" ])\n",
"\n",
"# load CIFAR10 dataset\n",
"cifar10_train = CIFAR10(root='/mnt_mount/labproject_data', train=True, download=True, transform=transform)\n",
"cifar10_test = CIFAR10(root='/mnt_mount/labproject_data', train=False, download=True, transform=transform)\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"dataloader_1 = torch.utils.data.DataLoader(cifar10_train, batch_size=100, shuffle=False, num_workers=1)\n",
"dataloader_2 = torch.utils.data.DataLoader(cifar10_test, batch_size=100, shuffle=False, num_workers=1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/miniconda3/envs/labproject/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/mnt/miniconda3/envs/labproject/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=Inception_V3_Weights.IMAGENET1K_V1`. You can also use `weights=Inception_V3_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n",
"/mnt/miniconda3/envs/labproject/lib/python3.9/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n",
" warnings.warn(\"Can't initialize NVML\")\n"
]
}
],
"source": [
"from torchvision.models import inception_v3\n",
"from torchvision.datasets import ImageFolder\n",
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"\n",
"\n",
"# get embedding net\n",
"def get_embedding_net():\n",
" model = inception_v3(pretrained=True)\n",
" model.fc = torch.nn.Identity() # replace the classifier with identity to get features\n",
" model.eval()\n",
" return model.to('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# extract features\n",
"def extract_features(dataloader, model):\n",
" features = []\n",
" with torch.no_grad():\n",
" for data, _ in dataloader:\n",
" data = data.to('cuda' if torch.cuda.is_available() else 'cpu')\n",
" features.append(model(data))\n",
" return torch.cat(features).cpu().numpy()\n",
"\n",
"\n",
"embedding_net = get_embedding_net()\n",
"\n",
"features1 = extract_features(dataloader_1, embedding_net)\n",
"features2 = extract_features(dataloader_2, embedding_net)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((50000, 2048), (10000, 2048))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"features1.shape, features2.shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance\n",
"\n",
"swd = sliced_wasserstein_distance(torch.from_numpy(features1)[:10000], torch.from_numpy(features2), num_projections=1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "labproject",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit c036a22

Please sign in to comment.