diff --git a/README.md b/README.md
index 751e615..fff4939 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,15 @@ Lung Cancer Risk Prediction
 
+## Run a regression test
+
+```shell
+python tests/regression_test.py
+```
+
+This will download the `sybil_ensemble` model and sample data, and compare the results to what has previously been calculated.
+
+
 ## Run the model
 
 You can load our pretrained model trained on the NLST dataset, and score a given DICOM serie as follows:
diff --git a/docs/requirements.txt b/docs/requirements.txt
index d57f8e4..b2adea0 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,26 +1,6 @@
 # Requirements file for ReadTheDocs, check .readthedocs.yml.
 # To build the module reference correctly, make sure every external package
 # under `install_requires` in `setup.cfg` is also listed here!
-# sphinx_rtd_theme
---find-links https://download.pytorch.org/whl/cu113/torch_stable.html
+# sphinx_rtd_theme
 recommonmark
 sphinx>=3.2.1
-# deep learning
-torch==1.10.1+cu113
-torchvision==0.11.2+cu113
-pytorch_lightning==1.5.6
-# math
-scikit-learn==1.0.2
-# utils
-tqdm
-lifelines==0.26.4
-# loading
-opencv-python==4.5.4.60
-opencv-python-headless==4.5.4.60
-albumentations==1.1.0
-pydicom==2.2.2
-# logging
-#comet-ml
-torchio==0.18.74
-# downloading snapshots
-gdown==4.6.0
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index c87c2eb..cb647a5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,12 +7,12 @@ author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.0.1
+version = 1.0.3
 # url =
 project_urls =
 ;    Documentation = https://.../docs
-    Source = https://github.com/pgmikhael/Sybil/
-    Tracker = https://github.com/pgmikhael/Sybil/issues
+    Source = https://github.com/reginabarzilaygroup/sybil
+    Tracker = https://github.com/reginabarzilaygroup/sybil/issues
 
 # Change if running only on Windows, Mac or Linux (comma-separated)
@@ -28,13 +28,28 @@ zip_safe = False
 packages = find:
 include_package_data = True
 python_requires = >=3.8
-
 # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0.
 # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
 # new major versions. This works if the required packages follow Semantic Versioning.
 # For more information, check out https://semver.org/.
+# Use --find-links https://download.pytorch.org/whl/cu113/torch_stable.html for torch libraries
 install_requires =
-    importlib-metadata; python_version<"3.8"
+    importlib-metadata; python_version>="3.8"
+    numpy==1.24.1
+    torch==1.11.0+cu113; sys_platform != "darwin"
+    torch==1.11.0; sys_platform == "darwin"
+    torchvision==0.12.0+cu113; sys_platform != "darwin"
+    torchvision==0.12.0; sys_platform == "darwin"
+    pytorch_lightning==1.5.6
+    scikit-learn==1.0.2
+    tqdm==4.62.3
+    lifelines==0.26.4
+    opencv-python==4.5.4.60
+    opencv-python-headless==4.5.4.60
+    albumentations==1.1.0
+    pydicom==2.2.2
+    torchio==0.18.74
+    gdown==4.6.0
 
 [options.packages.find]
@@ -76,7 +91,6 @@ norecursedirs =
     build
     .tox
 addopts =
-    --cov sybil --cov-report term-missing
     --verbose
 testpaths = tests
 # Use pytest markers to select/deselect specific tests
