diff --git a/src/cryo_sbi/inference/models/embedding_nets.py b/src/cryo_sbi/inference/models/embedding_nets.py index 221e441..daa99ed 100644 --- a/src/cryo_sbi/inference/models/embedding_nets.py +++ b/src/cryo_sbi/inference/models/embedding_nets.py @@ -76,6 +76,26 @@ def forward(self, x): return x +@add_embedding("RESNET18_FFT") +class ResNet18_Encoder(nn.Module): + def __init__(self, output_dimension: int): + super(ResNet18_Encoder, self).__init__() + print("Using FFT ResNet18") + self.resnet = models.resnet18() + self.resnet.conv1 = nn.Conv2d( + 2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False + ) + self.resnet.fc = nn.Linear( + in_features=512, out_features=output_dimension, bias=True + ) + + def forward(self, x): + x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1))) + x = torch.stack([x.real, x.imag], dim=1) + x = self.resnet(x) + return x + + @add_embedding("RESNET50") class ResNet50_Encoder(nn.Module): def __init__(self, output_dimension: int):