From 37a5649bf75939e90516591cd66850d069ac1b97 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 15 Feb 2024 12:38:23 +0800 Subject: [PATCH 1/5] Improve explainability of adapter --- src/frdc/models/inceptionv3.py | 35 ++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/frdc/models/inceptionv3.py b/src/frdc/models/inceptionv3.py index 3fe835c5..a48b2824 100644 --- a/src/frdc/models/inceptionv3.py +++ b/src/frdc/models/inceptionv3.py @@ -4,6 +4,7 @@ from sklearn.preprocessing import OrdinalEncoder, StandardScaler from torch import nn from torchvision.models import Inception_V3_Weights, inception_v3 +from torchvision.models.inception import BasicConv2d, Inception3 from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA @@ -81,29 +82,43 @@ def __init__( self.ema_lr = ema_lr @staticmethod - def adapt_inception_multi_channel(inception: nn.Module, in_channels: int): - """Adapt the 1st layer of the InceptionV3 model to accept n-channels.""" + def adapt_inception_multi_channel( + inception: Inception3, + in_channels: int, + ) -> Inception3: + """Adapt the 1st layer of the InceptionV3 model to accept n-channels. + + Notes: + This operation is in-place, however will still return the model + + Args: + inception: The InceptionV3 model + in_channels: The number of input channels + + Returns: + The adapted InceptionV3 model. + """ + + original_in_channels = inception.Conv2d_1a_3x3.conv.in_channels # Replicate the first layer, but with a different number of channels - # We can dynamically pull the architecture from inception if you want - # to make it more general. - conv2d_1a_3x3 = nn.Sequential( - nn.Conv2d(in_channels, 32, bias=False, kernel_size=3, stride=2), - nn.BatchNorm2d(32, eps=0.001), + conv2d_1a_3x3 = BasicConv2d( + in_channels=in_channels, + out_channels=inception.Conv2d_1a_3x3.conv.out_channels, ) # Copy the BGR weights from the first layer of the original model conv2d_1a_3x3[0].weight.data[ - :, :3 + :, :original_in_channels ] = inception.Conv2d_1a_3x3.conv.weight.data # We'll repeat the G weights to the other channels as an initial # approximation # We use [1:2] instead of [1] so it doesn't lose the dimension conv2d_1a_3x3[0].weight.data[ - :, 3: + :, original_in_channels: ] = inception.Conv2d_1a_3x3.conv.weight.data[:, 1:2].tile( - (in_channels - 3, 1, 1) + (in_channels - original_in_channels, 1, 1) ) # Finally, set the new layer back From 0e080eefc3b81a261bbe5a04b4808a0b73338145 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 15 Feb 2024 12:46:55 +0800 Subject: [PATCH 2/5] Fix missing kernel size and stride spec --- src/frdc/models/inceptionv3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/frdc/models/inceptionv3.py b/src/frdc/models/inceptionv3.py index a48b2824..ebfe4376 100644 --- a/src/frdc/models/inceptionv3.py +++ b/src/frdc/models/inceptionv3.py @@ -105,6 +105,8 @@ def adapt_inception_multi_channel( conv2d_1a_3x3 = BasicConv2d( in_channels=in_channels, out_channels=inception.Conv2d_1a_3x3.conv.out_channels, + kernel_size=3, + stride=2, ) # Copy the BGR weights from the first layer of the original model From 91a4e34afed8fe586f9227b3ddf1a50d364456e8 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 15 Feb 2024 12:47:41 +0800 Subject: [PATCH 3/5] Dynamically fetch the kernel and stride --- src/frdc/models/inceptionv3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/frdc/models/inceptionv3.py b/src/frdc/models/inceptionv3.py index ebfe4376..8f6b27af 100644 --- a/src/frdc/models/inceptionv3.py +++ b/src/frdc/models/inceptionv3.py @@ -105,8 +105,8 @@ def adapt_inception_multi_channel( conv2d_1a_3x3 = BasicConv2d( in_channels=in_channels, out_channels=inception.Conv2d_1a_3x3.conv.out_channels, - kernel_size=3, - stride=2, + kernel_size=inception.Conv2d_1a_3x3.conv.kernel_size, + stride=inception.Conv2d_1a_3x3.conv.stride, ) # Copy the BGR weights from the first layer of the original model From 48fb4028de3aac783ae7b98578ddf3bc578889b3 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 15 Feb 2024 12:57:20 +0800 Subject: [PATCH 4/5] Fix bad access to conv layer --- src/frdc/models/inceptionv3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/frdc/models/inceptionv3.py b/src/frdc/models/inceptionv3.py index 8f6b27af..4ee47ed6 100644 --- a/src/frdc/models/inceptionv3.py +++ b/src/frdc/models/inceptionv3.py @@ -110,14 +110,14 @@ def adapt_inception_multi_channel( ) # Copy the BGR weights from the first layer of the original model - conv2d_1a_3x3[0].weight.data[ + conv2d_1a_3x3.conv.weight.data[ :, :original_in_channels ] = inception.Conv2d_1a_3x3.conv.weight.data # We'll repeat the G weights to the other channels as an initial # approximation # We use [1:2] instead of [1] so it doesn't lose the dimension - conv2d_1a_3x3[0].weight.data[ + conv2d_1a_3x3.conv.weight.data[ :, original_in_channels: ] = inception.Conv2d_1a_3x3.conv.weight.data[:, 1:2].tile( (in_channels - original_in_channels, 1, 1) From d5c4e67bf87bc384023de8a75dcb368d390f0606 Mon Sep 17 00:00:00 2001 From: Evening Date: Thu, 15 Feb 2024 12:59:59 +0800 Subject: [PATCH 5/5] Lint --- src/frdc/conf.py | 12 ++++++------ src/frdc/evaluate/__init__.py | 1 - src/frdc/train/__init__.py | 1 - src/frdc/train/frdc_datamodule.py | 2 +- src/frdc/train/mixmatch_module.py | 2 -- src/frdc/train/stratified_sampling.py | 1 - 6 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/frdc/conf.py b/src/frdc/conf.py index d0cf406f..d9683d23 100644 --- a/src/frdc/conf.py +++ b/src/frdc/conf.py @@ -56,7 +56,7 @@ ) GCS_BUCKET = GCS_CLIENT.bucket(GCS_BUCKET_NAME) logger.info("Connected to GCS.") -except Exception as e: +except Exception: logger.warning( "Could not connect to GCS. Will not be able to download files. " "Check that you've (1) Installed the GCS CLI and (2) Set up the" @@ -79,11 +79,11 @@ LABEL_STUDIO_CLIENT.get_project(1) except requests.exceptions.HTTPError: logger.warning( - f"Could not get main annotation project. " - f"Pulling annotations may not work. " - f"It's possible that your API Key is incorrect, " - f"or somehow your .netrc is preventing you from " - f"accessing the project. " + "Could not get main annotation project. " + "Pulling annotations may not work. " + "It's possible that your API Key is incorrect, " + "or somehow your .netrc is preventing you from " + "accessing the project. " ) except requests.exceptions.ConnectionError: logger.warning( diff --git a/src/frdc/evaluate/__init__.py b/src/frdc/evaluate/__init__.py index 8b137891..e69de29b 100644 --- a/src/frdc/evaluate/__init__.py +++ b/src/frdc/evaluate/__init__.py @@ -1 +0,0 @@ - diff --git a/src/frdc/train/__init__.py b/src/frdc/train/__init__.py index 8b137891..e69de29b 100644 --- a/src/frdc/train/__init__.py +++ b/src/frdc/train/__init__.py @@ -1 +0,0 @@ - diff --git a/src/frdc/train/frdc_datamodule.py b/src/frdc/train/frdc_datamodule.py index 5e4e6dbd..3b8db5f4 100644 --- a/src/frdc/train/frdc_datamodule.py +++ b/src/frdc/train/frdc_datamodule.py @@ -4,7 +4,7 @@ from typing import Literal from lightning import LightningDataModule -from torch.utils.data import DataLoader, RandomSampler, Sampler +from torch.utils.data import DataLoader, RandomSampler from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset from frdc.train.stratified_sampling import RandomStratifiedSampler diff --git a/src/frdc/train/mixmatch_module.py b/src/frdc/train/mixmatch_module.py index 8acb276d..eafa132a 100644 --- a/src/frdc/train/mixmatch_module.py +++ b/src/frdc/train/mixmatch_module.py @@ -6,8 +6,6 @@ import numpy as np import torch import torch.nn.functional as F -import torch.nn.parallel -import torch.nn.parallel import wandb from lightning import LightningModule from sklearn.preprocessing import StandardScaler, OrdinalEncoder diff --git a/src/frdc/train/stratified_sampling.py b/src/frdc/train/stratified_sampling.py index dd17762c..b578271f 100644 --- a/src/frdc/train/stratified_sampling.py +++ b/src/frdc/train/stratified_sampling.py @@ -2,7 +2,6 @@ from typing import Iterator, Any, Sequence -import pandas as pd import torch from sklearn.preprocessing import LabelEncoder from torch.utils.data import Sampler