Adding UNet Model #210

Merged: 30 commits merged into FluxML:master on Jan 27, 2023
Conversation

@shivance (Contributor) commented Dec 27, 2022

This PR adds the UNet implementation to Metalhead.jl in place of #112.

I referred to the official torchhub implementation and to @DhairyaLGandhi's UNet.jl package.

PR Checklist

  • Tests are added
  • Documentation, if applicable
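
For context, here is a minimal usage sketch of the model this PR adds. It assumes the positional constructor signature quoted later in the review (UNet(imsize, inchannels, outplanes, ...)) and is not code from the PR itself:

using Flux, Metalhead

# Hypothetical usage, based on the signature discussed in the review below:
# UNet(imsize = (256, 256), inchannels = 3, outplanes = 3, ...).
model = UNet((256, 256), 3, 3)

# Width x height x channels x batch, Flux's usual memory layout.
x = rand(Float32, 256, 256, 3, 1)
y = model(x)   # a segmentation-style output with outplanes channels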

@shivance (Contributor, Author) commented Dec 27, 2022

The PR is ready for code review. I'm still new to Flux, so apologies for any silly mistakes such as not following the docstring style or the design and code conventions of the Julia ecosystem.

@ToucheSir (Member) commented:

Metalhead has a JuliaFormatter config, so if your editor supports it, I would recommend running the formatter to help with code-style adherence.

@pri1311 (Contributor) left a comment:

I doubt there is supposed to be a BatchNorm and ReLU layer before the concat + conv layer.

@shivance (Contributor, Author) commented:

@ToucheSir I ran JuliaFormatter.jl using format("."). It looks like the other files in Metalhead had not been formatted either. Should I keep it this way, or format only unet.jl?

@pri1311 Yup, corrected that.
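
For reference, a minimal sketch of formatting just the new file instead of the whole repository (using the path from this PR):

using JuliaFormatter

# Format only the new file so unrelated files stay untouched;
# format(".") walks the entire repository.
format("src/convnets/unet.jl")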

@shivance requested review from darsnack and pri1311 and removed request for darsnack and pri1311 (December 30, 2022 03:57)
@shivance (Contributor, Author) commented Jan 1, 2023

@ToucheSir @darsnack Finally made it: the model now uses only Parallel, with no custom forward pass.

It was both challenging and confusing, because I kept hitting a DimensionMismatch error. I knew it was caused by the Parallel block: tensors were being routed through the maxpool when they weren't supposed to be.

I tried the debugger as well, but ran into problems with it.

What helped:

  1. I drew the architecture on paper, wrote down all the layers, and matched them against my code.
  2. That let me localize the error.

I resolved it by moving

layers = Chain(layers, decoder_layer)

to before the decoder block. I had been chaining the layers with the decoder, while the decoder layers are themselves a Chain of the concat and the decoder conv layers. Chaining at that stage made every tensor flow through the concat again, hence the DimensionMismatch.

Moving the line earlier avoids this.
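
For readers following along, here is a minimal sketch (not the PR's code) of the pattern described above: one UNet level built from Parallel plus channel concatenation, with the decoder chained after the skip block. The catchannels and conv_block helpers are illustrative only:

using Flux

# Channel-wise concatenation used as the Parallel connection
# (illustrative helper; Metalhead defines a similar cat_channels).
catchannels(xs...) = cat(xs...; dims = 3)

conv_block(in_ch, out_ch) = Chain(Conv((3, 3), in_ch => out_ch; pad = 1),
                                  BatchNorm(out_ch, relu))

# One UNet level: the encoder output feeds both the downsampled inner path
# and, via identity, the concatenation that forms the skip connection.
inner   = Chain(MaxPool((2, 2)), conv_block(32, 64),
                ConvTranspose((2, 2), 64 => 64; stride = 2))
skip    = Parallel(catchannels, inner, identity)   # 64 + 32 channels out

# The decoder conv is chained AFTER the Parallel block; chaining it into the
# skip's Chain instead routes the concatenated tensor back through the
# concat, producing the kind of DimensionMismatch described above.
decoder = conv_block(96, 32)
level   = Chain(conv_block(3, 32), skip, decoder)

x = rand(Float32, 32, 32, 3, 1)
@assert size(level(x)) == (32, 32, 32, 1)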

@shivance requested review from ToucheSir and darsnack and removed request for darsnack and ToucheSir (January 1, 2023 15:58)
@darsnack (Member) left a comment:

Great job! I still need to review the architecture details, but here are some initial minor changes.

@shivance requested a review from darsnack (January 2, 2023 11:57)
@shivance (Contributor, Author) commented Jan 2, 2023

Good to go for the next round of review, @darsnack @ToucheSir @pri1311!

This reverts commit ca73586.
@pri1311 (Contributor) commented Jan 3, 2023

The following is the output I get from loading the model:

UNet(
  Chain(
    Chain(
      Chain([
        Chain(
          conv1 = Conv((3, 3), 3 => 32, pad=1),  # 896 parameters
          norm1 = BatchNorm(32, relu),  # 64 parameters, plus 64
          conv2 = Conv((3, 3), 32 => 32, pad=1),  # 9_248 parameters
          norm2 = BatchNorm(32, relu),  # 64 parameters, plus 64
        ),
        Chain(
          conv1 = Conv((3, 3), 32 => 64, pad=1),  # 18_496 parameters
          norm1 = BatchNorm(64, relu),  # 128 parameters, plus 128
          conv2 = Conv((3, 3), 64 => 64, pad=1),  # 36_928 parameters
          norm2 = BatchNorm(64, relu),  # 128 parameters, plus 128
        ),
        Chain(
          conv1 = Conv((3, 3), 64 => 128, pad=1),  # 73_856 parameters
          norm1 = BatchNorm(128, relu),  # 256 parameters, plus 256
          conv2 = Conv((3, 3), 128 => 128, pad=1),  # 147_584 parameters
          norm2 = BatchNorm(128, relu),  # 256 parameters, plus 256
        ),
        Chain(
          conv1 = Conv((3, 3), 128 => 256, pad=1),  # 295_168 parameters
          norm1 = BatchNorm(256, relu),  # 512 parameters, plus 512
          conv2 = Conv((3, 3), 256 => 256, pad=1),  # 590_080 parameters
          norm2 = BatchNorm(256, relu),  # 512 parameters, plus 512
        ),
      ]),
      Chain(
        conv1 = Conv((3, 3), 256 => 512, pad=1),  # 1_180_160 parameters
        norm1 = BatchNorm(512, relu),   # 1_024 parameters, plus 1_024
        conv2 = Conv((3, 3), 512 => 512, pad=1),  # 2_359_808 parameters
        norm2 = BatchNorm(512, relu),   # 1_024 parameters, plus 1_024
      ),
    ),
    Chain(
      Chain(
        conv1 = Conv((3, 3), 512 => 256, pad=1),  # 1_179_904 parameters
        norm1 = BatchNorm(256, relu),   # 512 parameters, plus 512
        conv2 = Conv((3, 3), 256 => 256, pad=1),  # 590_080 parameters
        norm2 = BatchNorm(256, relu),   # 512 parameters, plus 512
      ),
      Chain(
        conv1 = Conv((3, 3), 256 => 128, pad=1),  # 295_040 parameters
        norm1 = BatchNorm(128, relu),   # 256 parameters, plus 256
        conv2 = Conv((3, 3), 128 => 128, pad=1),  # 147_584 parameters
        norm2 = BatchNorm(128, relu),   # 256 parameters, plus 256
      ),
      Chain(
        conv1 = Conv((3, 3), 128 => 64, pad=1),  # 73_792 parameters
        norm1 = BatchNorm(64, relu),    # 128 parameters, plus 128
        conv2 = Conv((3, 3), 64 => 64, pad=1),  # 36_928 parameters
        norm2 = BatchNorm(64, relu),    # 128 parameters, plus 128
      ),
      Chain(
        conv1 = Conv((3, 3), 64 => 32, pad=1),  # 18_464 parameters
        norm1 = BatchNorm(32, relu),    # 64 parameters, plus 64
        conv2 = Conv((3, 3), 32 => 32, pad=1),  # 9_248 parameters
        norm2 = BatchNorm(32, relu),    # 64 parameters, plus 64
      ),
    ),
  ),
)         # Total: 72 trainable arrays, 7_069_152 parameters,
          # plus 36 non-trainable, 5_888 parameters, summarysize 27.002 MiB.

I believe some layers are missing. I am not completely well versed in Flux, but to my knowledge it should display all of the layers, even the custom cat_channels function/layer.
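
For reference, a small sketch (not the PR's code) of how a Parallel block with a custom connection shows up in the printed model; the cat_channels definition here is assumed to mirror the helper added in src/utilities.jl:

using Flux

cat_channels(xs...) = cat(xs...; dims = 3)   # assumed helper, as in src/utilities.jl

m = Chain(Conv((3, 3), 3 => 8; pad = 1),
          Parallel(cat_channels,
                   Chain(Conv((3, 3), 8 => 8; pad = 1), BatchNorm(8, relu)),
                   identity))

# Flux's pretty-printer lists Parallel blocks together with their connection
# function, so a printout with no Parallel/cat_channels entries suggests the
# skip connections were dropped while assembling the Chain.
display(m)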

Comment on lines 2 to 5:

return Chain(conv1 = Conv(kernel, in_chs => out_chs; pad = (1, 1)),
             norm1 = BatchNorm(out_chs, relu),
             conv2 = Conv(kernel, out_chs => out_chs; pad = (1, 1)),
             norm2 = BatchNorm(out_chs, relu))
A Contributor commented:

Any specific reason for using named layers? I don't see the names being used anywhere, and I haven't seen Metalhead use this convention, so it seems a little inconsistent with the code base.

A Member commented:

Agree with this, it looks out of place. I think the only place we might need named layers is if we specifically need to index into a Chain later and the layer we want isn't otherwise apparent. Here I don't see that happening.
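
For comparison, a sketch of the unnamed form of the block quoted above (basic_conv_block is just an illustrative name, not Metalhead code):

using Flux

# Positional (unnamed) layers, matching the prevailing Metalhead convention;
# names only pay off when a layer must later be looked up by name.
basic_conv_block(kernel, in_chs, out_chs) =
    Chain(Conv(kernel, in_chs => out_chs; pad = (1, 1)),
          BatchNorm(out_chs, relu),
          Conv(kernel, out_chs => out_chs; pad = (1, 1)),
          BatchNorm(out_chs, relu))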

end
@functor UNet

function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
A Member commented:

Suggested change:

- function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,
+ function UNet(imsize::Dims = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3,

Is there anything in the UNet implementation that would prevent us from generalizing it to 1, 3 or more dimensions?

A Member commented:

Due to my own ignorance, which dimensions are spatial in the 1 and N>2 cases? Meaning which ones should be downscaled?

A Member commented:

Same as with 2D: spatial dimensions × channels/features × batch size, so all but the last two, assuming the usual memory layout.

A Member commented:

@shivance I think the point is that you don't need any changes other than dropping the type restriction to generalize to more dimensions.

But we'd want to have that in the test, so we can save it for another PR if you'd like.
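
A hedged sketch of the generalization under discussion; nd_conv_block is hypothetical, not Metalhead code, and only illustrates that once the Dims{2} restriction is dropped, kernel and pooling sizes can be derived from length(imsize):

using Flux

function nd_conv_block(imsize::Dims, in_chs::Integer, out_chs::Integer)
    N = length(imsize)               # number of spatial dimensions
    kernel = ntuple(_ -> 3, N)
    return Chain(Conv(kernel, in_chs => out_chs; pad = 1),
                 BatchNorm(out_chs, relu),
                 MaxPool(ntuple(_ -> 2, N)))
end

# 3-D volume: spatial dims x channels x batch, as noted above.
x = rand(Float32, 16, 16, 16, 3, 1)
block = nd_conv_block((16, 16, 16), 3, 8)
@assert size(block(x)) == (8, 8, 8, 8, 1)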

@shivance (Contributor, Author) commented Jan 22, 2023

It's funny how many rounds of review and architecture changes this PR has gone through; it has been under review for over a month 😆
Contributing to open source requires a lot of perseverance, I must say.

@shivance requested review from darsnack and theabhirath and removed request for lorenzoh, darsnack and theabhirath (January 24, 2023 12:50)
@ToucheSir (Member) commented:

Thanks for your patience, I think we're very close!

What you're experiencing is a triple learning curve of sorts. Julia is a new language for most contributors and so it takes longer to learn idiomatic code patterns than e.g. already knowing idiomatic Python. Flux is a new library for most contributors and thus folks are less familiar with what's available to use + limitations than they would be with PyTorch/TF. Metalhead is even more domain-specific and more opinionated because it sits at a higher level in the stack. I think a good analogy would be opening a PR contributing a new model to timm after just a month or two of Python experience ;)

Some things we can do on our side to flatten the learning curve:

  • Add pre-commit hooks and CI checks for formatting so zero review time is consumed on it
  • Add difficulty markers to feature issues so that potential contributors know what they're getting into. GH labels could work here.
  • Write proper devdocs for Flux and Metalhead. This is a much bigger project and almost certainly will be an ongoing one.

@shivance (Contributor, Author) commented:

Thanks @ToucheSir!

Are we still going with the N-dimensional UNet?

@shivance requested review from ToucheSir and removed request for darsnack (January 25, 2023 19:17)
@shivance (Contributor, Author) commented Jan 25, 2023

  • Write proper devdocs for Flux and Metalhead. This is a much bigger project and almost certainly will be an ongoing one.

@ToucheSir Come to think of it, this could be a potential GSoD project!

@darsnack (Member) commented:

Unfortunately, I think GSoCs are not allowed to be solely for documentation (I'll have to double check this). But you can propose it for GSoD!

@shivance (Contributor, Author) commented:

@ToucheSir @darsnack I'm willing to open a follow-up PR to add support for N spatial dimensions.
Let's get the 2-dimensional PR merged first!
(It's kind of demotivating to drag this PR on after so many rounds; it feels like all this work has produced no result yet.) 😅

Open for review in the current state...

@darsnack (Member) left a comment:

This looks like it is ready to merge modulo one small docstring issue.

PRs can take time (sometimes drawn out by how frequently we're able to review). In our case this is even more true, since FluxML is very community-driven. That makes our development extremely distributed, and it is important that PRs are "release ready" before merging. Otherwise, simple changes can get bottlenecked from release by larger changes that still require refactoring or polishing.

Your patience and hard work are very much appreciated. The long review time is not a reflection of your work, just a consequence of the fact that we're all contributing on a volunteer basis. Please don't feel discouraged! Some of the most prolific Julia contributors have high-impact PRs that took months to get right. So you're in good company!


Co-authored-by: Kyle Daruwalla <[email protected]>
@shivance (Contributor, Author) commented:

@darsnack So I'll leave the signature of imsize as

imsize::Dims{2} = (256, 256)

for now?

Or make it

UNet(imsize::Dims = (256, 256)

@darsnack (Member) left a comment:

Let's leave it as is for now. Good job!

@shivance (Contributor, Author) commented:

Thanks @darsnack!

@shivance (Contributor, Author) commented:

Thank you @ToucheSir @darsnack!

@darsnack merged commit 80ab995 into FluxML:master on Jan 27, 2023
@shivance deleted the unet branch on January 30, 2023 18:37
@CarloLucibello mentioned this pull request on May 7, 2023
@shivance changed the title from "Adding UNet implementation" to "Adding UNet Model" on Aug 1, 2023