-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/mackelab/labproject
- Loading branch information
Showing
2 changed files
with
261 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"The autoreload extension is already loaded. To reload it, use:\n", | ||
" %reload_ext autoreload\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", | ||
"Collecting torchvision\n", | ||
" Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl.metadata (6.6 kB)\n", | ||
"Requirement already satisfied: numpy in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (1.26.3)\n", | ||
"Requirement already satisfied: requests in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (2.31.0)\n", | ||
"Requirement already satisfied: torch==2.2.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (2.2.0)\n", | ||
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torchvision) (10.2.0)\n", | ||
"Requirement already satisfied: filelock in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (3.13.1)\n", | ||
"Requirement already satisfied: typing-extensions>=4.8.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (4.9.0)\n", | ||
"Requirement already satisfied: sympy in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (1.12)\n", | ||
"Requirement already satisfied: networkx in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (3.2.1)\n", | ||
"Requirement already satisfied: jinja2 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (3.1.3)\n", | ||
"Requirement already satisfied: fsspec in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (2023.12.2)\n", | ||
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n", | ||
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n", | ||
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n", | ||
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (8.9.2.26)\n", | ||
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.3.1)\n", | ||
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (11.0.2.54)\n", | ||
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (10.3.2.106)\n", | ||
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (11.4.5.107)\n", | ||
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.0.106)\n", | ||
"Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (2.19.3)\n", | ||
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (12.1.105)\n", | ||
"Requirement already satisfied: triton==2.2.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from torch==2.2.0->torchvision) (2.2.0)\n", | ||
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0->torchvision) (12.3.101)\n", | ||
"Requirement already satisfied: charset-normalizer<4,>=2 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (3.3.2)\n", | ||
"Requirement already satisfied: idna<4,>=2.5 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (3.6)\n", | ||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (2.2.0)\n", | ||
"Requirement already satisfied: certifi>=2017.4.17 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from requests->torchvision) (2023.11.17)\n", | ||
"Requirement already satisfied: MarkupSafe>=2.0 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from jinja2->torch==2.2.0->torchvision) (2.1.4)\n", | ||
"Requirement already satisfied: mpmath>=0.19 in /mnt/miniconda3/envs/labproject/lib/python3.9/site-packages (from sympy->torch==2.2.0->torchvision) (1.3.0)\n", | ||
"Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl (6.9 MB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.9/6.9 MB\u001b[0m \u001b[31m69.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", | ||
"\u001b[?25hInstalling collected packages: torchvision\n", | ||
"Successfully installed torchvision-0.17.0\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"!pip install torchvision" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /mnt_mount/labproject_data/cifar-10-python.tar.gz\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100.0%\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Extracting /mnt_mount/labproject_data/cifar-10-python.tar.gz to /mnt_mount/labproject_data\n", | ||
"Files already downloaded and verified\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torchvision.transforms as transforms\n", | ||
"from torchvision.datasets import CIFAR10\n", | ||
"\n", | ||
"transform = transforms.Compose([\n", | ||
" transforms.Resize((299, 299)),\n", | ||
" transforms.ToTensor(),\n", | ||
" # normalize specific to inception model\n", | ||
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", | ||
" ])\n", | ||
"\n", | ||
"# load CIFAR10 dataset\n", | ||
"cifar10_train = CIFAR10(root='/mnt_mount/labproject_data', train=True, download=True, transform=transform)\n", | ||
"cifar10_test = CIFAR10(root='/mnt_mount/labproject_data', train=False, download=True, transform=transform)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dataloader_1 = torch.utils.data.DataLoader(cifar10_train, batch_size=100, shuffle=False, num_workers=1)\n", | ||
"dataloader_2 = torch.utils.data.DataLoader(cifar10_test, batch_size=100, shuffle=False, num_workers=1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/mnt/miniconda3/envs/labproject/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", | ||
" warnings.warn(\n", | ||
"/mnt/miniconda3/envs/labproject/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=Inception_V3_Weights.IMAGENET1K_V1`. You can also use `weights=Inception_V3_Weights.DEFAULT` to get the most up-to-date weights.\n", | ||
" warnings.warn(msg)\n", | ||
"/mnt/miniconda3/envs/labproject/lib/python3.9/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n", | ||
" warnings.warn(\"Can't initialize NVML\")\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from torchvision.models import inception_v3\n", | ||
"from torchvision.datasets import ImageFolder\n", | ||
"from torch.utils.data import DataLoader\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"\n", | ||
"# get embedding net\n", | ||
"def get_embedding_net():\n", | ||
" model = inception_v3(pretrained=True)\n", | ||
" model.fc = torch.nn.Identity() # replace the classifier with identity to get features\n", | ||
" model.eval()\n", | ||
" return model.to('cuda' if torch.cuda.is_available() else 'cpu')\n", | ||
"\n", | ||
"# extract features\n", | ||
"def extract_features(dataloader, model):\n", | ||
" features = []\n", | ||
" with torch.no_grad():\n", | ||
" for data, _ in dataloader:\n", | ||
" data = data.to('cuda' if torch.cuda.is_available() else 'cpu')\n", | ||
" features.append(model(data))\n", | ||
" return torch.cat(features).cpu().numpy()\n", | ||
"\n", | ||
"\n", | ||
"embedding_net = get_embedding_net()\n", | ||
"\n", | ||
"features1 = extract_features(dataloader_1, embedding_net)\n", | ||
"features2 = extract_features(dataloader_2, embedding_net)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"((50000, 2048), (10000, 2048))" | ||
] | ||
}, | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"features1.shape, features2.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance\n", | ||
"\n", | ||
"swd = sliced_wasserstein_distance(torch.from_numpy(features1)[:10000], torch.from_numpy(features2), num_projections=1000)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "labproject", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |