From 69230486b40e327ad2ff835fc756066988b56248 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Mon, 29 Apr 2024 16:16:12 -0400 Subject: [PATCH] msa --- src/beignet/nn/_msa.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/beignet/nn/_msa.py b/src/beignet/nn/_msa.py index f8746a3a7d..92d6b7eeac 100644 --- a/src/beignet/nn/_msa.py +++ b/src/beignet/nn/_msa.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torch.nn import Conv2d, Module +from torch.nn import Conv1d, Module import beignet.operators @@ -21,12 +21,17 @@ def __init__( self.temperature = temperature - self.embed = Conv2d(in_channels, out_channels, kernel_size) + self.embedding = Conv1d( + in_channels, + out_channels, + kernel_size, + padding="same", + ) def forward(self, inputs: (Tensor, Tensor)) -> Tensor: matrices, shapes = inputs - embedding = self.embed(matrices) + embedding = self.embedding(matrices) embedding = embedding @ embedding[0].T