Commit

* Reword docstrings and README.
* Minor bugfix in visualize_attentions
jsilter committed Feb 21, 2024
1 parent 9d15789 commit da43ce8
Showing 6 changed files with 63 additions and 39 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -92,8 +92,8 @@ attentions = results.attentions

 The `attentions` will be a list of length equal to the number of series. Each series has a dictionary with the following keys:

-- `image_attention_1`: log-softmax attention scores over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble.
-- `volume_attention_1`: log-softmax attention scores over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble.
+- `image_attention_1`: attention scores (as logits) over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble.
+- `volume_attention_1`: attention scores (as logits) over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble.

 To visualize the attention scores, you can use the following code. This will return a list of 2D images, where the attention scores are overlaid on the original images. If you provide a `save_directory`, the images will be saved as a GIF. If multiple series are provided, the function will return a list of lists, one for each series.
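For context, a minimal sketch of the visualization call this README section refers to, based on the usage added in `tests/regression_test.py` later in this commit; the `dicom_files` paths and the save directory are illustrative, not part of the repository:

```python
from sybil import Serie, Sybil, visualize_attentions

# Illustrative paths; point these at a real CT series' DICOM slices
dicom_files = ["ct/slice_001.dcm", "ct/slice_002.dcm"]

model = Sybil("sybil_ensemble")
serie = Serie(dicom_files)
prediction = model.predict([serie], return_attentions=True)

# One list of 2D overlay images per serie; a save_directory also writes a GIF
images = visualize_attentions(
    [serie],
    attentions=prediction.attentions,
    save_directory="attention_output",  # illustrative path
    gain=3,
)
```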
5 changes: 3 additions & 2 deletions sybil/loaders/image_loaders.py
@@ -22,8 +22,8 @@ def cached_extension(self):


 class DicomLoader(abstract_loader):
-    def __init__(self, cache_path, augmentations, args):
-        super(DicomLoader, self).__init__(cache_path, augmentations, args)
+    def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
+        super(DicomLoader, self).__init__(cache_path, augmentations, args, apply_augmentations)
         self.window_center = -600
         self.window_width = 1500

@@ -41,6 +41,7 @@ def load_input(self, path, sample):
     def cached_extension(self):
         return ""

+
 def apply_windowing(image, center, width, bit_size=16):
     """Windowing function to transform image pixels for presentation.
     Must be run after a DICOM modality LUT is applied to the image.
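The function body is truncated in this view. For reference, a minimal sketch of the standard DICOM linear windowing transform the docstring describes (an approximation under the DICOM PS3.3 windowing equations, not necessarily the repository's exact implementation):

```python
import numpy as np

def apply_windowing_sketch(image, center, width, bit_size=16):
    """Map pixels in [center - width/2, center + width/2] onto the full
    output range [0, 2**bit_size - 1], clamping values outside the window."""
    y_max = 2 ** bit_size - 1
    lower = center - 0.5 - (width - 1) / 2
    upper = center - 0.5 + (width - 1) / 2
    # Linear ramp inside the window, per the DICOM standard's 0.5 offsets
    scaled = ((image - (center - 0.5)) / (width - 1) + 0.5) * y_max
    return np.where(image <= lower, 0, np.where(image > upper, y_max, scaled))
```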
43 changes: 24 additions & 19 deletions sybil/model.py
@@ -235,7 +235,7 @@ def _predict(
         model: SybilNet,
         series: Union[Serie, List[Serie]],
         return_attentions: bool = False,
-    ) -> np.ndarray:
+    ) -> Prediction:
         """Run predictions over the given serie(s).

         Parameters
@@ -244,6 +244,8 @@ def _predict(
             Instance of SybilNet
         series : Union[Serie, Iterable[Serie]]
             One or multiple series to run predictions for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.

         Returns
         -------
@@ -257,7 +259,7 @@ def _predict(
             raise ValueError("Expected either a Serie object or list of Serie objects.")

         scores: List[List[float]] = []
-        attentions: List[Dict[str, np.ndarray]] = []
+        attentions: List[Dict[str, np.ndarray]] = [] if return_attentions else None
         for serie in series:
             if not isinstance(serie, Serie):
                 raise ValueError("Expected a list of Serie objects.")
@@ -269,7 +271,7 @@ def _predict(
             with torch.no_grad():
                 out = model(volume)
                 score = out["logit"].sigmoid().squeeze(0).cpu().numpy()
-                scores.append(score)
+                scores.append(score.tolist())
                 if return_attentions:
                     attentions.append(
                         {
@@ -281,10 +283,8 @@ def _predict(
                             .cpu(),
                         }
                     )
-        if return_attentions:
-            return Prediction(scores=np.stack(scores), attentions=attentions)

-        return Prediction(scores=np.stack(scores))
+        return Prediction(scores=scores, attentions=attentions)

     def predict(
         self, series: Union[Serie, List[Serie]], return_attentions: bool = False
@@ -295,6 +295,8 @@ def predict(
         ----------
         series : Union[Serie, Iterable[Serie]]
             One or multiple series to run predictions for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.

         Returns
         -------
@@ -303,25 +305,31 @@ def predict(
         """
         scores = []
-        attentions_ = []
+        attentions_ = [] if return_attentions else None
+        attention_keys = None
         for sybil in self.ensemble:
             pred = self._predict(sybil, series, return_attentions)
             scores.append(pred.scores)
             if return_attentions:
                 attentions_.append(pred.attentions)
+                if attention_keys is None:
+                    attention_keys = pred.attentions[0].keys()

         scores = np.mean(np.array(scores), axis=0)
         calib_scores = self._calibrate(scores).tolist()

+        attentions = None
         if return_attentions:
             attentions = []
             for i in range(len(series)):
                 att = {}
-                for key in pred.attentions[0].keys():
-                    att[key] = [
+                for key in attention_keys:
+                    att[key] = np.stack([
                         attentions_[j][i][key] for j in range(len(self.ensemble))
-                    ]
+                    ])
                 attentions.append(att)
-            return Prediction(scores=calib_scores, attentions=attentions)
-        return Prediction(scores=calib_scores)

+        return Prediction(scores=calib_scores, attentions=attentions)

     def evaluate(
         self, series: Union[Serie, List[Serie]], return_attentions: bool = False
@@ -332,11 +340,13 @@ def evaluate(
         ----------
         series : Union[Serie, List[Serie]]
             One or multiple series to run evaluation for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.

         Returns
         -------
         Evaluation
-            Output evaluation. See details for :class:`~sybil.model.Evaluation`".
+            Output evaluation. See details for :class:`~sybil.model.Evaluation`.
         """
         if isinstance(series, Serie):
@@ -368,9 +378,4 @@ def evaluate(
         auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)]
         c_index = float(out["c_index"])

-        if return_attentions:
-            attentions = predictions.attentions
-            return Evaluation(
-                auc=auc, c_index=c_index, scores=scores, attentions=attentions
-            )
-        return Evaluation(auc=auc, c_index=c_index, scores=scores)
+        return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)
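Taken together, these changes mean `predict` always returns a `Prediction`, with `attentions` either `None` or one dict per serie whose values are stacked across the ensemble axis. A hedged usage sketch; the `dicom_files` paths are illustrative:

```python
from sybil import Serie, Sybil

# Illustrative paths to one CT series' DICOM slices
dicom_files = ["ct/slice_001.dcm", "ct/slice_002.dcm"]

model = Sybil("sybil_ensemble")
prediction = model.predict([Serie(dicom_files)], return_attentions=True)

print(prediction.scores[0])  # calibrated per-year risk scores for the serie

att = prediction.attentions[0]
# Each value is stacked over the ensemble, so axis 0 indexes the models
print(att["image_attention_1"].shape[0])   # == len(model.ensemble)
print(att["volume_attention_1"].shape[0])  # == len(model.ensemble)
```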
14 changes: 10 additions & 4 deletions sybil/models/pooling_layer.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn

+
 class MultiAttentionPool(nn.Module):
     def __init__(self):
         super(MultiAttentionPool, self).__init__()
@@ -44,7 +45,8 @@ def forward(self, x):
         output['hidden'] = self.hidden_fc(hidden)

         return output

+
 class GlobalMaxPool(nn.Module):
     '''
     Pool to obtain the maximum value for each channel
@@ -64,6 +66,7 @@ def forward(self, x):
         hidden, _ = torch.max(x, dim=-1)
         return {'hidden': hidden}

+
 class PerFrameMaxPool(nn.Module):
     '''
     Pool to obtain the maximum value for each slice in 3D input
@@ -86,6 +89,7 @@ def forward(self, x):
         output['multi_image_hidden'], _ = torch.max(x, dim=-1)
         return output

+
 class Conv1d_AttnPool(nn.Module):
     '''
     Pool to learn an attention over the slices after convolution
@@ -108,6 +112,7 @@ def forward(self, x):
         x = self.conv1d(x)  # B, C, N'
         return self.aggregate(x)

+
 class Simple_AttentionPool(nn.Module):
     '''
     Pool to learn an attention over the slices
@@ -125,7 +130,7 @@ def forward(self, x):
         - x: tensor of shape (B, C, N)
         returns:
         - output: dict
-            + output['attention_scores']: tensor (B, C)
+            + output['volume_attention']: tensor (B, N)
             + output['hidden']: tensor (B, C)
         '''
         output = {}
@@ -142,6 +147,7 @@ def forward(self, x):
         output['hidden'] = torch.sum(x, dim=-1)
         return output

+
 class Simple_AttentionPool_MultiImg(nn.Module):
     '''
     Pool to learn an attention over the slices and the volume
@@ -159,8 +165,8 @@ def forward(self, x):
         - x: tensor of shape (B, C, T, W, H)
         returns:
         - output: dict
-            + output['attention_scores']: tensor (B, T, C)
-            + output['multi_image_hidden']: tensor (B, T, C)
+            + output['image_attention']: tensor (B, T, W*H)
+            + output['multi_image_hidden']: tensor (B, C, T)
             + output['hidden']: tensor (B, T*C)
         '''
         output = {}
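For orientation, the corrected docstrings above describe attention pooling of roughly this shape. A minimal sketch under assumed dimensions, not the repository's exact module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttentionPoolSketch(nn.Module):
    """Attention pooling over the N slices of a (B, C, N) feature map."""

    def __init__(self, num_channels: int):
        super().__init__()
        self.attention_fc = nn.Linear(num_channels, 1)

    def forward(self, x):
        # x: (B, C, N) -> one score per slice, kept as log-softmax logits
        scores = self.attention_fc(x.transpose(1, 2)).squeeze(-1)  # (B, N)
        volume_attention = F.log_softmax(scores, dim=-1)           # (B, N)
        weights = volume_attention.exp().unsqueeze(1)              # (B, 1, N)
        hidden = (x * weights).sum(dim=-1)                         # (B, C)
        return {"volume_attention": volume_attention, "hidden": hidden}
```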
16 changes: 9 additions & 7 deletions sybil/utils/visualization.py
@@ -2,21 +2,21 @@
 import torch
 import torch.nn.functional as F
 from sybil.serie import Serie
-from typing import Dict, List
+from typing import Dict, List, Union
 import os
 import imageio


 def visualize_attentions(
-    series: Serie,
-    attentions: List[Dict[str, torch.Tensor]],
+    series: Union[Serie, List[Serie]],
+    attentions: List[Dict[str, np.ndarray]],
     save_directory: str = None,
     gain: int = 3,
 ) -> List[List[np.ndarray]]:
     """
     Args:
         series (Serie): series object
-        attention_dict (Dict[str, torch.Tensor]): attention dictionary output from model
+        attentions (Dict[str, np.ndarray]): attention dictionary output from model
         save_directory (str, optional): where to save the images. Defaults to None.
         gain (int, optional): how much to scale attention values by for visualization. Defaults to 3.
@@ -32,10 +32,12 @@ def visualize_attentions(
         a1 = attentions[serie_idx]["image_attention_1"]
         v1 = attentions[serie_idx]["volume_attention_1"]

+        a1 = torch.Tensor(a1)
+        v1 = torch.Tensor(v1)
+
         # take mean attention over ensemble
         if len(a1) > 1:
-            a1 = torch.exp(torch.stack(a1)).mean(0)
-            v1 = torch.exp(torch.stack(v1)).mean(0)
+            a1 = torch.exp(a1).mean(0)
+            v1 = torch.exp(v1).mean(0)

         attention = a1 * v1.unsqueeze(-1)
         attention = attention.view(1, 25, 16, 16)
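The bugfix converts the stacked ensemble attentions to tensors before exponentiating and averaging. Conceptually, the computation looks like this sketch; the sizes are assumptions inferred from the `view(1, 25, 16, 16)` call, and the averaging is shown unconditionally where the real code guards it with `if len(a1) > 1`:

```python
import torch

n_models, n_slices, n_pixels = 2, 25, 16 * 16  # assumed sizes

image_att = torch.randn(n_models, n_slices, n_pixels)  # log-softmax scores
volume_att = torch.randn(n_models, n_slices)

# exp() maps log-softmax scores back to probabilities; mean(0) averages models
a1 = torch.exp(image_att).mean(0)   # (n_slices, n_pixels)
v1 = torch.exp(volume_att).mean(0)  # (n_slices,)

# Each pixel's attention is weighted by its slice's volume attention
attention = (a1 * v1.unsqueeze(-1)).view(1, 25, 16, 16)
```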
20 changes: 15 additions & 5 deletions tests/regression_test.py
@@ -4,7 +4,7 @@
 import requests
 import zipfile

-from sybil import Serie, Sybil
+from sybil import Serie, Sybil, visualize_attentions

 script_directory = os.path.dirname(os.path.abspath(__file__))
 project_directory = os.path.dirname(script_directory)
@@ -68,15 +68,18 @@ def main():
     # Load a trained model
     model = Sybil("sybil_ensemble")

-    # myprint(f"Beginning prediction using {num_files} from {image_data_dir}")
+    myprint(f"Beginning prediction using {num_files} files from {image_data_dir}")

     # Get risk scores
     serie = Serie(dicom_files)
-    prediction = model.predict([serie])[0]
-    actual_scores = prediction[0]
+    series = [serie]
+    prediction = model.predict(series, return_attentions=True)
+    actual_scores = prediction.scores[0]
     count = len(actual_scores)

-    # myprint(f"Prediction finished. Results\n{actual_scores}")
+    myprint(f"Prediction finished. Results\n{actual_scores}")
+
+    # pprint.pprint(f"Prediction object: {prediction}")

     assert len(expected_scores) == len(actual_scores), f"Unexpected score length {count}"
@@ -88,6 +91,13 @@ def main():

     print(f"Data URL: {demo_data_url}\nAll {count} elements match: {all_elements_match}")

+    series_with_attention = visualize_attentions(
+        series,
+        attentions=prediction.attentions,
+        save_directory="regression_test_output",
+        gain=3,
+    )
+

 if __name__ == "__main__":
     main()
