Skip to content

Commit

Permalink
attention code
Browse files Browse the repository at this point in the history
  • Loading branch information
pgmikhael committed Feb 4, 2024
1 parent 95830d5 commit 84f852e
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 14 deletions.
4 changes: 2 additions & 2 deletions sybil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@

from sybil.model import Sybil
from sybil.serie import Serie
from sybil.utils.visualization import visualize_attentions


__all__ = ["Sybil", "Serie"]
__all__ = ["Sybil", "Serie", "visualize_attentions"]
54 changes: 47 additions & 7 deletions sybil/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@

class Prediction(NamedTuple):
scores: List[List[float]]
attentions: List[Dict[str, np.ndarray]] = None


class Evaluation(NamedTuple):
auc: List[float]
c_index: float
scores: List[List[float]]
attentions: List[Dict[str, np.ndarray]] = None


def download_sybil(name, cache):
Expand Down Expand Up @@ -230,6 +232,7 @@ def _predict(
self,
model: SybilNet,
series: Union[Serie, List[Serie]],
return_attentions: bool = False,
) -> np.ndarray:
"""Run predictions over the given serie(s).
Expand All @@ -252,6 +255,7 @@ def _predict(
raise ValueError("Expected either a Serie object or list of Serie objects.")

scores: List[List[float]] = []
attentions: List[Dict[str, np.ndarray]] = []
for serie in series:
if not isinstance(serie, Serie):
raise ValueError("Expected a list of Serie objects.")
Expand All @@ -264,10 +268,25 @@ def _predict(
out = model(volume)
score = out["logit"].sigmoid().squeeze(0).cpu().numpy()
scores.append(score)

return np.stack(scores)

def predict(self, series: Union[Serie, List[Serie]]) -> Prediction:
if return_attentions:
attentions.append(
{
"image_attention_1": out["image_attention_1"]
.detach()
.cpu(),
"volume_attention_1": out["volume_attention_1"]
.detach()
.cpu(),
}
)
if return_attentions:
return Prediction(scores=np.stack(scores), attentions=attentions)

return Prediction(scores=np.stack(scores))

def predict(
self, series: Union[Serie, List[Serie]], return_attentions: bool = False
) -> Prediction:
"""Run predictions over the given serie(s) and ensemble
Parameters
Expand All @@ -282,14 +301,29 @@ def predict(self, series: Union[Serie, List[Serie]]) -> Prediction:
"""
scores = []
attentions_ = []
for sybil in self.ensemble:
pred = self._predict(sybil, series)
scores.append(pred)
scores.append(pred.scores)
if return_attentions:
attentions_.append(pred.attentions)
scores = np.mean(np.array(scores), axis=0)
calib_scores = self._calibrate(scores).tolist()
if return_attentions:
attentions = []
for i in range(len(series)):
att = {}
for key in pred.attentions[0][i].keys():
att[key] = [
pred.attentions[j][i][key] for j in range(len(self.ensemble))
]
attentions.append(att)
return Prediction(scores=calib_scores, attentions=attentions)
return Prediction(scores=calib_scores)

def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation:
def evaluate(
self, series: Union[Serie, List[Serie]], return_attentions: bool = False
) -> Evaluation:
"""Run evaluation over the given serie(s).
Parameters
Expand All @@ -315,7 +349,8 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation:
raise ValueError("All series must have a label for evaluation")

# Get scores and labels
scores = self.predict(series).scores
predictions = self.predict(series, return_attentions)
scores = predictions.scores
labels = [serie.get_label(self._max_followup) for serie in series]

# Convert to format for survival metrics
Expand All @@ -331,4 +366,9 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation:
auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)]
c_index = float(out["c_index"])

if return_attentions:
attentions = predictions.attentions
return Evaluation(
auc=auc, c_index=c_index, scores=scores, attentions=attentions
)
return Evaluation(auc=auc, c_index=c_index, scores=scores)
29 changes: 24 additions & 5 deletions sybil/serie.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self._censor_time = censor_time
self._label = label
args = self._load_args(file_type)
self._args = args
self._loader = get_sample_loader(split, args)
self._meta = self._load_metadata(dicoms, voxel_spacing, file_type)
self._check_valid(args)
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_label(self, max_followup: int = 6) -> Label:
raise ValueError("No label in this serie.")

# First convert months to years
year_to_cancer = self._censor_time # type: ignore
year_to_cancer = self._censor_time # type: ignore

y_seq = np.zeros(max_followup, dtype=np.float64)
y = int((year_to_cancer < max_followup) and self._label) # type: ignore
Expand All @@ -119,6 +120,22 @@ def get_label(self, max_followup: int = 6) -> Label:
)
return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer)

def get_raw_images(self) -> List[np.ndarray]:
"""
Load raw images from serie
Returns
-------
List[np.ndarray]
List of CT slices of shape (1, C, H, W)
"""

loader = get_sample_loader("test", self._args)
loader.apply_augmentations = False
input_dicts = [loader.get_image(path, {}) for path in self._meta.paths]
images = [i["input"] for i in input_dicts]
return images

def get_volume(self) -> torch.Tensor:
"""
Load loaded 3D CT volume
Expand All @@ -131,10 +148,12 @@ def get_volume(self) -> torch.Tensor:

sample = {"seed": np.random.randint(0, 2**32 - 1)}

input_dicts = [self._loader.get_image(path, sample) for path in self._meta.paths]

x = torch.cat( [i["input"].unsqueeze(0) for i in input_dicts], dim = 0)

input_dicts = [
self._loader.get_image(path, sample) for path in self._meta.paths
]

x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0)

# Convert from (T, C, H, W) to (C, T, H, W)
x = x.permute(1, 0, 2, 3)

Expand Down
76 changes: 76 additions & 0 deletions sybil/utils/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import numpy as np
import torch
import torch.nn.functional as F
from sybil.serie import Serie
from typing import Dict, List, Union
import cv2
import os


def visualize_attentions(
series: Serie,
attentions: List[Dict[str, torch.Tensor]],
save_directory: str = None,
gain: int = 3,
) -> List[List[np.ndarray]]:
"""
Args:
series (Serie): series object
attention_dict (Dict[str, torch.Tensor]): attention dictionary output from model
save_directory (str, optional): where to save the images. Defaults to None.
gain (int, optional): how much to scale attention values by for visualization. Defaults to 3.
Returns:
List[List[np.ndarray]]: list of list of overlayed images
"""

if isinstance(series, Serie):
series = [series]

series_overlays = []
for serie_idx, serie in enumerate(series):
a1 = attentions[serie_idx]["image_attention_1"]
v1 = attentions[serie_idx]["volume_attention_1"]

# TODO:
if len(a1) > 1:
a1 = a1.mean(0)
v1 = v1.mean(0)

attention = torch.exp(a1) * torch.exp(v1).unsqueeze(-1)
attention = attention.view(1, 25, 16, 16)

N = len(serie)
attention_up = F.interpolate(
attention.unsqueeze(0), (N, 512, 512), mode="trilinear"
)

# get original image
images = serie.get_raw_images()

overlayed_images = []
for i in range(N):
overlayed = np.zeros((512, 512, 3))
overlayed[..., 0] = images[i]
overlayed[..., 1] = images[i]
overlayed[..., 2] = np.int16(
np.clip(
(attention_up[0, 0, i] * gain * 256) + images[i],
a_min=0,
a_max=256,
)
)
overlayed_images.append(overlayed)
if save_directory is not None:
save_path = os.path.join(save_directory, f"serie_{serie_idx}")
save_images(overlayed_images, save_path, f"serie_{serie_idx}")

series_overlays.append(overlayed_images)
return series_overlays


def save_images(img_list, directory, name):
os.makedirs(directory, exist_ok=True)
N = len(str(len(img_list)))
for i, im in enumerate(img_list):
cv2.imwrite(f"{directory}/{name}_{'0'*(N - len(str(i))) }{i}.png", im)

0 comments on commit 84f852e

Please sign in to comment.