diff --git a/tests/regression_test.py b/tests/regression_test.py
new file mode 100644
index 0000000..6a29e67
--- /dev/null
+++ b/tests/regression_test.py
@@ -0,0 +1,93 @@
+import datetime
+import math
+import os
+import requests
+import zipfile
+
+from sybil import Serie, Sybil
+
+script_directory = os.path.dirname(os.path.abspath(__file__))
+project_directory = os.path.dirname(script_directory)
+
+
+def myprint(instr):
+    print(f"{datetime.datetime.now()} - {instr}")
+
+
+def download_and_extract_zip(zip_file_name, cache_dir, url, demo_data_dir):
+    # Check and construct the full path of the zip file
+    zip_file_path = os.path.join(cache_dir, zip_file_name)
+
+    # 1. Check if the zip file exists
+    if not os.path.exists(zip_file_path):
+        # myprint(f"Zip file not found at {zip_file_path}. Downloading from {url}...")
+        # 2. Download the file
+        response = requests.get(url)
+        with open(zip_file_path, 'wb') as file:
+            file.write(response.content)
+        # myprint(f"Downloaded zip file to {zip_file_path}")
+
+    # 3. Check if the output directory exists
+    if not os.path.exists(demo_data_dir):
+        # myprint(f"Output directory {demo_data_dir} does not exist. Creating and extracting...")
+        # 4. Extract the zip file
+        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+            zip_ref.extractall(demo_data_dir)
+        # myprint(f"Extracted zip file to {demo_data_dir}")
+    else:
+        pass
+        # myprint(f"Output directory {demo_data_dir} already exists. No extraction needed.")
+
+
+def main():
+    # Note that this function is named so that pytest will not automatically discover it
+    # It takes a long time to run and potentially a lot of disk space
+
+    # Download demo data
+    demo_data_url = "https://www.dropbox.com/sh/addq480zyguxbbg/AACJRVsKDL0gpq-G9o3rfCBQa?dl=1"
+    expected_scores = [
+        0.021628819563619374,
+        0.03857256315036462,
+        0.07191945816622261,
+        0.07926975188037134,
+        0.09584583525781108,
+        0.13568094038444453
+    ]
+
+    zip_file_name = "SYBIL.zip"
+    cache_dir = os.path.expanduser("~/.sybil")
+    demo_data_dir = os.path.join(cache_dir, "SYBIL")
+    image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data")
+    os.makedirs(cache_dir, exist_ok=True)
+    download_and_extract_zip(zip_file_name, cache_dir, demo_data_url, demo_data_dir)
+
+    dicom_files = os.listdir(image_data_dir)
+    dicom_files = [os.path.join(image_data_dir, x) for x in dicom_files]
+    num_files = len(dicom_files)
+
+    # Load a trained model
+    model = Sybil("sybil_ensemble")
+
+    # myprint(f"Beginning prediction using {num_files} from {image_data_dir}")
+
+    # Get risk scores
+    serie = Serie(dicom_files)
+    prediction = model.predict([serie])[0]
+    actual_scores = prediction[0]
+    count = len(actual_scores)
+
+    # myprint(f"Prediction finished. Results\n{actual_scores}")
Results\n{actual_scores}") + + assert len(expected_scores) == len(actual_scores), f"Unexpected score length {count}" + + all_elements_match = True + for exp_score, act_score in zip(expected_scores, actual_scores): + does_match = math.isclose(exp_score, act_score, rel_tol=1e-6) + assert does_match, f"Mismatched scores. {exp_score} != {act_score}" + all_elements_match &= does_match + + print(f"Data URL: {demo_data_url}\nAll {count} elements match: {all_elements_match}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_create_sybilnet.py b/tests/test_create_sybilnet.py new file mode 100644 index 0000000..9574c5a --- /dev/null +++ b/tests/test_create_sybilnet.py @@ -0,0 +1,18 @@ +import argparse +import datetime +import os + +from sybil import Serie, Sybil + +def test_create_sybilnet(): + from sybil.models.sybil import SybilNet + + fake_args = argparse.Namespace( + dropout=0.1, + max_followup=5, + ) + + sybil_net = SybilNet(fake_args) + + assert sybil_net.hidden_dim == 512 + assert sybil_net.prob_of_failure_layer is not None diff --git a/tox.ini b/tox.ini index 45b1308..77a3595 100644 --- a/tox.ini +++ b/tox.ini @@ -14,14 +14,15 @@ deps = setuptools pytest pytest-cov - flake8 - mypy - black + # flake8 + # mypy + # black +install_command = pip install --pre --find-links https://download.pytorch.org/whl/cu113/torch_stable.html {opts} {packages} commands = pytest {posargs} - black {toxinidir}/sybil --check - flake8 {toxinidir}/sybil - mypy {toxinidir}/sybil + # black {toxinidir}/sybil --check + # flake8 {toxinidir}/sybil + # mypy {toxinidir}/sybil [testenv:{clean,build}]