Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
HastingsGreer committed Oct 28, 2024
1 parent 0aa8943 commit 2a5de26
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 42 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sphinx-tabs
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"sphinx.ext.linkcode",
"sphinx.ext.mathjax",
"matplotlib.sphinxext.plot_directive",
"sphinx_tabs.tabs",
]

plot_html_show_source_link = False
Expand Down
93 changes: 51 additions & 42 deletions docs/source/medical_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,69 +39,79 @@ For this tutorial we will use the LUMIR dataset and evaluation provided by Learn
LUMIR_L2R24_TrainVal.zip imagesTr imagesVal
Selecting a Model
================
=================

This tutorial can be used to train the architectures GradICON or Inverse Consistency by Construction, or to finetune uniGradICON.

Create model.py as follows:

.. tabs::

.. code-tab:: python GradICON

# model.py

import icon_registration as icon

input_shape = [1, 1, 96, 112, 80]

inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))

for _ in range(3):
inner_net = icon.TwoStepRegistration(
icon.DownsampleRegistration(inner_net, dimension=2),
icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
)

net = icon.GradientICON(inner_net, icon.LNCC(sigma=4), lmbda=.5)
def make_network():
inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))

for _ in range(3):
inner_net = icon.TwoStepRegistration(
icon.DownsampleRegistration(inner_net, dimension=2),
icon.FunctionFromVectorField(networks.tallUNet2(dimension=2))
)

net = icon.GradientICON(inner_net, icon.LNCC(sigma=4), lmbda=.5)
net.assign_identity_map(input_shape)
return net

.. code-tab:: python ConstrICON

input_shape = [1, 1, 96, 112, 80]
# model.py

def make_network():
import icon_registration.constricon as constricon

import icon_registration.constricon as constricon
input_shape = [1, 1, 96, 112, 80]

net = multiscale_constr_model.FirstTransform(
multiscale_constr_model.TwoStepInverseConsistent(
multiscale_constr_model.ConsistentFromMatrix(
def make_network():
net = constricon.FirstTransform(
constricon.TwoStepInverseConsistent(
constricon.ConsistentFromMatrix(
networks.ConvolutionalMatrixNet(dimension=3)
),
multiscale_constr_model.TwoStepInverseConsistent(
multiscale_constr_model.ConsistentFromMatrix(
constricon.TwoStepInverseConsistent(
constricon.ConsistentFromMatrix(
networks.ConvolutionalMatrixNet(dimension=3)
),
multiscale_constr_model.TwoStepInverseConsistent(
multiscale_constr_model.ICONSquaringVelocityField(
constricon.TwoStepInverseConsistent(
constricon.ICONSquaringVelocityField(
networks.tallUNet2(dimension=3)
),
multiscale_constr_model.ICONSquaringVelocityField(
constricon.ICONSquaringVelocityField(
networks.tallUNet2(dimension=3)
),
),
),
)
)
)

net = constricon.VelocityFieldDiffusion(net, icon.LNCC(5), lmbda)
net.assign_identity_map(input_shape)
return net

loss = multiscale_constr_model.VelocityFieldDiffusion(net, icon.LNCC(5), lmbda)
return loss
.. code-tab:: python uniGradICON

# model.py

import unigradicon

input_shape = [1, 1, 175, 175, 175]

def make_network():

return unigradicon.get_unigradicon() # Initialize unified GradICON model with pretrained wieghts
return unigradicon.get_unigradicon()


Preprocessing the Dataset
Expand All @@ -120,6 +130,9 @@ the same resolution if they were heterogeneous resolutions or downsampling if th
import tqdm
import numpy as np
import glob
from model import input_shape
footsteps.initialize()
image_paths = glob.glob("imagesTr/LUMIRMRI_*_*.nii.gz") #
Expand All @@ -129,7 +142,8 @@ the same resolution if they were heterogeneous resolutions or downsampling if th
def process(image):
image = image[None, None] # add batch and channel dimensions
image = torch.nn.functional.avg_pool3d(image, 2) # comment this line to train at full resolution
#image = torch.nn.functional.avg_pool3d(image, 2)
image = F.interpolate(image, input_shape, mode="trilinear")
return image
Expand Down Expand Up @@ -167,21 +181,9 @@ Once the data is preprocessed, we train a network to register it. In this exampl
import icon_registration.networks as networks
import torch
from model import input_shape, make_network
input_shape = [1, 1, 96, 112, 80]
def make_network():
inner_net = icon.FunctionFromVectorField(networks.tallUNet2(dimension=3))
for _ in range(2):
inner_net = icon.TwoStepRegistration(
icon.DownsampleRegistration(inner_net, dimension=3),
icon.FunctionFromVectorField(networks.tallUNet2(dimension=3))
)
net = icon.GradientICON(inner_net, icon.LNCC(sigma=4), lmbda=1.5)
net.assign_identity_map(input_shape)
return net
We define a custom function for creating and preparing batches of images. Feel free to do this with a torch :class:`torch.Dataset`, but I am more confident about predicting the performance of procedural code for this task.

Expand Down Expand Up @@ -248,12 +250,12 @@ What we have now is a trained model that operates at resolution [96, 112, 80] wh
import argparse
import itk
import train
import model
import icon_registration.register_pair
import icon_registration.config
def get_model():
net = train.make_network()
net = model.make_network()
# modify weights_location based on the training run you want to use
weights_location = "results/train_halfres/network_weights_49800"
trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
Expand Down Expand Up @@ -315,3 +317,10 @@ What we have now is a trained model that operates at resolution [96, 112, 80] wh
)
itk.imwrite(warped_moving_image, args.warped_moving_out)
Now, we are able to register images.

.. code-block:: bash
python register_pair.py --fixed fixed.nrrd --moving moving.nrrd --transform_out transform.hdf5 --warped_moving_out warped.nrrd
The warped image warped.nrrd and transform transform.hdf5 can be viewed and further used (e.g. to warp a segmentation) using medical imaging software such as 3-D Slicer. (https://www.slicer.org/)

0 comments on commit 2a5de26

Please sign in to comment.