Skip to content

Commit

Permalink
center_crop
Browse files Browse the repository at this point in the history
  • Loading branch information
HastingsGreer committed May 5, 2023
1 parent 763bd3b commit 2d25dc6
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
27 changes: 24 additions & 3 deletions training_scripts/unigradICON/ITK_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,33 @@ def itk_crop_foreground(image: itk.Image, monai_crop_pad:monai.transforms.CropFo

return image

def itk_resize_with_pad_or_crop(image: itk.image, pad_or_crop:monai.transforms.ResizeWithPadOrCrop):
def itk_resize_with_pad_or_crop(image: itk.image, pad_or_crop:monai.transforms.SpatialCrop):
torch_image = torch.tensor(itk.GetArrayFromImage(image))

boundary_coords = pad_or_crop.com


def itk_crop(image:itk.image, crop:monai.transforms.SpatialPad):
pass
def itk_crop(image:itk.image, cropper:monai.transforms.CenterSpatialCrop):

torch_image = torch.tensor(itk.GetArrayFromImage(image))
crop = cropper.compute_slices(spatial_size=torch_image.shape[:])

lower = [s.start for s in reversed(crop)]
upper = [dim_len - s.stop for dim_len, s in reversed(list(zip(torch_image.shape, crop)))]

print(lower, upper)


filter = itk.CropImageFilter[type(image), type(image)].New()
filter.SetInput(image)
filter.SetLowerBoundaryCropSize(lower)
filter.SetUpperBoundaryCropSize(upper)

filter.Update()

image = filter.GetOutput()

return image



Expand Down
45 changes: 43 additions & 2 deletions training_scripts/unigradICON/test_ITK_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import tqdm
import torch

from monai.transforms import CropForeground, SpatialPad, ResizeWithPadOrCrop
from monai.transforms import CropForeground, SpatialPad, ResizeWithPadOrCrop, SpatialCrop
with open(f"/playpen-raid2/lin.tian/projects/icon_lung/ICON/training_scripts/brain_t1_pipeline/splits/train.txt") as f:
image_paths = f.readlines()
f_path = image_paths[0].split(".nii.gz")[0] + "_restore_brain.nii.gz"
import ITK_transforms
import importlib
def doit():
def test_itk_crop_foreground():
importlib.reload(ITK_transforms)

# crop foreground test
Expand All @@ -29,6 +29,47 @@ def doit():

print(torch.sum((cropped_monai - cropped_itk)**2))

def test_itk_crop_filter():

importlib.reload(ITK_transforms)

image = itk.imread(f_path)

transform = ResizeWithPadOrCrop([175, 175, 175]).cropper

print(transform)

image_t = torch.tensor(np.asarray(image))[None]

import pdb
# pdb.set_trace()
cropped_monai = transform(image_t)

print("Cropped monai shape", cropped_monai.shape)

crop = transform.compute_slices(spatial_size=image_t.shape[1:])

print(dir(crop[0]))




cropped_itk =torch.tensor( np.asarray(ITK_transforms.itk_crop(image, transform)))[None, None]

print("itk, monai:", cropped_itk.shape, cropped_monai.shape)


print(torch.sum((cropped_monai - cropped_itk)**2))

def test_itk_pad_image_filter():
importlib.reload(ITK_transforms)

image = itk.imread(f_path)

image = ITK_transforms.itk_crop_foreground(image, CropForeground())

tranform = SpatialPad([175, 175, 175])




Expand Down

0 comments on commit 2d25dc6

Please sign in to comment.