Skip to content

Commit

Permalink
Merge pull request #27 from reginabarzilaygroup/attention-visualization
Browse files Browse the repository at this point in the history
Attention visualization
  • Loading branch information
jsilter authored Feb 22, 2024
2 parents 95830d5 + da43ce8 commit 6453d08
Show file tree
Hide file tree
Showing 10 changed files with 247 additions and 42 deletions.
36 changes: 35 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ You can replicate the results from our model using our training script:
python train.py
```

See our [documentation](docs/readme.md) for a full description of Sybil's training parameters.
See our [documentation](docs/readme.md) for a full description of Sybil's training parameters. Additional information on the training process can be found on the [train](https://github.com/reginabarzilaygroup/Sybil/tree/train) branch of this repository.


## LDCT Orientation
Expand Down Expand Up @@ -76,6 +76,40 @@ Annotations are availble to download in JSON format [here](https://drive.google.
}
```

## Attention Scores

The multi-attention pooling layer aims to learn the importance of each slice in the 3D volume and the importance of each pixel in the 2D slice. During training, these are supervised by bounding boxes of the cancerous nodules. This is a soft attention mechanism, and the model's primary task is to predict the risk of lung cancer. However, the attention scores can be extracted and used to visualize the model's focus on the 3D volume and the 2D slices.

To extract the attention scores, you can use the `return_attentions` argument as follows:

```python

results = model.predict([serie], return_attentions=True)

attentions = results.attentions

```

The `attentions` will be a list of length equal to the number of series. Each series has a dictionary with the following keys:

- `image_attention_1`: attention scores (as logits) over the pixels in the 2D slice. This will be a list of length equal to the size of the model ensemble.
- `volume_attention_1`: attention scores (as logits) over each slice in the 3D volume. This will be a list of length equal to the size of the model ensemble.

To visualize the attention scores, you can use the following code. This will return a list of 2D images, where the attention scores are overlaid on the original images. If you provide a `save_directory`, the images will be saved as a GIF. If multiple series are provided, the function will return a list of lists, one for each series.

```python

from sybil import visualize_attentions

series_with_attention = visualize_attentions(
series,
attentions = attentions,
save_directory = "path_to_save_directory",
gain = 3,
)

```

## Cite

