From da43ce85b01d5dd3b5faf4577d322f37de2f3f1f Mon Sep 17 00:00:00 2001 From: Jacob Silterra Date: Fri, 9 Feb 2024 10:41:30 -0500 Subject: [PATCH] * Reword docstrings and README. * Minor Bugfix visualize_attentions --- README.md | 4 ++-- sybil/loaders/image_loaders.py | 5 ++-- sybil/model.py | 43 +++++++++++++++++++--------------- sybil/models/pooling_layer.py | 14 +++++++---- sybil/utils/visualization.py | 16 +++++++------ tests/regression_test.py | 20 ++++++++++++---- 6 files changed, 63 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index f3b6f01..06cbc42 100644 --- a/README.md +++ b/README.md @@ -92,8 +92,8 @@ attentions = results.attentions The `attentions` will be a list of length equal to the number of series. Each series has a dictionary with the following keys: -- `image_attention_1`: log-softmax attention scores over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble. -- `volume_attention_1`: log-softmax attention scores over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble. +- `image_attention_1`: attention scores (as logits) over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble. +- `volume_attention_1`: attention scores (as logits) over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble. To visualize the attention scores, you can use the following code. This will return a list of 2D images, where the attention scores are overlaid on the original images. If you provide a `save_directory`, the images will be saved as a GIF. If multiple series are provided, the function will return a list of lists, one for each series. diff --git a/sybil/loaders/image_loaders.py b/sybil/loaders/image_loaders.py index 5505b5c..0d2402d 100644 --- a/sybil/loaders/image_loaders.py +++ b/sybil/loaders/image_loaders.py @@ -22,8 +22,8 @@ def cached_extension(self): class DicomLoader(abstract_loader): - def __init__(self, cache_path, augmentations, args): - super(DicomLoader, self).__init__(cache_path, augmentations, args) + def __init__(self, cache_path, augmentations, args, apply_augmentations=True): + super(DicomLoader, self).__init__(cache_path, augmentations, args, apply_augmentations) self.window_center = -600 self.window_width = 1500 @@ -41,6 +41,7 @@ def load_input(self, path, sample): def cached_extension(self): return "" + def apply_windowing(image, center, width, bit_size=16): """Windowing function to transform image pixels for presentation. Must be run after a DICOM modality LUT is applied to the image. diff --git a/sybil/model.py b/sybil/model.py index 72693ed..2e734e7 100644 --- a/sybil/model.py +++ b/sybil/model.py @@ -235,7 +235,7 @@ def _predict( model: SybilNet, series: Union[Serie, List[Serie]], return_attentions: bool = False, - ) -> np.ndarray: + ) -> Prediction: """Run predictions over the given serie(s). Parameters @@ -244,6 +244,8 @@ def _predict( Instance of SybilNet series : Union[Serie, Iterable[Serie]] One or multiple series to run predictions for. + return_attentions : bool + If True, returns attention scores for each serie. See README for details. 
Returns ------- @@ -257,7 +259,7 @@ def _predict( raise ValueError("Expected either a Serie object or list of Serie objects.") scores: List[List[float]] = [] - attentions: List[Dict[str, np.ndarray]] = [] + attentions: List[Dict[str, np.ndarray]] = [] if return_attentions else None for serie in series: if not isinstance(serie, Serie): raise ValueError("Expected a list of Serie objects.") @@ -269,7 +271,7 @@ def _predict( with torch.no_grad(): out = model(volume) score = out["logit"].sigmoid().squeeze(0).cpu().numpy() - scores.append(score) + scores.append(score.tolist()) if return_attentions: attentions.append( { @@ -281,10 +283,8 @@ def _predict( .cpu(), } ) - if return_attentions: - return Prediction(scores=np.stack(scores), attentions=attentions) - return Prediction(scores=np.stack(scores)) + return Prediction(scores=scores, attentions=attentions) def predict( self, series: Union[Serie, List[Serie]], return_attentions: bool = False @@ -295,6 +295,8 @@ def predict( ---------- series : Union[Serie, Iterable[Serie]] One or multiple series to run predictions for. + return_attentions : bool + If True, returns attention scores for each serie. See README for details. Returns ------- @@ -303,25 +305,31 @@ def predict( """ scores = [] - attentions_ = [] + attentions_ = [] if return_attentions else None + attention_keys = None for sybil in self.ensemble: pred = self._predict(sybil, series, return_attentions) scores.append(pred.scores) if return_attentions: attentions_.append(pred.attentions) + if attention_keys is None: + attention_keys = pred.attentions[0].keys() + scores = np.mean(np.array(scores), axis=0) calib_scores = self._calibrate(scores).tolist() + + attentions = None if return_attentions: attentions = [] for i in range(len(series)): att = {} - for key in pred.attentions[0].keys(): - att[key] = [ + for key in attention_keys: + att[key] = np.stack([ attentions_[j][i][key] for j in range(len(self.ensemble)) - ] + ]) attentions.append(att) - return Prediction(scores=calib_scores, attentions=attentions) - return Prediction(scores=calib_scores) + + return Prediction(scores=calib_scores, attentions=attentions) def evaluate( self, series: Union[Serie, List[Serie]], return_attentions: bool = False @@ -332,11 +340,13 @@ def evaluate( ---------- series : Union[Serie, List[Serie]] One or multiple series to run evaluation for. + return_attentions : bool + If True, returns attention scores for each serie. See README for details. Returns ------- Evaluation - Output evaluation. See details for :class:`~sybil.model.Evaluation`". + Output evaluation. See details for :class:`~sybil.model.Evaluation`. 
""" if isinstance(series, Serie): @@ -368,9 +378,4 @@ def evaluate( auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)] c_index = float(out["c_index"]) - if return_attentions: - attentions = predictions.attentions - return Evaluation( - auc=auc, c_index=c_index, scores=scores, attentions=attentions - ) - return Evaluation(auc=auc, c_index=c_index, scores=scores) + return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions) diff --git a/sybil/models/pooling_layer.py b/sybil/models/pooling_layer.py index ee5f9fa..b9d7c27 100644 --- a/sybil/models/pooling_layer.py +++ b/sybil/models/pooling_layer.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + class MultiAttentionPool(nn.Module): def __init__(self): super(MultiAttentionPool, self).__init__() @@ -44,7 +45,8 @@ def forward(self, x): output['hidden'] = self.hidden_fc(hidden) return output - + + class GlobalMaxPool(nn.Module): ''' Pool to obtain the maximum value for each channel @@ -64,6 +66,7 @@ def forward(self, x): hidden, _ = torch.max(x, dim=-1) return {'hidden': hidden} + class PerFrameMaxPool(nn.Module): ''' Pool to obtain the maximum value for each slice in 3D input @@ -86,6 +89,7 @@ def forward(self, x): output['multi_image_hidden'], _ = torch.max(x, dim=-1) return output + class Conv1d_AttnPool(nn.Module): ''' Pool to learn an attention over the slices after convolution @@ -108,6 +112,7 @@ def forward(self, x): x = self.conv1d(x) # B, C, N' return self.aggregate(x) + class Simple_AttentionPool(nn.Module): ''' Pool to learn an attention over the slices @@ -125,7 +130,7 @@ def forward(self, x): - x: tensor of shape (B, C, N) returns: - output: dict - + output['attention_scores']: tensor (B, C) + + output['volume_attention']: tensor (B, N) + output['hidden']: tensor (B, C) ''' output = {} @@ -142,6 +147,7 @@ def forward(self, x): output['hidden'] = torch.sum(x, dim=-1) return output + class Simple_AttentionPool_MultiImg(nn.Module): ''' Pool to learn an attention over the slices and the volume @@ -159,8 +165,8 @@ def forward(self, x): - x: tensor of shape (B, C, T, W, H) returns: - output: dict - + output['attention_scores']: tensor (B, T, C) - + output['multi_image_hidden']: tensor (B, T, C) + + output['image_attention']: tensor (B, T, W*H) + + output['multi_image_hidden']: tensor (B, C, T) + output['hidden']: tensor (B, T*C) ''' output = {} diff --git a/sybil/utils/visualization.py b/sybil/utils/visualization.py index 205c328..82cc7cf 100644 --- a/sybil/utils/visualization.py +++ b/sybil/utils/visualization.py @@ -2,21 +2,21 @@ import torch import torch.nn.functional as F from sybil.serie import Serie -from typing import Dict, List +from typing import Dict, List, Union import os import imageio def visualize_attentions( - series: Serie, - attentions: List[Dict[str, torch.Tensor]], + series: Union[Serie, List[Serie]], + attentions: List[Dict[str, np.ndarray]], save_directory: str = None, gain: int = 3, ) -> List[List[np.ndarray]]: """ Args: series (Serie): series object - attention_dict (Dict[str, torch.Tensor]): attention dictionary output from model + attentions (Dict[str, np.ndarray]): attention dictionary output from model save_directory (str, optional): where to save the images. Defaults to None. gain (int, optional): how much to scale attention values by for visualization. Defaults to 3. 
@@ -32,10 +32,12 @@ def visualize_attentions( a1 = attentions[serie_idx]["image_attention_1"] v1 = attentions[serie_idx]["volume_attention_1"] + a1 = torch.Tensor(a1) + v1 = torch.Tensor(v1) + # take mean attention over ensemble - if len(a1) > 1: - a1 = torch.exp(torch.stack(a1)).mean(0) - v1 = torch.exp(torch.stack(v1)).mean(0) + a1 = torch.exp(a1).mean(0) + v1 = torch.exp(v1).mean(0) attention = a1 * v1.unsqueeze(-1) attention = attention.view(1, 25, 16, 16) diff --git a/tests/regression_test.py b/tests/regression_test.py index 6a29e67..b07169f 100644 --- a/tests/regression_test.py +++ b/tests/regression_test.py @@ -4,7 +4,7 @@ import requests import zipfile -from sybil import Serie, Sybil +from sybil import Serie, Sybil, visualize_attentions script_directory = os.path.dirname(os.path.abspath(__file__)) project_directory = os.path.dirname(script_directory) @@ -68,15 +68,18 @@ def main(): # Load a trained model model = Sybil("sybil_ensemble") - # myprint(f"Beginning prediction using {num_files} from {image_data_dir}") + myprint(f"Beginning prediction using {num_files} files from {image_data_dir}") # Get risk scores serie = Serie(dicom_files) - prediction = model.predict([serie])[0] - actual_scores = prediction[0] + series = [serie] + prediction = model.predict(series, return_attentions=True) + actual_scores = prediction.scores[0] count = len(actual_scores) - # myprint(f"Prediction finished. Results\n{actual_scores}") + myprint(f"Prediction finished. Results\n{actual_scores}") + + # pprint.pprint(f"Prediction object: {prediction}") assert len(expected_scores) == len(actual_scores), f"Unexpected score length {count}" @@ -88,6 +91,13 @@ def main(): print(f"Data URL: {demo_data_url}\nAll {count} elements match: {all_elements_match}") + series_with_attention = visualize_attentions( + series, + attentions=prediction.attentions, + save_directory="regression_test_output", + gain=3, + ) + if __name__ == "__main__": main()
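Putting the changes together, the reworked flow mirrors what the regression test above exercises: `predict()` now always returns a `Prediction` (with `attentions=None` unless `return_attentions=True`), and `visualize_attentions` consumes the stacked per-ensemble attention arrays directly. The following is a minimal usage sketch, not part of the patch; the DICOM paths and output directory are placeholders, and only names already shown in this diff (`Sybil`, `Serie`, `visualize_attentions`, `Prediction.scores` / `.attentions`) are assumed.

```python
from sybil import Serie, Sybil, visualize_attentions

# Placeholder paths: point these at a real chest CT DICOM series on disk.
dicom_files = ["/path/to/ct/slice_001.dcm", "/path/to/ct/slice_002.dcm"]

model = Sybil("sybil_ensemble")   # trained ensemble, as in the regression test
series = [Serie(dicom_files)]

# predict() returns a Prediction; attentions is None unless explicitly requested.
prediction = model.predict(series, return_attentions=True)
print(prediction.scores[0])       # per-year risk scores for the first (and only) serie

# Overlay slice- and pixel-level attention on the input images.
# With save_directory set, the overlays are also written out as a GIF.
overlaid_images = visualize_attentions(
    series,
    attentions=prediction.attentions,
    save_directory="attention_output",  # placeholder output directory
    gain=3,
)
```

As described in the README change above, `gain` only rescales attention values for display, and the return value is a list with one entry per serie, each containing the overlaid 2D images.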