Skip to content
This repository has been archived by the owner on Aug 4, 2022. It is now read-only.

import PyTorch weights for VGG11/13/16/19 and ResNet50/101/152 #2

Merged
merged 1 commit into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions pytorch2flux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Converts the weigths of a PyTorch model to a Flux model from Metalhead

# PyTorch need to be installed

# Tested on ResNet and VGG models

using Flux
import Metalhead
using DataStructures
using Statistics
using BSON
using PyCall
using Images
using Test

torchvision = pyimport("torchvision")
torch = pyimport("torch")

modellib = [
("vgg11", () -> Metalhead.VGG(11), torchvision.models.vgg11),
("vgg13", () -> Metalhead.VGG(13), torchvision.models.vgg13),
("vgg16", () -> Metalhead.VGG(16), torchvision.models.vgg16),
("vgg19", () -> Metalhead.VGG(19), torchvision.models.vgg19),
("resnet18", () -> Metalhead.ResNet(18), torchvision.models.resnet18),
("resnet34", () -> Metalhead.ResNet(34), torchvision.models.resnet34),
("resnet50", () -> Metalhead.ResNet(50), torchvision.models.resnet50),
("resnet101",() -> Metalhead.ResNet(101),torchvision.models.resnet101),
("resnet152",() -> Metalhead.ResNet(152),torchvision.models.resnet152),
]


function _list_state(node::Flux.BatchNorm,channel,prefix)
# use the same order of parameters than PyTorch
put!(channel, (prefix * ".γ", node.γ)) # weigth (learnable)
put!(channel, (prefix * ".β", node.β)) # bias (learnable)
put!(channel, (prefix * ".μ", node.μ)) # running mean
put!(channel, (prefix * ".σ²", node.σ²)) # running variance
end

function _list_state(node::Union{Flux.Conv,Flux.Dense},channel,prefix)
put!(channel, (prefix * ".weight", node.weight))

if node.bias !== Flux.Zeros()
put!(channel, (prefix * ".bias", node.bias))
end
end

_list_state(node,channel,prefix) = nothing

function _list_state(node::Union{Flux.Chain,Flux.Parallel},channel,prefix)
for (i,n) in enumerate(node.layers)
_list_state(n,channel,prefix * ".layers[$i]")
end
end

function list_state(node; prefix = "model")
Channel() do channel
_list_state(node,channel,prefix)
end
end

for (modelname,jlmodel,pymodel) in modellib

model = jlmodel()
pytorchmodel = pymodel(pretrained=true)

state = OrderedDict(list_state(model.layers))

# pytorchmodel.state_dict() looses the order
state_dict = OrderedDict(pycall(pytorchmodel.state_dict,PyObject).items())
pytorch_pp = OrderedDict((k,v.numpy()) for (k,v) in state_dict if !occursin("num_batches_tracked",k))


# loop over all parameters
for ((flux_key,flux_param),(pytorch_key,pytorch_param)) in zip(state,pytorch_pp)
if size(flux_param) == size(pytorch_param)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit risky. What if the input size is equal to the output size but a permutation is needed?
Maybe we should have specific rules for each pytorch layer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, but the scope of this script (so far at least) is just the tested VGG and ResNet models (not for an arbitrary model).

# Dense weight and vectors
flux_param .= pytorch_param
elseif size(flux_param) == reverse(size(pytorch_param))
tmp = pytorch_param
tmp = permutedims(tmp,ndims(tmp):-1:1)

if ndims(flux_param) == 4
# convolutional weights
flux_param .= reverse(tmp,dims=(1,2))
else
flux_param .= tmp
end
else
@debug begin
@show size(flux_param), size(pytorch_param)
end
error("incompatible shape $flux_key $pytorch_key")
end
end

@info "saving model $modelname"
BSON.@save joinpath(@__DIR__,"weights","$(modelname).bson") model
end
96 changes: 96 additions & 0 deletions test/compare_pytorch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Compare Flux model from Metalhead to PyTorch model
# for a sample image

# PyTorch need to be installed

# Tested on ResNet and VGG models

using Flux
import Metalhead
using DataStructures
using Statistics
using BSON
using PyCall
using Images
using Test

using MLUtils
using Random

torchvision = pyimport("torchvision")
torch = pyimport("torch")

modellib = [
("vgg11", () -> Metalhead.VGG(11), torchvision.models.vgg11),
("vgg13", () -> Metalhead.VGG(13), torchvision.models.vgg13),
("vgg16", () -> Metalhead.VGG(16), torchvision.models.vgg16),
("vgg19", () -> Metalhead.VGG(19), torchvision.models.vgg19),
("resnet18", () -> Metalhead.ResNet(18), torchvision.models.resnet18),
("resnet34", () -> Metalhead.ResNet(34), torchvision.models.resnet34),
("resnet50", () -> Metalhead.ResNet(50), torchvision.models.resnet50),
("resnet101",() -> Metalhead.ResNet(101),torchvision.models.resnet101),
("resnet152",() -> Metalhead.ResNet(152),torchvision.models.resnet152),
]


tr(tmp) = permutedims(tmp,ndims(tmp):-1:1)


function normalize(data)
cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1))
cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1))
return (data .- cmean) ./ cstd
end

# test image
guitar_path = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")

# image net labels
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))

weightsdir = joinpath(@__DIR__,"..","weights")

for (modelname,jlmodel,pymodel) in modellib
println(modelname)

model = jlmodel()

saved_model = BSON.load(joinpath(weightsdir,"$(modelname).bson"))
Flux.loadmodel!(model,saved_model[:model])

pytorchmodel = pymodel(pretrained=true)

Flux.testmode!(model)

sz = (224, 224)
img = Images.load(guitar_path);
img = imresize(img, sz);
# CHW -> WHC
data = permutedims(convert(Array{Float32}, channelview(img)), (3,2,1))
data = normalize(data[:,:,:,1:1])

out = model(data) |> softmax;
out = out[:,1]

println(" Flux:")

for i in sortperm(out,rev=true)[1:5]
println(" $(labels[i]): $(out[i])")
end


pytorchmodel.eval()
output = pytorchmodel(torch.Tensor(tr(data)));
probabilities = torch.nn.functional.softmax(output[0], dim=0).detach().numpy();

println(" PyTorch:")

for i in sortperm(probabilities[:,1],rev=true)[1:5]
println(" $(labels[i]): $(probabilities[i])")
end

@test maximum(out) ≈ maximum(probabilities)
@test argmax(out) ≈ argmax(probabilities)

println()
end