Skip to content

Commit

Permalink
added fft embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 24, 2024
1 parent 1dbc494 commit d9624f9
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ def forward(self, x):
return x


@add_embedding("RESNET18_FFT")
class ResNet18_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(ResNet18_Encoder, self).__init__()
print("Using FFT ResNet18")
self.resnet = models.resnet18()
self.resnet.conv1 = nn.Conv2d(
2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.resnet.fc = nn.Linear(
in_features=512, out_features=output_dimension, bias=True
)

def forward(self, x):
x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1)))
x = torch.stack([x.real, x.imag], dim=1)
x = self.resnet(x)
return x


@add_embedding("RESNET50")
class ResNet50_Encoder(nn.Module):
def __init__(self, output_dimension: int):
Expand Down

0 comments on commit d9624f9

Please sign in to comment.