Skip to content

Commit

Permalink
run through viz
Browse files Browse the repository at this point in the history
  • Loading branch information
pgmikhael committed Feb 7, 2024
1 parent 053da14 commit 3ea1d87
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions sybil/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import torch
import torch.nn.functional as F
from sybil.serie import Serie
from typing import Dict, List, Union
import cv2
from typing import Dict, List
import os
import imageio


def visualize_attentions(
Expand Down Expand Up @@ -32,35 +32,35 @@ def visualize_attentions(
a1 = attentions[serie_idx]["image_attention_1"]
v1 = attentions[serie_idx]["volume_attention_1"]

# TODO:
# take mean attention over ensemble
if len(a1) > 1:
a1 = a1.mean(0)
v1 = v1.mean(0)
a1 = torch.exp(torch.stack(a1)).mean(0)
v1 = torch.exp(torch.stack(v1)).mean(0)

attention = torch.exp(a1) * torch.exp(v1).unsqueeze(-1)
attention = a1 * v1.unsqueeze(-1)
attention = attention.view(1, 25, 16, 16)

N = len(serie)
# get original image
images = serie.get_raw_images()

N = len(images)
attention_up = F.interpolate(
attention.unsqueeze(0), (N, 512, 512), mode="trilinear"
)

# get original image
images = serie.get_raw_images()

overlayed_images = []
for i in range(N):
overlayed = np.zeros((512, 512, 3))
overlayed[..., 0] = images[i]
overlayed[..., 1] = images[i]
overlayed[..., 2] = np.int16(
np.clip(
(attention_up[0, 0, i] * gain * 256) + images[i],
a_min=0,
a_max=256,
)
overlayed[..., 2] = np.clip(
(attention_up[0, 0, i] * gain * 256) + images[i],
a_min=0,
a_max=256,
)
overlayed_images.append(overlayed)

overlayed_images.append(np.uint8(overlayed))

if save_directory is not None:
save_path = os.path.join(save_directory, f"serie_{serie_idx}")
save_images(overlayed_images, save_path, f"serie_{serie_idx}")
Expand All @@ -69,8 +69,18 @@ def visualize_attentions(
return series_overlays


def save_images(img_list, directory, name):
def save_images(img_list: List[np.ndarray], directory: str, name: str):
"""
Saves a list of images as a GIF in the specified directory with the given name.
Args:
``img_list`` (List[np.ndarray]): A list of numpy arrays representing the images to be saved.
``directory`` (str): The directory where the GIF should be saved.
``name`` (str): The name of the GIF file.
Returns:
None
"""
os.makedirs(directory, exist_ok=True)
N = len(str(len(img_list)))
for i, im in enumerate(img_list):
cv2.imwrite(f"{directory}/{name}_{'0'*(N - len(str(i))) }{i}.png", im)
path = os.path.join(directory, f"{name}.gif")
imageio.mimsave(path, img_list)

0 comments on commit 3ea1d87

Please sign in to comment.