Skip to content

Commit

Permalink
Refactor attention calculation
Browse files Browse the repository at this point in the history
Make it easier for other tools (e.g., ark) to calculate attention. Also clip small attention values by default.
  • Loading branch information
jsilter committed Sep 5, 2024
1 parent 0699737 commit c1f8c2a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
4 changes: 2 additions & 2 deletions sybil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sybil.model import Sybil
from sybil.serie import Serie
from sybil.utils.visualization import visualize_attentions
from sybil.utils.visualization import visualize_attentions, collate_attention_scores
import sybil.utils.logging_utils

__all__ = ["Sybil", "Serie", "visualize_attentions", "__version__"]
__all__ = ["Sybil", "Serie", "visualize_attentions", "collate_attention_scores", "__version__"]
42 changes: 24 additions & 18 deletions sybil/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@
import os
import imageio

def collate_attention_scores(attention_dict: Dict[str, np.ndarray], N: int, eps: float = 1e-6) -> np.ndarray:
    """Combine image- and volume-level attention into a full-resolution map.

    Averages the ensemble attention outputs, combines the per-slice image
    attention with the per-volume slice attention, and upsamples the result
    to the original scan resolution.

    Args:
        attention_dict: Mapping with keys ``"image_attention_1"`` and
            ``"volume_attention_1"`` holding log-space attention scores whose
            leading dimension is the ensemble.
        N: Number of slices in the original series; output is upsampled to
            ``(N, 512, 512)``.
        eps: Values less than or equal to ``eps`` are clipped to 0 after
            upsampling. Pass ``0`` or ``None`` to disable clipping.

    Returns:
        ``np.ndarray`` of shape ``(1, 1, N, 512, 512)`` with the upsampled
        (and optionally clipped) attention scores.
    """
    a1 = torch.Tensor(attention_dict["image_attention_1"])
    v1 = torch.Tensor(attention_dict["volume_attention_1"])

    # Scores are stored in log space: exponentiate, then take the mean
    # attention over the ensemble dimension.
    a1 = torch.exp(a1).mean(0)
    v1 = torch.exp(v1).mean(0)

    # Weight each slice's image attention by that slice's volume attention,
    # then reshape to the model's fixed attention grid.
    # NOTE(review): the (25, 16, 16) grid and 512x512 target are hard-coded
    # to the Sybil model's input geometry — confirm before reusing elsewhere.
    attention = a1 * v1.unsqueeze(-1)
    attention = attention.view(1, 25, 16, 16)

    # Trilinear upsampling to the full (N, 512, 512) volume.
    attention_up = F.interpolate(
        attention.unsqueeze(0), (N, 512, 512), mode="trilinear"
    )
    attention_up = attention_up.cpu().numpy()
    if eps:
        # Clip negligible attention so downstream overlays stay clean.
        attention_up[attention_up <= eps] = 0.0

    return attention_up


def visualize_attentions(
series: Union[Serie, List[Serie]],
Expand All @@ -29,26 +52,9 @@ def visualize_attentions(

series_overlays = []
for serie_idx, serie in enumerate(series):
a1 = attentions[serie_idx]["image_attention_1"]
v1 = attentions[serie_idx]["volume_attention_1"]

a1 = torch.Tensor(a1)
v1 = torch.Tensor(v1)

# take mean attention over ensemble
a1 = torch.exp(a1).mean(0)
v1 = torch.exp(v1).mean(0)

attention = a1 * v1.unsqueeze(-1)
attention = attention.view(1, 25, 16, 16)

# get original image
images = serie.get_raw_images()

N = len(images)
attention_up = F.interpolate(
attention.unsqueeze(0), (N, 512, 512), mode="trilinear"
)
attention_up = collate_attention_scores(attentions[serie_idx], N)

overlayed_images = []
for i in range(N):
Expand Down

0 comments on commit c1f8c2a

Please sign in to comment.