Add model implementations #112
cc @theabhirath, please add the models you are planning on implementing to this list.
A starting point for EfficientNet can be found at https://github.com/pxl-th/EfficientNet.jl (see #106).
I'm planning on at first. I also want to start on object detection and semantic segmentation models, but there are a lot of helper functions related to those fields that I will have to write. If everything CV-related is planned to live in Metalhead, then I'll go ahead and start coding them up.
Hello, can I work on Inception v4 and EfficientNet? (I have also already coded ESRGANs.)
There's #113 for EfficientNet. It would be good to port SRGAN etc. into Metalhead as well.
I could provide pretrained weights for VGG and ResNet converted from PyTorch (once some minor changes to the Metalhead models are merged so that they are equivalent to the PyTorch models). (I would be very interested in an ESRGAN implementation. 😃)
Also see #109 for pretrained models. It's hard to know beforehand which weights will transfer well between PyTorch and Flux, but if we have some pretrained weights we can validate, that would be welcome!
As mentioned, there is already a PR for EfficientNet, but Inception v4 would be very welcome! ESRGAN would be welcome too.
Yes, I think your flow will work well for both those models. Please submit PRs to MetalheadWeights when you have them!
How do I contribute models here? I'm fairly new.
The Flux contribution guide has some info, as well as links on how to get started with making your first contribution. I'll briefly summarize the process for this repo (apologies if you already know this):
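For anyone completely new to Julia package development, the generic loop looks roughly like the sketch below. This is the standard `Pkg` workflow, not repo-specific policy, and `username` is a placeholder for your own fork:

```julia
using Pkg

# Work against a local clone of your fork rather than the registered release.
# ("username" is a placeholder for your GitHub account.)
Pkg.develop(url="https://github.com/username/Metalhead.jl")

# After adding a model under src/, run the test suite before opening a PR.
Pkg.test("Metalhead")
```

From there it is the usual GitHub flow: push a branch to your fork and open a pull request against the main repository.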
I hope this is the right place for this discussion, but it looks like using PyTorch weights might not be too complicated. I've had success opening Torch's checkpoint files with Pickle.jl:

```julia
using Downloads: download
using Pickle
using Flux, Metalhead

# Links from https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py
model_urls = Dict(
    "vgg11" => "https://download.pytorch.org/models/vgg11-8a719046.pth",
    "vgg13" => "https://download.pytorch.org/models/vgg13-19584684.pth",
    "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
)

model_name = "vgg11"
datadir = joinpath(@__DIR__, "data")
path = joinpath(datadir, "$model_name.pth")
!isdir(datadir) && mkdir(datadir)
!isfile(path) && download(model_urls[model_name], path)

# Torchvision uses NCHW ordering, Flux WHCN
function permute_weights(A::AbstractArray{T, N}) where {T, N}
    if N == 4
        return permutedims(A, (4, 3, 2, 1))
    end
    return A
end

torchweights = Pickle.Torch.THload(path)
weights = map(permute_weights, torchweights.vals)
model = VGG11()
Flux.loadparams!(model, weights)
```

where

```julia
julia> torchweights
OrderedCollections.OrderedDict{Any, Any} with 22 entries:
  "features.0.weight"   => [0.288164 0.401512 0.216151; -0.3528 -0.574001 -0.028024;…
  "features.0.bias"     => Float32[0.193867, 0.304219, 0.18251, -1.11219, 0.0441538,…
  "features.3.weight"   => [0.0302545 0.0595999 … 0.00978228 0.0180896; -0.00835048 …
  "features.3.bias"     => Float32[-0.0372088, -0.115514, 0.148786, -0.106784, 0.153…
  "features.6.weight"   => [-0.0107434 0.00274947 … 0.0393466 0.0168702; -0.0110679 …
  "features.6.bias"     => Float32[0.0696629, -0.0745776, 0.0681913, -0.115447, 0.11…
  "features.8.weight"   => [-0.0158845 -0.0120116 … -0.0287082 0.00195862; -0.026421…
  "features.8.bias"     => Float32[-0.00635051, 0.031504, 0.0732542, 0.0478025, 0.32…
  "features.11.weight"  => [0.0145356 -0.0262731 … -0.0032392 0.0459495; -0.00929091…
  "features.11.bias"    => Float32[-0.0146084, 0.187013, -0.0683434, 0.0223707, -0.0…
  "features.13.weight"  => [0.0514621 -0.0490013 … -0.00740175 0.00124351; 0.0382785…
  "features.13.bias"    => Float32[0.324421, -0.00724723, 0.0839103, 0.180003, 0.075…
  "features.16.weight"  => [0.00122506 -0.0199895 … -0.0369922 -0.0188395; -0.023495…
  "features.16.bias"    => Float32[-0.0753394, 0.19634, 0.0544855, 0.0230991, 0.2478…
  "features.18.weight"  => [0.0104981 -0.0085396 … -0.00996796 0.00263586; -0.022834…
  "features.18.bias"    => Float32[-0.0132562, -0.128536, -0.021685, 0.0401009, 0.14…
  "classifier.0.weight" => Float32[-0.00160107 -0.00533261 … 0.00406176 0.0014802; 0…
  "classifier.0.bias"   => Float32[0.0252175, -0.00486407, 0.0436967, 0.0097529, 0.0…
  "classifier.3.weight" => Float32[0.00620286 -0.0210906 … -0.00731366 -0.0212599; -…
  "classifier.3.bias"   => Float32[0.0449707, 0.0848237, 0.0727477, 0.0816414, 0.074…
  "classifier.6.weight" => Float32[-0.0118874 0.0186423 … 0.0170721 0.0105425; 0.033…
  "classifier.6.bias"   => Float32[0.0146753, -0.00972212, -0.0238722, -0.0290253, -…
```

If this looks promising to you in any way, I'd be glad to open a draft PR with some guidance. :)

Edit: transposed weights according to Alexander's feedback. Thanks!
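To make the layout conversion concrete, the effect of that `permutedims` call can be sanity-checked on a toy convolution kernel (the sizes here are illustrative, not taken from VGG):

```julia
# PyTorch stores conv kernels as (out_channels, in_channels, kH, kW);
# Flux's Conv layer expects (kW, kH, in_channels, out_channels).
W_torch = rand(Float32, 64, 3, 5, 7)          # 64 out, 3 in, 5×7 kernel
W_flux  = permutedims(W_torch, (4, 3, 2, 1))  # reverse the dimension order
size(W_flux)  # → (7, 5, 3, 64)
```

Note that the 2-D fully-connected weight matrices also need a transpose when moving between row-major PyTorch and column-major Julia, which I believe is the transposition mentioned in the edit above.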
Indeed! In this PR I have taken a similar approach (but calling PyTorch via PyCall). Sometimes the weights have to be transposed, as you found from the error message. The ability to use PyTorch weights without needing PyTorch installed is indeed very nice! This also seems relevant:
I would like to give LeViT a try, if someone else is not already working on it.
@darsnack I would like to take a crack at MobileViT! Shall I go ahead and create a new issue for it?
Go for it! For both ViT-based models, make sure to look at the existing ViT implementation and the Layers submodule. Torchvision models are good references, but ultimately we want a Flux-idiomatic model, not a simple clone.
Below are the models that we still need to implement.
Will implement
These are models for which we are actively seeking implementations/weights.
Model implementations
- Inceptionv4, InceptionResNetv2 and Xception (#170)

Pre-trained weights
Will consider implementing
These are models that aren't necessarily must haves, but we are happy to accept contributions that add them.