diff --git a/README.md b/README.md
index 06cbc42..fd9e41b 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,9 @@
 
 # Sybil
 
-Lung Cancer Risk Prediction
+Lung Cancer Risk Prediction.
+
+Additional documentation can be found on the [GitHub Wiki](https://github.com/reginabarzilaygroup/Sybil/wiki).
 
 ## Run a regression test
 
@@ -21,7 +23,7 @@ You can load our pretrained model trained on the NLST dataset, and score a given
 from sybil import Serie, Sybil
 
 # Load a trained model
-model = Sybil("sybil_base")
+model = Sybil("sybil_ensemble")
 
 # Get risk scores
 serie = Serie([dicom_path_1, dicom_path_2, ...])
@@ -32,9 +34,9 @@ serie = Serie([dicom_path_1, dicom_path_2, ...], label=1)
 results = model.evaluate([serie])
 ```
 
-Models available include: `sybil_base` and `sybil_ensemble`.
+Models available include: `sybil_1`, `sybil_2`, `sybil_3`, `sybil_4`, `sybil_5` and `sybil_ensemble`.
 
-All model files are available [here](https://drive.google.com/drive/folders/1nBp05VV9mf5CfEO6W5RY4ZpcpxmPDEeR?usp=sharing).
+All model files are available on [GitHub releases](https://github.com/reginabarzilaygroup/Sybil/releases) as well as [here](https://drive.google.com/drive/folders/1nBp05VV9mf5CfEO6W5RY4ZpcpxmPDEeR?usp=sharing).
 
 ## Replicating results
 
diff --git a/setup.cfg b/setup.cfg
index 685b6f2..7cfc01a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,7 +7,7 @@ author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.0.4
+version = 1.1.0
 # url =
 project_urls =
 ; Documentation = https://.../docs
diff --git a/sybil/model.py b/sybil/model.py
index 2e734e7..93af540 100644
--- a/sybil/model.py
+++ b/sybil/model.py
@@ -1,7 +1,10 @@
-from typing import NamedTuple, Union, Dict, List, Optional
-import os
 from argparse import Namespace
-import gdown
+from io import BytesIO
+import os
+from typing import NamedTuple, Union, Dict, List, Optional, Tuple
+from urllib.request import urlopen
+from zipfile import ZipFile
+# import gdown
 
 import torch
 import numpy as np
@@ -12,6 +15,7 @@
 from sybil.utils.metrics import get_survival_metrics
 
 
+# Leaving this here for a bit; these are IDs to download the models from Google Drive
 NAME_TO_FILE = {
     "sybil_base": {
         "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
@@ -62,6 +66,8 @@
     },
 }
 
+CHECKPOINT_URL = "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.0.3/sybil_checkpoints.zip"
+
 
 class Prediction(NamedTuple):
     scores: List[List[float]]
@@ -75,7 +81,7 @@ class Evaluation(NamedTuple):
     attentions: List[Dict[str, np.ndarray]] = None
 
 
-def download_sybil(name, cache):
+def download_sybil_gdrive(name, cache):
     """Download trained models and calibrator from Google Drive
 
     Parameters
@@ -118,10 +124,74 @@ def download_sybil(name, cache):
     return download_model_paths, download_calib_path
 
 
+def download_sybil(name, cache) -> Tuple[List[str], str]:
+    """Download trained models and calibrator.
+
+    Parameters
+    ----------
+    name : str
+        Key of ``NAME_TO_FILE`` selecting which model files are needed.
+    cache : str
+        Directory holding the model files; ``~`` is expanded and the
+        directory is created when missing.
+
+    Returns
+    -------
+    download_model_paths : list of str
+        Expected ``<checkpoint>.ckpt`` paths inside ``cache``.
+    download_calib_path : str
+        Expected ``<name>.p`` calibrator path inside ``cache``.
+    """
+    # Create cache folder if not exists
+    cache = os.path.expanduser(cache)
+    os.makedirs(cache, exist_ok=True)
+
+    model_files = NAME_TO_FILE[name]
+    download_calib_path = os.path.join(cache, f"{name}.p")
+    download_model_paths = [
+        os.path.join(cache, f"{checkpoint}.ckpt")
+        for checkpoint in model_files["checkpoint"]
+    ]
+
+    # Fetch the release archive only when an expected file is missing
+    expected_paths = [download_calib_path, *download_model_paths]
+    if not all(os.path.exists(path) for path in expected_paths):
+        print(f"Downloading models to {cache}")
+        download_and_extract(CHECKPOINT_URL, cache)
+
+    return download_model_paths, download_calib_path
+
+
+def download_and_extract(remote_model_url: str, local_model_dir) -> List[str]:
+    """Download a zip archive and extract it into a local directory.
+
+    Parameters
+    ----------
+    remote_model_url : str
+        URL of the zip archive to download.
+    local_model_dir : str
+        Directory to extract into; created when missing.
+
+    Returns
+    -------
+    list of str
+        Names of all archive members, as given by ``ZipFile.namelist()``.
+    """
+    # The context manager closes the HTTP response even if the read fails
+    with urlopen(remote_model_url) as resp:
+        archive_bytes = resp.read()
+
+    os.makedirs(local_model_dir, exist_ok=True)
+    with ZipFile(BytesIO(archive_bytes)) as zip_file:
+        all_files_and_dirs = zip_file.namelist()
+        zip_file.extractall(local_model_dir)
+    return all_files_and_dirs
+
+
 class Sybil:
     def __init__(
         self,
-        name_or_path: Union[List[str], str] = "sybil_base",
+        name_or_path: Union[List[str], str] = "sybil_ensemble",
         cache: str = "~/.sybil/",
         calibrator_path: Optional[str] = None,
         device: Optional[str] = None,
diff --git a/tests/regression_test.py b/tests/regression_test.py
index b07169f..a3dfb65 100644
--- a/tests/regression_test.py
+++ b/tests/regression_test.py
@@ -66,7 +66,8 @@ def main():
     num_files = len(dicom_files)
 
     # Load a trained model
-    model = Sybil("sybil_ensemble")
+    # model = Sybil("sybil_ensemble")
+    model = Sybil()
 
     myprint(f"Beginning prediction using {num_files} files from {image_data_dir}")
 