```
Expand Down
4 changes: 2 additions & 2 deletions sybil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@

from sybil.model import Sybil
from sybil.serie import Serie
from sybil.utils.visualization import visualize_attentions


__all__ = ["Sybil", "Serie"]
__all__ = ["Sybil", "Serie", "visualize_attentions"]
10 changes: 3 additions & 7 deletions sybil/loaders/abstract_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def rem(self, image_path, attr_key):
class abstract_loader:
__metaclass__ = ABCMeta

def __init__(self, cache_path, augmentations, args):
def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
self.pad_token = IMG_PAD_TOKEN
self.augmentations = augmentations
self.args = args
self.apply_augmentations = apply_augmentations
if cache_path is not None:
self.use_cache = True
self.cache = cache(cache_path, self.cached_extension)
Expand All @@ -152,13 +153,8 @@ def load_input(self, path, sample):
def cached_extension(self):
pass

@property
@abstractmethod
def apply_augmentations(self):
return True

def configure_path(self, path, sample):
return path
return path

def get_image(self, path, sample):
"""
Expand Down
5 changes: 3 additions & 2 deletions sybil/loaders/image_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def cached_extension(self):


class DicomLoader(abstract_loader):
def __init__(self, cache_path, augmentations, args):
super(DicomLoader, self).__init__(cache_path, augmentations, args)
def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
super(DicomLoader, self).__init__(cache_path, augmentations, args, apply_augmentations)
self.window_center = -600
self.window_width = 1500

Expand All @@ -41,6 +41,7 @@ def load_input(self, path, sample):
def cached_extension(self):
return ""


def apply_windowing(image, center, width, bit_size=16):
"""Windowing function to transform image pixels for presentation.
Must be run after a DICOM modality LUT is applied to the image.
Expand Down
73 changes: 60 additions & 13 deletions sybil/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@

class Prediction(NamedTuple):
scores: List[List[float]]
attentions: List[Dict[str, np.ndarray]] = None


class Evaluation(NamedTuple):
auc: List[float]
c_index: float
scores: List[List[float]]
attentions: List[Dict[str, np.ndarray]] = None


def download_sybil(name, cache):
Expand Down Expand Up @@ -167,6 +169,8 @@ def __init__(

if calibrator_path is not None:
self.calibrator = pickle.load(open(calibrator_path, "rb"))
else:
self.calibrator = None

def load_model(self, path):
"""Load model from path.
Expand Down Expand Up @@ -230,7 +234,8 @@ def _predict(
self,
model: SybilNet,
series: Union[Serie, List[Serie]],
) -> np.ndarray:
return_attentions: bool = False,
) -> Prediction:
"""Run predictions over the given serie(s).
Parameters
Expand All @@ -239,6 +244,8 @@ def _predict(
Instance of SybilNet
series : Union[Serie, Iterable[Serie]]
One or multiple series to run predictions for.
return_attentions : bool
If True, returns attention scores for each serie. See README for details.
Returns
-------
Expand All @@ -252,6 +259,7 @@ def _predict(
raise ValueError("Expected either a Serie object or list of Serie objects.")

scores: List[List[float]] = []
attentions: List[Dict[str, np.ndarray]] = [] if return_attentions else None
for serie in series:
if not isinstance(serie, Serie):
raise ValueError("Expected a list of Serie objects.")
Expand All @@ -263,17 +271,32 @@ def _predict(
with torch.no_grad():
out = model(volume)
score = out["logit"].sigmoid().squeeze(0).cpu().numpy()
scores.append(score)

return np.stack(scores)

def predict(self, series: Union[Serie, List[Serie]]) -> Prediction:
scores.append(score.tolist())
if return_attentions:
attentions.append(
{
"image_attention_1": out["image_attention_1"]
.detach()
.cpu(),
"volume_attention_1": out["volume_attention_1"]
.detach()
.cpu(),
}
)

return Prediction(scores=scores, attentions=attentions)

def predict(
self, series: Union[Serie, List[Serie]], return_attentions: bool = False
) -> Prediction:
"""Run predictions over the given serie(s) and ensemble
Parameters
----------
series : Union[Serie, Iterable[Serie]]
One or multiple series to run predictions for.
return_attentions : bool
If True, returns attention scores for each serie. See README for details.
Returns
-------
Expand All @@ -282,25 +305,48 @@ def predict(self, series: Union[Serie, List[Serie]]) -> Prediction:
"""
scores = []
attentions_ = [] if return_attentions else None
attention_keys = None
for sybil in self.ensemble:
pred = self._predict(sybil, series)
scores.append(pred)
pred = self._predict(sybil, series, return_attentions)
scores.append(pred.scores)
if return_attentions:
attentions_.append(pred.attentions)
if attention_keys is None:
attention_keys = pred.attentions[0].keys()

scores = np.mean(np.array(scores), axis=0)
calib_scores = self._calibrate(scores).tolist()
return Prediction(scores=calib_scores)

def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation:
attentions = None
if return_attentions:
attentions = []
for i in range(len(series)):
att = {}
for key in attention_keys:
att[key] = np.stack([
attentions_[j][i][key] for j in range(len(self.ensemble))
])
attentions.append(att)

return Prediction(scores=calib_scores, attentions=attentions)

def evaluate(
self, series: Union[Serie, List[Serie]], return_attentions: bool = False
) -> Evaluation:
"""Run evaluation over the given serie(s).
Parameters
----------
series : Union[Serie, List[Serie]]
One or multiple series to run evaluation for.
return_attentions : bool
If True, returns attention scores for each serie. See README for details.
Returns
-------
Evaluation
Output evaluation. See details for :class:`~sybil.model.Evaluation`".
Output evaluation. See details for :class:`~sybil.model.Evaluation`.
"""
if isinstance(series, Serie):
Expand All @@ -315,7 +361,8 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation:
raise ValueError("All series must have a label for evaluation")

# Get scores and labels
scores = self.predict(series).scores
predictions = self.predict(series, return_attentions)
scores = predictions.scores
labels = [serie.get_label(self._max_followup) for serie in series]

# Convert to format for survival metrics
Expand All @@ -331,4 +378,4 @@ def evaluate(self, series: Union[Serie, List[Serie]]) -> Evaluation:
auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)]
c_index = float(out["c_index"])

return Evaluation(auc=auc, c_index=c_index, scores=scores)
return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)
14 changes: 10 additions & 4 deletions sybil/models/pooling_layer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn


class MultiAttentionPool(nn.Module):
def __init__(self):
super(MultiAttentionPool, self).__init__()
Expand Down Expand Up @@ -44,7 +45,8 @@ def forward(self, x):
output['hidden'] = self.hidden_fc(hidden)

return output



class GlobalMaxPool(nn.Module):
'''
Pool to obtain the maximum value for each channel
Expand All @@ -64,6 +66,7 @@ def forward(self, x):
hidden, _ = torch.max(x, dim=-1)
return {'hidden': hidden}


class PerFrameMaxPool(nn.Module):
'''
Pool to obtain the maximum value for each slice in 3D input
Expand All @@ -86,6 +89,7 @@ def forward(self, x):
output['multi_image_hidden'], _ = torch.max(x, dim=-1)
return output


class Conv1d_AttnPool(nn.Module):
'''
Pool to learn an attention over the slices after convolution
Expand All @@ -108,6 +112,7 @@ def forward(self, x):
x = self.conv1d(x) # B, C, N'
return self.aggregate(x)


class Simple_AttentionPool(nn.Module):
'''
Pool to learn an attention over the slices
Expand All @@ -125,7 +130,7 @@ def forward(self, x):
- x: tensor of shape (B, C, N)
returns:
- output: dict
+ output['attention_scores']: tensor (B, C)
+ output['volume_attention']: tensor (B, N)
+ output['hidden']: tensor (B, C)
'''
output = {}
Expand All @@ -142,6 +147,7 @@ def forward(self, x):
output['hidden'] = torch.sum(x, dim=-1)
return output


class Simple_AttentionPool_MultiImg(nn.Module):
'''
Pool to learn an attention over the slices and the volume
Expand All @@ -159,8 +165,8 @@ def forward(self, x):
- x: tensor of shape (B, C, T, W, H)
returns:
- output: dict
+ output['attention_scores']: tensor (B, T, C)
+ output['multi_image_hidden']: tensor (B, T, C)
+ output['image_attention']: tensor (B, T, W*H)
+ output['multi_image_hidden']: tensor (B, C, T)
+ output['hidden']: tensor (B, T*C)
'''
output = {}
Expand Down
28 changes: 23 additions & 5 deletions sybil/serie.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self._censor_time = censor_time
self._label = label
args = self._load_args(file_type)
self._args = args
self._loader = get_sample_loader(split, args)
self._meta = self._load_metadata(dicoms, voxel_spacing, file_type)
self._check_valid(args)
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_label(self, max_followup: int = 6) -> Label:
raise ValueError("No label in this serie.")

# First convert months to years
year_to_cancer = self._censor_time # type: ignore
year_to_cancer = self._censor_time # type: ignore

y_seq = np.zeros(max_followup, dtype=np.float64)
y = int((year_to_cancer < max_followup) and self._label) # type: ignore
Expand All @@ -119,6 +120,21 @@ def get_label(self, max_followup: int = 6) -> Label:
)
return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer)

def get_raw_images(self) -> List[np.ndarray]:
"""
Load raw images from serie
Returns
-------
List[np.ndarray]
List of CT slices of shape (1, C, H, W)
"""

loader = get_sample_loader("test", self._args, apply_augmentations=False)
input_dicts = [loader.get_image(path, {}) for path in self._meta.paths]
images = [i["input"] for i in input_dicts]
return images

def get_volume(self) -> torch.Tensor:
"""
Load loaded 3D CT volume
Expand All @@ -131,10 +147,12 @@ def get_volume(self) -> torch.Tensor:

sample = {"seed": np.random.randint(0, 2**32 - 1)}

input_dicts = [self._loader.get_image(path, sample) for path in self._meta.paths]

x = torch.cat( [i["input"].unsqueeze(0) for i in input_dicts], dim = 0)

input_dicts = [
self._loader.get_image(path, sample) for path in self._meta.paths
]

x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0)

# Convert from (T, C, H, W) to (C, T, H, W)
x = x.permute(1, 0, 2, 3)

Expand Down
Loading

0 comments on commit 6453d08

Please sign in to comment.