diff --git a/src/frdc/models/inceptionv3.py b/src/frdc/models/inceptionv3.py index df29187b..4ee47ed6 100644 --- a/src/frdc/models/inceptionv3.py +++ b/src/frdc/models/inceptionv3.py @@ -4,24 +4,12 @@ from sklearn.preprocessing import OrdinalEncoder, StandardScaler from torch import nn from torchvision.models import Inception_V3_Weights, inception_v3 -from torchvision.models.inception import Inception3 +from torchvision.models.inception import BasicConv2d, Inception3 from frdc.train.mixmatch_module import MixMatchModule from frdc.utils.ema import EMA -class BasicConv2d(nn.Module): - def __init__(self, in_channels: int, out_channels: int, **kwargs) -> None: - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) - self.bn = nn.BatchNorm2d(out_channels, eps=0.001) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return x - - class InceptionV3MixMatchModule(MixMatchModule): INCEPTION_OUT_DIMS = 2048 INCEPTION_AUX_DIMS = 1000