GH28 Individual inference
Add inference.py, a simple CLI script that performs inference on a single exam.

Also remove the `sample` argument from several loader functions, since it is unused.
jsilter committed Mar 4, 2024
1 parent 6453d08 commit d0da448
Showing 7 changed files with 145 additions and 32 deletions.
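The new entry point can be invoked as `python3 scripts/inference.py --output-dir out /path/to/dicom_dir` (see scripts/run_inference_demo.sh below). For orientation, a minimal sketch of the equivalent library usage, with a placeholder exam directory:

import glob
import json

from sybil import Serie, Sybil

# Placeholder: any directory of .dcm files from a single exam.
dicom_files = sorted(glob.glob("path/to/exam/*.dcm"))

model = Sybil("sybil_ensemble")
prediction = model.predict([Serie(dicom_files)])

# One list of year-by-year risk scores per input serie.
print(json.dumps({"predictions": prediction.scores}, indent=2))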
99 changes: 99 additions & 0 deletions scripts/inference.py
@@ -0,0 +1,99 @@
import argparse
import datetime
import json
import logging
import os
import pickle

from sybil import Serie, Sybil, visualize_attentions

script_directory = os.path.dirname(os.path.abspath(__file__))
project_directory = os.path.dirname(script_directory)


def _get_parser():
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument('--output-dir', default="sybil_result", dest="output_dir",
                        help="Output directory in which to save prediction results. "
                             "Prediction will be printed to stdout as well.")

    parser.add_argument('--return-attentions', default=False, action="store_true",
                        help="Generate an image which overlays attention scores.")

    parser.add_argument('dicom_dir', default=None,
                        help="Path to directory containing DICOM files (from a single exam) to run inference on. "
                             "Every .dcm file in the directory will be included.")

    parser.add_argument('--model-name', default="sybil_ensemble", dest="model_name",
                        help="Name of the model to use for prediction. Default: sybil_ensemble")

    parser.add_argument('-l', '--log', '--loglevel', default="INFO", dest="loglevel")

    return parser


def logging_basic_config(args):
    info_fmt = "[%(asctime)s] - %(message)s"
    debug_fmt = "[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s"
    fmt = debug_fmt if args.loglevel.upper() == "DEBUG" else info_fmt

    logging.basicConfig(format=fmt,
                        datefmt="%Y-%m-%d %H:%M:%S",
                        level=args.loglevel.upper())


def inference(dicom_dir, output_dir, model_name="sybil_ensemble", return_attentions=False):
    logger = logging.getLogger('inference')

    dicom_files = os.listdir(dicom_dir)
    dicom_files = [os.path.join(dicom_dir, x) for x in dicom_files]
    dicom_files = [x for x in dicom_files if x.endswith(".dcm") and os.path.isfile(x)]
    num_files = len(dicom_files)

    # Load a trained model
    model = Sybil(model_name)

    logger.debug(f"Beginning prediction using {num_files} files from {dicom_dir}")

    # Get risk scores
    serie = Serie(dicom_files)
    series = [serie]
    prediction = model.predict(series, return_attentions=return_attentions)
    prediction_scores = prediction.scores[0]

    logger.debug(f"Prediction finished. Results:\n{prediction_scores}")

    prediction_path = os.path.join(output_dir, "prediction_scores.json")
    pred_dict = {"predictions": prediction.scores}
    with open(prediction_path, "w") as f:
        json.dump(pred_dict, f, indent=2)

    if return_attentions:
        attention_path = os.path.join(output_dir, "attention_scores.pkl")
        with open(attention_path, "wb") as f:
            pickle.dump(prediction, f)

        series_with_attention = visualize_attentions(
            series,
            attentions=prediction.attentions,
            save_directory=output_dir,
            gain=3,
        )

    return pred_dict


def main():
    args = _get_parser().parse_args()
    logging_basic_config(args)

    os.makedirs(args.output_dir, exist_ok=True)

    pred_dict = inference(args.dicom_dir, args.output_dir, args.model_name, args.return_attentions)

    print(json.dumps(pred_dict, indent=2))


if __name__ == "__main__":
    main()
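For reference, a sketch of reading back what the script writes to --output-dir (file names as in the code above; the pickle is written only when --return-attentions is passed, and unpickling it requires the sybil package to be importable):

import json
import pickle

with open("sybil_result/prediction_scores.json") as f:
    scores = json.load(f)["predictions"]

with open("sybil_result/attention_scores.pkl", "rb") as f:
    prediction = pickle.load(f)  # the full prediction object, attentions included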
20 changes: 20 additions & 0 deletions scripts/run_inference_demo.sh
@@ -0,0 +1,20 @@
#!/bin/bash

# Run inference on the demo data
# The output will be printed to the console


demo_scan_dir=sybil_demo_data

# Download the demo data if it doesn't exist
if [ ! -d "$demo_scan_dir" ]; then
    # Download example data
    curl -L -o sybil_example.zip "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&dl=1"
    tar -xf sybil_example.zip
fi

python3 scripts/inference.py \
    --loglevel DEBUG \
    --output-dir demo_prediction \
    --return-attentions \
    $demo_scan_dir
2 changes: 1 addition & 1 deletion setup.cfg
@@ -7,7 +7,7 @@ author_email =
 license_file = LICENSE.txt
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8; variant=GFM
-version = 1.0.3
+version = 1.0.4
 # url =
 project_urls =
 ; Documentation = https://.../docs
18 changes: 8 additions & 10 deletions sybil/augmentations.py
@@ -45,7 +45,7 @@ def __init__(self):
         )

     @abstractmethod
-    def __call__(self, img, mask=None, additional=None):
+    def __call__(self, input_dict):
         pass

     def set_seed(self, seed):
@@ -80,9 +80,9 @@ def __init__(self, augmentations):
         super(ComposeAug, self).__init__()
         self.augmentations = augmentations

-    def __call__(self, input_dict, sample=None):
+    def __call__(self, input_dict):
         for transformer in self.augmentations:
-            input_dict = transformer(input_dict, sample)
+            input_dict = transformer(input_dict)

         return input_dict

@@ -97,7 +97,7 @@ def __init__(self):
         self.transform = ToTensorV2()
         self.name = "totensor"

-    def __call__(self, input_dict, sample=None):
+    def __call__(self, input_dict):
         input_dict["input"] = torch.from_numpy(input_dict["input"]).float()
         if input_dict.get("mask", None) is not None:
             input_dict["mask"] = torch.from_numpy(input_dict["mask"]).float()
@@ -117,7 +117,7 @@ def __init__(self, args, kwargs):
         self.set_cachable(width, height)
         self.transform = A.Resize(height, width)

-    def __call__(self, input_dict, sample=None):
+    def __call__(self, input_dict):
         out = self.transform(
             image=input_dict["input"], mask=input_dict.get("mask", None)
         )
@@ -140,9 +140,7 @@ def __init__(self, args, kwargs):
         self.max_angle = int(kwargs["deg"])
         self.transform = A.Rotate(limit=self.max_angle, p=0.5)

-    def __call__(self, input_dict, sample=None):
-        if "seed" in sample:
-            self.set_seed(sample["seed"])
+    def __call__(self, input_dict):
         out = self.transform(
             image=input_dict["input"], mask=input_dict.get("mask", None)
         )
@@ -171,7 +169,7 @@ def __init__(self, args, kwargs):
             "png",
         ]

-    def __call__(self, input_dict, sample=None):
+    def __call__(self, input_dict):
         img = input_dict["input"]
         if len(img.size()) == 2:
             img = img.unsqueeze(0)
@@ -195,7 +193,7 @@ def __init__(self, args, kwargs):
         assert len(kwargs) == 0
         self.args = args

-    def __call__(self, input_dict, sample=None):
+    def __call__(self, input_dict):
         img = input_dict["input"]
         mask = input_dict.get("mask", None)
         if mask is not None:
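The net effect in this file: every augmentation's `__call__` now takes a single input_dict (a required "input" array plus an optional "mask") and returns the transformed dict; the per-sample seed threading is gone. A minimal sketch of the convention, using a stand-in transform rather than the real classes:

import numpy as np

class IdentityAug:
    # Stand-in obeying the new interface: one dict in, one dict out.
    def __call__(self, input_dict):
        return input_dict

pipeline = [IdentityAug()]  # the real pipeline is a ComposeAug over such objects
input_dict = {"input": np.zeros((256, 256), dtype=np.float32)}
for aug in pipeline:
    input_dict = aug(input_dict)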
24 changes: 11 additions & 13 deletions sybil/loaders/abstract_loader.py
@@ -59,7 +59,7 @@ def split_augmentations_by_cache(augmentations):


 def apply_augmentations_and_cache(
-    loaded_input, sample, img_path, augmentations, cache, base_key=""
+    loaded_input, img_path, augmentations, cache, base_key=""
 ):
     """
     Loads the loaded input by its absolute path and apply the augmentations one
@@ -69,7 +69,7 @@ def apply_augmentations_and_cache(
     all_prev_cachable = True
     key = base_key
     for ind, trans in enumerate(augmentations):
-        loaded_input = trans(loaded_input, sample)
+        loaded_input = trans(loaded_input)
         if not all_prev_cachable or not trans.cachable():
             all_prev_cachable = False
         else:
@@ -145,39 +145,39 @@ def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
         self.composed_all_augmentations = ComposeAug(augmentations)

     @abstractmethod
-    def load_input(self, path, sample):
+    def load_input(self, path):
         pass

     @property
     @abstractmethod
     def cached_extension(self):
         pass

-    def configure_path(self, path, sample):
+    def configure_path(self, path):
         return path

-    def get_image(self, path, sample):
+    def get_image(self, path):
         """
         Returns a transformed image by its absolute path.
         If cache is used - transformed image will be loaded if available,
         and saved to cache if not.
         """
         input_dict = {}
-        input_path = self.configure_path(path, sample)
+        input_path = self.configure_path(path)

         if input_path == self.pad_token:
-            return self.load_input(input_path, sample)
+            return self.load_input(input_path)

         if not self.use_cache:
-            input_dict = self.load_input(input_path, sample)
+            input_dict = self.load_input(input_path)
             # hidden loaders typically do not use augmentation
             if self.apply_augmentations:
-                input_dict = self.composed_all_augmentations(input_dict, sample)
+                input_dict = self.composed_all_augmentations(input_dict)
             return input_dict

         if self.args.use_annotations:
             input_dict["mask"] = get_scaled_annotation_mask(
-                sample["annotations"], self.args
+                input_dict["annotations"], self.args
             )

         for key, post_augmentations in self.split_augmentations:
@@ -192,7 +192,6 @@ def get_image(self, path, sample):
             if self.apply_augmentations:
                 input_dict = apply_augmentations_and_cache(
                     input_dict,
-                    sample,
                     input_path,
                     post_augmentations,
                     self.cache,
@@ -207,11 +206,10 @@ def get_image(self, path, sample):
                 warnings.warn(CORUPTED_FILE_ERR.format(sys.exc_info()[0]))
                 self.cache.rem(input_path, key)
                 all_augmentations = self.split_augmentations[-1][1]
-                input_dict = self.load_input(input_path, sample)
+                input_dict = self.load_input(input_path)
                 if self.apply_augmentations:
                     input_dict = apply_augmentations_and_cache(
                         input_dict,
-                        sample,
                         input_path,
                         all_augmentations,
                         self.cache,
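With load_input, configure_path, and get_image keyed on the path alone, a concrete loader stays small. A hypothetical subclass for illustration only (NpyLoader is not in the codebase; the import path is assumed from the file layout):

import numpy as np

from sybil.loaders.abstract_loader import abstract_loader


class NpyLoader(abstract_loader):
    """Hypothetical loader showing the new path-only interface."""

    def load_input(self, path):
        return {"input": np.load(path)}

    @property
    def cached_extension(self):
        return ".npy"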
8 changes: 4 additions & 4 deletions sybil/loaders/image_loaders.py
@@ -10,11 +10,11 @@

 class OpenCVLoader(abstract_loader):

-    def load_input(self, path, sample):
+    def load_input(self, path):
         """
         loads as grayscale image
         """
-        return {"input": cv2.imread(path, 0) }
+        return {"input": cv2.imread(path, 0)}

     @property
     def cached_extension(self):
@@ -27,12 +27,12 @@ def __init__(self, cache_path, augmentations, args, apply_augmentations=True):
         self.window_center = -600
         self.window_width = 1500

-    def load_input(self, path, sample):
+    def load_input(self, path):
         try:
             dcm = pydicom.dcmread(path)
             dcm = apply_modality_lut(dcm.pixel_array, dcm)
             arr = apply_windowing(dcm, self.window_center, self.window_width)
-            arr = arr//256 # parity with images loaded as 8 bit
+            arr = arr//256  # parity with images loaded as 8 bit
         except Exception:
             raise Exception(LOADING_ERROR.format("COULD NOT LOAD DICOM."))
         return {"input": arr}
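For context, center -600 / width 1500 is a standard lung window, and //256 rescales the 16-bit windowed output for parity with 8-bit images. The same windowing written out directly in NumPy (illustrative only, not the codepath above):

import numpy as np

def lung_window(hu, center=-600.0, width=1500.0):
    # Clip Hounsfield units to [center - width/2, center + width/2],
    # then rescale to the full 16-bit range.
    lo, hi = center - width / 2, center + width / 2
    windowed = (np.clip(hu, lo, hi) - lo) / (hi - lo)
    return (windowed * 65535).astype(np.uint16)

arr = lung_window(np.linspace(-1024.0, 400.0, 512 * 512).reshape(512, 512))
arr = arr // 256  # parity with images loaded as 8 bit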
6 changes: 2 additions & 4 deletions sybil/serie.py
@@ -131,7 +131,7 @@ def get_raw_images(self) -> List[np.ndarray]:
         """

         loader = get_sample_loader("test", self._args, apply_augmentations=False)
-        input_dicts = [loader.get_image(path, {}) for path in self._meta.paths]
+        input_dicts = [loader.get_image(path) for path in self._meta.paths]
         images = [i["input"] for i in input_dicts]
         return images

@@ -145,10 +145,8 @@ def get_volume(self) -> torch.Tensor:
             CT volume of shape (1, C, N, H, W)
         """

-        sample = {"seed": np.random.randint(0, 2**32 - 1)}
-
         input_dicts = [
-            self._loader.get_image(path, sample) for path in self._meta.paths
+            self._loader.get_image(path) for path in self._meta.paths
         ]

         x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0)
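With the seed plumbing removed, fetching a preprocessed volume is a plain call. A sketch with placeholder paths:

from sybil import Serie

serie = Serie(["exam/slice_0001.dcm", "exam/slice_0002.dcm"])  # placeholder paths
volume = serie.get_volume()  # torch.Tensor of shape (1, C, N, H, W)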
