Skip to content

Commit

Permalink
msa
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 29, 2024
1 parent 2d04fc9 commit 6923048
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/beignet/nn/_msa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import Tensor
from torch.nn import Conv2d, Module
from torch.nn import Conv1d, Module

import beignet.operators

Expand All @@ -21,12 +21,17 @@ def __init__(

self.temperature = temperature

self.embed = Conv2d(in_channels, out_channels, kernel_size)
self.embedding = Conv1d(
in_channels,
out_channels,
kernel_size,
padding="same",
)

def forward(self, inputs: (Tensor, Tensor)) -> Tensor:
matrices, shapes = inputs

embedding = self.embed(matrices)
embedding = self.embedding(matrices)

embedding = embedding @ embedding[0].T

Expand Down

0 comments on commit 6923048

Please sign in to comment.