diff --git a/src/boltz/model/layers/outer_product_mean.py b/src/boltz/model/layers/outer_product_mean.py index ff79de1..fbd8867 100644 --- a/src/boltz/model/layers/outer_product_mean.py +++ b/src/boltz/model/layers/outer_product_mean.py @@ -85,6 +85,7 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor: z_out = z_out + z.to(m) @ sliced_weight_proj_o.T return z_out else: + mask = mask[:, :, None, :] * mask[:, :, :, None] num_mask = mask.sum(1).clamp(min=1) z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float()) z = z.reshape(*z.shape[:3], -1)