Skip to content

Commit

Permalink
WIP: 4DCT
Browse files Browse the repository at this point in the history
  • Loading branch information
NicerNewerCar committed Sep 27, 2023
1 parent 713e827 commit 86717ef
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 49 deletions.
254 changes: 211 additions & 43 deletions AutoscoperM/AutoscoperM.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self, parent=None):
self.logic = None
self._parameterNode = None
self._updatingGUIFromParameterNode = False
self.is_4d = False

def setup(self):
"""
Expand Down Expand Up @@ -201,6 +202,8 @@ def setup(self):

self.ui.loadPVButton.connect("clicked(bool)", self.onLoadPV)

self.ui.volumeSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.check4D)

# Default output directory
self.ui.mainOutputSelector.setCurrentPath(
os.path.join(slicer.mrmlScene.GetCacheManager().GetRemoteCacheDirectory(), "AutoscoperM-Pre-Processing")
Expand Down Expand Up @@ -259,7 +262,10 @@ def initializeParameterNode(self):
self.setParameterNode(self.logic.getParameterNode())

# Select default input nodes if nothing is selected yet to save a few clicks for the user
# NA
if self.ui.volumeSelector.currentNode() is not None:
self.check4D(self.ui.volumeSelector.currentNode())
if self.ui.mVRG_markupSelector.currentNode() is not None:
self.onMarkupNodeChanged(self.ui.mVRG_markupSelector.currentNode())

def setParameterNode(self, inputParameterNode):
"""
Expand Down Expand Up @@ -400,6 +406,8 @@ def onGeneratePartialVolumes(self):
self.logic.createPathsIfNotExists(
mainOutputDir, os.path.join(mainOutputDir, tiffSubDir), os.path.join(mainOutputDir, tfmSubDir)
)
self.logic.centerVolume(volumeNode, os.path.join(mainOutputDir, tfmSubDir), self.is_4d)

self.ui.progressBar.setValue(0)
self.ui.progressBar.setMaximum(100)
self.logic.saveSubVolumesFromSegmentation(
Expand Down Expand Up @@ -430,6 +438,7 @@ def onGenerateVRG(self):
tmpDir = self.ui.vrgTempDir.text
cameraSubDir = self.ui.cameraSubDir.text
vrgSubDir = self.ui.vrgSubDir.text
tfmSubDir = self.ui.tfmSubDir.text
self.logic.validateInputs(
volumeNode=volumeNode,
segmentationNode=segmentationNode,
Expand All @@ -443,6 +452,9 @@ def onGenerateVRG(self):
vrgSubDir=vrgSubDir,
)
self.logic.validatePaths(mainOutputDir=mainOutputDir)
self.logic.createPathsIfNotExists(os.path.join(mainOutputDir, tfmSubDir))
self.logic.centerVolume(volumeNode, os.path.join(mainOutputDir, tfmSubDir), self.is_4d)

if nPossibleCameras < nOptimizedCameras:
logging.error("Failed to generate VRG: more optimized cameras than possible cameras")
return
Expand All @@ -465,8 +477,8 @@ def onGenerateVRG(self):
cameras,
volumeImageData,
os.path.join(mainOutputDir, tmpDir),
width,
height,
[width, height],
frameNum=1,
progressCallback=self.updateProgressBar,
)

Expand Down Expand Up @@ -557,16 +569,41 @@ def onSegmentation(self):
self.ui.progressBar.setMaximum(100)

volumeNode = self.ui.volumeSelector.currentNode()
mainOutputDir = self.ui.mainOutputSelector.currentPath
tfmSubDir = self.ui.tfmSubDir.text

self.logic.validateInputs(voluemNode=volumeNode, mainOutputDir=mainOutputDir, tfmSubDir=tfmSubDir)
self.logic.validatePaths(mainOutputDir=mainOutputDir)
self.logic.createPathsIfNotExists(os.path.join(mainOutputDir, tfmSubDir))

self.logic.validateInputs(voluemNode=volumeNode)
self.logic.centerVolume(volumeNode, os.path.join(mainOutputDir, tfmSubDir), self.is_4d)

