Skip to content

Commit

Permalink
ENH: Address review
Browse files Browse the repository at this point in the history
  • Loading branch information
NicerNewerCar committed Mar 25, 2024
1 parent 71c20a0 commit 379e71b
Showing 1 changed file with 65 additions and 53 deletions.
118 changes: 65 additions & 53 deletions AutoscoperM/AutoscoperM.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import glob
import logging
import os
import re
import shutil
import time
import zipfile
from typing import Optional
from typing import Optional, Union

import qt
import slicer
Expand Down Expand Up @@ -222,7 +221,6 @@ def setup(self):
self.ui.segmentationButton.connect("clicked(bool)", self.onSegmentation)

self.ui.loadPVButton.connect("clicked(bool)", self.onLoadPV)
self.ui.volumeSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.logic.check4D)

# Default output directory
self.ui.mainOutputSelector.setCurrentPath(
Expand Down Expand Up @@ -280,8 +278,6 @@ def initializeParameterNode(self):
# so that when the scene is saved and reloaded, these settings are restored.

self.setParameterNode(self.logic.getParameterNode())
if self.ui.volumeSelector.currentNode() is not None:
self.logic.check4D(self.ui.volumeSelector.currentNode())
if self.ui.mVRG_markupSelector.currentNode() is not None:
self.onMarkupNodeChanged(self.ui.mVRG_markupSelector.currentNode())

Expand Down Expand Up @@ -501,7 +497,7 @@ def onGenerateVRG(self):

# center the volume
self.logic.createPathsIfNotExists(os.path.join(mainOutputDir, tfmPath))
if not self.logic.is_4d:
if not self.logic.IsSequenceVolume(volumeNode):
volumeNode.AddCenteringTransform()
tfmNode = slicer.util.getNode(f"{volumeNode.GetName()} centering transform")
volumeNode.HardenTransform()
Expand Down Expand Up @@ -534,7 +530,7 @@ def onGenerateVRG(self):
numFrames = 1
currentNode = volumeNode
curName = volumeNode.GetName()
if self.logic.is_4d:
if self.logic.IsSequenceVolume(currentNode):
numFrames = volumeNode.GetNumberOfDataNodes()
currentNode, curName = self.logic.getItemInSequence(volumeNode, 0)
bounds = [0] * 6
Expand Down Expand Up @@ -570,7 +566,9 @@ def onGenerateVRG(self):
progress = (i + 1) / numFrames * 40 + 10
self.updateProgressBar(progress)

currentNode, curName = self.logic.getNextItemInSequence(volumeNode) if self.logic.is_4d else currentNode
currentNode, curName = (
self.logic.getNextItemInSequence(volumeNode) if self.logic.IsSequenceVolume(volumeNode) else currentNode
)

# Optimize the camera positions
bestCameras = RadiographGeneration.optimizeCameras(
Expand Down Expand Up @@ -643,10 +641,7 @@ def onGenerateConfig(self):
int(self.ui.flipZ.isChecked()),
]

if self.logic.is_4d:
voxel_spacing = self.logic.getItemInSequence(volumeNode, 0).GetSpacing()
else:
voxel_spacing = volumeNode.GetSpacing()
voxel_spacing = self.logic.GetVolumeSpacing(volumeNode)
# generate the config file
configFilePath = IO.generateConfigFile(
mainOutputDir,
Expand Down Expand Up @@ -677,7 +672,7 @@ def onSegmentation(self):
if self.ui.segGen_autoRadioButton.isChecked():
currentVolumeNode = volumeNode
numFrames = 1
if self.logic.is_4d:
if self.logic.IsSequenceVolume(volumeNode):
numFrames = volumeNode.GetNumberOfDataNodes()
currentVolumeNode = self.logic.getItemInSequence(volumeNode, 0)
segmentationSequenceNode = self.logic.createSequenceNodeInBrowser(
Expand All @@ -692,7 +687,7 @@ def onSegmentation(self):
)
progress = (i + 1) / numFrames * 100
self.ui.progressBar.setValue(progress)
if self.logic.is_4d:
if self.logic.IsSequenceVolume(volumeNode):
segmentationSequenceNode.SetDataNodeAtValue(segmentationNode, str(i))
slicer.mrmlScene.RemoveNode(segmentationNode)
currentVolumeNode = self.logic.getNextItemInSequence(volumeNode)
Expand Down Expand Up @@ -784,22 +779,13 @@ def onManualVRGGen(self):
if self.logic.vrgManualCameras is None:
self.onMarkupNodeChanged(markupsNode) # create the cameras

# Check if the volume is centered at the origin
bounds = [0.0] * 6
if self.logic.is_4d:
# get the bounds of the first frame
volumeNode.GetNthDataNode(0).GetRASBounds(bounds)
else:
volumeNode.GetRASBounds(bounds)

isCentered = volumeNode.GetNthDataNode(0).IsCentered() if self.logic.is_4d else volumeNode.IsCentered()
if not isCentered:
if not self.logic.IsVolumeCentered(volumeNode):
logging.warning("Volume is not centered at the origin. This may cause issues with Autoscoper.")

numFrames = 1
currentNode = volumeNode
curName = currentNode.GetName()
if self.logic.is_4d:
if self.logic.IsSequenceVolume(currentNode):
numFrames = volumeNode.GetNumberOfDataNodes()
currentNode, curName = self.logic.getItemInSequence(volumeNode, 0)

Expand All @@ -812,7 +798,9 @@ def onManualVRGGen(self):
[width, height],
filename=filename,
)
currentNode, curName = self.logic.getNextItemInSequence(volumeNode) if self.logic.is_4d else currentNode
currentNode, curName = (
self.logic.getNextItemInSequence(volumeNode) if self.logic.IsSequenceVolume(volumeNode) else currentNode
)

self.updateProgressBar(100)

Expand All @@ -835,9 +823,7 @@ def onMarkupNodeChanged(self, node):
# get the volume nodes
volumeNode = self.ui.volumeSelector.currentNode()
self.logic.validateInputs(volumeNode=volumeNode)
volumeNode, _ = self.logic.getItemInSequence(volumeNode, 0) if self.logic.is_4d else volumeNode
bounds = [0] * 6
volumeNode.GetBounds(bounds)
bounds = self.logic.GetRASBounds(volumeNode)
self.logic.vrgManualCameras = RadiographGeneration.generateCamerasFromMarkups(
node,
bounds,
Expand Down Expand Up @@ -883,7 +869,10 @@ def __init__(self):
self.AutoscoperProcess.setProcessChannelMode(qt.QProcess.ForwardedChannels)
self.AutoscoperSocket = None
self.vrgManualCameras = None
self.is_4d = False

@staticmethod
def IsSequenceVolume(node: Union[slicer.vtkMRMLNode, None]) -> bool:
return isinstance(node, slicer.vtkMRMLSequenceNode)

def setDefaultParameters(self, parameterNode):
"""
Expand Down Expand Up @@ -1024,7 +1013,8 @@ def progressCallback(x):
slicer.app.layoutManager().resetSliceViews()
return True

def showVolumeIn3D(self, volumeNode: slicer.vtkMRMLVolumeNode):
@staticmethod
def showVolumeIn3D(volumeNode: slicer.vtkMRMLVolumeNode):
logic = slicer.modules.volumerendering.logic()
displayNode = logic.CreateVolumeRenderingDisplayNode()
displayNode.UnRegister(logic)
Expand All @@ -1033,7 +1023,8 @@ def showVolumeIn3D(self, volumeNode: slicer.vtkMRMLVolumeNode):
logic.UpdateDisplayNodeFromVolumeNode(displayNode, volumeNode)
slicer.mrmlScene.RemoveNode(slicer.util.getNode("Volume rendering ROI"))

def validateInputs(self, *args: tuple, **kwargs: dict) -> bool:
@staticmethod
def validateInputs(*args: tuple, **kwargs: dict) -> bool:
"""
Validates that the provided inputs are not None.
Expand Down Expand Up @@ -1065,7 +1056,8 @@ def validateInputs(self, *args: tuple, **kwargs: dict) -> bool:

return all(statuses)

def validatePaths(self, *args: tuple, **kwargs: dict) -> bool:
@staticmethod
def validatePaths(*args: tuple, **kwargs: dict) -> bool:
"""
Checks that the provided paths exist.
Expand All @@ -1091,7 +1083,8 @@ def validatePaths(self, *args: tuple, **kwargs: dict) -> bool:

return all(statuses)

def createPathsIfNotExists(self, *args: tuple) -> None:
@staticmethod
def createPathsIfNotExists(*args: tuple) -> None:
"""
Creates a path if it does not exist.
Expand All @@ -1101,8 +1094,8 @@ def createPathsIfNotExists(self, *args: tuple) -> None:
if not os.path.exists(arg):
os.makedirs(arg)

@staticmethod
def extractSubVolumeForVRG(
self,
volumeNode: slicer.vtkMRMLVolumeNode,
segmentationNode: slicer.vtkMRMLSegmentationNode,
cameraDebugMode: bool = False,
Expand Down Expand Up @@ -1254,7 +1247,8 @@ def progressCallback(x):
progress = ((idx + 1) / len(bestCameras)) * 10 + 90
progressCallback(progress)

def convertNodeToData(self, volumeNode: slicer.vtkMRMLVolumeNode) -> vtk.vtkImageData:
@staticmethod
def convertNodeToData(volumeNode: slicer.vtkMRMLVolumeNode) -> vtk.vtkImageData:
"""
Converts a volume node to a vtkImageData object
"""
Expand Down Expand Up @@ -1285,13 +1279,8 @@ def convertNodeToData(self, volumeNode: slicer.vtkMRMLVolumeNode) -> vtk.vtkImag

return imageData

def check4D(self, node: slicer.vtkMRMLNode) -> bool:
"""
Checks if the volume is 4D
"""
self.is_4d = type(node) == slicer.vtkMRMLSequenceNode

def getItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode, idx: int) -> slicer.vtkMRMLNode:
@staticmethod
def getItemInSequence(sequenceNode: slicer.vtkMRMLSequenceNode, idx: int) -> slicer.vtkMRMLNode:
"""
Returns the item at the specified index in the sequence node
Expand All @@ -1302,7 +1291,7 @@ def getItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode, idx: int)
:return: item at the specified index
:rtype: slicer.vtkMRMLNode
"""
if type(sequenceNode) != slicer.vtkMRMLSequenceNode:
if not AutoscoperMLogic.IsSequenceVolume(sequenceNode):
logging.error("[AutoscoperM.logic.getItemInSequence] sequenceNode must be a sequence node")
return None

Expand All @@ -1314,7 +1303,8 @@ def getItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode, idx: int)
browserNode.SetSelectedItemNumber(idx)
return browserNode.GetProxyNode(sequenceNode), sequenceNode.GetNthDataNode(idx).GetName()

def getNextItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode) -> slicer.vtkMRMLNode:
@staticmethod
def getNextItemInSequence(sequenceNode: slicer.vtkMRMLSequenceNode) -> slicer.vtkMRMLNode:
"""
Returns the next item in the sequence
Expand All @@ -1323,7 +1313,7 @@ def getNextItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode) -> sli
:return: next item in the sequence
:rtype: slicer.vtkMRMLNode
"""
if type(sequenceNode) != slicer.vtkMRMLSequenceNode:
if not AutoscoperMLogic.IsSequenceVolume(sequenceNode):
logging.error("[AutoscoperM.logic.getNextItemInSequence] sequenceNode must be a sequence node")
return None

