From 84f852e8042fc2768e98f350800fa335cb22ae35 Mon Sep 17 00:00:00 2001 From: Peter Mikhael Date: Sat, 3 Feb 2024 23:10:42 -0500 Subject: [PATCH 1/8] attention code --- sybil/__init__.py | 4 +- sybil/model.py | 54 +++++++++++++++++++++---- sybil/serie.py | 29 +++++++++++--- sybil/utils/visualization.py | 76 ++++++++++++++++++++++++++++++++++++ 4 files changed, 149 insertions(+), 14 deletions(-) create mode 100644 sybil/utils/visualization.py diff --git a/sybil/__init__.py b/sybil/__init__.py index 9d6f030..bc7b412 100644 --- a/sybil/__init__.py +++ b/sybil/__init__.py @@ -19,6 +19,6 @@ from sybil.model import Sybil from sybil.serie import Serie +from sybil.utils.visualization import visualize_attentions - -__all__ = ["Sybil", "Serie"] +__all__ = ["Sybil", "Serie", "visualize_attentions"] diff --git a/sybil/model.py b/sybil/model.py index 437e1b3..1992a9b 100644 --- a/sybil/model.py +++ b/sybil/model.py @@ -65,12 +65,14 @@ class Prediction(NamedTuple): scores: List[List[float]] + attentions: List[Dict[str, np.ndarray]] = None class Evaluation(NamedTuple): auc: List[float] c_index: float scores: List[List[float]] + attentions: List[Dict[str, np.ndarray]] = None def download_sybil(name, cache): @@ -230,6 +232,7 @@ def _predict( self, model: SybilNet, series: Union[Serie, List[Serie]], + return_attentions: bool = False, ) -> np.ndarray: """Run predictions over the given serie(s). @@ -252,6 +255,7 @@ def _predict( raise ValueError("Expected either a Serie object or list of Serie objects.") scores: List[List[float]] = [] + attentions: List[Dict[str, np.ndarray]] = [] for serie in series: if not isinstance(serie, Serie): raise ValueError("Expected a list of Serie objects.") @@ -264,10 +268,25 @@ def _predict( out = model(volume) score = out["logit"].sigmoid().squeeze(0).cpu().numpy() scores.append(score) - - return np.stack(scores) - - def predict(self, series: Union[Serie, List[Serie]]) -> Prediction: + if return_attentions: + attentions.append( + { + "image_attention_1": out["image_attention_1"] + .detach() + .cpu(), + "volume_attention_1": out["volume_attention_1"] + .detach() + .cpu(), + } + ) + if return_attentions: + return Prediction(scores=np.stack(scores), attentions=attentions) + + return Prediction(scores=np.stack(scores)) + + def predict( + self, series: Union[Serie, List[Serie]], return_attentions: bool = False + ) -> Prediction: """Run predictions over the given serie(s) and ensemble Parameters @@ -282,14 +301,29 @@ def predict(self, series: Union[Serie, List[Serie]]) -> Prediction: """ scores = [] + attentions_ = [] for sybil in self.ensemble: pred = self._predict(sybil, series) - scores.append(pred) + scores.append(pred.scores) + if return_attentions: + attentions_.append(pred.attentions) scores = np.mean(np.array(scores), axis=0) calib_scores = self._calibrate(scores).tolist() + if return_attentions: + attentions = [] + for i in range(len(series)): + att = {} + for key in pred.attentions[0][i].keys(): + att[key] = [ + pred.attentions[j][i][key] for j in range(len(self.ensemble)) + ] + attentions.append(att) + return Prediction(scores=calib_scores, attentions=attentions) return Prediction(scores=calib_scores) - def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation: + def evaluate( + self, series: Union[Serie, List[Serie]], return_attentions: bool = False + ) -> Evaluation: """Run evaluation over the given serie(s). Parameters @@ -315,7 +349,8 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation: raise ValueError("All series must have a label for evaluation") # Get scores and labels - scores = self.predict(series).scores + predictions = self.predict(series, return_attentions) + scores = predictions.scores labels = [serie.get_label(self._max_followup) for serie in series] # Convert to format for survival metrics @@ -331,4 +366,9 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation: auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)] c_index = float(out["c_index"]) + if return_attentions: + attentions = predictions.attentions + return Evaluation( + auc=auc, c_index=c_index, scores=scores, attentions=attentions + ) return Evaluation(auc=auc, c_index=c_index, scores=scores) diff --git a/sybil/serie.py b/sybil/serie.py index b4ea826..0baf186 100644 --- a/sybil/serie.py +++ b/sybil/serie.py @@ -63,6 +63,7 @@ def __init__( self._censor_time = censor_time self._label = label args = self._load_args(file_type) + self._args = args self._loader = get_sample_loader(split, args) self._meta = self._load_metadata(dicoms, voxel_spacing, file_type) self._check_valid(args) @@ -104,7 +105,7 @@ def get_label(self, max_followup: int = 6) -> Label: raise ValueError("No label in this serie.") # First convert months to years - year_to_cancer = self._censor_time # type: ignore + year_to_cancer = self._censor_time # type: ignore y_seq = np.zeros(max_followup, dtype=np.float64) y = int((year_to_cancer < max_followup) and self._label) # type: ignore @@ -119,6 +120,22 @@ def get_label(self, max_followup: int = 6) -> Label: ) return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer) + def get_raw_images(self) -> List[np.ndarray]: + """ + Load raw images from serie + + Returns + ------- + List[np.ndarray] + List of CT slices of shape (1, C, H, W) + """ + + loader = get_sample_loader("test", self._args) + loader.apply_augmentations = False + input_dicts = [loader.get_image(path, {}) for path in self._meta.paths] + images = [i["input"] for i in input_dicts] + return images + def get_volume(self) -> torch.Tensor: """ Load loaded 3D CT volume @@ -131,10 +148,12 @@ def get_volume(self) -> torch.Tensor: sample = {"seed": np.random.randint(0, 2**32 - 1)} - input_dicts = [self._loader.get_image(path, sample) for path in self._meta.paths] - - x = torch.cat( [i["input"].unsqueeze(0) for i in input_dicts], dim = 0) - + input_dicts = [ + self._loader.get_image(path, sample) for path in self._meta.paths + ] + + x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0) + # Convert from (T, C, H, W) to (C, T, H, W) x = x.permute(1, 0, 2, 3) diff --git a/sybil/utils/visualization.py b/sybil/utils/visualization.py new file mode 100644 index 0000000..e6d5096 --- /dev/null +++ b/sybil/utils/visualization.py @@ -0,0 +1,76 @@ +import numpy as np +import torch +import torch.nn.functional as F +from sybil.serie import Serie +from typing import Dict, List, Union +import cv2 +import os + + +def visualize_attentions( + series: Serie, + attentions: List[Dict[str, torch.Tensor]], + save_directory: str = None, + gain: int = 3, +) -> List[List[np.ndarray]]: + """ + Args: + series (Serie): series object + attention_dict (Dict[str, torch.Tensor]): attention dictionary output from model + save_directory (str, optional): where to save the images. Defaults to None. + gain (int, optional): how much to scale attention values by for visualization. Defaults to 3. + + Returns: + List[List[np.ndarray]]: list of list of overlayed images + """ + + if isinstance(series, Serie): + series = [series] + + series_overlays = [] + for serie_idx, serie in enumerate(series): + a1 = attentions[serie_idx]["image_attention_1"] + v1 = attentions[serie_idx]["volume_attention_1"] + + # TODO: + if len(a1) > 1: + a1 = a1.mean(0) + v1 = v1.mean(0) + + attention = torch.exp(a1) * torch.exp(v1).unsqueeze(-1) + attention = attention.view(1, 25, 16, 16) + + N = len(serie) + attention_up = F.interpolate( + attention.unsqueeze(0), (N, 512, 512), mode="trilinear" + ) + + # get original image + images = serie.get_raw_images() + + overlayed_images = [] + for i in range(N): + overlayed = np.zeros((512, 512, 3)) + overlayed[..., 0] = images[i] + overlayed[..., 1] = images[i] + overlayed[..., 2] = np.int16( + np.clip( + (attention_up[0, 0, i] * gain * 256) + images[i], + a_min=0, + a_max=256, + ) + ) + overlayed_images.append(overlayed) + if save_directory is not None: + save_path = os.path.join(save_directory, f"serie_{serie_idx}") + save_images(overlayed_images, save_path, f"serie_{serie_idx}") + + series_overlays.append(overlayed_images) + return series_overlays + + +def save_images(img_list, directory, name): + os.makedirs(directory, exist_ok=True) + N = len(str(len(img_list))) + for i, im in enumerate(img_list): + cv2.imwrite(f"{directory}/{name}_{'0'*(N - len(str(i))) }{i}.png", im) From 6aff7f20f829479792def4285fc956b81afd3887 Mon Sep 17 00:00:00 2001 From: Peter Mikhael Date: Tue, 6 Feb 2024 16:12:32 -0500 Subject: [PATCH 2/8] debug --- sybil/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sybil/model.py b/sybil/model.py index 1992a9b..5a8bfb2 100644 --- a/sybil/model.py +++ b/sybil/model.py @@ -303,7 +303,7 @@ def predict( scores = [] attentions_ = [] for sybil in self.ensemble: - pred = self._predict(sybil, series) + pred = self._predict(sybil, series, return_attentions) scores.append(pred.scores) if return_attentions: attentions_.append(pred.attentions) @@ -313,9 +313,9 @@ def predict( attentions = [] for i in range(len(series)): att = {} - for key in pred.attentions[0][i].keys(): + for key in pred.attentions[0].keys(): att[key] = [ - pred.attentions[j][i][key] for j in range(len(self.ensemble)) + attentions_[j][i][key] for j in range(len(self.ensemble)) ] attentions.append(att) return Prediction(scores=calib_scores, attentions=attentions) From 37ec57286bbae9588d0c957646a7a37164280467 Mon Sep 17 00:00:00 2001 From: Peter Mikhael Date: Tue, 6 Feb 2024 16:22:26 -0500 Subject: [PATCH 3/8] make calibrator none if no path given --- sybil/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sybil/model.py b/sybil/model.py index 5a8bfb2..72693ed 100644 --- a/sybil/model.py +++ b/sybil/model.py @@ -169,6 +169,8 @@ def __init__( if calibrator_path is not None: self.calibrator = pickle.load(open(calibrator_path, "rb")) + else: + self.calibrator = None def load_model(self, path): """Load model from path. From 053da14df6e965fab98bafb1d7939fc8d386a071 Mon Sep 17 00:00:00 2001 From: Peter Mikhael Date: Wed, 7 Feb 2024 00:57:49 -0500 Subject: [PATCH 4/8] option to turn off augmentations --- sybil/loaders/abstract_loader.py | 10 +++------- sybil/serie.py | 3 +-- sybil/utils/loading.py | 11 ++++++++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sybil/loaders/abstract_loader.py b/sybil/loaders/abstract_loader.py index e81e835..073277b 100644 --- a/sybil/loaders/abstract_loader.py +++ b/sybil/loaders/abstract_loader.py @@ -131,10 +131,11 @@ def rem(self, image_path, attr_key): class abstract_loader: __metaclass__ = ABCMeta - def __init__(self, cache_path, augmentations, args): + def __init__(self, cache_path, augmentations, args, apply_augmentations=True): self.pad_token = IMG_PAD_TOKEN self.augmentations = augmentations self.args = args + self.apply_augmentations = apply_augmentations if cache_path is not None: self.use_cache = True self.cache = cache(cache_path, self.cached_extension) @@ -152,13 +153,8 @@ def load_input(self, path, sample): def cached_extension(self): pass - @property - @abstractmethod - def apply_augmentations(self): - return True - def configure_path(self, path, sample): - return path + return path def get_image(self, path, sample): """ diff --git a/sybil/serie.py b/sybil/serie.py index 0baf186..c7f48fe 100644 --- a/sybil/serie.py +++ b/sybil/serie.py @@ -130,8 +130,7 @@ def get_raw_images(self) -> List[np.ndarray]: List of CT slices of shape (1, C, H, W) """ - loader = get_sample_loader("test", self._args) - loader.apply_augmentations = False + loader = get_sample_loader("test", self._args, apply_augmentations=False) input_dicts = [loader.get_image(path, {}) for path in self._meta.paths] images = [i["input"] for i in input_dicts] return images diff --git a/sybil/utils/loading.py b/sybil/utils/loading.py index 0c7f3f1..e0b67e7 100644 --- a/sybil/utils/loading.py +++ b/sybil/utils/loading.py @@ -158,7 +158,11 @@ def concat_all_gather(tensor): return output -def get_sample_loader(split_group: Literal["train", "dev", "test"], args: Namespace): +def get_sample_loader( + split_group: Literal["train", "dev", "test"], + args: Namespace, + apply_augmentations=True, +): """[summary] Parameters @@ -167,6 +171,7 @@ def get_sample_loader(split_group: Literal["train", "dev", "test"], args: Namesp dataset split according to which the augmentation is selected (choices are ['train', 'dev', 'test']) ``args`` : Namespace global args + ``apply_augmentations`` : bool, optional (default=True) Returns ------- @@ -180,8 +185,8 @@ def get_sample_loader(split_group: Literal["train", "dev", "test"], args: Namesp """ augmentations = get_augmentations(split_group, args) if args.img_file_type == "dicom": - return DicomLoader(args.cache_path, augmentations, args) + return DicomLoader(args.cache_path, augmentations, args, apply_augmentations) elif args.img_file_type == "png": - return OpenCVLoader(args.cache_path, augmentations, args) + return OpenCVLoader(args.cache_path, augmentations, args, apply_augmentations) else: raise NotImplementedError From 3ea1d87dd567156472a054703aa7ad45cd57d034 Mon Sep 17 00:00:00 2001 From: Peter Mikhael Date: Wed, 7 Feb 2024 11:08:16 -0500 Subject: [PATCH 5/8] run through viz --- sybil/utils/visualization.py | 52 +++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/sybil/utils/visualization.py b/sybil/utils/visualization.py index e6d5096..b410a59 100644 --- a/sybil/utils/visualization.py +++ b/sybil/utils/visualization.py @@ -2,9 +2,9 @@ import torch import torch.nn.functional as F from sybil.serie import Serie -from typing import Dict, List, Union -import cv2 +from typing import Dict, List import os +import imageio def visualize_attentions( @@ -32,35 +32,35 @@ def visualize_attentions( a1 = attentions[serie_idx]["image_attention_1"] v1 = attentions[serie_idx]["volume_attention_1"] - # TODO: + # take mean attention over ensemble if len(a1) > 1: - a1 = a1.mean(0) - v1 = v1.mean(0) + a1 = torch.exp(torch.stack(a1)).mean(0) + v1 = torch.exp(torch.stack(v1)).mean(0) - attention = torch.exp(a1) * torch.exp(v1).unsqueeze(-1) + attention = a1 * v1.unsqueeze(-1) attention = attention.view(1, 25, 16, 16) - N = len(serie) + # get original image + images = serie.get_raw_images() + + N = len(images) attention_up = F.interpolate( attention.unsqueeze(0), (N, 512, 512), mode="trilinear" ) - # get original image - images = serie.get_raw_images() - overlayed_images = [] for i in range(N): overlayed = np.zeros((512, 512, 3)) overlayed[..., 0] = images[i] overlayed[..., 1] = images[i] - overlayed[..., 2] = np.int16( - np.clip( - (attention_up[0, 0, i] * gain * 256) + images[i], - a_min=0, - a_max=256, - ) + overlayed[..., 2] = np.clip( + (attention_up[0, 0, i] * gain * 256) + images[i], + a_min=0, + a_max=256, ) - overlayed_images.append(overlayed) + + overlayed_images.append(np.uint8(overlayed)) + if save_directory is not None: save_path = os.path.join(save_directory, f"serie_{serie_idx}") save_images(overlayed_images, save_path, f"serie_{serie_idx}") @@ -69,8 +69,18 @@ def visualize_attentions( return series_overlays -def save_images(img_list, directory, name): +def save_images(img_list: List[np.ndarray], directory: str, name: str): + """ + Saves a list of images as a GIF in the specified directory with the given name. + + Args: + ``img_list`` (List[np.ndarray]): A list of numpy arrays representing the images to be saved. + ``directory`` (str): The directory where the GIF should be saved. + ``name`` (str): The name of the GIF file. + + Returns: + None + """ os.makedirs(directory, exist_ok=True) - N = len(str(len(img_list))) - for i, im in enumerate(img_list): - cv2.imwrite(f"{directory}/{name}_{'0'*(N - len(str(i))) }{i}.png", im) + path = os.path.join(directory, f"{name}.gif") + imageio.mimsave(path, img_list) From 0477675b51da5e3b1b52c7e103b117f743e27a3a Mon Sep 17 00:00:00 2001 From: Peter Mikhael Date: Wed, 7 Feb 2024 11:34:03 -0500 Subject: [PATCH 6/8] add documentation --- README.md | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fff4939..f3b6f01 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ You can replicate the results from our model using our training script: python train.py ``` -See our [documentation](docs/readme.md) for a full description of Sybil's training parameters. +See our [documentation](docs/readme.md) for a full description of Sybil's training parameters. Additional information on the training process can be found on the [train](https://github.com/reginabarzilaygroup/Sybil/tree/train) branch of this repository. ## LDCT Orientation @@ -76,6 +76,40 @@ Annotations are availble to download in JSON format [here](https://drive.google. } ``` +## Attention Scores + +The multi-attention pooling layer aims to learn the importance of each slice in the 3D volume and the importance of each pixel in the 2D slice. During training, these are supervised by bounding boxes of the cancerous nodules. This is a soft attention mechanism, and the model's primary task is to predict the risk of lung cancer. However, the attention scores can be extracted and used to visualize the model's focus on the 3D volume and the 2D slices. + +To extract the attention scores, you can use the `return_attentions` argument as follows: + +```python + +results = model.predict([serie], return_attentions=True) + +attentions = results.attentions + +``` + +The `attentions` will be a list of length equal to the number of series. Each series has a dictionary with the following keys: + +- `image_attention_1`: log-softmax attention scores over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble. +- `volume_attention_1`: log-softmax attention scores over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble. + +To visualize the attention scores, you can use the following code. This will return a list of 2D images, where the attention scores are overlaid on the original images. If you provide a `save_directory`, the images will be saved as a GIF. If multiple series are provided, the function will return a list of lists, one for each series. + +```python + +from sybil import visualize_attentions + +series_with_attention = visualize_attentions( + series, + attentions = attentions, + save_directory = "path_to_save_directory", + gain = 3, +) + +``` + ## Cite ``` From 9d15789f2ca30eeb4234a69c6fc70793b2674e81 Mon Sep 17 00:00:00 2001 From: Peter Mikhael Date: Wed, 7 Feb 2024 11:39:26 -0500 Subject: [PATCH 7/8] set rbg order in overlaying images --- sybil/utils/visualization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sybil/utils/visualization.py b/sybil/utils/visualization.py index b410a59..205c328 100644 --- a/sybil/utils/visualization.py +++ b/sybil/utils/visualization.py @@ -51,9 +51,9 @@ def visualize_attentions( overlayed_images = [] for i in range(N): overlayed = np.zeros((512, 512, 3)) - overlayed[..., 0] = images[i] + overlayed[..., 2] = images[i] overlayed[..., 1] = images[i] - overlayed[..., 2] = np.clip( + overlayed[..., 0] = np.clip( (attention_up[0, 0, i] * gain * 256) + images[i], a_min=0, a_max=256, From da43ce85b01d5dd3b5faf4577d322f37de2f3f1f Mon Sep 17 00:00:00 2001 From: Jacob Silterra Date: Fri, 9 Feb 2024 10:41:30 -0500 Subject: [PATCH 8/8] * Reword docstrings and README. * Minor Bugfix visualize_attentions --- README.md | 4 ++-- sybil/loaders/image_loaders.py | 5 ++-- sybil/model.py | 43 +++++++++++++++++++--------------- sybil/models/pooling_layer.py | 14 +++++++---- sybil/utils/visualization.py | 16 +++++++------ tests/regression_test.py | 20 ++++++++++++---- 6 files changed, 63 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index f3b6f01..06cbc42 100644 --- a/README.md +++ b/README.md @@ -92,8 +92,8 @@ attentions = results.attentions The `attentions` will be a list of length equal to the number of series. Each series has a dictionary with the following keys: -- `image_attention_1`: log-softmax attention scores over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble. -- `volume_attention_1`: log-softmax attention scores over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble. +- `image_attention_1`: attention scores (as logits) over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble. +- `volume_attention_1`: attention scores (as logits) over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble. To visualize the attention scores, you can use the following code. This will return a list of 2D images, where the attention scores are overlaid on the original images. If you provide a `save_directory`, the images will be saved as a GIF. If multiple series are provided, the function will return a list of lists, one for each series. diff --git a/sybil/loaders/image_loaders.py b/sybil/loaders/image_loaders.py index 5505b5c..0d2402d 100644 --- a/sybil/loaders/image_loaders.py +++ b/sybil/loaders/image_loaders.py @@ -22,8 +22,8 @@ def cached_extension(self): class DicomLoader(abstract_loader): - def __init__(self, cache_path, augmentations, args): - super(DicomLoader, self).__init__(cache_path, augmentations, args) + def __init__(self, cache_path, augmentations, args, apply_augmentations=True): + super(DicomLoader, self).__init__(cache_path, augmentations, args, apply_augmentations) self.window_center = -600 self.window_width = 1500 @@ -41,6 +41,7 @@ def load_input(self, path, sample): def cached_extension(self): return "" + def apply_windowing(image, center, width, bit_size=16): """Windowing function to transform image pixels for presentation. Must be run after a DICOM modality LUT is applied to the image. diff --git a/sybil/model.py b/sybil/model.py index 72693ed..2e734e7 100644 --- a/sybil/model.py +++ b/sybil/model.py @@ -235,7 +235,7 @@ def _predict( model: SybilNet, series: Union[Serie, List[Serie]], return_attentions: bool = False, - ) -> np.ndarray: + ) -> Prediction: """Run predictions over the given serie(s). Parameters @@ -244,6 +244,8 @@ def _predict( Instance of SybilNet series : Union[Serie, Iterable[Serie]] One or multiple series to run predictions for. + return_attentions : bool + If True, returns attention scores for each serie. See README for details. Returns ------- @@ -257,7 +259,7 @@ def _predict( raise ValueError("Expected either a Serie object or list of Serie objects.") scores: List[List[float]] = [] - attentions: List[Dict[str, np.ndarray]] = [] + attentions: List[Dict[str, np.ndarray]] = [] if return_attentions else None for serie in series: if not isinstance(serie, Serie): raise ValueError("Expected a list of Serie objects.") @@ -269,7 +271,7 @@ def _predict( with torch.no_grad(): out = model(volume) score = out["logit"].sigmoid().squeeze(0).cpu().numpy() - scores.append(score) + scores.append(score.tolist()) if return_attentions: attentions.append( { @@ -281,10 +283,8 @@ def _predict( .cpu(), } ) - if return_attentions: - return Prediction(scores=np.stack(scores), attentions=attentions) - return Prediction(scores=np.stack(scores)) + return Prediction(scores=scores, attentions=attentions) def predict( self, series: Union[Serie, List[Serie]], return_attentions: bool = False @@ -295,6 +295,8 @@ def predict( ---------- series : Union[Serie, Iterable[Serie]] One or multiple series to run predictions for. + return_attentions : bool + If True, returns attention scores for each serie. See README for details. Returns ------- @@ -303,25 +305,31 @@ def predict( """ scores = [] - attentions_ = [] + attentions_ = [] if return_attentions else None + attention_keys = None for sybil in self.ensemble: pred = self._predict(sybil, series, return_attentions) scores.append(pred.scores) if return_attentions: attentions_.append(pred.attentions) + if attention_keys is None: + attention_keys = pred.attentions[0].keys() + scores = np.mean(np.array(scores), axis=0) calib_scores = self._calibrate(scores).tolist() + + attentions = None if return_attentions: attentions = [] for i in range(len(series)): att = {} - for key in pred.attentions[0].keys(): - att[key] = [ + for key in attention_keys: + att[key] = np.stack([ attentions_[j][i][key] for j in range(len(self.ensemble)) - ] + ]) attentions.append(att) - return Prediction(scores=calib_scores, attentions=attentions) - return Prediction(scores=calib_scores) + + return Prediction(scores=calib_scores, attentions=attentions) def evaluate( self, series: Union[Serie, List[Serie]], return_attentions: bool = False @@ -332,11 +340,13 @@ def evaluate( ---------- series : Union[Serie, List[Serie]] One or multiple series to run evaluation for. + return_attentions : bool + If True, returns attention scores for each serie. See README for details. Returns ------- Evaluation - Output evaluation. See details for :class:`~sybil.model.Evaluation`". + Output evaluation. See details for :class:`~sybil.model.Evaluation`. """ if isinstance(series, Serie): @@ -368,9 +378,4 @@ def evaluate( auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)] c_index = float(out["c_index"]) - if return_attentions: - attentions = predictions.attentions - return Evaluation( - auc=auc, c_index=c_index, scores=scores, attentions=attentions - ) - return Evaluation(auc=auc, c_index=c_index, scores=scores) + return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions) diff --git a/sybil/models/pooling_layer.py b/sybil/models/pooling_layer.py index ee5f9fa..b9d7c27 100644 --- a/sybil/models/pooling_layer.py +++ b/sybil/models/pooling_layer.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + class MultiAttentionPool(nn.Module): def __init__(self): super(MultiAttentionPool, self).__init__() @@ -44,7 +45,8 @@ def forward(self, x): output['hidden'] = self.hidden_fc(hidden) return output - + + class GlobalMaxPool(nn.Module): ''' Pool to obtain the maximum value for each channel @@ -64,6 +66,7 @@ def forward(self, x): hidden, _ = torch.max(x, dim=-1) return {'hidden': hidden} + class PerFrameMaxPool(nn.Module): ''' Pool to obtain the maximum value for each slice in 3D input @@ -86,6 +89,7 @@ def forward(self, x): output['multi_image_hidden'], _ = torch.max(x, dim=-1) return output + class Conv1d_AttnPool(nn.Module): ''' Pool to learn an attention over the slices after convolution @@ -108,6 +112,7 @@ def forward(self, x): x = self.conv1d(x) # B, C, N' return self.aggregate(x) + class Simple_AttentionPool(nn.Module): ''' Pool to learn an attention over the slices @@ -125,7 +130,7 @@ def forward(self, x): - x: tensor of shape (B, C, N) returns: - output: dict - + output['attention_scores']: tensor (B, C) + + output['volume_attention']: tensor (B, N) + output['hidden']: tensor (B, C) ''' output = {} @@ -142,6 +147,7 @@ def forward(self, x): output['hidden'] = torch.sum(x, dim=-1) return output + class Simple_AttentionPool_MultiImg(nn.Module): ''' Pool to learn an attention over the slices and the volume @@ -159,8 +165,8 @@ def forward(self, x): - x: tensor of shape (B, C, T, W, H) returns: - output: dict - + output['attention_scores']: tensor (B, T, C) - + output['multi_image_hidden']: tensor (B, T, C) + + output['image_attention']: tensor (B, T, W*H) + + output['multi_image_hidden']: tensor (B, C, T) + output['hidden']: tensor (B, T*C) ''' output = {} diff --git a/sybil/utils/visualization.py b/sybil/utils/visualization.py index 205c328..82cc7cf 100644 --- a/sybil/utils/visualization.py +++ b/sybil/utils/visualization.py @@ -2,21 +2,21 @@ import torch import torch.nn.functional as F from sybil.serie import Serie -from typing import Dict, List +from typing import Dict, List, Union import os import imageio def visualize_attentions( - series: Serie, - attentions: List[Dict[str, torch.Tensor]], + series: Union[Serie, List[Serie]], + attentions: List[Dict[str, np.ndarray]], save_directory: str = None, gain: int = 3, ) -> List[List[np.ndarray]]: """ Args: series (Serie): series object - attention_dict (Dict[str, torch.Tensor]): attention dictionary output from model + attentions (Dict[str, np.ndarray]): attention dictionary output from model save_directory (str, optional): where to save the images. Defaults to None. gain (int, optional): how much to scale attention values by for visualization. Defaults to 3. @@ -32,10 +32,12 @@ def visualize_attentions( a1 = attentions[serie_idx]["image_attention_1"] v1 = attentions[serie_idx]["volume_attention_1"] + a1 = torch.Tensor(a1) + v1 = torch.Tensor(v1) + # take mean attention over ensemble - if len(a1) > 1: - a1 = torch.exp(torch.stack(a1)).mean(0) - v1 = torch.exp(torch.stack(v1)).mean(0) + a1 = torch.exp(a1).mean(0) + v1 = torch.exp(v1).mean(0) attention = a1 * v1.unsqueeze(-1) attention = attention.view(1, 25, 16, 16) diff --git a/tests/regression_test.py b/tests/regression_test.py index 6a29e67..b07169f 100644 --- a/tests/regression_test.py +++ b/tests/regression_test.py @@ -4,7 +4,7 @@ import requests import zipfile -from sybil import Serie, Sybil +from sybil import Serie, Sybil, visualize_attentions script_directory = os.path.dirname(os.path.abspath(__file__)) project_directory = os.path.dirname(script_directory) @@ -68,15 +68,18 @@ def main(): # Load a trained model model = Sybil("sybil_ensemble") - # myprint(f"Beginning prediction using {num_files} from {image_data_dir}") + myprint(f"Beginning prediction using {num_files} files from {image_data_dir}") # Get risk scores serie = Serie(dicom_files) - prediction = model.predict([serie])[0] - actual_scores = prediction[0] + series = [serie] + prediction = model.predict(series, return_attentions=True) + actual_scores = prediction.scores[0] count = len(actual_scores) - # myprint(f"Prediction finished. Results\n{actual_scores}") + myprint(f"Prediction finished. Results\n{actual_scores}") + + # pprint.pprint(f"Prediction object: {prediction}") assert len(expected_scores) == len(actual_scores), f"Unexpected score length {count}" @@ -88,6 +91,13 @@ def main(): print(f"Data URL: {demo_data_url}\nAll {count} elements match: {all_elements_match}") + series_with_attention = visualize_attentions( + series, + attentions=prediction.attentions, + save_directory="regression_test_output", + gain=3, + ) + if __name__ == "__main__": main()