diff --git a/bin/icon_register_pair b/bin/icon_register_pair new file mode 100644 index 0000000..7ffc5fb --- /dev/null +++ b/bin/icon_register_pair @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +import argparse + +import icon_registration as icon +import icon_registration.itk_wrapper +import itk +import numpy as np +import torch +import importlib + +parser = argparse.ArgumentParser() +parser.add_argument("--model", required=True) +parser.add_argument("--weights_path") +parser.add_argument("--fixed_image", required=True) +parser.add_argument("--moving_image", required=True) +parser.add_argument("--transform_out") +parser.add_argument("--displacement_image_out") +parser.add_argument("--warped_moving_out") +parser.add_argument("--finetune", action='store_true') + +args = parser.parse_args() + +model_package_name = ".".join(args.model.split(".")[:-1]) +model_name = args.model.split(".")[-1] + +model_package = importlib.import_module(model_package_name) +net = getattr(model_package, model_name)() + +if args.weights_path: + weights = torch.load(args.weights_path) + net.regis_net.load_state_dict(weights) + +fixed_image = itk.imread(args.fixed_image) +moving_image = itk.imread(args.moving_image) + +# We want images normalized to 0.0, 1.0 +for image in fixed_image, moving_image: + assert(np.abs(np.max(np.array(image)) - 1.0) < 0.2) + assert(np.abs(np.min(np.array(image))) < 0.2) + +phi, _ = icon.itk_wrapper.register_pair( + net, moving_image, fixed_image, finetune_steps=90 if args.finetune else None +) + +if args.transform_out: + itk.transformwrite([phi], args.transform_out) + +if args.displacement_image_out: + filter = itk.TransformToDisplacementFieldFilter[ + itk.itkImagePython.itkImageVF33, itk.D + ].New() + decorator = itk.DataObjectDecorator[itk.Transform[itk.D, 3, 3]].New() + decorator.Set(phi) + filter.SetInput(decorator) + filter.SetReferenceImage(fixed_image) + filter.SetUseReferenceImage(True) + filter.Update() + itk.imwrite(filter.GetOutput(), args.displacement_image_out) + +if args.warped_moving_out: + warped = itk.resample_image_filter( + moving_image, + use_reference_image=True, + reference_image=fixed_image, + transform=phi, + interpolator=itk.LinearInterpolateImageFunction.New(moving_image), + ) + itk.imwrite(warped, args.warped_moving_out) + + diff --git a/setup.cfg b/setup.cfg index e3c360d..4e2c564 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,9 @@ package_dir = packages = find: python_requires = >=3.4 +scripts = + bin/icon_register_pair + install_requires = torch torchvision diff --git a/test/test_console_scripts.py b/test/test_console_scripts.py new file mode 100644 index 0000000..63cfa4e --- /dev/null +++ b/test/test_console_scripts.py @@ -0,0 +1,23 @@ +import unittest +import subprocess +from icon_registration import test_utils + +class TestConsoleScripts(unittest.TestCase): + def test_register_pair_script(self): + import footsteps + test_utils.download_test_data() + subprocess.run( + [ + "icon_register_pair", + "--fixed_image", + test_utils.TEST_DATA_DIR / "brain_test_data" / "2_T1w_acpc_dc_restore_brain.nii.gz", + "--moving_image", + test_utils.TEST_DATA_DIR / "brain_test_data" / "8_T1w_acpc_dc_restore_brain.nii.gz", + "--model", + "icon_registration.pretrained_models.brain_registration_model" + "--warped_image_out", + footsteps.output_dir + "warped.nii.gz" + ] + ) + import itk + itk.imread(footsteps.output_dir + "warped.nii.gz")