Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ICON-console-scripts #55

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
71 changes: 71 additions & 0 deletions bin/icon_register_pair
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python

import argparse

import icon_registration as icon
import icon_registration.itk_wrapper
import itk
import numpy as np
import torch
import importlib

parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--weights_path")
parser.add_argument("--fixed_image", required=True)
parser.add_argument("--moving_image", required=True)
parser.add_argument("--transform_out")
parser.add_argument("--displacement_image_out")
parser.add_argument("--warped_moving_out")
parser.add_argument("--finetune", action='store_true')

args = parser.parse_args()

model_package_name = ".".join(args.model.split(".")[:-1])
model_name = args.model.split(".")[-1]

model_package = importlib.import_module(model_package_name)
net = getattr(model_package, model_name)()

if args.weights_path:
weights = torch.load(args.weights_path)
net.regis_net.load_state_dict(weights)

fixed_image = itk.imread(args.fixed_image)
moving_image = itk.imread(args.moving_image)

# We want images normalized to 0.0, 1.0
for image in fixed_image, moving_image:
assert(np.abs(np.max(np.array(image)) - 1.0) < 0.2)
assert(np.abs(np.min(np.array(image))) < 0.2)

phi, _ = icon.itk_wrapper.register_pair(
net, moving_image, fixed_image, finetune_steps=90 if args.finetune else None
)

if args.transform_out:
itk.transformwrite([phi], args.transform_out)

if args.displacement_image_out:
filter = itk.TransformToDisplacementFieldFilter[
itk.itkImagePython.itkImageVF33, itk.D
].New()
decorator = itk.DataObjectDecorator[itk.Transform[itk.D, 3, 3]].New()
decorator.Set(phi)
filter.SetInput(decorator)
filter.SetReferenceImage(fixed_image)
filter.SetUseReferenceImage(True)
filter.Update()
itk.imwrite(filter.GetOutput(), args.displacement_image_out)

if args.warped_moving_out:
warped = itk.resample_image_filter(
moving_image,
use_reference_image=True,
reference_image=fixed_image,
transform=phi,
interpolator=itk.LinearInterpolateImageFunction.New(moving_image),
)
itk.imwrite(warped, args.warped_moving_out)


3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ package_dir =
packages = find:
python_requires = >=3.4

scripts =
bin/icon_register_pair

install_requires =
torch
torchvision
Expand Down
23 changes: 23 additions & 0 deletions test/test_console_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest
import subprocess
from icon_registration import test_utils

class TestConsoleScripts(unittest.TestCase):
def test_register_pair_script(self):
import footsteps
test_utils.download_test_data()
subprocess.run(
[
"icon_register_pair",
"--fixed_image",
test_utils.TEST_DATA_DIR / "brain_test_data" / "2_T1w_acpc_dc_restore_brain.nii.gz",
"--moving_image",
test_utils.TEST_DATA_DIR / "brain_test_data" / "8_T1w_acpc_dc_restore_brain.nii.gz",
"--model",
"icon_registration.pretrained_models.brain_registration_model"
"--warped_image_out",
footsteps.output_dir + "warped.nii.gz"
]
)
import itk
itk.imread(footsteps.output_dir + "warped.nii.gz")