Expand All @@ -1332,17 +1322,39 @@ def getNextItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode) -> sli
idx = browserNode.GetSelectedItemNumber()
return browserNode.GetProxyNode(sequenceNode), sequenceNode.GetNthDataNode(idx).GetName()

def cleanFilename(self, volumeName: str, index: Optional[int] = None) -> str:
filename = (
re.sub(r"\s+", "_", f"{index}_{volumeName}") if index is not None else re.sub(r"\s+", "_", f"{volumeName}")
) # Remove spaces
filename = re.sub(r"[^\w]", "", filename) # Remove non alphanumeric characters
return re.sub(r"__+", "_", filename) # Remove double or more underscores
@staticmethod
def cleanFilename(volumeName: str, index: Optional[int] = None) -> str:
filename = slicer.qSlicerCoreIOManager().forceFileNameValidCharacters(volumeName)
return f"{index}_{filename}" if index is not None else filename

@staticmethod
def createSequenceNodeInBrowser(nodename, sequenceNode):
if not AutoscoperMLogic.IsSequenceVolume(sequenceNode):
logging.error("[AutoscoperMLogic.createSequenceNodeInBrowser] sequenceNode must be a sequence node")
return None

def createSequenceNodeInBrowser(self, nodename, sequenceNode):
browserNode = slicer.modules.sequences.logic().GetFirstBrowserNodeForSequenceNode(sequenceNode)
newSeqenceNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSequenceNode", nodename)
browserNode.AddSynchronizedSequenceNode(newSeqenceNode)
browserNode.SetOverwriteProxyName(newSeqenceNode, True)
browserNode.SetSaveChanges(newSeqenceNode, True)
return newSeqenceNode

@staticmethod
def GetVolumeSpacing(node: Union[slicer.vtkMRMLVolumeNode, slicer.vtkMRMLSequenceNode]) -> list[float]:
if AutoscoperMLogic.IsSequenceVolume(node):
return AutoscoperMLogic.getItemInSequence(node, 0)[0].GetSpacing()
return node.GetSpacing()

@staticmethod
def GetRASBounds(node: Union[slicer.vtkMRMLVolumeNode, slicer.vtkMRMLSequenceNode]) -> list[float]:
bounds = [0] * 6
if AutoscoperMLogic.IsSequenceVolume(node):
return AutoscoperMLogic.getItemInSequence(node, 0)[0].GetRASBounds(bounds)
return node.GetRASBounds(bounds)

@staticmethod
def IsVolumeCentered(node: Union[slicer.vtkMRMLVolumeNode, slicer.vtkMRMLSequenceNode]) -> bool:
if AutoscoperMLogic.IsSequenceVolume(node):
return AutoscoperMLogic.getItemInSequence(node, 0)[0].IsCentered()
return node.IsCentered()

0 comments on commit 379e71b

Please sign in to comment.