From 844c4a5d2b1e3a480eb3ca3eafbca5323eac9566 Mon Sep 17 00:00:00 2001 From: Jacob Silterra Date: Mon, 9 Sep 2024 13:31:09 -0400 Subject: [PATCH] Add example python scripts --- examples/local.py | 41 +++++++++++++++++++ examples/remote_ark_sybil.py | 77 ++++++++++++++++++++++++++++++++++++ examples/utils.py | 42 ++++++++++++++++++++ 3 files changed, 160 insertions(+) create mode 100644 examples/local.py create mode 100644 examples/remote_ark_sybil.py create mode 100644 examples/utils.py diff --git a/examples/local.py b/examples/local.py new file mode 100644 index 0000000..e4a752c --- /dev/null +++ b/examples/local.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python + +__doc__ = """ +Simple example script showing how to use the Sybil library locally to predict risk scores for a set of DICOM files. +""" + +import sybil +from sybil import visualize_attentions + +from utils import get_demo_data + + +def main(): + # Load a trained model + model = sybil.Sybil("sybil_ensemble") + + dicom_files = get_demo_data() + + # Get risk scores + serie = sybil.Serie(dicom_files) + print(f"Processing {len(dicom_files)} DICOM files") + prediction = model.predict([serie], return_attentions=True) + scores = prediction.scores + + print(f"Risk scores: {scores}") + + # Visualize attention maps + output_dir = "sybil_attention_output" + + print(f"Writing attention images to {output_dir}") + series_with_attention = visualize_attentions( + serie, + attentions=prediction.attentions, + save_directory=output_dir, + gain=3, + ) + + print(f"Finished writing attention images to {output_dir}") + +if __name__ == "__main__": + main() diff --git a/examples/remote_ark_sybil.py b/examples/remote_ark_sybil.py new file mode 100644 index 0000000..a11aced --- /dev/null +++ b/examples/remote_ark_sybil.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +__doc__ = """ +This example shows how to use a client to access a +remote Sybil server (running Ark) to predict risk scores for a set of DICOM files. + +The server must be started separately. + +https://github.com/reginabarzilaygroup/Sybil/wiki +https://github.com/reginabarzilaygroup/ark/wiki +""" +import json +import os + +import numpy as np +import requests + +import sybil.utils.visualization + +from utils import get_demo_data + +if __name__ == "__main__": + + dicom_files = get_demo_data() + serie = sybil.Serie(dicom_files) + + # Set the URL of the remote Sybil server + ark_hostname = "localhost" + ark_port = 5000 + + # Set the URL of the remote Sybil server + ark_host = f"http://{ark_hostname}:{ark_port}" + + data_dict = {"return_attentions": True} + payload = {"data": json.dumps(data_dict)} + + # Check if the server is running and reachable + resp = requests.get(f"{ark_host}/info") + if resp.status_code != 200: + raise ValueError(f"Failed to connect to ARK server. Status code: {resp.status_code}") + + info_data = resp.json()["data"] + assert info_data["modelName"].lower() == "sybil", "The ARK server is not running Sybil" + print(f"ARK server info: {info_data}") + + # Submit prediction to ARK server. + files = [('dicom', open(file_path, 'rb')) for file_path in dicom_files] + r = requests.post(f"{ark_host}/dicom/files", files=files, data=payload) + _ = [f[1].close() for f in files] + if r.status_code != 200: + raise ValueError(f"Error occurred processing DICOM files. Status code: {r.status_code}.\n{r.text}") + + r_json = r.json() + predictions = r_json["data"]["predictions"] + + scores = predictions[0] + print(f"Risk scores: {scores}") + + attentions = predictions[1] + attentions = np.array(attentions) + print(f"Attention shape: {attentions.shape}") + + # Visualize attention maps + save_directory = "remote_ark_sybil_attention_output" + + print(f"Writing attention images to {save_directory}") + + images = serie.get_raw_images() + overlayed_images = sybil.utils.visualization.build_overlayed_images(images, attentions, gain=3) + + if save_directory is not None: + serie_idx = 0 + save_path = os.path.join(save_directory, f"serie_{serie_idx}") + sybil.utils.visualization.save_images(overlayed_images, save_path, f"serie_{serie_idx}") + + print(f"Finished writing attention images to {save_directory}") + diff --git a/examples/utils.py b/examples/utils.py new file mode 100644 index 0000000..c30417c --- /dev/null +++ b/examples/utils.py @@ -0,0 +1,42 @@ +import os +from urllib.request import urlopen + + +def download_file(url, filepath): + response = urlopen(url) + + target_dir = os.path.dirname(filepath) + if target_dir and not os.path.exists(target_dir): + os.makedirs(target_dir) + + # Check if the request was successful + if response.status == 200: + with open(filepath, 'wb') as f: + f.write(response.read()) + else: + print(f"Failed to download file. Status code: {response.status_code}") + + return filepath + +def get_demo_data(): + demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1" + + zip_file_name = "sybil_example.zip" + cache_dir = os.path.expanduser("~/.sybil") + zip_file_path = os.path.join(cache_dir, zip_file_name) + os.makedirs(cache_dir, exist_ok=True) + if not os.path.exists(zip_file_path): + print(f"Downloading demo data to {zip_file_path}") + download_file(demo_data_url, zip_file_path) + + demo_data_dir = os.path.join(cache_dir, "sybil_example") + image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data") + if not os.path.exists(demo_data_dir): + print(f"Extracting demo data to {demo_data_dir}") + import zipfile + with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: + zip_ref.extractall(demo_data_dir) + + dicom_files = os.listdir(image_data_dir) + dicom_files = [os.path.join(image_data_dir, x) for x in dicom_files] + return dicom_files