diff --git a/AutoscoperM/AutoscoperM.py b/AutoscoperM/AutoscoperM.py index 3848ad9..b23e8c7 100644 --- a/AutoscoperM/AutoscoperM.py +++ b/AutoscoperM/AutoscoperM.py @@ -9,6 +9,7 @@ import qt import slicer +import vtkAddon import vtk from slicer.ScriptedLoadableModule import ( ScriptedLoadableModule, @@ -1061,25 +1062,28 @@ def progressCallback(x): IO.writeTFMFile(filename, spacing, origin) self.showVolumeIn3D(segmentVolume) - bounds = [0] * 6 - segmentVolume.GetRASBounds(bounds) - segmentVolumeSize = [abs(bounds[i + 1] - bounds[i]) for i in range(0, len(bounds), 2)] - # Write TRA tfm = vtk.vtkMatrix4x4() tfm.SetElement(0, 3, origin[0]) tfm.SetElement(1, 3, origin[1]) tfm.SetElement(2, 3, origin[2]) - IO.createTRAFile( - volName=segmentName, - trialName=None, - outputDir=outputDir, - trackingsubDir=trackingSubDir, - volSize=segmentVolumeSize, - Origin2DicomTransformFile=origin2DicomTransformFile, - transform=tfm, + if origin2DicomTransformFile is not None: + origin2DicomNode = self.loadTransformFromFile(origin2DicomTransformFile) + tfm = self.applyOrigin2DicomTransform(tfm, origin2DicomNode, ApplyDicom2Origin=True) + slicer2autoscoperNode = self.createAndAddSlicer2AutoscoperTransformNode(segmentVolume) + slicer2autoscoperFilename = os.path.join( + outputDir, transformSubDir, f"{segmentVolume.GetName()}-Slicer2AUT.tfm" ) + slicer.util.saveNode(slicer2autoscoperNode, slicer2autoscoperFilename) + tfm = self.applySlicer2AutoscoperTransform(tfm, slicer2autoscoperNode, ApplyAutoscoper2Slicer=False) + traDir = os.path.join(outputDir, trackingSubDir) + if not os.path.exists(traDir): + os.mkdir(traDir) + filename = os.path.join(traDir, f"{segmentName}.tra") + IO.writeTRA(filename, [tfm]) + + slicer.mrmlScene.RemoveNode(slicer2autoscoperNode) # update progress bar progressCallback((idx + 1) / numSegments * 100) @@ -1427,3 +1431,75 @@ def IsVolumeCentered(node: Union[slicer.vtkMRMLVolumeNode, slicer.vtkMRMLSequenc if AutoscoperMLogic.IsSequenceVolume(node): return AutoscoperMLogic.getItemInSequence(node, 0)[0].IsCentered() return node.IsCentered() + + @staticmethod + def loadTransformFromFile(transformFileName: str) -> slicer.vtkMRMLLinearTransformNode: + return slicer.util.loadNodeFromFile(transformFileName) + + @staticmethod + def applyOrigin2DicomTransform( + transform: vtk.vtkMatrix4x4, + origin2DicomTransformNode: slicer.vtkMRMLLinearTransformNode, + ApplyDicom2Origin: bool = False, + ) -> vtk.vtkMatrix4x4: + if ApplyDicom2Origin: + origin2DicomTransformNode.Inverse() + + origin2DicomTransformMatrix = vtk.vtkMatrix4x4() + origin2DicomTransformNode.GetMatrixTransformToParent(origin2DicomTransformMatrix) + + vtk.vtkMatrix4x4.Multiply4x4(origin2DicomTransformMatrix, transform, transform) + + slicer.mrmlScene.RemoveNode(origin2DicomTransformNode) + return transform + + @staticmethod + def applySlicer2AutoscoperTransform( + transform: vtk.vtkMatrix4x4, + slicer2AutoscoperNode: slicer.vtkMRMLLinearTransformNode, + ApplyAutoscoper2Slicer: bool = False, + ) -> vtk.vtkMatrix4x4: + """Utility function for converting a transform between the Slicer and Autoscoper coordinate systems.""" + from itertools import product + + if ApplyAutoscoper2Slicer: + slicer2AutoscoperNode.Inverse() + + slicer2Autoscoper = vtk.vtkMatrix4x4() + slicer2AutoscoperNode.GetMatrixTransformToParent(slicer2Autoscoper) + + # Extract the rotation matrices so we are not affecting the translation vector + transformR = vtk.vtkMatrix3x3() + slicer2AutoscoperR = vtk.vtkMatrix3x3() + vtkAddon.vtkAddonMathUtilities.GetOrientationMatrix(transform, transformR) + vtkAddon.vtkAddonMathUtilities.GetOrientationMatrix(slicer2Autoscoper, slicer2AutoscoperR) + + vtk.vtkMatrix3x3.Multiply3x3(slicer2AutoscoperR, transformR, transformR) + + vtkAddon.vtkAddonMathUtilities.SetOrientationMatrix(transformR, transform) + + # Apply the translation vector + for i in range(3): + transform.SetElement(i, 3, transform.GetElement(i, 3) + slicer2Autoscoper.GetElement(i, 3)) + + return transform + + @staticmethod + def createAndAddSlicer2AutoscoperTransformNode(volumeNode: slicer.vtkMRMLVolumeNode) -> slicer.vtkMRMLLinearTransformNode: + # Slicer 2 Autoscoper Transform + # https://github.com/BrownBiomechanics/Autoscoper/issues/280 + bounds = [0] * 6 + volumeNode.GetRASBounds(bounds) + volSize = [abs(bounds[i + 1] - bounds[i]) for i in range(0, len(bounds), 2)] + + slicer2autoscoper = vtk.vtkMatrix4x4() + slicer2autoscoper.Identity() + # Rotation matrix for a 180 x-axis rotation + slicer2autoscoper.SetElement(1, 1, -slicer2autoscoper.GetElement(1, 1)) + slicer2autoscoper.SetElement(1, 2, -slicer2autoscoper.GetElement(1, 2)) + slicer2autoscoper.SetElement(2, 2, -slicer2autoscoper.GetElement(2, 2)) + slicer2autoscoper.SetElement(0, 3, -volSize[0]) # Offset -X + + slicer2autoscoperNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLinearTransformNode") + slicer2autoscoperNode.SetMatrixTransformToParent(slicer2autoscoper) + return slicer2autoscoperNode diff --git a/AutoscoperM/AutoscoperMLib/IO.py b/AutoscoperM/AutoscoperMLib/IO.py index 0d94c8c..b8ccd52 100644 --- a/AutoscoperM/AutoscoperMLib/IO.py +++ b/AutoscoperM/AutoscoperMLib/IO.py @@ -1,6 +1,7 @@ import glob import logging import os +from itertools import product import numpy as np import slicer @@ -255,56 +256,13 @@ def writeTFMFile(filename: str, spacing: list[float], origin: list[float]): slicer.mrmlScene.RemoveNode(transformNode) -def createTRAFile( - volName: str, - trialName: str, - outputDir: str, - trackingsubDir: str, - volSize: list[float], - Origin2DicomTransformFile: str, - transform: vtk.vtkMatrix4x4, -): - transformNode = slicer.vtkMRMLLinearTransformNode() - transformNode.SetMatrixTransformToParent(transform) - slicer.mrmlScene.AddNode(transformNode) - - if Origin2DicomTransformFile is not None: - origin2DicomTransformNode = slicer.util.loadNodeFromFile(Origin2DicomTransformFile) - origin2DicomTransformNode.Inverse() - transformNode.SetAndObserveTransformNodeID(origin2DicomTransformNode.GetID()) - transformNode.HardenTransform() - slicer.mrmlScene.RemoveNode(origin2DicomTransformNode) - - filename = f"{trialName}_{volName}.tra" if trialName is not None else f"{volName}.tra" - filename = os.path.join(outputDir, trackingsubDir, filename) - - if not os.path.exists(os.path.join(outputDir, trackingsubDir)): - os.mkdir(os.path.join(outputDir, trackingsubDir)) - - tfmMat = vtk.vtkMatrix4x4() - transformNode.GetMatrixTransformToParent(tfmMat) - - writeTRA(filename, volSize, tfmMat) - - slicer.mrmlScene.RemoveNode(transformNode) - - -def writeTRA(filename: str, volSize: list[float], transform: vtk.vtkMatrix4x4): - # Slicer 2 Autoscoper Transform - # https://github.com/BrownBiomechanics/Autoscoper/issues/280 - transform.SetElement(1, 1, -transform.GetElement(1, 1)) # Flip Y - transform.SetElement(2, 2, -transform.GetElement(2, 2)) # Flip Z - - transform.SetElement(0, 3, transform.GetElement(0, 3) - volSize[0]) # Offset X - - # Write TRA - rowwise = [] - for i in range(4): # Row - for j in range(4): # Col - rowwise.append(str(transform.GetElement(i, j))) - - with open(filename, "w+") as f: - f.write(",".join(rowwise)) +def writeTRA(fileName: str, transforms: list[vtk.vtkMatrix4x4]) -> None: + rowWiseStrings = [] + for transform in transforms: + rowWiseStrings.append([str(transform.GetElement(i, j)) for i, j in product(range(4), range(4))]) + with open(fileName, "w+") as traFile: + for row in rowWiseStrings: + traFile.write(",".join(row) + "\n") def writeTemporyFile(filename: str, data: vtk.vtkImageData) -> str: