diff --git a/sybil/__init__.py b/sybil/__init__.py index 9d6f030..bc7b412 100644 --- a/sybil/__init__.py +++ b/sybil/__init__.py @@ -19,6 +19,6 @@ from sybil.model import Sybil from sybil.serie import Serie +from sybil.utils.visualization import visualize_attentions - -__all__ = ["Sybil", "Serie"] +__all__ = ["Sybil", "Serie", "visualize_attentions"] diff --git a/sybil/model.py b/sybil/model.py index 437e1b3..1992a9b 100644 --- a/sybil/model.py +++ b/sybil/model.py @@ -65,12 +65,14 @@ class Prediction(NamedTuple): scores: List[List[float]] + attentions: List[Dict[str, np.ndarray]] = None class Evaluation(NamedTuple): auc: List[float] c_index: float scores: List[List[float]] + attentions: List[Dict[str, np.ndarray]] = None def download_sybil(name, cache): @@ -230,6 +232,7 @@ def _predict( self, model: SybilNet, series: Union[Serie, List[Serie]], + return_attentions: bool = False, ) -> np.ndarray: """Run predictions over the given serie(s). @@ -252,6 +255,7 @@ def _predict( raise ValueError("Expected either a Serie object or list of Serie objects.") scores: List[List[float]] = [] + attentions: List[Dict[str, np.ndarray]] = [] for serie in series: if not isinstance(serie, Serie): raise ValueError("Expected a list of Serie objects.") @@ -264,10 +268,25 @@ def _predict( out = model(volume) score = out["logit"].sigmoid().squeeze(0).cpu().numpy() scores.append(score) - - return np.stack(scores) - - def predict(self, series: Union[Serie, List[Serie]]) -> Prediction: + if return_attentions: + attentions.append( + { + "image_attention_1": out["image_attention_1"] + .detach() + .cpu(), + "volume_attention_1": out["volume_attention_1"] + .detach() + .cpu(), + } + ) + if return_attentions: + return Prediction(scores=np.stack(scores), attentions=attentions) + + return Prediction(scores=np.stack(scores)) + + def predict( + self, series: Union[Serie, List[Serie]], return_attentions: bool = False + ) -> Prediction: """Run predictions over the given serie(s) and ensemble Parameters @@ -282,14 +301,29 @@ def predict(self, series: Union[Serie, List[Serie]]) -> Prediction: """ scores = [] + attentions_ = [] for sybil in self.ensemble: pred = self._predict(sybil, series) - scores.append(pred) + scores.append(pred.scores) + if return_attentions: + attentions_.append(pred.attentions) scores = np.mean(np.array(scores), axis=0) calib_scores = self._calibrate(scores).tolist() + if return_attentions: + attentions = [] + for i in range(len(series)): + att = {} + for key in pred.attentions[0][i].keys(): + att[key] = [ + pred.attentions[j][i][key] for j in range(len(self.ensemble)) + ] + attentions.append(att) + return Prediction(scores=calib_scores, attentions=attentions) return Prediction(scores=calib_scores) - def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation: + def evaluate( + self, series: Union[Serie, List[Serie]], return_attentions: bool = False + ) -> Evaluation: """Run evaluation over the given serie(s). Parameters @@ -315,7 +349,8 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation: raise ValueError("All series must have a label for evaluation") # Get scores and labels - scores = self.predict(series).scores + predictions = self.predict(series, return_attentions) + scores = predictions.scores labels = [serie.get_label(self._max_followup) for serie in series] # Convert to format for survival metrics @@ -331,4 +366,9 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation: auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)] c_index = float(out["c_index"]) + if return_attentions: + attentions = predictions.attentions + return Evaluation( + auc=auc, c_index=c_index, scores=scores, attentions=attentions + ) return Evaluation(auc=auc, c_index=c_index, scores=scores) diff --git a/sybil/serie.py b/sybil/serie.py index b4ea826..0baf186 100644 --- a/sybil/serie.py +++ b/sybil/serie.py @@ -63,6 +63,7 @@ def __init__( self._censor_time = censor_time self._label = label args = self._load_args(file_type) + self._args = args self._loader = get_sample_loader(split, args) self._meta = self._load_metadata(dicoms, voxel_spacing, file_type) self._check_valid(args) @@ -104,7 +105,7 @@ def get_label(self, max_followup: int = 6) -> Label: raise ValueError("No label in this serie.") # First convert months to years - year_to_cancer = self._censor_time # type: ignore + year_to_cancer = self._censor_time # type: ignore y_seq = np.zeros(max_followup, dtype=np.float64) y = int((year_to_cancer < max_followup) and self._label) # type: ignore @@ -119,6 +120,22 @@ def get_label(self, max_followup: int = 6) -> Label: ) return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer) + def get_raw_images(self) -> List[np.ndarray]: + """ + Load raw images from serie + + Returns + ------- + List[np.ndarray] + List of CT slices of shape (1, C, H, W) + """ + + loader = get_sample_loader("test", self._args) + loader.apply_augmentations = False + input_dicts = [loader.get_image(path, {}) for path in self._meta.paths] + images = [i["input"] for i in input_dicts] + return images + def get_volume(self) -> torch.Tensor: """ Load loaded 3D CT volume @@ -131,10 +148,12 @@ def get_volume(self) -> torch.Tensor: sample = {"seed": np.random.randint(0, 2**32 - 1)} - input_dicts = [self._loader.get_image(path, sample) for path in self._meta.paths] - - x = torch.cat( [i["input"].unsqueeze(0) for i in input_dicts], dim = 0) - + input_dicts = [ + self._loader.get_image(path, sample) for path in self._meta.paths + ] + + x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0) + # Convert from (T, C, H, W) to (C, T, H, W) x = x.permute(1, 0, 2, 3) diff --git a/sybil/utils/visualization.py b/sybil/utils/visualization.py new file mode 100644 index 0000000..e6d5096 --- /dev/null +++ b/sybil/utils/visualization.py @@ -0,0 +1,76 @@ +import numpy as np +import torch +import torch.nn.functional as F +from sybil.serie import Serie +from typing import Dict, List, Union +import cv2 +import os + + +def visualize_attentions( + series: Serie, + attentions: List[Dict[str, torch.Tensor]], + save_directory: str = None, + gain: int = 3, +) -> List[List[np.ndarray]]: + """ + Args: + series (Serie): series object + attention_dict (Dict[str, torch.Tensor]): attention dictionary output from model + save_directory (str, optional): where to save the images. Defaults to None. + gain (int, optional): how much to scale attention values by for visualization. Defaults to 3. + + Returns: + List[List[np.ndarray]]: list of list of overlayed images + """ + + if isinstance(series, Serie): + series = [series] + + series_overlays = [] + for serie_idx, serie in enumerate(series): + a1 = attentions[serie_idx]["image_attention_1"] + v1 = attentions[serie_idx]["volume_attention_1"] + + # TODO: + if len(a1) > 1: + a1 = a1.mean(0) + v1 = v1.mean(0) + + attention = torch.exp(a1) * torch.exp(v1).unsqueeze(-1) + attention = attention.view(1, 25, 16, 16) + + N = len(serie) + attention_up = F.interpolate( + attention.unsqueeze(0), (N, 512, 512), mode="trilinear" + ) + + # get original image + images = serie.get_raw_images() + + overlayed_images = [] + for i in range(N): + overlayed = np.zeros((512, 512, 3)) + overlayed[..., 0] = images[i] + overlayed[..., 1] = images[i] + overlayed[..., 2] = np.int16( + np.clip( + (attention_up[0, 0, i] * gain * 256) + images[i], + a_min=0, + a_max=256, + ) + ) + overlayed_images.append(overlayed) + if save_directory is not None: + save_path = os.path.join(save_directory, f"serie_{serie_idx}") + save_images(overlayed_images, save_path, f"serie_{serie_idx}") + + series_overlays.append(overlayed_images) + return series_overlays + + +def save_images(img_list, directory, name): + os.makedirs(directory, exist_ok=True) + N = len(str(len(img_list))) + for i, im in enumerate(img_list): + cv2.imwrite(f"{directory}/{name}_{'0'*(N - len(str(i))) }{i}.png", im)