diff --git a/siamese_network_prototype.ipynb b/siamese_network_prototype.ipynb new file mode 100644 index 0000000..43a9234 --- /dev/null +++ b/siamese_network_prototype.ipynb @@ -0,0 +1,1216 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# http://www.cs.utoronto.ca/~gkoch/files/msc-thesis.pdf\n", + "# https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "from fastai.vision import *\n", + "from fastai.metrics import accuracy_thresh\n", + "from fastai.basic_data import *\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from torch import nn\n", + "from fastai.callbacks.hooks import num_features_model, model_sizes\n", + "from fastai.layers import BCEWithLogitsFlat\n", + "from fastai.basic_train import Learner\n", + "from skimage.util import montage\n", + "import pandas as pd\n", + "from torch import optim\n", + "import re\n", + "\n", + "from utils import *" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# import fastai\n", + "# from fastprogress import force_console_behavior\n", + "# import fastprogress\n", + "# fastprogress.fastprogress.NO_BAR = True\n", + "# master_bar, progress_bar = force_console_behavior()\n", + "# fastai.basic_train.master_bar, fastai.basic_train.progress_bar = master_bar, progress_bar" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Posing the problem as a classification task is probably not ideal. We are asking our NN to learn to recognize a whale out of 5004 possible candidates based on what it has learned about the whales. That is a tall order.\n", + "\n", + "Instead, here we will try to pose the problem as a verification task. When presented with two images of whale flukes, we will ask the network - are the images of the same whale or of different whales? In particular, we will try to teach our network to learn features that can be useful in determining the similarity between whale images (hence the name of this approach - feature learning).\n", + "\n", + "This seems like a much easier task, at least in theory. Either way, no need to start with a relatively big CNN like resnet50. Let's see what mileage we can get out of resnet18." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# new architecture calls for a new validation set, this time our validation set will consist of all whales that have exactly two images\n", + "df = pd.read_csv('data/train.csv')\n", + "im_count = df[df.Id != 'new_whale'].Id.value_counts()\n", + "im_count.name = 'sighting_count'\n", + "df = df.join(im_count, on='Id')\n", + "val_fns = set(df[df.sighting_count == 2].Image)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2570" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(val_fns)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "fn2label = {row[1].Image: row[1].Id for row in df.iterrows()}\n", + "path2fn = lambda path: re.search('\\w*\\.jpg$', path).group(0)\n", + "\n", + "name = f'res18-siamese'" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "SZ = 224\n", + "BS = 64\n", + "NUM_WORKERS = 12\n", + "SEED=0" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# data_block api creates categories based on classes it sees in the train set and\n", + "# our val set contains whales whose ids do not appear in the train set\n", + "classes = df.Id.unique()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "data = (\n", + " ImageItemList\n", + " .from_df(df[df.Id != 'new_whale'], f'data/train-{SZ}', cols=['Image'])\n", + " .split_by_valid_func(lambda path: path2fn(path) in val_fns)\n", + " .label_from_func(lambda path: fn2label[path2fn(path)], classes=classes)\n", + " .add_test(ImageItemList.from_folder(f'data/test-{SZ}'))\n", + " .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)\n", + "# .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')\n", + "# .normalize(imagenet_stats)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I am still using the ImageItemList even though I will create my own datasets. Why? Because I want to reuse the functionality that is already there (creating datasets from files, augmentations, resizing, etc).\n", + "\n", + "I realize the code is neither clean nor elegant but for the time being I am happy with this approach." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def is_even(num): return num % 2 == 0\n", + "\n", + "class TwoImDataset(Dataset):\n", + " def __init__(self, ds):\n", + " self.ds = ds\n", + " self.whale_ids = ds.y.items\n", + " def __len__(self):\n", + " return 2 * len(self.ds)\n", + " def __getitem__(self, idx):\n", + " if is_even(idx):\n", + " return self.sample_same(idx // 2)\n", + " else: return self.sample_different((idx-1) // 2)\n", + " def sample_same(self, idx):\n", + " whale_id = self.whale_ids[idx] \n", + " candidates = list(np.where(self.whale_ids == whale_id)[0])\n", + " candidates.remove(idx) # dropping our current whale - we don't want to compare against an identical image!\n", + " \n", + " if len(candidates) == 0: # oops, there is only a single whale with this id in the dataset\n", + " return self.sample_different(idx)\n", + " \n", + " np.random.shuffle(candidates)\n", + " return self.construct_example(self.ds[idx][0], self.ds[candidates[0]][0], 1)\n", + " def sample_different(self, idx):\n", + " whale_id = self.whale_ids[idx]\n", + " candidates = list(np.where(self.whale_ids != whale_id)[0])\n", + " np.random.shuffle(candidates)\n", + " return self.construct_example(self.ds[idx][0], self.ds[candidates[0]][0], 0)\n", + " \n", + " def construct_example(self, im_A, im_B, class_idx):\n", + " return [im_A, im_B], class_idx" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "train_dl = DataLoader(\n", + " TwoImDataset(data.train),\n", + " batch_size=BS,\n", + " shuffle=True,\n", + " num_workers=NUM_WORKERS\n", + ")\n", + "valid_dl = DataLoader(\n", + " TwoImDataset(data.valid),\n", + " batch_size=BS,\n", + " shuffle=False,\n", + " num_workers=NUM_WORKERS\n", + ")\n", + "\n", + "data_bunch = ImageDataBunch(train_dl, valid_dl)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def normalize_batch(batch):\n", + " stat_tensors = [torch.tensor(l).cuda() for l in imagenet_stats]\n", + " return [normalize(batch[0][0], *stat_tensors), normalize(batch[0][1], *stat_tensors)], batch[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "data_bunch.add_tfm(normalize_batch)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from functional import seq\n", + "\n", + "class SiameseNetwork(nn.Module):\n", + " def __init__(self, arch=models.resnet18):\n", + " super().__init__() \n", + " self.cnn = create_body(arch)\n", + " self.head = nn.Linear(num_features_model(self.cnn), 1)\n", + " \n", + " def forward(self, im_A, im_B):\n", + " # dl - distance layer\n", + " x1, x2 = seq(im_A, im_B).map(self.cnn).map(self.process_features)\n", + " dl = self.calculate_distance(x1, x2)\n", + " out = self.head(dl)\n", + " return out\n", + " \n", + " def process_features(self, x): return x.reshape(*x.shape[:2], -1).max(-1)[0]\n", + " def calculate_distance(self, x1, x2): return (x1 - x2).abs_()\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below I include two slightly different siamese networks. I leave the code commented out and choose to use the one above." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# from functional import seq\n", + "\n", + "# def cnn_activations_count(model):\n", + "# _, ch, h, w = model_sizes(create_body(models.resnet18), (SZ, SZ))[-1]\n", + "# return ch * h * w\n", + "\n", + "# class SiameseNetwork(nn.Module):\n", + "# def __init__(self, lin_ftrs=2048, arch=models.resnet18):\n", + "# super().__init__() \n", + "# self.cnn = create_body(arch)\n", + "# self.fc1 = nn.Linear(cnn_activations_count(self.cnn), lin_ftrs)\n", + "# self.fc2 = nn.Linear(lin_ftrs, 1)\n", + " \n", + "# def forward(self, im_A, im_B):\n", + "# x1, x2 = seq(im_A, im_B).map(self.cnn).map(self.process_features).map(self.fc1)\n", + "# dl = self.calculate_distance(x1.sigmoid(), x2.sigmoid())\n", + "# out = self.fc2(dl)\n", + "# return out\n", + " \n", + "# def calculate_distance(self, x1, x2): return (x1 - x2).abs_()\n", + "# def process_features(self, x): return x.reshape(x.shape[0], -1)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# from functional import seq\n", + "\n", + "# def cnn_activations_count(model):\n", + "# _, ch, h, w = model_sizes(create_body(models.resnet18), (SZ, SZ))[-1]\n", + "# return ch * h * w\n", + "\n", + "# class SiameseNetwork(nn.Module):\n", + "# def __init__(self, lin_ftrs=2048, pool_to=3, arch=models.resnet18, pooling_layer=nn.AdaptiveMaxPool2d):\n", + "# super().__init__() \n", + "# self.cnn = create_body(arch)\n", + "# self.pool = pooling_layer(pool_to)\n", + "# self.fc1 = nn.Linear(num_features_model(self.cnn) * pool_to**2, lin_ftrs)\n", + "# self.fc2 = nn.Linear(lin_ftrs, 1)\n", + " \n", + "# def forward(self, im_A, im_B):\n", + "# x1, x2 = seq(im_A, im_B).map(self.cnn).map(self.pool).map(self.process_features).map(self.fc1)\n", + "# dl = self.calculate_distance(x1.sigmoid(), x2.sigmoid())\n", + "# out = self.fc2(dl)\n", + "# return out\n", + " \n", + "# def calculate_distance(self, x1, x2): return (x1 - x2).abs_()\n", + "# def process_features(self, x): return x.reshape(x.shape[0], -1)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "learn = Learner(data_bunch, SiameseNetwork(), loss_func=BCEWithLogitsFlat(), metrics=[lambda preds, targs: accuracy_thresh(preds.squeeze(), targs, sigmoid=False)])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "learn.split([learn.model.cnn[:6], learn.model.cnn[6:], learn.model.head])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "learn.freeze_to(-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" + ] + } + ], + "source": [ + "learn.lr_find()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "learn.recorder.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Total time: 05:11

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochtrain_lossvalid_loss
10.5251710.7363620.532685
20.4090360.4255130.759533
30.3694580.3146080.868093
40.3285930.2960970.857588
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "learn.fit_one_cycle(4, 1e-2)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "learn.save(f'{name}-stage-1')" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "learn.unfreeze()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "max_lr = 5e-4\n", + "lrs = [max_lr/100, max_lr/10, max_lr]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Total time: 15:46

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochtrain_lossvalid_loss
10.2999700.2851360.863424
20.2867530.2601440.887743
30.2776950.2694930.872763
40.2594900.2344930.895720
50.2291940.2249730.912257
60.2170030.2327600.897082
70.2021610.2152720.907977
80.2039440.2284680.894163
90.2014180.2221400.896498
100.1985990.2179330.899416
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "learn.fit_one_cycle(10, lrs)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "learn.save(f'{name}-stage-2')" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "

" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "learn.recorder.plot_losses()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model is not doing that well - out of presented pairs it gets roughly 10% of examples wrong. I also did a cursory error analysis (not shown here for the sake of brevity) and the model is not doing that great at all.\n", + "\n", + "How can this be? Maybe the nearly absolute positional invariance through the use of global max pooling is not working that well. Maybe there is a bug somewhere? Maybe the model has not been trained for long enough or lacks capacity?\n", + "\n", + "If I do continue to work on this I will definitely take a closer look at each of the angles I list above. For the time being, let's try to predict on the validation set and finish off with making a submission.\n", + "\n", + "The predicting part is where the code gets really messy. That is good enough for now though." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "learn.load(f'{name}-stage-2');" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "new_whale_fns = set(df[df.Id == 'new_whale'].sample(frac=1).Image.iloc[:1000])" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "data = (\n", + " ImageItemList\n", + " .from_df(df, f'data/train-{SZ}', cols=['Image'])\n", + " .split_by_valid_func(lambda path: path2fn(path) in val_fns.union(new_whale_fns))\n", + " .label_from_func(lambda path: fn2label[path2fn(path)], classes=classes)\n", + " .add_test(ImageItemList.from_folder(f'data/test-{SZ}'))\n", + " .transform(get_transforms(do_flip=False), size=SZ, resize_method=ResizeMethod.SQUISH)\n", + " .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')\n", + " .normalize(imagenet_stats)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3570" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(data.valid_ds)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.48 s, sys: 924 ms, total: 2.41 s\n", + "Wall time: 3.29 s\n" + ] + } + ], + "source": [ + "%%time\n", + "targs = []\n", + "feats = []\n", + "learn.model.eval()\n", + "for ims, ts in data.valid_dl:\n", + " feats.append(learn.model.process_features(learn.model.cnn(ims)).detach().cpu())\n", + " targs.append(ts)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "feats = torch.cat(feats)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3570, 512])" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "feats.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 21.2 s, sys: 108 ms, total: 21.3 s\n", + "Wall time: 21.3 s\n" + ] + } + ], + "source": [ + "%%time\n", + "sims = []\n", + "for feat in feats:\n", + " dists = learn.model.calculate_distance(feats, feat.unsqueeze(0).repeat(3570, 1))\n", + " predicted_similarity = learn.model.head(dists.cuda()).sigmoid_()\n", + " sims.append(predicted_similarity.squeeze().detach().cpu())" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3570" + ] + }, + "execution_count": 121, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(sims)" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "metadata": {}, + "outputs": [], + "source": [ + "new_whale_idx = np.where(classes == 'new_whale')[0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.13 s, sys: 4 ms, total: 1.13 s\n", + "Wall time: 1.13 s\n" + ] + } + ], + "source": [ + "%%time\n", + "top_5s = []\n", + "for sim in sims:\n", + " idxs = sim.argsort(descending=True)\n", + " probs = sim[idxs]\n", + " top_5 = []\n", + " for i, p in zip(idxs, probs):\n", + " if len(top_5) == 5: break\n", + " if i == new_whale_idx: continue\n", + " predicted_class = data.valid_ds.y.items[i]\n", + " if predicted_class not in top_5: top_5.append(predicted_class)\n", + " top_5s.append(top_5)" + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.24428104575163398" + ] + }, + "execution_count": 147, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# without predicting new_whale\n", + "mapk(data.valid_ds.y.items.reshape(-1,1), np.stack(top_5s), 5)" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.98 0.27504668534080295\n", + "0.9822222222222222 0.2793790849673203\n", + "0.9844444444444445 0.2841456582633053\n", + "0.9866666666666667 0.2927777777777778\n", + "0.9888888888888889 0.3001960784313726\n", + "0.991111111111111 0.31275443510737627\n", + "0.9933333333333333 0.3257049486461251\n", + "0.9955555555555555 0.33599439775910367\n", + "0.9977777777777778 0.3447152194211017\n", + "1.0 0.34714285714285714\n", + "CPU times: user 12.3 s, sys: 4 ms, total: 12.3 s\n", + "Wall time: 12.3 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "for thresh in np.linspace(0.98, 1, 10):\n", + " top_5s = []\n", + " for sim in sims:\n", + " idxs = sim.argsort(descending=True)\n", + " probs = sim[idxs]\n", + " top_5 = []\n", + " for i, p in zip(idxs, probs):\n", + " if new_whale_idx not in top_5 and p < thresh and len(top_5) < 5: top_5.append(new_whale_idx)\n", + " if len(top_5) == 5: break\n", + " if i == new_whale_idx: continue\n", + " predicted_class = data.valid_ds.y.items[i]\n", + " if predicted_class not in top_5: top_5.append(predicted_class)\n", + " top_5s.append(top_5)\n", + " print(thresh, mapk(data.valid_ds.y.items.reshape(-1,1), np.stack(top_5s), 5))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are many reasons why the best threshold here might not carry over to what would make sense on the test set. It is some indication though of how our model is doing and a useful data point." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Predict" + ] + }, + { + "cell_type": "code", + "execution_count": 163, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7960" + ] + }, + "execution_count": 163, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(data.test_ds)" + ] + }, + { + "cell_type": "code", + "execution_count": 189, + "metadata": {}, + "outputs": [], + "source": [ + "data = (\n", + " ImageItemList\n", + " .from_df(df, f'data/train-{SZ}', cols=['Image'])\n", + " .split_by_valid_func(lambda path: path2fn(path) in {'69823499d.jpg'}) # in newer version of the fastai library there is .no_split that could be used here\n", + " .label_from_func(lambda path: fn2label[path2fn(path)], classes=classes)\n", + " .add_test(ImageItemList.from_folder(f'data/test-{SZ}'))\n", + " .transform(None, size=SZ, resize_method=ResizeMethod.SQUISH)\n", + " .databunch(bs=BS, num_workers=NUM_WORKERS, path='data')\n", + " .normalize(imagenet_stats)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 190, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.9 s, sys: 1.79 s, total: 4.69 s\n", + "Wall time: 5.03 s\n" + ] + } + ], + "source": [ + "%%time\n", + "test_feats = []\n", + "learn.model.eval()\n", + "for ims, _ in data.test_dl:\n", + " test_feats.append(learn.model.process_features(learn.model.cnn(ims)).detach().cpu())" + ] + }, + { + "cell_type": "code", + "execution_count": 195, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 9.02 s, sys: 5.04 s, total: 14.1 s\n", + "Wall time: 14.4 s\n" + ] + } + ], + "source": [ + "%%time\n", + "train_feats = []\n", + "train_class_idxs = []\n", + "learn.model.eval()\n", + "for ims, t in data.train_dl:\n", + " train_feats.append(learn.model.process_features(learn.model.cnn(ims)).detach().cpu())\n", + " train_class_idxs.append(t)" + ] + }, + { + "cell_type": "code", + "execution_count": 196, + "metadata": {}, + "outputs": [], + "source": [ + "train_class_idxs = torch.cat(train_class_idxs)\n", + "train_feats = torch.cat(train_feats)" + ] + }, + { + "cell_type": "code", + "execution_count": 206, + "metadata": {}, + "outputs": [], + "source": [ + "test_feats = torch.cat(test_feats)" + ] + }, + { + "cell_type": "code", + "execution_count": 209, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 5min 7s, sys: 2min 58s, total: 8min 6s\n", + "Wall time: 8min 6s\n" + ] + } + ], + "source": [ + "%%time\n", + "sims = []\n", + "for feat in test_feats:\n", + " dists = learn.model.calculate_distance(train_feats, feat.unsqueeze(0).repeat(25344, 1))\n", + " predicted_similarity = learn.model.head(dists.cuda()).sigmoid_()\n", + " sims.append(predicted_similarity.squeeze().detach().cpu())" + ] + }, + { + "cell_type": "code", + "execution_count": 211, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 19.6 s, sys: 128 ms, total: 19.7 s\n", + "Wall time: 19.7 s\n" + ] + } + ], + "source": [ + "%%time\n", + "thresh = 1\n", + "\n", + "top_5s = []\n", + "for sim in sims:\n", + " idxs = sim.argsort(descending=True)\n", + " probs = sim[idxs]\n", + " top_5 = []\n", + " for i, p in zip(idxs, probs):\n", + " if new_whale_idx not in top_5 and p < thresh and len(top_5) < 5: top_5.append(new_whale_idx)\n", + " if len(top_5) == 5: break\n", + " if i == new_whale_idx: continue\n", + " predicted_class = train_class_idxs[i]\n", + " if predicted_class not in top_5: top_5.append(predicted_class)\n", + " top_5s.append(top_5)" + ] + }, + { + "cell_type": "code", + "execution_count": 221, + "metadata": {}, + "outputs": [], + "source": [ + "top_5_classes = []\n", + "for top_5 in top_5s:\n", + " top_5_classes.append(' '.join([classes[t] for t in top_5]))" + ] + }, + { + "cell_type": "code", + "execution_count": 222, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['new_whale w_9bedea6 w_448e190 w_ab629bb w_67e9aa8',\n", + " 'new_whale w_edce644 w_dd79a10 w_99af1a9 w_ae393cd',\n", + " 'new_whale w_4516ff1 w_d1207d9 w_02c7e9d w_8003858',\n", + " 'new_whale w_0369a5c w_f66ec54 w_ae8982d w_d0475b2',\n", + " 'new_whale w_8cd5c91 w_0cc0430 w_06460d7 w_e8b82f6']" + ] + }, + "execution_count": 222, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top_5_classes[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": 223, + "metadata": {}, + "outputs": [], + "source": [ + "sub = pd.DataFrame({'Image': [path.name for path in data.test_ds.x.items]})\n", + "sub['Id'] = top_5_classes\n", + "sub.to_csv(f'subs/{name}.csv.gz', index=False, compression='gzip')" + ] + }, + { + "cell_type": "code", + "execution_count": 224, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ImageId
047380533f.jpgnew_whale w_9bedea6 w_448e190 w_ab629bb w_67e9aa8
11d9de38ba.jpgnew_whale w_edce644 w_dd79a10 w_99af1a9 w_ae393cd
2b3d4ee916.jpgnew_whale w_4516ff1 w_d1207d9 w_02c7e9d w_8003858
3460fd63ae.jpgnew_whale w_0369a5c w_f66ec54 w_ae8982d w_d0475b2
479738ffc1.jpgnew_whale w_8cd5c91 w_0cc0430 w_06460d7 w_e8b82f6
\n", + "
" + ], + "text/plain": [ + " Image Id\n", + "0 47380533f.jpg new_whale w_9bedea6 w_448e190 w_ab629bb w_67e9aa8\n", + "1 1d9de38ba.jpg new_whale w_edce644 w_dd79a10 w_99af1a9 w_ae393cd\n", + "2 b3d4ee916.jpg new_whale w_4516ff1 w_d1207d9 w_02c7e9d w_8003858\n", + "3 460fd63ae.jpg new_whale w_0369a5c w_f66ec54 w_ae8982d w_d0475b2\n", + "4 79738ffc1.jpg new_whale w_8cd5c91 w_0cc0430 w_06460d7 w_e8b82f6" + ] + }, + "execution_count": 224, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.read_csv(f'subs/{name}.csv.gz').head()" + ] + }, + { + "cell_type": "code", + "execution_count": 225, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 225, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.read_csv(f'subs/{name}.csv.gz').Id.str.split().apply(lambda x: x[0] == 'new_whale').mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 226, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████| 164k/164k [00:03<00:00, 46.1kB/s]\n", + "Successfully submitted to Humpback Whale Identification" + ] + } + ], + "source": [ + "!kaggle competitions submit -c humpback-whale-identification -f subs/{name}.csv.gz -m \"{name}\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}