diff --git a/AutoscoperM/AutoscoperM.py b/AutoscoperM/AutoscoperM.py index 1a3623f..3e7ad53 100644 --- a/AutoscoperM/AutoscoperM.py +++ b/AutoscoperM/AutoscoperM.py @@ -2,11 +2,10 @@ import glob import logging import os -import re import shutil import time import zipfile -from typing import Optional +from typing import Optional, Union import qt import slicer @@ -222,7 +221,6 @@ def setup(self): self.ui.segmentationButton.connect("clicked(bool)", self.onSegmentation) self.ui.loadPVButton.connect("clicked(bool)", self.onLoadPV) - self.ui.volumeSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.logic.check4D) # Default output directory self.ui.mainOutputSelector.setCurrentPath( @@ -280,8 +278,6 @@ def initializeParameterNode(self): # so that when the scene is saved and reloaded, these settings are restored. self.setParameterNode(self.logic.getParameterNode()) - if self.ui.volumeSelector.currentNode() is not None: - self.logic.check4D(self.ui.volumeSelector.currentNode()) if self.ui.mVRG_markupSelector.currentNode() is not None: self.onMarkupNodeChanged(self.ui.mVRG_markupSelector.currentNode()) @@ -501,7 +497,7 @@ def onGenerateVRG(self): # center the volume self.logic.createPathsIfNotExists(os.path.join(mainOutputDir, tfmPath)) - if not self.logic.is_4d: + if not self.logic.IsSequenceVolume(volumeNode): volumeNode.AddCenteringTransform() tfmNode = slicer.util.getNode(f"{volumeNode.GetName()} centering transform") volumeNode.HardenTransform() @@ -534,7 +530,7 @@ def onGenerateVRG(self): numFrames = 1 currentNode = volumeNode curName = volumeNode.GetName() - if self.logic.is_4d: + if self.logic.IsSequenceVolume(currentNode): numFrames = volumeNode.GetNumberOfDataNodes() currentNode, curName = self.logic.getItemInSequence(volumeNode, 0) bounds = [0] * 6 @@ -570,7 +566,9 @@ def onGenerateVRG(self): progress = (i + 1) / numFrames * 40 + 10 self.updateProgressBar(progress) - currentNode, curName = self.logic.getNextItemInSequence(volumeNode) if self.logic.is_4d else currentNode + currentNode, curName = ( + self.logic.getNextItemInSequence(volumeNode) if self.logic.IsSequenceVolume(volumeNode) else currentNode + ) # Optimize the camera positions bestCameras = RadiographGeneration.optimizeCameras( @@ -643,10 +641,7 @@ def onGenerateConfig(self): int(self.ui.flipZ.isChecked()), ] - if self.logic.is_4d: - voxel_spacing = self.logic.getItemInSequence(volumeNode, 0).GetSpacing() - else: - voxel_spacing = volumeNode.GetSpacing() + voxel_spacing = self.logic.GetVolumeSpacing(volumeNode) # generate the config file configFilePath = IO.generateConfigFile( mainOutputDir, @@ -677,7 +672,7 @@ def onSegmentation(self): if self.ui.segGen_autoRadioButton.isChecked(): currentVolumeNode = volumeNode numFrames = 1 - if self.logic.is_4d: + if self.logic.IsSequenceVolume(volumeNode): numFrames = volumeNode.GetNumberOfDataNodes() currentVolumeNode = self.logic.getItemInSequence(volumeNode, 0) segmentationSequenceNode = self.logic.createSequenceNodeInBrowser( @@ -692,7 +687,7 @@ def onSegmentation(self): ) progress = (i + 1) / numFrames * 100 self.ui.progressBar.setValue(progress) - if self.logic.is_4d: + if self.logic.IsSequenceVolume(volumeNode): segmentationSequenceNode.SetDataNodeAtValue(segmentationNode, str(i)) slicer.mrmlScene.RemoveNode(segmentationNode) currentVolumeNode = self.logic.getNextItemInSequence(volumeNode) @@ -793,22 +788,13 @@ def onManualVRGGen(self): if self.logic.vrgManualCameras is None: self.onMarkupNodeChanged(markupsNode) # create the cameras - # Check if the volume is centered at the origin - bounds = [0.0] * 6 - if self.logic.is_4d: - # get the bounds of the first frame - volumeNode.GetNthDataNode(0).GetRASBounds(bounds) - else: - volumeNode.GetRASBounds(bounds) - - isCentered = volumeNode.GetNthDataNode(0).IsCentered() if self.logic.is_4d else volumeNode.IsCentered() - if not isCentered: + if not self.logic.IsVolumeCentered(volumeNode): logging.warning("Volume is not centered at the origin. This may cause issues with Autoscoper.") numFrames = 1 currentNode = volumeNode curName = currentNode.GetName() - if self.logic.is_4d: + if self.logic.IsSequenceVolume(currentNode): numFrames = volumeNode.GetNumberOfDataNodes() currentNode, curName = self.logic.getItemInSequence(volumeNode, 0) @@ -821,7 +807,9 @@ def onManualVRGGen(self): [width, height], filename=filename, ) - currentNode, curName = self.logic.getNextItemInSequence(volumeNode) if self.logic.is_4d else currentNode + currentNode, curName = ( + self.logic.getNextItemInSequence(volumeNode) if self.logic.IsSequenceVolume(volumeNode) else currentNode + ) self.updateProgressBar(100) @@ -844,9 +832,7 @@ def onMarkupNodeChanged(self, node): # get the volume nodes volumeNode = self.ui.volumeSelector.currentNode() self.logic.validateInputs(volumeNode=volumeNode) - volumeNode, _ = self.logic.getItemInSequence(volumeNode, 0) if self.logic.is_4d else volumeNode - bounds = [0] * 6 - volumeNode.GetBounds(bounds) + bounds = self.logic.GetRASBounds(volumeNode) self.logic.vrgManualCameras = RadiographGeneration.generateCamerasFromMarkups( node, bounds, @@ -892,7 +878,10 @@ def __init__(self): self.AutoscoperProcess.setProcessChannelMode(qt.QProcess.ForwardedChannels) self.AutoscoperSocket = None self.vrgManualCameras = None - self.is_4d = False + + @staticmethod + def IsSequenceVolume(node: Union[slicer.vtkMRMLNode, None]) -> bool: + return isinstance(node, slicer.vtkMRMLSequenceNode) def setDefaultParameters(self, parameterNode): """ @@ -1042,7 +1031,8 @@ def progressCallback(x): slicer.app.layoutManager().resetSliceViews() return True - def showVolumeIn3D(self, volumeNode: slicer.vtkMRMLVolumeNode): + @staticmethod + def showVolumeIn3D(volumeNode: slicer.vtkMRMLVolumeNode): logic = slicer.modules.volumerendering.logic() displayNode = logic.CreateVolumeRenderingDisplayNode() displayNode.UnRegister(logic) @@ -1051,7 +1041,8 @@ def showVolumeIn3D(self, volumeNode: slicer.vtkMRMLVolumeNode): logic.UpdateDisplayNodeFromVolumeNode(displayNode, volumeNode) slicer.mrmlScene.RemoveNode(slicer.util.getNode("Volume rendering ROI")) - def validateInputs(self, *args: tuple, **kwargs: dict) -> bool: + @staticmethod + def validateInputs(*args: tuple, **kwargs: dict) -> bool: """ Validates that the provided inputs are not None. @@ -1083,7 +1074,8 @@ def validateInputs(self, *args: tuple, **kwargs: dict) -> bool: return all(statuses) - def validatePaths(self, *args: tuple, **kwargs: dict) -> bool: + @staticmethod + def validatePaths(*args: tuple, **kwargs: dict) -> bool: """ Checks that the provided paths exist. @@ -1109,7 +1101,8 @@ def validatePaths(self, *args: tuple, **kwargs: dict) -> bool: return all(statuses) - def createPathsIfNotExists(self, *args: tuple) -> None: + @staticmethod + def createPathsIfNotExists(*args: tuple) -> None: """ Creates a path if it does not exist. @@ -1119,8 +1112,8 @@ def createPathsIfNotExists(self, *args: tuple) -> None: if not os.path.exists(arg): os.makedirs(arg) + @staticmethod def extractSubVolumeForVRG( - self, volumeNode: slicer.vtkMRMLVolumeNode, segmentationNode: slicer.vtkMRMLSegmentationNode, cameraDebugMode: bool = False, @@ -1268,7 +1261,8 @@ def progressCallback(x): progress = ((idx + 1) / len(bestCameras)) * 10 + 90 progressCallback(progress) - def convertNodeToData(self, volumeNode: slicer.vtkMRMLVolumeNode) -> vtk.vtkImageData: + @staticmethod + def convertNodeToData(volumeNode: slicer.vtkMRMLVolumeNode) -> vtk.vtkImageData: """ Converts a volume node to a vtkImageData object """ @@ -1299,13 +1293,8 @@ def convertNodeToData(self, volumeNode: slicer.vtkMRMLVolumeNode) -> vtk.vtkImag return imageData - def check4D(self, node: slicer.vtkMRMLNode) -> bool: - """ - Checks if the volume is 4D - """ - self.is_4d = type(node) == slicer.vtkMRMLSequenceNode - - def getItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode, idx: int) -> slicer.vtkMRMLNode: + @staticmethod + def getItemInSequence(sequenceNode: slicer.vtkMRMLSequenceNode, idx: int) -> slicer.vtkMRMLNode: """ Returns the item at the specified index in the sequence node @@ -1314,7 +1303,7 @@ def getItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode, idx: int) :return: item at the specified index """ - if type(sequenceNode) != slicer.vtkMRMLSequenceNode: + if not AutoscoperMLogic.IsSequenceVolume(sequenceNode): logging.error("[AutoscoperM.logic.getItemInSequence] sequenceNode must be a sequence node") return None @@ -1326,7 +1315,8 @@ def getItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode, idx: int) browserNode.SetSelectedItemNumber(idx) return browserNode.GetProxyNode(sequenceNode), sequenceNode.GetNthDataNode(idx).GetName() - def getNextItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode) -> slicer.vtkMRMLNode: + @staticmethod + def getNextItemInSequence(sequenceNode: slicer.vtkMRMLSequenceNode) -> slicer.vtkMRMLNode: """ Returns the next item in the sequence @@ -1334,7 +1324,7 @@ def getNextItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode) -> sli :return: next item in the sequence """ - if type(sequenceNode) != slicer.vtkMRMLSequenceNode: + if not AutoscoperMLogic.IsSequenceVolume(sequenceNode): logging.error("[AutoscoperM.logic.getNextItemInSequence] sequenceNode must be a sequence node") return None @@ -1343,17 +1333,39 @@ def getNextItemInSequence(self, sequenceNode: slicer.vtkMRMLSequenceNode) -> sli idx = browserNode.GetSelectedItemNumber() return browserNode.GetProxyNode(sequenceNode), sequenceNode.GetNthDataNode(idx).GetName() - def cleanFilename(self, volumeName: str, index: Optional[int] = None) -> str: - filename = ( - re.sub(r"\s+", "_", f"{index}_{volumeName}") if index is not None else re.sub(r"\s+", "_", f"{volumeName}") - ) # Remove spaces - filename = re.sub(r"[^\w]", "", filename) # Remove non alphanumeric characters - return re.sub(r"__+", "_", filename) # Remove double or more underscores + @staticmethod + def cleanFilename(volumeName: str, index: Optional[int] = None) -> str: + filename = slicer.qSlicerCoreIOManager().forceFileNameValidCharacters(volumeName) + return f"{index}_{filename}" if index is not None else filename + + @staticmethod + def createSequenceNodeInBrowser(nodename, sequenceNode): + if not AutoscoperMLogic.IsSequenceVolume(sequenceNode): + logging.error("[AutoscoperMLogic.createSequenceNodeInBrowser] sequenceNode must be a sequence node") + return None - def createSequenceNodeInBrowser(self, nodename, sequenceNode): browserNode = slicer.modules.sequences.logic().GetFirstBrowserNodeForSequenceNode(sequenceNode) newSeqenceNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSequenceNode", nodename) browserNode.AddSynchronizedSequenceNode(newSeqenceNode) browserNode.SetOverwriteProxyName(newSeqenceNode, True) browserNode.SetSaveChanges(newSeqenceNode, True) return newSeqenceNode + + @staticmethod + def GetVolumeSpacing(node: Union[slicer.vtkMRMLVolumeNode, slicer.vtkMRMLSequenceNode]) -> list[float]: + if AutoscoperMLogic.IsSequenceVolume(node): + return AutoscoperMLogic.getItemInSequence(node, 0)[0].GetSpacing() + return node.GetSpacing() + + @staticmethod + def GetRASBounds(node: Union[slicer.vtkMRMLVolumeNode, slicer.vtkMRMLSequenceNode]) -> list[float]: + bounds = [0] * 6 + if AutoscoperMLogic.IsSequenceVolume(node): + return AutoscoperMLogic.getItemInSequence(node, 0)[0].GetRASBounds(bounds) + return node.GetRASBounds(bounds) + + @staticmethod + def IsVolumeCentered(node: Union[slicer.vtkMRMLVolumeNode, slicer.vtkMRMLSequenceNode]) -> bool: + if AutoscoperMLogic.IsSequenceVolume(node): + return AutoscoperMLogic.getItemInSequence(node, 0)[0].IsCentered() + return node.IsCentered()