From 3ad510a1ba8d9bc97727a3ffb5d4c971a8b4460d Mon Sep 17 00:00:00 2001 From: Saro Passaro Date: Thu, 28 Nov 2024 11:30:21 +0000 Subject: [PATCH] Bug on non-chunking! --- src/boltz/model/layers/outer_product_mean.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/boltz/model/layers/outer_product_mean.py b/src/boltz/model/layers/outer_product_mean.py index ff79de1..fbd8867 100644 --- a/src/boltz/model/layers/outer_product_mean.py +++ b/src/boltz/model/layers/outer_product_mean.py @@ -85,6 +85,7 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor: z_out = z_out + z.to(m) @ sliced_weight_proj_o.T return z_out else: + mask = mask[:, :, None, :] * mask[:, :, :, None] num_mask = mask.sum(1).clamp(min=1) z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float()) z = z.reshape(*z.shape[:3], -1)