Skip to content

Commit

Permalink
Merge pull request #22 from reginabarzilaygroup/v1.0.3_dev
Browse files Browse the repository at this point in the history
V1.0.3 dev
  • Loading branch information
pgmikhael authored Dec 1, 2023
2 parents ccd6db0 + 4dd26b6 commit 95830d5
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 33 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

Lung Cancer Risk Prediction

## Run a regression test

```shell
python tests/regression_test.py
```

This will download the `sybil_ensemble` model and sample data, and compare the results to what has previously been calculated.


## Run the model

You can load our pretrained model trained on the NLST dataset, and score a given DICOM serie as follows:
Expand Down
22 changes: 1 addition & 21 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,26 +1,6 @@
# Requirements file for ReadTheDocs, check .readthedocs.yml.
# To build the module reference correctly, make sure every external package
# under `install_requires` in `setup.cfg` is also listed here!
# sphinx_rtd_theme
--find-links https://download.pytorch.org/whl/cu113/torch_stable.html
# sphinx_rtd_theme
recommonmark
sphinx>=3.2.1
# deep learning
torch==1.10.1+cu113
torchvision==0.11.2+cu113
pytorch_lightning==1.5.6
# math
scikit-learn==1.0.2
# utils
tqdm
lifelines==0.26.4
# loading
opencv-python==4.5.4.60
opencv-python-headless==4.5.4.60
albumentations==1.1.0
pydicom==2.2.2
# logging
#comet-ml
torchio==0.18.74
# downloading snapshots
gdown==4.6.0
26 changes: 20 additions & 6 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ author_email =
license_file = LICENSE.txt
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
version = 1.0.1
version = 1.0.3
# url =
project_urls =
; Documentation = https://.../docs
Source = https://github.com/pgmikhael/Sybil/
Tracker = https://github.com/pgmikhael/Sybil/issues
Source = https://github.com/reginabarzilaygroup/sybil
Tracker = https://github.com/reginabarzilaygroup/sybil/issues


# Change if running only on Windows, Mac or Linux (comma-separated)
Expand All @@ -28,13 +28,28 @@ zip_safe = False
packages = find:
include_package_data = True
python_requires = >=3.8

# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0.
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
# new major versions. This works if the required packages follow Semantic Versioning.
# For more information, check out https://semver.org/.
# Use --find-links https://download.pytorch.org/whl/cu113/torch_stable.html for torch libraries
install_requires =
importlib-metadata; python_version<"3.8"
importlib-metadata; python_version>="3.8"
numpy==1.24.1
torch==1.11.0+cu113; sys_platform != "darwin"
torch==1.11.0; sys_platform == "darwin"
torchvision==0.12.0+cu113; sys_platform != "darwin"
torchvision==0.12.0; sys_platform == "darwin"
pytorch_lightning==1.5.6
scikit-learn==1.0.2
tqdm==4.62.3
lifelines==0.26.4
opencv-python==4.5.4.60
opencv-python-headless==4.5.4.60
albumentations==1.1.0
pydicom==2.2.2
torchio==0.18.74
gdown==4.6.0


[options.packages.find]
Expand Down Expand Up @@ -76,7 +91,6 @@ norecursedirs =
build
.tox
addopts =
--cov sybil --cov-report term-missing
--verbose
testpaths = tests
# Use pytest markers to select/deselect specific tests
Expand Down
93 changes: 93 additions & 0 deletions tests/regression_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import datetime
import math
import os
import requests
import zipfile

from sybil import Serie, Sybil

# Absolute path of the directory that contains this script.
script_directory = os.path.dirname(os.path.abspath(__file__))
# Parent directory of the script directory (named as the project root).
project_directory = os.path.dirname(script_directory)


def myprint(instr):
    """Print ``instr`` to stdout, prefixed with the current timestamp."""
    timestamp = datetime.datetime.now()
    print(f"{timestamp} - {instr}")


