
Commit

Merge pull request #54 from FR-DC/FRML-117
FRML-117 Improve explainability of adapter
Eve-ning authored Feb 15, 2024
2 parents 6ac2841 + d5c4e67 commit ef8daff
Showing 7 changed files with 36 additions and 24 deletions.
src/frdc/conf.py: 12 changes (6 additions, 6 deletions)
@@ -56,7 +56,7 @@
     )
     GCS_BUCKET = GCS_CLIENT.bucket(GCS_BUCKET_NAME)
     logger.info("Connected to GCS.")
-except Exception as e:
+except Exception:
     logger.warning(
         "Could not connect to GCS. Will not be able to download files. "
         "Check that you've (1) Installed the GCS CLI and (2) Set up the"
@@ -79,11 +79,11 @@
     LABEL_STUDIO_CLIENT.get_project(1)
 except requests.exceptions.HTTPError:
     logger.warning(
-        f"Could not get main annotation project. "
-        f"Pulling annotations may not work. "
-        f"It's possible that your API Key is incorrect, "
-        f"or somehow your .netrc is preventing you from "
-        f"accessing the project. "
+        "Could not get main annotation project. "
+        "Pulling annotations may not work. "
+        "It's possible that your API Key is incorrect, "
+        "or somehow your .netrc is preventing you from "
+        "accessing the project. "
     )
 except requests.exceptions.ConnectionError:
     logger.warning(
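Both hunks above follow the same connection-guard pattern: try to reach an external service when conf.py is imported, and degrade to a logged warning instead of raising. A minimal sketch of that pattern, using a hypothetical connect_to_gcs() stand-in rather than the repository's actual client setup:

```python
import logging

logger = logging.getLogger(__name__)


def connect_to_gcs():
    """Hypothetical stand-in for the GCS client setup in src/frdc/conf.py."""
    raise ConnectionError("GCS credentials not configured")


try:
    GCS_CLIENT = connect_to_gcs()
    logger.info("Connected to GCS.")
except Exception:
    # No `as e`: the exception object is never used, only the warning matters.
    # Plain strings (no `f` prefix) since there are no placeholders to fill.
    GCS_CLIENT = None
    logger.warning(
        "Could not connect to GCS. Will not be able to download files. "
        "Check that the GCS CLI is installed and credentials are set up."
    )
```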
src/frdc/evaluate/__init__.py: 1 change (0 additions, 1 deletion)
@@ -1 +0,0 @@
-
src/frdc/models/inceptionv3.py: 41 changes (29 additions, 12 deletions)
@@ -4,6 +4,7 @@
 from sklearn.preprocessing import OrdinalEncoder, StandardScaler
 from torch import nn
 from torchvision.models import Inception_V3_Weights, inception_v3
+from torchvision.models.inception import BasicConv2d, Inception3
 
 from frdc.train.mixmatch_module import MixMatchModule
 from frdc.utils.ema import EMA
@@ -81,29 +82,45 @@ def __init__(
         self.ema_lr = ema_lr
 
     @staticmethod
-    def adapt_inception_multi_channel(inception: nn.Module, in_channels: int):
-        """Adapt the 1st layer of the InceptionV3 model to accept n-channels."""
+    def adapt_inception_multi_channel(
+        inception: Inception3,
+        in_channels: int,
+    ) -> Inception3:
+        """Adapt the 1st layer of the InceptionV3 model to accept n-channels.
+
+        Notes:
+            This operation is in-place, however will still return the model
+
+        Args:
+            inception: The InceptionV3 model
+            in_channels: The number of input channels
+
+        Returns:
+            The adapted InceptionV3 model.
+        """
+
+        original_in_channels = inception.Conv2d_1a_3x3.conv.in_channels
 
         # Replicate the first layer, but with a different number of channels
-        # We can dynamically pull the architecture from inception if you want
-        # to make it more general.
-        conv2d_1a_3x3 = nn.Sequential(
-            nn.Conv2d(in_channels, 32, bias=False, kernel_size=3, stride=2),
-            nn.BatchNorm2d(32, eps=0.001),
+        conv2d_1a_3x3 = BasicConv2d(
+            in_channels=in_channels,
+            out_channels=inception.Conv2d_1a_3x3.conv.out_channels,
+            kernel_size=inception.Conv2d_1a_3x3.conv.kernel_size,
+            stride=inception.Conv2d_1a_3x3.conv.stride,
         )
 
         # Copy the BGR weights from the first layer of the original model
-        conv2d_1a_3x3[0].weight.data[
-            :, :3
+        conv2d_1a_3x3.conv.weight.data[
+            :, :original_in_channels
         ] = inception.Conv2d_1a_3x3.conv.weight.data
 
         # We'll repeat the G weights to the other channels as an initial
         # approximation
         # We use [1:2] instead of [1] so it doesn't lose the dimension
-        conv2d_1a_3x3[0].weight.data[
-            :, 3:
+        conv2d_1a_3x3.conv.weight.data[
+            :, original_in_channels:
         ] = inception.Conv2d_1a_3x3.conv.weight.data[:, 1:2].tile(
-            (in_channels - 3, 1, 1)
+            (in_channels - original_in_channels, 1, 1)
         )
 
         # Finally, set the new layer back
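The rewritten adapter above derives everything from the existing first layer (its out_channels, kernel_size, and stride) instead of hard-coding 32, kernel_size=3, stride=2, and it reuses torchvision's BasicConv2d so the replacement matches the original conv + batch-norm block. Below is a standalone sketch of the same weight-replication idea outside the MixMatch module; the function name adapt_first_conv, the 5-channel example, and the transform_input toggle are illustrative assumptions, not part of the repository (transform_input is disabled here because the pretrained builder's 3-channel input scaling would otherwise clash with the extra bands):

```python
import torch
from torchvision.models import Inception_V3_Weights, inception_v3
from torchvision.models.inception import BasicConv2d, Inception3


def adapt_first_conv(inception: Inception3, in_channels: int) -> Inception3:
    """Replace Conv2d_1a_3x3 with an n-channel BasicConv2d, reusing weights."""
    old = inception.Conv2d_1a_3x3.conv
    original_in_channels = old.in_channels  # 3 for the stock model

    new_block = BasicConv2d(
        in_channels=in_channels,
        out_channels=old.out_channels,
        kernel_size=old.kernel_size,
        stride=old.stride,
    )
    # Keep the pretrained weights for the original (RGB) input channels ...
    new_block.conv.weight.data[:, :original_in_channels] = old.weight.data
    # ... and tile the G-channel weights across the extra channels as an
    # initial approximation, as in the hunk above.
    new_block.conv.weight.data[:, original_in_channels:] = old.weight.data[
        :, 1:2
    ].tile((in_channels - original_in_channels, 1, 1))

    inception.Conv2d_1a_3x3 = new_block
    return inception


model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
model.transform_input = False  # pretrained builder enables 3-channel input scaling
model = adapt_first_conv(model, in_channels=5)

model.eval()
with torch.no_grad():
    logits = model(torch.rand(2, 5, 299, 299))
print(logits.shape)  # torch.Size([2, 1000])
```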
src/frdc/train/__init__.py: 1 change (0 additions, 1 deletion)
@@ -1 +0,0 @@
-
src/frdc/train/frdc_datamodule.py: 2 changes (1 addition, 1 deletion)
@@ -4,7 +4,7 @@
 from typing import Literal
 
 from lightning import LightningDataModule
-from torch.utils.data import DataLoader, RandomSampler, Sampler
+from torch.utils.data import DataLoader, RandomSampler
 
 from frdc.load.dataset import FRDCDataset, FRDCUnlabelledDataset
 from frdc.train.stratified_sampling import RandomStratifiedSampler
src/frdc/train/mixmatch_module.py: 2 changes (0 additions, 2 deletions)
@@ -6,8 +6,6 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-import torch.nn.parallel
-import torch.nn.parallel
 import wandb
 from lightning import LightningModule
 from sklearn.preprocessing import StandardScaler, OrdinalEncoder
src/frdc/train/stratified_sampling.py: 1 change (0 additions, 1 deletion)
@@ -2,7 +2,6 @@
 
 from typing import Iterator, Any, Sequence
 
-import pandas as pd
 import torch
 from sklearn.preprocessing import LabelEncoder
 from torch.utils.data import Sampler
