-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
160 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env python | ||
|
||
__doc__ = """ | ||
Simple example script showing how to use the Sybil library locally to predict risk scores for a set of DICOM files. | ||
""" | ||
|
||
import sybil | ||
from sybil import visualize_attentions | ||
|
||
from utils import get_demo_data | ||
|
||
|
||
def main(): | ||
# Load a trained model | ||
model = sybil.Sybil("sybil_ensemble") | ||
|
||
dicom_files = get_demo_data() | ||
|
||
# Get risk scores | ||
serie = sybil.Serie(dicom_files) | ||
print(f"Processing {len(dicom_files)} DICOM files") | ||
prediction = model.predict([serie], return_attentions=True) | ||
scores = prediction.scores | ||
|
||
print(f"Risk scores: {scores}") | ||
|
||
# Visualize attention maps | ||
output_dir = "sybil_attention_output" | ||
|
||
print(f"Writing attention images to {output_dir}") | ||
series_with_attention = visualize_attentions( | ||
serie, | ||
attentions=prediction.attentions, | ||
save_directory=output_dir, | ||
gain=3, | ||
) | ||
|
||
print(f"Finished writing attention images to {output_dir}") | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
#!/usr/bin/env python | ||
|
||
__doc__ = """ | ||
This example shows how to use a client to access a | ||
remote Sybil server (running Ark) to predict risk scores for a set of DICOM files. | ||
The server must be started separately. | ||
https://github.com/reginabarzilaygroup/Sybil/wiki | ||
https://github.com/reginabarzilaygroup/ark/wiki | ||
""" | ||
import json | ||
import os | ||
|
||
import numpy as np | ||
import requests | ||
|
||
import sybil.utils.visualization | ||
|
||
from utils import get_demo_data | ||
|
||
if __name__ == "__main__": | ||
|
||
dicom_files = get_demo_data() | ||
serie = sybil.Serie(dicom_files) | ||
|
||
# Set the URL of the remote Sybil server | ||
ark_hostname = "localhost" | ||
ark_port = 5000 | ||
|
||
# Set the URL of the remote Sybil server | ||
ark_host = f"http://{ark_hostname}:{ark_port}" | ||
|
||
data_dict = {"return_attentions": True} | ||
payload = {"data": json.dumps(data_dict)} | ||
|
||
# Check if the server is running and reachable | ||
resp = requests.get(f"{ark_host}/info") | ||
if resp.status_code != 200: | ||
raise ValueError(f"Failed to connect to ARK server. Status code: {resp.status_code}") | ||
|
||
info_data = resp.json()["data"] | ||
assert info_data["modelName"].lower() == "sybil", "The ARK server is not running Sybil" | ||
print(f"ARK server info: {info_data}") | ||
|
||
# Submit prediction to ARK server. | ||
files = [('dicom', open(file_path, 'rb')) for file_path in dicom_files] | ||
r = requests.post(f"{ark_host}/dicom/files", files=files, data=payload) | ||
_ = [f[1].close() for f in files] | ||
if r.status_code != 200: | ||
raise ValueError(f"Error occurred processing DICOM files. Status code: {r.status_code}.\n{r.text}") | ||
|
||
r_json = r.json() | ||
predictions = r_json["data"]["predictions"] | ||
|
||
scores = predictions[0] | ||
print(f"Risk scores: {scores}") | ||
|
||
attentions = predictions[1] | ||
attentions = np.array(attentions) | ||
print(f"Attention shape: {attentions.shape}") | ||
|
||
# Visualize attention maps | ||
save_directory = "remote_ark_sybil_attention_output" | ||
|
||
print(f"Writing attention images to {save_directory}") | ||
|
||
images = serie.get_raw_images() | ||
overlayed_images = sybil.utils.visualization.build_overlayed_images(images, attentions, gain=3) | ||
|
||
if save_directory is not None: | ||
serie_idx = 0 | ||
save_path = os.path.join(save_directory, f"serie_{serie_idx}") | ||
sybil.utils.visualization.save_images(overlayed_images, save_path, f"serie_{serie_idx}") | ||
|
||
print(f"Finished writing attention images to {save_directory}") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
from urllib.request import urlopen | ||
|
||
|
||
def download_file(url, filepath): | ||
response = urlopen(url) | ||
|
||
target_dir = os.path.dirname(filepath) | ||
if target_dir and not os.path.exists(target_dir): | ||
os.makedirs(target_dir) | ||
|
||
# Check if the request was successful | ||
if response.status == 200: | ||
with open(filepath, 'wb') as f: | ||
f.write(response.read()) | ||
else: | ||
print(f"Failed to download file. Status code: {response.status_code}") | ||
|
||
return filepath | ||
|
||
def get_demo_data(): | ||
demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1" | ||
|
||
zip_file_name = "sybil_example.zip" | ||
cache_dir = os.path.expanduser("~/.sybil") | ||
zip_file_path = os.path.join(cache_dir, zip_file_name) | ||
os.makedirs(cache_dir, exist_ok=True) | ||
if not os.path.exists(zip_file_path): | ||
print(f"Downloading demo data to {zip_file_path}") | ||
download_file(demo_data_url, zip_file_path) | ||
|
||
demo_data_dir = os.path.join(cache_dir, "sybil_example") | ||
image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data") | ||
if not os.path.exists(demo_data_dir): | ||
print(f"Extracting demo data to {demo_data_dir}") | ||
import zipfile | ||
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: | ||
zip_ref.extractall(demo_data_dir) | ||
|
||
dicom_files = os.listdir(image_data_dir) | ||
dicom_files = [os.path.join(image_data_dir, x) for x in dicom_files] | ||
return dicom_files |