Skip to content

Commit

Permalink
Add example python scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
jsilter committed Sep 9, 2024
1 parent b061207 commit 844c4a5
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 0 deletions.
41 changes: 41 additions & 0 deletions examples/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python

__doc__ = """
Simple example script showing how to use the Sybil library locally to predict risk scores for a set of DICOM files.
"""

import sybil
from sybil import visualize_attentions

from utils import get_demo_data


def main():
# Load a trained model
model = sybil.Sybil("sybil_ensemble")

dicom_files = get_demo_data()

# Get risk scores
serie = sybil.Serie(dicom_files)
print(f"Processing {len(dicom_files)} DICOM files")
prediction = model.predict([serie], return_attentions=True)
scores = prediction.scores

print(f"Risk scores: {scores}")

# Visualize attention maps
output_dir = "sybil_attention_output"

print(f"Writing attention images to {output_dir}")
series_with_attention = visualize_attentions(
serie,
attentions=prediction.attentions,
save_directory=output_dir,
gain=3,
)

print(f"Finished writing attention images to {output_dir}")

if __name__ == "__main__":
main()
77 changes: 77 additions & 0 deletions examples/remote_ark_sybil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python

__doc__ = """
This example shows how to use a client to access a
remote Sybil server (running Ark) to predict risk scores for a set of DICOM files.
The server must be started separately.
https://github.com/reginabarzilaygroup/Sybil/wiki
https://github.com/reginabarzilaygroup/ark/wiki
"""
import json
import os

import numpy as np
import requests

import sybil.utils.visualization

from utils import get_demo_data

if __name__ == "__main__":

dicom_files = get_demo_data()
serie = sybil.Serie(dicom_files)

# Set the URL of the remote Sybil server
ark_hostname = "localhost"
ark_port = 5000

# Set the URL of the remote Sybil server
ark_host = f"http://{ark_hostname}:{ark_port}"

data_dict = {"return_attentions": True}
payload = {"data": json.dumps(data_dict)}

# Check if the server is running and reachable
resp = requests.get(f"{ark_host}/info")
if resp.status_code != 200:
raise ValueError(f"Failed to connect to ARK server. Status code: {resp.status_code}")

info_data = resp.json()["data"]
assert info_data["modelName"].lower() == "sybil", "The ARK server is not running Sybil"
print(f"ARK server info: {info_data}")

# Submit prediction to ARK server.
files = [('dicom', open(file_path, 'rb')) for file_path in dicom_files]
r = requests.post(f"{ark_host}/dicom/files", files=files, data=payload)
_ = [f[1].close() for f in files]
if r.status_code != 200:
raise ValueError(f"Error occurred processing DICOM files. Status code: {r.status_code}.\n{r.text}")

r_json = r.json()
predictions = r_json["data"]["predictions"]

scores = predictions[0]
print(f"Risk scores: {scores}")

attentions = predictions[1]
attentions = np.array(attentions)
print(f"Attention shape: {attentions.shape}")

# Visualize attention maps
save_directory = "remote_ark_sybil_attention_output"

print(f"Writing attention images to {save_directory}")

images = serie.get_raw_images()
overlayed_images = sybil.utils.visualization.build_overlayed_images(images, attentions, gain=3)

if save_directory is not None:
serie_idx = 0
save_path = os.path.join(save_directory, f"serie_{serie_idx}")
sybil.utils.visualization.save_images(overlayed_images, save_path, f"serie_{serie_idx}")

print(f"Finished writing attention images to {save_directory}")

42 changes: 42 additions & 0 deletions examples/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from urllib.request import urlopen


def download_file(url, filepath):
response = urlopen(url)

target_dir = os.path.dirname(filepath)
if target_dir and not os.path.exists(target_dir):
os.makedirs(target_dir)

# Check if the request was successful
if response.status == 200:
with open(filepath, 'wb') as f:
f.write(response.read())
else:
print(f"Failed to download file. Status code: {response.status_code}")

return filepath

def get_demo_data():
demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1"

zip_file_name = "sybil_example.zip"
cache_dir = os.path.expanduser("~/.sybil")
zip_file_path = os.path.join(cache_dir, zip_file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(zip_file_path):
print(f"Downloading demo data to {zip_file_path}")
download_file(demo_data_url, zip_file_path)

demo_data_dir = os.path.join(cache_dir, "sybil_example")
image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data")
if not os.path.exists(demo_data_dir):
print(f"Extracting demo data to {demo_data_dir}")
import zipfile
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
zip_ref.extractall(demo_data_dir)

dicom_files = os.listdir(image_data_dir)
dicom_files = [os.path.join(image_data_dir, x) for x in dicom_files]
return dicom_files

0 comments on commit 844c4a5

Please sign in to comment.