diff --git a/src/networks/lenet.py b/src/networks/lenet.py index 2b4f012..b556cb3 100644 --- a/src/networks/lenet.py +++ b/src/networks/lenet.py @@ -2,7 +2,7 @@ import torch.nn.functional as F -class LeNet(nn.Module): +class LeNetArch(nn.Module): """LeNet-like network for tests with MNIST (28x28).""" def __init__(self, in_channels=1, num_classes=10, **kwargs): @@ -28,3 +28,9 @@ def forward(self, x): out = F.relu(self.fc2(out)) out = self.fc(out) return out + + +def LeNet(pretrained=False, **kwargs): + if pretrained: + raise NotImplementedError + return LeNetArch(**kwargs)