def download_and_extract_zip(zip_file_name, cache_dir, url, demo_data_dir):
    """Fetch a zip archive into the cache if missing, then extract it if needed.

    Args:
        zip_file_name: File name of the archive inside ``cache_dir``.
        cache_dir: Existing directory in which the archive is cached.
        url: Location the archive is downloaded from when it is not cached.
        demo_data_dir: Directory the archive is extracted into. If it already
            exists, extraction is skipped.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        zipfile.BadZipFile: If the cached file is not a valid zip archive.
    """
    zip_file_path = os.path.join(cache_dir, zip_file_name)

    # 1. Download the archive only when it is not cached yet.
    if not os.path.exists(zip_file_path):
        # Write to a temporary name and rename on success, so a failed or
        # interrupted transfer never leaves a truncated file that later runs
        # would mistake for a valid cached archive.
        partial_path = zip_file_path + ".part"
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this check an HTTP error page would be saved as the zip.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                # Stream in 1 MiB chunks instead of buffering the whole body.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    file.write(chunk)
        os.replace(partial_path, zip_file_path)

    # 2. Extract only when the output directory does not exist yet.
    if not os.path.exists(demo_data_dir):
        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            zip_ref.extractall(demo_data_dir)


def main():
    """Run the end-to-end regression check on the demo data.

    Downloads and extracts the demo archive into ``~/.sybil`` (when not
    already present), loads the ``sybil_ensemble`` model, predicts on the
    demo files as a single ``Serie``, and compares the resulting scores with
    the hard-coded expected values using a relative tolerance of 1e-6.
    """
    # Deliberately not named ``test_*`` so that pytest does not discover it
    # automatically: it takes a long time to run and potentially a lot of
    # disk space.
    demo_data_url = "https://www.dropbox.com/sh/addq480zyguxbbg/AACJRVsKDL0gpq-G9o3rfCBQa?dl=1"
    expected_scores = [
        0.021628819563619374,
        0.03857256315036462,
        0.07191945816622261,
        0.07926975188037134,
        0.09584583525781108,
        0.13568094038444453,
    ]

    # Locations of the cached archive and of the extracted demo images.
    cache_dir = os.path.expanduser("~/.sybil")
    zip_file_name = "SYBIL.zip"
    demo_data_dir = os.path.join(cache_dir, "SYBIL")
    image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data")

    os.makedirs(cache_dir, exist_ok=True)
    download_and_extract_zip(zip_file_name, cache_dir, demo_data_url, demo_data_dir)

    # Every entry of the image directory is passed to the model.
    dicom_files = [
        os.path.join(image_data_dir, entry) for entry in os.listdir(image_data_dir)
    ]

    # Load a trained model.
    model = Sybil("sybil_ensemble")

    # Get risk scores for the single demo serie.
    serie = Serie(dicom_files)
    prediction = model.predict([serie])[0]
    actual_scores = prediction[0]
    count = len(actual_scores)

    assert len(expected_scores) == len(actual_scores), f"Unexpected score length {count}"

    all_elements_match = True
    for exp_score, act_score in zip(expected_scores, actual_scores):
        does_match = math.isclose(exp_score, act_score, rel_tol=1e-6)
        assert does_match, f"Mismatched scores. {exp_score} != {act_score}"
        # Still tracked separately so the summary stays meaningful when
        # assertions are disabled (``python -O``).
        all_elements_match = all_elements_match and does_match

    print(f"Data URL: {demo_data_url}\nAll {count} elements match: {all_elements_match}")


if __name__ == "__main__":
    main()
18 changes: 18 additions & 0 deletions tests/test_create_sybilnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import argparse
import datetime
import os

from sybil import Serie, Sybil

def test_create_sybilnet():
    """Smoke test: ``SybilNet`` can be built from a minimal argument namespace."""
    # Imported inside the test so the import itself is part of what is checked.
    from sybil.models.sybil import SybilNet

    minimal_args = argparse.Namespace(dropout=0.1, max_followup=5)
    net = SybilNet(minimal_args)

    assert net.hidden_dim == 512
    assert net.prob_of_failure_layer is not None
assert sybil_net.prob_of_failure_layer is not None
13 changes: 7 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ deps =
setuptools
pytest
pytest-cov
flake8
mypy
black
# flake8
# mypy
# black
install_command = pip install --pre --find-links https://download.pytorch.org/whl/cu113/torch_stable.html {opts} {packages}
commands =
pytest {posargs}
black {toxinidir}/sybil --check
flake8 {toxinidir}/sybil
mypy {toxinidir}/sybil
# black {toxinidir}/sybil --check
# flake8 {toxinidir}/sybil
# mypy {toxinidir}/sybil


[testenv:{clean,build}]
Expand Down

0 comments on commit 95830d5

Please sign in to comment.