V1.3.0 dev #39

Merged: 3 commits, Jun 4, 2024
3 changes: 2 additions & 1 deletion scripts/run_inference_demo.sh
@@ -13,7 +13,8 @@ if [ ! -d "$demo_scan_dir" ]; then
unzip -q sybil_example.zip
fi

python3 scripts/inference.py \
# Either python3 sybil/predict.py or sybil-predict (if installed via pip)
python3 sybil/predict.py \
--loglevel DEBUG \
--output-dir demo_prediction \
--return-attentions \
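For callers who prefer the Python API over the renamed script, the same run can be driven programmatically. A minimal sketch based on the predict() signature introduced in sybil/predict.py below; the directory names are illustrative, not part of the demo script:

    # Programmatic equivalent of the demo script; a sketch, not documented API.
    # "sybil_demo_data" and "demo_prediction" are illustrative names.
    from sybil.predict import predict

    pred_dict, series_with_attention = predict(
        image_dir="sybil_demo_data",   # directory holding one exam's DICOM/PNG files
        output_dir="demo_prediction",  # directory where prediction results are saved
        model_name="sybil_ensemble",
        return_attentions=True,        # also pickle hidden vectors and attentions
        write_attention_images=False,
        file_type="auto",              # infer dicom vs. png from file extensions
    )
    print(pred_dict)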
27 changes: 13 additions & 14 deletions setup.cfg
@@ -7,7 +7,7 @@ author_email =
license_file = LICENSE.txt
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
version = 1.2.2
version = 1.3.0
# url =
project_urls =
; Documentation = https://.../docs
@@ -31,31 +31,27 @@ find_links =
zip_safe = False
packages = find:
include_package_data = True
python_requires = >=3.8
python_requires = >=3.8,<3.11
# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0.
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
# new major versions. This works if the required packages follow Semantic Versioning.
# For more information, check out https://semver.org/.
install_requires =
importlib-metadata; python_version>="3.8"
albumentations==1.1.0
numpy==1.24.1
torch==1.13.1+cu117; sys_platform != "darwin"
torch==1.13.1; sys_platform == "darwin"
torchvision==0.14.1+cu117; sys_platform != "darwin"
torchvision==0.14.1; sys_platform == "darwin"
pytorch_lightning==1.6.0
scikit-learn==1.0.2
tqdm==4.62.3
lifelines==0.26.4
opencv-python==4.5.4.60
opencv-python-headless==4.5.4.60
albumentations==1.1.0
pillow>=10.2.0
pydicom==2.3.0
pylibjpeg[all]==2.0.0
scikit-learn==1.0.2
torch==1.13.1+cu117; platform_machine == "x86_64"
torch==1.13.1; platform_machine != "x86_64"
torchio==0.18.74
gdown==4.6.0

torchvision==0.14.1+cu117; platform_machine == "x86_64"
torchvision==0.14.1; platform_machine != "x86_64"
tqdm==4.62.3

[options.packages.find]
exclude =
@@ -71,10 +67,13 @@ testing =
flake8
mypy
black
train =
lifelines==0.26.4
pytorch_lightning==1.6.0

[options.entry_points]
console_scripts =
sybil = sybil.main:main
sybil-predict = sybil.predict:main


[bdist_wheel]
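With lifelines and pytorch_lightning moved into a train extra, a plain pip install of the package no longer pulls them in, so any code path that needs them has to import lazily and fail with a useful message. A minimal sketch of that guard; the function name and error wording are illustrative:

    # Sketch of guarding an optional training-only dependency; the helper
    # name and error message are illustrative, not from the codebase.
    def make_km_fitter():
        try:
            from lifelines import KaplanMeierFitter
        except ImportError as err:
            raise ImportError(
                "lifelines is required for training metrics; "
                "install it via the train extra: pip install sybil[train]"
            ) from err
        return KaplanMeierFitter()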
2 changes: 1 addition & 1 deletion sybil/__init__.py
@@ -22,4 +22,4 @@
from sybil.utils.visualization import visualize_attentions
import sybil.utils.logging_utils

__all__ = ["Sybil", "Serie", "visualize_attentions"]
__all__ = ["Sybil", "Serie", "visualize_attentions", "__version__"]
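Re-exporting __version__ lets downstream code and the new --version flag in sybil/predict.py report the installed version directly:

    # Querying the package version through the public API.
    from sybil import __version__

    print(__version__)  # e.g. "1.3.0"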
14 changes: 0 additions & 14 deletions sybil/datasets/utils.py
@@ -1,6 +1,5 @@
import numpy as np
import math
from lifelines import KaplanMeierFitter

# Error Messages
METAFILE_NOTFOUND_ERR = "Metadata file {} could not be parsed! Exception: {}!"
@@ -104,16 +103,3 @@ def get_scaled_annotation_area(sample, args):
areas.append(mask.sum() / (mask.shape[0] * mask.shape[1]))
return np.array(areas)


def get_censoring_dist(train_dataset):
_dataset = train_dataset.dataset
times, event_observed = (
[d["time_at_event"] for d in _dataset],
[d["y"] for d in _dataset],
)
all_observed_times = set(times)
kmf = KaplanMeierFitter()
kmf.fit(times, event_observed)

censoring_dist = {str(time): kmf.predict(time) for time in all_observed_times}
return censoring_dist
47 changes: 3 additions & 44 deletions sybil/model.py
@@ -4,7 +4,6 @@
from typing import NamedTuple, Union, Dict, List, Optional, Tuple
from urllib.request import urlopen
from zipfile import ZipFile
# import gdown

import torch
import numpy as np
@@ -83,49 +82,6 @@ class Evaluation(NamedTuple):
attentions: List[Dict[str, np.ndarray]] = None


def download_sybil_gdrive(name, cache):
"""Download trained models and calibrator from Google Drive

Parameters
----------
name (str): name of model to use. A key in NAME_TO_FILE
cache (str): path to directory where files are downloaded

Returns
-------
download_model_paths (list): paths to .ckpt models
download_calib_path (str): path to calibrator
"""
# Create cache folder if not exists
cache = os.path.expanduser(cache)
os.makedirs(cache, exist_ok=True)

# Download if needed
model_files = NAME_TO_FILE[name]

# Download models
download_model_paths = []
for model_name, google_id in zip(
model_files["checkpoint"], model_files["google_checkpoint_id"]
):
model_path = os.path.join(cache, f"{model_name}.ckpt")
if not os.path.exists(model_path):
print(f"Downloading model to {cache}")
gdown.download(id=google_id, output=model_path, quiet=False)
download_model_paths.append(model_path)

# download calibrator
download_calib_path = os.path.join(cache, f"{name}.p")
if not os.path.exists(download_calib_path):
gdown.download(
id=model_files["google_calibrator_id"],
output=download_calib_path,
quiet=False,
)

return download_model_paths, download_calib_path


def download_sybil(name, cache) -> Tuple[List[str], str]:
"""Download trained models and calibrator"""
# Create cache folder if not exists
@@ -329,6 +285,9 @@ def _predict(
"volume_attention_1": out["volume_attention_1"]
.detach()
.cpu(),
"hidden": out["hidden"]
.detach()
.cpu(),
}
)

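Because the per-serie attention dictionaries now carry the hidden tensor, callers that request attentions get the model's hidden representation as well. A sketch, assuming attentions holds one dictionary per input serie as in the Evaluation NamedTuple above; the DICOM paths are illustrative:

    from sybil import Serie, Sybil

    # Illustrative inputs: DICOM slices from a single exam.
    model = Sybil("sybil_ensemble")
    serie = Serie(["scan_001.dcm", "scan_002.dcm"])
    prediction = model.predict([serie], return_attentions=True)

    # Assumes one attention dictionary per input serie.
    hidden = prediction.attentions[0]["hidden"]
    print(hidden.shape)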
61 changes: 42 additions & 19 deletions scripts/inference.py → sybil/predict.py
@@ -1,42 +1,55 @@
#!/usr/bin/env python

__doc__ = """
Use Sybil to run inference on a single exam.
"""

import argparse
import datetime
import json
import logging
import os
import pickle
import typing
from typing import Literal

import sybil.utils.logging_utils
from sybil import Serie, Sybil, visualize_attentions

script_directory = os.path.dirname(os.path.abspath(__file__))
project_directory = os.path.dirname(script_directory)
import sybil.datasets.utils
from sybil import Serie, Sybil, visualize_attentions, __version__


def _get_parser():
parser = argparse.ArgumentParser(description=__doc__)
description = __doc__ + f"\nVersion: {__version__}\n"
parser = argparse.ArgumentParser(description=description)

parser.add_argument(
"image_dir",
default=None,
help="Path to directory containing DICOM/PNG files (from a single exam) to run inference on."
help="Path to directory containing DICOM/PNG files (from a single exam) to run inference on. "
"Every file in the directory will be included.",
)

parser.add_argument(
"--output-dir",
default="sybil_result",
dest="output_dir",
help="Output directory in which to save prediction results."
help="Output directory in which to save prediction results. "
"Prediction will be printed to stdout as well.",
)

parser.add_argument(
"--return-attentions",
default=False,
action="store_true",
help="Generate an image which overlaps attention scores.",
help="Return hidden vectors and attention scores, write them to a pickle file.",
)

parser.add_argument(
"--write-attention-images",
default=False,
action="store_true",
help="Generate images with attention overlap. Sets --return-attentions (if not already set).",
)


parser.add_argument(
"--file-type",
default="auto",
@@ -56,33 +69,41 @@ def _get_parser():
parser.add_argument("-l", "--log", "--loglevel", "--log-level",
default="INFO", dest="loglevel")

parser.add_argument("-v", "--version", action="version", version=__version__)

return parser


def inference(
def predict(
image_dir,
output_dir,
model_name="sybil_ensemble",
return_attentions=False,
file_type="auto",
write_attention_images=False,
file_type: Literal["auto", "dicom", "png"] = "auto",
):
logger = sybil.utils.logging_utils.get_logger()

input_files = os.listdir(image_dir)
input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
input_files = [x for x in input_files if os.path.isfile(x)]
extensions = {os.path.splitext(x)[1] for x in input_files}
if len(extensions) > 1:
raise ValueError(
f"Multiple file types found in {image_dir}: {','.join(extensions)}"
)

voxel_spacing = None
if file_type == "auto":
extensions = {os.path.splitext(x)[1] for x in input_files}
extension = extensions.pop()
if len(extensions) > 1:
raise ValueError(
f"Multiple file types found in {image_dir}: {','.join(extensions)}"
)

file_type = "dicom"
if extension.lower() in {".png", "png"}:
file_type = "png"
voxel_spacing = sybil.datasets.utils.VOXEL_SPACING
logger.debug(f"Using default voxel spacing: {voxel_spacing}")
assert file_type in {"dicom", "png"}
file_type = typing.cast(Literal["dicom", "png"], file_type)

num_files = len(input_files)

Expand All @@ -92,7 +113,7 @@ def inference(
model = Sybil(model_name)

# Get risk scores
serie = Serie(input_files, file_type=file_type)
serie = Serie(input_files, voxel_spacing=voxel_spacing, file_type=file_type)
series = [serie]
prediction = model.predict(series, return_attentions=return_attentions)
prediction_scores = prediction.scores[0]
@@ -110,6 +131,7 @@ def inference(
with open(attention_path, "wb") as f:
pickle.dump(prediction, f)

if write_attention_images:
series_with_attention = visualize_attentions(
series,
attentions=prediction.attentions,
@@ -126,11 +148,12 @@ def main():

os.makedirs(args.output_dir, exist_ok=True)

pred_dict, series_with_attention = inference(
pred_dict, series_with_attention = predict(
args.image_dir,
args.output_dir,
args.model_name,
args.return_attentions,
args.write_attention_images,
file_type=args.file_type,
)

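When --return-attentions is set, the whole prediction object is pickled into the output directory, so it can be reloaded later for offline analysis. A sketch; the file name is illustrative, since the exact path written by predict() is not shown here:

    import pickle

    # Reloading the pickle written under --return-attentions; the file name
    # below is illustrative.
    with open("demo_prediction/attention_scores.pkl", "rb") as f:
        prediction = pickle.load(f)

    print(prediction.scores)      # per-serie risk scores
    print(prediction.attentions)  # attention maps and hidden vectors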
12 changes: 6 additions & 6 deletions sybil/serie.py
@@ -59,6 +59,8 @@ def __init__(
"""
if label is not None and censor_time is None:
raise ValueError("censor_time should also provided with label.")
if file_type == "png" and voxel_spacing is None:
raise ValueError("voxel_spacing should be provided for PNG files.")

self._censor_time = censor_time
self._label = label
@@ -263,13 +265,11 @@ def _check_valid(self, args):
- serie doesn't have a label, OR
- slice thickness is too big
"""
if (self._meta.thickness is None) or (
self._meta.thickness > args.slice_thickness_filter
):
if self._meta.thickness is None:
raise ValueError("slice thickness not found")
if self._meta.thickness > args.slice_thickness_filter:
raise ValueError(
"slice thickness is greater than {}.".format(
args.slice_thickness_filter
)
f"slice thickness {self._meta.thickness} is greater than {args.slice_thickness_filter}."
)
if self._meta.voxel_spacing is None:
raise ValueError("voxel spacing either not set or not found in DICOM")
2 changes: 1 addition & 1 deletion sybil/utils/logging_utils.py
@@ -7,7 +7,7 @@


def _get_formatter(loglevel="INFO"):
warn_fmt = "[%(asctime)s] %(levelname)s -%(message)s"
warn_fmt = "[%(asctime)s] %(levelname)s - %(message)s"
debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
fmt = debug_fmt if loglevel.upper() in {"DEBUG"} else warn_fmt
return logging.Formatter(
4 changes: 2 additions & 2 deletions sybil/utils/metrics.py
@@ -10,8 +10,6 @@
average_precision_score,
)
import numpy as np
from lifelines.utils.btree import _BTree
from lifelines import KaplanMeierFitter
import warnings

EPSILON = 1e-6
@@ -154,6 +152,7 @@ def include_exam_and_determine_label(prob_arr, censor_time, gold):


def get_censoring_dist(train_dataset):
from lifelines import KaplanMeierFitter
_dataset = train_dataset.dataset
times, event_observed = (
[d["time_at_event"] for d in _dataset],
@@ -309,6 +308,7 @@ def _concordance_summary_statistics(
censored_truth = censored_truth[ix]
censored_pred = predicted_event_times[~died_mask][ix]

from lifelines.utils.btree import _BTree
censored_ix = 0
died_ix = 0
times_to_compare = {}
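get_censoring_dist still expects a dataset object whose .dataset attribute yields dicts with time_at_event and y keys, and now pulls in lifelines only when it is actually called. A sketch with illustrative records; lifelines (the train extra) must be installed:

    from sybil.utils.metrics import get_censoring_dist

    # Minimal stand-in for a training dataset; the records are illustrative.
    class FakeDataset:
        dataset = [
            {"time_at_event": 1, "y": 1},
            {"time_at_event": 2, "y": 0},
            {"time_at_event": 3, "y": 1},
        ]

    # Returns {str(time): Kaplan-Meier survival estimate} per observed time.
    censoring_dist = get_censoring_dist(FakeDataset())
    print(censoring_dist)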