Incorporate other changes:
* Update version to 1.1.0
* Download model from GitHub
* Update README to point to wiki
* Update default model to be "sybil_ensemble"
jsilter committed Mar 18, 2024
1 parent 4d72a7b commit 1685364
Showing 4 changed files with 54 additions and 11 deletions.
10 changes: 6 additions & 4 deletions README.md
@@ -2,7 +2,9 @@

 # Sybil
 
-Lung Cancer Risk Prediction
+Lung Cancer Risk Prediction.
+
+Additional documentation can be found on the [GitHub Wiki](https://github.com/reginabarzilaygroup/Sybil/wiki).
 
 ## Run a regression test

@@ -21,7 +23,7 @@ You can load our pretrained model trained on the NLST dataset, and score a given
 from sybil import Serie, Sybil
 
 # Load a trained model
-model = Sybil("sybil_base")
+model = Sybil("sybil_ensemble")
 
 # Get risk scores
 serie = Serie([dicom_path_1, dicom_path_2, ...])
@@ -32,9 +34,9 @@ serie = Serie([dicom_path_1, dicom_path_2, ...], label=1)
 results = model.evaluate([serie])
 ```
 
-Models available include: `sybil_base` and `sybil_ensemble`.
+Models available include: `sybil_1`, `sybil_2`, `sybil_3`, `sybil_4`, `sybil_5` and `sybil_ensemble`.
 
-All model files are available [here](https://drive.google.com/drive/folders/1nBp05VV9mf5CfEO6W5RY4ZpcpxmPDEeR?usp=sharing).
+All model files are available on [GitHub releases](https://github.com/reginabarzilaygroup/Sybil/releases) as well as [here](https://drive.google.com/drive/folders/1nBp05VV9mf5CfEO6W5RY4ZpcpxmPDEeR?usp=sharing).
 
 ## Replicating results
2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.0.4
+version = 1.1.0
 # url =
 project_urls =
 ; Documentation = https://.../docs
50 changes: 45 additions & 5 deletions sybil/model.py
@@ -1,7 +1,10 @@
-from typing import NamedTuple, Union, Dict, List, Optional
-import os
 from argparse import Namespace
-import gdown
+from io import BytesIO
+import os
+from typing import NamedTuple, Union, Dict, List, Optional, Tuple
+from urllib.request import urlopen
+from zipfile import ZipFile
+# import gdown
 
 import torch
 import numpy as np
@@ -12,6 +15,7 @@
 from sybil.utils.metrics import get_survival_metrics
 
 
+# Leaving this here for a bit; these are IDs to download the models from Google Drive
 NAME_TO_FILE = {
     "sybil_base": {
         "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
@@ -62,6 +66,8 @@
     },
 }
 
+CHECKPOINT_URL = "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.0.3/sybil_checkpoints.zip"
+
 
 class Prediction(NamedTuple):
     scores: List[List[float]]
@@ -75,7 +81,7 @@ class Evaluation(NamedTuple):
     attentions: List[Dict[str, np.ndarray]] = None
 
 
-def download_sybil(name, cache):
+def download_sybil_gdrive(name, cache):
     """Download trained models and calibrator from Google Drive
 
     Parameters
@@ -118,10 +124,44 @@ def download_sybil(name, cache):
     return download_model_paths, download_calib_path
 
 
+def download_sybil(name, cache) -> Tuple[List[str], str]:
+    """Download trained models and calibrator"""
+    # Create cache folder if not exists
+    cache = os.path.expanduser(cache)
+    os.makedirs(cache, exist_ok=True)
+
+    # Download models
+    model_files = NAME_TO_FILE[name]
+    checkpoints = model_files["checkpoint"]
+    download_calib_path = os.path.join(cache, f"{name}.p")
+    have_all_files = os.path.exists(download_calib_path)
+
+    download_model_paths = []
+    for checkpoint in checkpoints:
+        cur_checkpoint_path = os.path.join(cache, f"{checkpoint}.ckpt")
+        have_all_files &= os.path.exists(cur_checkpoint_path)
+        download_model_paths.append(cur_checkpoint_path)
+
+    if not have_all_files:
+        print(f"Downloading models to {cache}")
+        download_and_extract(CHECKPOINT_URL, cache)
+
+    return download_model_paths, download_calib_path
+
+
+def download_and_extract(remote_model_url: str, local_model_dir) -> List[str]:
+    resp = urlopen(remote_model_url)
+    os.makedirs(local_model_dir, exist_ok=True)
+    with ZipFile(BytesIO(resp.read())) as zip_file:
+        all_files_and_dirs = zip_file.namelist()
+        zip_file.extractall(local_model_dir)
+    return all_files_and_dirs
+
+
 class Sybil:
     def __init__(
         self,
-        name_or_path: Union[List[str], str] = "sybil_base",
+        name_or_path: Union[List[str], str] = "sybil_ensemble",
         cache: str = "~/.sybil/",
         calibrator_path: Optional[str] = None,
         device: Optional[str] = None,
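As a quick illustration of the new download path, here is a minimal sketch that calls `download_sybil` directly. The function, `download_and_extract`, and `CHECKPOINT_URL` all come from the diff above; the `"~/.sybil/"` cache directory mirrors the constructor default, and the sketch assumes `"sybil_ensemble"` has an entry in `NAME_TO_FILE`.

```python
# Minimal sketch, assuming sybil/model.py from this commit is installed.
from sybil.model import download_sybil

# On a cold cache this prints "Downloading models to ...", fetches
# sybil_checkpoints.zip from GitHub releases, and extracts it into the
# cache; on a warm cache it returns the existing paths without re-downloading.
model_paths, calib_path = download_sybil("sybil_ensemble", "~/.sybil/")

print(model_paths)  # one <hash>.ckpt path per ensemble checkpoint
print(calib_path)   # ~/.sybil/sybil_ensemble.p
```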
3 changes: 2 additions & 1 deletion tests/regression_test.py
@@ -66,7 +66,8 @@ def main():
     num_files = len(dicom_files)
 
     # Load a trained model
-    model = Sybil("sybil_ensemble")
+    # model = Sybil("sybil_ensemble")
+    model = Sybil()
 
     myprint(f"Beginning prediction using {num_files} files from {image_data_dir}")
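Since the constructor default in `sybil/model.py` now points at `"sybil_ensemble"`, this test change is behavior-preserving; as a sketch, the two calls below are equivalent after this commit:

```python
from sybil import Sybil

model_a = Sybil()                  # uses the new default, "sybil_ensemble"
model_b = Sybil("sybil_ensemble")  # explicit name, same model
```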
