Input shape requirements #78

Open
koegl opened this issue Aug 13, 2024 · 1 comment

koegl commented Aug 13, 2024

What are the input shape requirements for the images passed to the network?

I'm referring to this function, where the shape is fixed. I want to use my own data for training, but I cannot figure out what the shape has to be.

def make_network():

    phi = network_wrappers.FunctionFromVectorField(
        networks.tallUNet(unet=networks.UNet2ChunkyMiddle, dimension=3)
    )
    psi = network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3))

    hires_net = icon_registration.GradientICON(
        network_wrappers.DoubleNet(
            network_wrappers.DownsampleNet(
                network_wrappers.TwoStepRegistration(phi, psi), dimension=3
            ),
            network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3)),
        ),
        icon_registration.LNCCOnlyInterpolated(sigma=5),
        3,
    )
    SCALE = 2  # 1 IS QUARTER RES, 2 IS HALF RES, 4 IS FULL RES
    input_shape = [1, 1, 40 * SCALE, 96 * SCALE, 96 * SCALE]
    hires_net.assign_identity_map(input_shape)
    return hires_net

(from lncc_train_knees.py)

I tried forcing my images to [1, 1, 40 * SCALE, 96 * SCALE, 96 * SCALE], but I get an error in the forward() of UNet2ChunkyMiddle at this line:
x = torch.cat([x, skips[depth]], 1)

Sizes of tensors must match except in dimension 1. Expected size 6 but got size 1 for tensor number 1 in the list.

These are the shapes:

x.shape
torch.Size([6, 256, 4, 6, 6])
skips[depth].shape
torch.Size([1, 256, 5, 12, 12])
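
For readers following along: torch.cat requires every axis except the concatenation axis to have the same size in all inputs. The sketch below is plain Python with no torch dependency, and the helper name cat_mismatches is made up for illustration. It lists every axis on which the two shapes reported above disagree. The error message quotes the 6 versus 1 on axis 0, but the spatial axes differ as well.

```python
def cat_mismatches(shape_a, shape_b, dim):
    """Return (axis, size_a, size_b) for each axis other than `dim` whose sizes differ."""
    if len(shape_a) != len(shape_b):
        raise ValueError("both shapes must have the same number of axes")
    return [
        (axis, a, b)
        for axis, (a, b) in enumerate(zip(shape_a, shape_b))
        if axis != dim and a != b
    ]


# Shapes copied from the issue; the concatenation is along axis 1 (channels).
x_shape = (6, 256, 4, 6, 6)
skip_shape = (1, 256, 5, 12, 12)

mismatches = cat_mismatches(x_shape, skip_shape, dim=1)
for axis, a, b in mismatches:
    print(f"axis {axis}: x has size {a}, skips[depth] has size {b}")
```

Here the channel axis is the only one that agrees, so the mismatch is not limited to a single spatial dimension.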
HastingsGreer self-assigned this Aug 29, 2024

HastingsGreer (Collaborator) commented

Hi! This is a great issue. The class UNet2ChunkyMiddle comes from the ICON paper, written before we were really aiming to generalize to arbitrary images. It only works if the input is one specific size; that size is not documented, and the class has no check that its input matches it. The short-term fix is to switch to UNet2, which is used for all stages in the GradICON and uniGradICON papers. UNet2 (you can get an instance with reasonable defaults by calling icon_registration.networks.tallUNet2(dimension=3)) is fully parametric over input size, so it won't raise this error. Long term, the next update of icon needs to add asserts on the input shape of UNet2ChunkyMiddle (and possibly deprecate it entirely).
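
A minimal sketch of that short-term fix, applied to the make_network function quoted in the issue. Every library call is taken from that snippet or from this reply. The only intended change is that phi is built with tallUNet2 rather than tallUNet(unet=networks.UNet2ChunkyMiddle, ...). The helper knee_input_shape and the input_shape parameter are illustrative additions, not part of the library, and the sketch has not been run against a particular icon_registration release.

```python
def knee_input_shape(scale=2):
    """Shape used in the issue's snippet: [batch, channel, depth, height, width]."""
    return [1, 1, 40 * scale, 96 * scale, 96 * scale]


def make_network(input_shape):
    # Imported lazily so the shape helper above can be used without the library.
    import icon_registration
    from icon_registration import network_wrappers, networks

    # Changed: tallUNet2 replaces the UNet2ChunkyMiddle-based network for phi.
    phi = network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3))
    psi = network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3))

    # Everything below is unchanged from the snippet in the issue.
    hires_net = icon_registration.GradientICON(
        network_wrappers.DoubleNet(
            network_wrappers.DownsampleNet(
                network_wrappers.TwoStepRegistration(phi, psi), dimension=3
            ),
            network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3)),
        ),
        icon_registration.LNCCOnlyInterpolated(sigma=5),
        3,
    )
    hires_net.assign_identity_map(input_shape)
    return hires_net
```

With the library installed, make_network(knee_input_shape(2)) would reproduce the half-resolution setup from the issue, and a different shape list can be passed for other data. The thread only states that UNet2 is parametric over input size; whether the surrounding wrappers impose their own constraints is not covered here.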