if self.ui.segGen_autoRadioButton.isChecked():
segmentationNode = SubVolumeExtraction.automaticSegmentation(
volumeNode,
self.ui.segGen_ThresholdSpinBox.value,
self.ui.segGen_marginSizeSpin.value,
progressCallback=self.updateProgressBar,
)
currentVolumeNode = volumeNode
numFrames = 1
if self.is_4d:
numFrames = volumeNode.GetNumberOfDataNodes()
browserNode = slicer.modules.sequences.logic().GetFirstBrowserNodeForSequenceNode(volumeNode)
browserNode.SetSelectedItemNumber(0)
currentVolumeNode = browserNode.GetProxyNode(volumeNode)
segmentationSequenceNode = slicer.mrmlScene.AddNewNodeByClass(
"vtkMRMLSequenceNode", f"{volumeNode.GetName()}_Segmentation"
)
browserNode.AddSynchronizedSequenceNode(segmentationSequenceNode)
browserNode.SetOverwriteProxyName(segmentationSequenceNode, True)
browserNode.SetSaveChanges(segmentationSequenceNode, True)
for i in range(numFrames):
segmentationNode = SubVolumeExtraction.automaticSegmentation(
currentVolumeNode,
self.ui.segGen_ThresholdSpinBox.value,
self.ui.segGen_marginSizeSpin.value,
)
progress = (i + 1) / numFrames * 100
self.ui.progressBar.setValue(progress)
if self.is_4d:
segmentationSequenceNode.SetDataNodeAtValue(segmentationNode, str(i))
slicer.mrmlScene.RemoveNode(segmentationNode)
currentVolumeNode = self.logic.getItemInSequence(volumeNode, i + 1)
elif self.ui.segGen_fileRadioButton.isChecked():
segmentationFileDir = self.ui.segGen_lineEdit.currentPath
self.logic.validatePaths(segmentationFileDir=segmentationFileDir)
Expand Down Expand Up @@ -648,18 +685,42 @@ def onManualVRGGen(self):
if self.logic.vrgManualCameras is None:
self.onMarkupNodeChanged(markupsNode) # create the cameras

volumeImageData, _ = self.logic.extractSubVolumeForVRG(
volumeNode, segmentationNode, cameraDebugMode=self.ui.camDebugCheckbox.isChecked()
)
# Check if the volume is centered at the origin
bounds = [0] * 6
if self.is_4d:
# get the bounds of the first frame
volumeNode.GetNthDataNode(0).GetRASBounds(bounds)
else:
volumeNode.GetRASBounds(bounds)

center = [(bounds[0] + bounds[1]) / 2, (bounds[2] + bounds[3]) / 2, (bounds[4] + bounds[5]) / 2]
center = [round(x) for x in center]
if center != [0, 0, 0]:
logging.warning("Volume is not centered at the origin. This may cause issues with Autoscoper.")

numFrames = 1
currentNode = volumeNode
currentSegmentationNode = segmentationNode
if self.is_4d:
numFrames = volumeNode.GetNumberOfDataNodes()

for i in range(numFrames):
if self.is_4d:
currentNode = self.logic.getItemInSequence(volumeNode, i)
currentSegmentationNode = self.logic.getItemInSequence(segmentationNode, i)

volumeImageData, _ = self.logic.extractSubVolumeForVRG(
currentNode, currentSegmentationNode, cameraDebugMode=self.ui.camDebugCheckbox.isChecked()
)

self.logic.generateVRGForCameras(
self.logic.vrgManualCameras,
volumeImageData,
os.path.join(mainOutputDir, vrgDir),
width,
height,
progressCallback=self.updateProgressBar,
)
self.logic.generateVRGForCameras(
self.logic.vrgManualCameras,
volumeImageData,
os.path.join(mainOutputDir, vrgDir),
[width, height],
frameNum=i,
progressCallback=self.updateProgressBar,
)

self.updateProgressBar(100)

