diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000..3cfacfb
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1 @@
+sphinx-tabs
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 0065ed5..0d916d8 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -38,6 +38,7 @@
     "sphinx.ext.linkcode",
     "sphinx.ext.mathjax",
     "matplotlib.sphinxext.plot_directive",
+    "sphinx_tabs.tabs",
 ]
 
 plot_html_show_source_link = False
diff --git a/docs/source/medical_training.rst b/docs/source/medical_training.rst
index 55e8df9..4d0d3de 100644
--- a/docs/source/medical_training.rst
+++ b/docs/source/medical_training.rst
@@ -39,69 +39,79 @@ For this tutorial we will use the LUMIR dataset and evaluation provided by Learn
     LUMIR_L2R24_TrainVal.zip
     imagesTr
     imagesVal
 
 Selecting a Model
-================
+=================
 
 This tutorial can be used to train the architectures GradICON or Inverse Consistency by Construction, or to finetune uniGradICON.
 
+Create ``model.py`` as follows:
+
 .. tabs::
 
    .. code-tab:: python GradICON
 
+      # model.py
+
       import icon_registration as icon
+      import icon_registration.networks as networks
 
       input_shape = [1, 1, 96, 112, 80]
 
-      inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
-
-      for _ in range(3):
-          inner_net = icon.TwoStepRegistration(
-              icon.DownsampleRegistration(inner_net, dimension=2),
-              icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
-          )
-
-      net = icon.GradientICON(inner_net, icon.LNCC(sigma=4), lmbda=.5)
+      def make_network():
+          # dimension=3 because the LUMIR images are 3D volumes
+          inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=3))
+
+          for _ in range(3):
+              inner_net = icon.TwoStepRegistration(
+                  icon.DownsampleRegistration(inner_net, dimension=3),
+                  icon.FunctionFromVectorField(networks.tallUNet2(dimension=3))
+              )
+
+          net = icon.GradientICON(inner_net, icon.LNCC(sigma=4), lmbda=.5)
+          net.assign_identity_map(input_shape)
+          return net
 
    .. code-tab:: python ConstrICON
 
-      input_shape = [1, 1, 96, 112, 80]
-
-      def make_network():
-
-          import icon_registration.constricon as constricon
-
-          net = multiscale_constr_model.FirstTransform(
-              multiscale_constr_model.TwoStepInverseConsistent(
-                  multiscale_constr_model.ConsistentFromMatrix(
-                      networks.ConvolutionalMatrixNet(dimension=3)
-                  ),
-                  multiscale_constr_model.TwoStepInverseConsistent(
-                      multiscale_constr_model.ConsistentFromMatrix(
-                          networks.ConvolutionalMatrixNet(dimension=3)
-                      ),
-                      multiscale_constr_model.TwoStepInverseConsistent(
-                          multiscale_constr_model.ICONSquaringVelocityField(
-                              networks.tallUNet2(dimension=3)
-                          ),
-                          multiscale_constr_model.ICONSquaringVelocityField(
-                              networks.tallUNet2(dimension=3)
-                          ),
-                      ),
-                  ),
-          )
-          )
-
-          loss = multiscale_constr_model.VelocityFieldDiffusion(net, icon.LNCC(5), lmbda)
-          return loss
+      # model.py
+
+      import icon_registration as icon
+      import icon_registration.constricon as constricon
+      import icon_registration.networks as networks
+
+      input_shape = [1, 1, 96, 112, 80]
+      lmbda = 0.5  # regularization weight (assumed; the GradICON tab uses 0.5)
+
+      def make_network():
+          net = constricon.FirstTransform(
+              constricon.TwoStepInverseConsistent(
+                  constricon.ConsistentFromMatrix(
+                      networks.ConvolutionalMatrixNet(dimension=3)
+                  ),
+                  constricon.TwoStepInverseConsistent(
+                      constricon.ConsistentFromMatrix(
+                          networks.ConvolutionalMatrixNet(dimension=3)
+                      ),
+                      constricon.TwoStepInverseConsistent(
+                          constricon.ICONSquaringVelocityField(
+                              networks.tallUNet2(dimension=3)
+                          ),
+                          constricon.ICONSquaringVelocityField(
+                              networks.tallUNet2(dimension=3)
+                          ),
+                      ),
+                  ),
+              )
+          )
+          net = constricon.VelocityFieldDiffusion(net, icon.LNCC(5), lmbda)
+          net.assign_identity_map(input_shape)
+          return net
 
   
 .. code-tab:: python uniGradICON
 
+      # model.py
+
       import unigradicon
 
       input_shape = [1, 1, 175, 175, 175]
 
       def make_network():
-          return unigradicon.get_unigradicon() # Initialize unified GradICON model with pretrained wieghts
+          return unigradicon.get_unigradicon()
 
 
 Preprocessing the Dataset
@@ -120,6 +130,9 @@ the same resolution if they were heterogeneous resolutions or downsampling if th
     import tqdm
     import numpy as np
     import glob
+
+    import torch.nn.functional as F
+    from model import input_shape
+
     footsteps.initialize()
 
     image_paths = glob.glob("imagesTr/LUMIRMRI_*_*.nii.gz") #
@@ -129,7 +142,8 @@ the same resolution if they were heterogeneous resolutions or downsampling if th
 
     def process(image):
         image = image[None, None] # add batch and channel dimensions
-        image = torch.nn.functional.avg_pool3d(image, 2) # comment this line to train at full resolution
+        # image = torch.nn.functional.avg_pool3d(image, 2)
+        image = F.interpolate(image, size=input_shape[2:], mode="trilinear")
         return image
@@ -167,21 +181,9 @@ Once the data is preprocessed, we train a network to register it. In this exampl
     import icon_registration.networks as networks
     import torch
 
-    input_shape = [1, 1, 96, 112, 80]
-
-    def make_network():
-        inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=3))
-
-        for _ in range(2):
-            inner_net = icon.TwoStepRegistration(
-                icon.DownsampleRegistration(inner_net, dimension=3),
-                icon.FunctionFromVectorField(networks.tallUNet2(dimension=3))
-            )
-        net = icon.GradientICON(inner_net, icon.LNCC(sigma=4), lmbda=1.5)
-        net.assign_identity_map(input_shape)
-        return net
+    from model import input_shape, make_network
 
 We define a custom function for creating and preparing batches of images. Feel free to do this with a torch :class:`torch.Dataset`, but I am more confident about predicting the performance of procedural code for this task.
@@ -248,12 +250,12 @@ What we have now is a trained model that operates at resolution [96, 112, 80] wh
     import argparse
     import itk
-    import train
+    import model
     import icon_registration.register_pair
     import icon_registration.config
 
     def get_model():
-        net = train.make_network()
+        net = model.make_network()
         # modify weights_location based on the training run you want to use
         weights_location = "results/train_halfres/network_weights_49800"
         trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
@@ -315,3 +317,10 @@ What we have now is a trained model that operates at resolution [96, 112, 80] wh
         )
 
     itk.imwrite(warped_moving_image, args.warped_moving_out)
+
+Now we can register a pair of images from the command line:
+
+.. code-block:: bash
+
+    python register_pair.py --fixed fixed.nrrd --moving moving.nrrd --transform_out transform.hdf5 --warped_moving_out warped.nrrd
+
+The warped image ``warped.nrrd`` and the transform ``transform.hdf5`` can be viewed and used further (e.g. to warp a segmentation) in medical imaging software such as `3D Slicer <https://www.slicer.org/>`_.
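+
+To warp a segmentation with the resulting transform, a minimal sketch using
+ITK's Python wrapping is below (untested; ``seg.nrrd`` and ``warped_seg.nrrd``
+are placeholder file names, and the segmentation is assumed to live in the
+moving image's space):
+
+.. code-block:: python
+
+    import itk
+
+    fixed = itk.imread("fixed.nrrd")
+    seg = itk.imread("seg.nrrd", itk.UC)
+
+    # itk.transformread returns a list of transforms; register_pair.py
+    # writes the registration transform as its first entry.
+    transform = itk.transformread("transform.hdf5")[0]
+
+    # Resample the segmentation onto the fixed image grid; nearest
+    # neighbor interpolation preserves integer label values.
+    warped_seg = itk.resample_image_filter(
+        seg,
+        transform=transform,
+        size=itk.size(fixed),
+        output_spacing=itk.spacing(fixed),
+        output_origin=itk.origin(fixed),
+        output_direction=fixed.GetDirection(),
+        interpolator=itk.NearestNeighborInterpolateImageFunction.New(seg),
+    )
+    itk.imwrite(warped_seg, "warped_seg.nrrd")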