Expand All @@ -684,7 +745,17 @@ def onMarkupNodeChanged(self, node):
if not self.logic.validateInputs(segmentationNode=segmentationNode):
return
bounds = [0] * 6
segmentationNode.GetBounds(bounds)
if self.is_4d:
# calculate the average bounds
tmp = [0] * 6
for i in range(segmentationNode.GetNumberOfDataNodes()):
segmentationNode.GetNthDataNode(i).GetBounds(tmp)
for j in range(6):
bounds[j] += tmp[j]
for j in range(6):
bounds[j] /= segmentationNode.GetNumberOfDataNodes()
else:
segmentationNode.GetRASBounds(bounds)
self.logic.vrgManualCameras = RadiographGeneration.generateCamerasFromMarkups(
node,
bounds,
Expand All @@ -704,6 +775,9 @@ def updateViewAngle(self, value):
cam.vtkCamera.SetViewAngle(value)
RadiographGeneration._updateFrustumModel(cam)

def check4D(self, node):
self.is_4d = type(node) == slicer.vtkMRMLSequenceNode


#
# AutoscoperMLogic
Expand Down Expand Up @@ -973,28 +1047,33 @@ def extractSubVolumeForVRG(
newVolumeNode.SetName(volumeNode.GetName() + " - Bone Subvolume")

bounds = [0, 0, 0, 0, 0, 0]
newVolumeNode.GetBounds(bounds)
newVolumeNode.GetRASBounds(bounds)

# Copy the metadata from the original volume into the ImageData
newVolumeImageData = vtk.vtkImageData()
newVolumeImageData.DeepCopy(newVolumeNode.GetImageData()) # So we don't modify the original volume
newVolumeImageData.SetSpacing(newVolumeNode.GetSpacing())
origin = list(newVolumeNode.GetOrigin())
origin[0:2] = [x * -1 for x in origin[0:2]]
newVolumeImageData.SetOrigin(origin)

# Ensure we are in the correct orientation (RAS vs LPS)
imageReslice = vtk.vtkImageReslice()
imageReslice.SetInputData(newVolumeImageData)
mat = vtk.vtkMatrix4x4()
volumeNode.GetIJKToRASMatrix(mat)

if mat.GetElement(0, 0) < 0 and mat.GetElement(1, 1) < 0:
origin[0:2] = [x * -1 for x in origin[0:2]]
newVolumeImageData.SetOrigin(origin)
# Ensure we are in the correct orientation (RAS vs LPS)
imageReslice = vtk.vtkImageReslice()
imageReslice.SetInputData(newVolumeImageData)

axes = vtk.vtkMatrix4x4()
axes.Identity()
axes.SetElement(0, 0, -1)
axes.SetElement(1, 1, -1)
axes = vtk.vtkMatrix4x4()
axes.Identity()
axes.SetElement(0, 0, -1)
axes.SetElement(1, 1, -1)

imageReslice.SetResliceAxes(axes)
imageReslice.Update()
newVolumeImageData = imageReslice.GetOutput()
imageReslice.SetResliceAxes(axes)
imageReslice.Update()
newVolumeImageData = imageReslice.GetOutput()

if not cameraDebugMode:
slicer.mrmlScene.RemoveNode(newVolumeNode)
Expand All @@ -1007,8 +1086,8 @@ def generateVRGForCameras(
cameras: list[RadiographGeneration.Camera],
volumeImageData: vtk.vtkImageData,
outputDir: str,
width: int,
height: int,
imageSize: list[int],
frameNum: int = 1,
progressCallback=None,
) -> None:
"""
Expand All @@ -1020,10 +1099,10 @@ def generateVRGForCameras(
:type volumeImageData: vtk.vtkImageData
:param outputDir: output directory
:type outputDir: str
:param width: width of the radiographs
:type width: int
:param height: height of the radiographs
:type height: int
:param imageSize: image size
:type imageSize: list[int]
:param frameNum: frame number, defaults to 1
:type frameNum: int, optional
:param progressCallback: progress callback, defaults to None
:type progressCallback: callable, optional
"""
Expand Down Expand Up @@ -1056,9 +1135,9 @@ def progressCallback(x):
"cameraViewUp": [camera.GetViewUp()[0], camera.GetViewUp()[1], camera.GetViewUp()[2]],
"cameraViewAngle": camera.GetViewAngle(),
"clippingRange": [camera.GetClippingRange()[0], camera.GetClippingRange()[1]],
"width": width,
"height": height,
"outputFName": os.path.join(cameraDir, "1.tif"),
"width": imageSize[0],
"height": imageSize[1],
"outputFName": os.path.join(cameraDir, f"{frameNum}.tif"),
}
cliNode = slicer.cli.run(cliModule, None, parameters) # run asynchronously
cliNodes.append(cliNode)
Expand All @@ -1072,6 +1151,7 @@ def progressCallback(x):
errorText = cliNode.GetErrorText()
slicer.mrmlScene.RemoveNode(cliNode)
raise ValueError("CLI execution failed: " + errorText)
# get the output
slicer.mrmlScene.RemoveNode(cliNode)
progress = ((i + 1) / len(cameras)) * 30 + 10
progressCallback(progress)
Expand Down Expand Up @@ -1122,3 +1202,91 @@ def progressCallback(x):

progress = ((i + 1) / len(bestCameras)) * 10 + 90
progressCallback(progress)

def centerVolume(self, volumeNode: slicer.vtkMRMLVolumeNode, transformPath: str, is_4d: bool) -> None:
"""
A requirement for Autoscoper is that the center of the volume is at the origin.
This method will center the volume and save the transform to the transformPath
:param volumeNode: volume node
:type volumeNode: slicer.vtkMRMLVolumeNode
:param transformPath: path to save the transform to
:type transformPath: str
:param is_4d: whether or not the volume is a 4D volume
:type is_4d: bool
:return: None
"""

# Get the bounds of the volume
bounds = [0] * 6
if is_4d:
volumeNode.GetNthDataNode(0).GetRASBounds(bounds)
else:
volumeNode.GetRASBounds(bounds)

# Get the center of the volume
center = [0] * 3
for i in range(3):
center[i] = (bounds[i * 2] + bounds[i * 2 + 1]) / 2

center_rounded = [round(x) for x in center] # don't want to move the volume if its off by a small amount
if center_rounded == [0, 0, 0]:
return # Already centered

# Create a transform node
transformNode = slicer.vtkMRMLTransformNode()
transformNode.SetName("CenteringTransform")
slicer.mrmlScene.AddNode(transformNode)

# Get the transform matrix
matrix = vtk.vtkMatrix4x4()

# Move the center of the volume to the origin
matrix.SetElement(0, 3, -center[0])
matrix.SetElement(1, 3, -center[1])
matrix.SetElement(2, 3, -center[2])

# Set the transform matrix
transformNode.SetMatrixTransformToParent(matrix)

# Apply the transform to the volume
num_frames = 1
curVol = volumeNode
if is_4d:
num_frames = volumeNode.GetNumberOfDataNodes()
for i in range(num_frames):
if is_4d:
curVol = self.getItemInSequence(volumeNode, i)
curVol.SetAndObserveTransformNodeID(transformNode.GetID())

# Harden the transform
slicer.modules.transforms.logic().hardenTransform(curVol)
curVol.SetAndObserveTransformNodeID(None)

# # Invert and save the transform
matrix.Invert()
transformNode.SetMatrixTransformToParent(matrix)
slicer.util.exportNode(transformNode, os.path.join(transformPath, "Origin2DICOMCenter.tfm"))

def getItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode, idx: int) -> slicer.vtkMRMLNode:
"""
Returns the item at the specified index in the sequence node
:param sequenceNode: sequence node
:type sequenceNode: slicer.vtkMRMLSequenceNode
:param idx: index
:type idx: int
:return: item at the specified index
:rtype: slicer.vtkMRMLNode
"""
if type(sequenceNode) != slicer.vtkMRMLSequenceNode:
logging.error("[AutoscoperM.logic.getItemInSequence] sequenceNode must be a sequence node")
return None

if idx >= sequenceNode.GetNumberOfDataNodes():
logging.error(f"[AutoscoperM.logic.getItemInSequence] index {idx} is out of range")
return None

browserNode = slicer.modules.sequences.logic().GetFirstBrowserNodeForSequenceNode(sequenceNode)
browserNode.SetSelectedItemNumber(idx)
return browserNode.GetProxyNode(sequenceNode)
Loading

0 comments on commit 86717ef

Please sign in to comment.