Skip to content
This repository has been archived by the owner on Aug 4, 2022. It is now read-only.

Commit

Permalink
Merge pull request #2 from Alexander-Barth/main
Browse files Browse the repository at this point in the history
import PyTorch weights for VGG11/13/16/19 and ResNet50/101/152
  • Loading branch information
CarloLucibello authored Jun 1, 2022
2 parents 3c801fe + abc51ce commit ab10732
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 0 deletions.
99 changes: 99 additions & 0 deletions pytorch2flux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Converts the weights of a PyTorch model to a Flux model from Metalhead

# PyTorch needs to be installed

# Tested on ResNet and VGG models

using Flux
import Metalhead
using DataStructures
using Statistics
using BSON
using PyCall
using Images
using Test

# Python modules providing the pretrained reference models
torchvision = pyimport("torchvision")
torch = pyimport("torch")

# Table of the models to convert. Every entry is a tuple
# (name, zero-argument constructor of the Flux model, torchvision constructor).
# The entries are generated from the depths of each architecture; the
# torchvision constructor is looked up by name in `torchvision.models`.
modellib = vcat(
    [
        ("vgg$depth", () -> Metalhead.VGG(depth),
         getproperty(torchvision.models, Symbol("vgg$depth")))
        for depth in (11, 13, 16, 19)
    ],
    [
        ("resnet$depth", () -> Metalhead.ResNet(depth),
         getproperty(torchvision.models, Symbol("resnet$depth")))
        for depth in (18, 34, 50, 101, 152)
    ],
)


# Emit the four state arrays of a BatchNorm layer on `channel` as
# `(key, array)` tuples, where `key` is `prefix` followed by the field name.
#
# The keys have to be unique per array: the pairs are later collected in an
# OrderedDict, where entries sharing a key would overwrite each other and
# shift the positional pairing with the PyTorch parameters.
function _list_state(node::Flux.BatchNorm, channel, prefix)
    # use the same order of parameters as PyTorch
    put!(channel, (prefix * ".γ", node.γ))   # weight (learnable)
    put!(channel, (prefix * ".β", node.β))   # bias (learnable)
    put!(channel, (prefix * ".μ", node.μ))   # running mean
    put!(channel, (prefix * ".σ²", node.σ²)) # running variance
end

# Emit the weight of a convolutional or dense layer on `channel`, followed by
# its bias unless the bias is the `Flux.Zeros()` placeholder.
# Keys are `prefix` followed by ".weight" / ".bias".
function _list_state(node::Union{Flux.Conv,Flux.Dense}, channel, prefix)
    put!(channel, (prefix * ".weight", node.weight))

    # nothing more to emit for a layer whose bias is the placeholder
    node.bias === Flux.Zeros() && return nothing

    put!(channel, (prefix * ".bias", node.bias))
end

# Fallback for every node type without a dedicated method:
# nothing is emitted on the channel.
function _list_state(node, channel, prefix)
    return nothing
end

# Container layers: visit the children in order and let each of them emit its
# own state. The key prefix of child number `idx` is extended by
# ".layers[idx]" (counting from 1).
function _list_state(node::Union{Flux.Chain,Flux.Parallel}, channel, prefix)
    foreach(enumerate(node.layers)) do (idx, layer)
        _list_state(layer, channel, string(prefix, ".layers[", idx, "]"))
    end
end

# Return a Channel producing the `(key, array)` tuples that `_list_state`
# emits for `node`. Keys start with `prefix`.
function list_state(node; prefix = "model")
    producer(channel) = _list_state(node, channel, prefix)
    return Channel(producer)
end

# For every entry of `modellib`: build the Flux model, fetch the pretrained
# PyTorch model, copy the PyTorch arrays into the Flux arrays (paired by
# position) and save the Flux model as weights/<modelname>.bson next to this
# script.
for (modelname, jlmodel, pymodel) in modellib

    model = jlmodel()
    pytorchmodel = pymodel(pretrained=true)

    # ordered (key => array) view of the Flux state; the arrays alias the
    # model, so writing into them updates `model`
    state = OrderedDict(list_state(model.layers))

    # pytorchmodel.state_dict() loses the order
    state_dict = OrderedDict(pycall(pytorchmodel.state_dict, PyObject).items())
    pytorch_pp = OrderedDict(
        (k, v.numpy()) for (k, v) in state_dict if !occursin("num_batches_tracked", k)
    )

    # zip stops at the shorter collection: without this check a mismatch in
    # the number of arrays would leave part of the model silently unconverted
    if length(state) != length(pytorch_pp)
        error("$modelname: $(length(state)) Flux arrays but " *
              "$(length(pytorch_pp)) PyTorch arrays")
    end

    # loop over all parameters
    for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(state, pytorch_pp)
        if size(flux_param) == size(pytorch_param)
            # Dense weight and vectors
            flux_param .= pytorch_param
        elseif size(flux_param) == reverse(size(pytorch_param))
            # same array with the dimensions in the opposite order
            tmp = permutedims(pytorch_param, ndims(pytorch_param):-1:1)

            if ndims(flux_param) == 4
                # convolutional weights: additionally flip the first two
                # dimensions
                flux_param .= reverse(tmp, dims=(1, 2))
            else
                flux_param .= tmp
            end
        else
            @debug "incompatible shape" flux_size = size(flux_param) pytorch_size = size(pytorch_param)
            error("incompatible shape $flux_key $pytorch_key")
        end
    end

    @info "saving model $modelname"
    # the output directory may not exist yet (e.g. fresh checkout)
    mkpath(joinpath(@__DIR__, "weights"))
    BSON.@save joinpath(@__DIR__, "weights", "$(modelname).bson") model
end
96 changes: 96 additions & 0 deletions test/compare_pytorch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Compare Flux model from Metalhead to PyTorch model
# for a sample image

# PyTorch needs to be installed

# Tested on ResNet and VGG models

using Flux
import Metalhead
using DataStructures
using Statistics
using BSON
using PyCall
using Images
using Test

using MLUtils
using Random

# Python modules providing the pretrained reference models
torchvision = pyimport("torchvision")
torch = pyimport("torch")

# Table of the models to compare. Every entry is a tuple
# (name, zero-argument constructor of the Flux model, torchvision constructor).
# The entries are generated from the depths of each architecture; the
# torchvision constructor is looked up by name in `torchvision.models`.
modellib = vcat(
    [
        ("vgg$depth", () -> Metalhead.VGG(depth),
         getproperty(torchvision.models, Symbol("vgg$depth")))
        for depth in (11, 13, 16, 19)
    ],
    [
        ("resnet$depth", () -> Metalhead.ResNet(depth),
         getproperty(torchvision.models, Symbol("resnet$depth")))
        for depth in (18, 34, 50, 101, 152)
    ],
)


# Return a copy of `tmp` with the order of its dimensions reversed
# (for a matrix this is the plain, non-conjugating transpose).
function tr(tmp)
    reversed_dims = reverse(1:ndims(tmp))
    return permutedims(tmp, reversed_dims)
end


# Standardise `data` channel by channel: subtract a fixed offset and divide by
# a fixed scale. Both are reshaped to (1,1,3,1) so that they broadcast along
# the third dimension of `data`, which therefore has to be of size 3.
# NOTE(review): the constants look like the usual ImageNet mean / standard
# deviation used by torchvision — confirm.
function normalize(data)
    shape = (1, 1, 3, 1)
    offset = reshape(Float32[0.485, 0.456, 0.406], shape)
    scale = reshape(Float32[0.229, 0.224, 0.225], shape)
    centered = data .- offset
    return centered ./ scale
end

# sample picture fed to every model (downloaded to a local file)
guitar_path = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")

# class names, one entry per line of the downloaded text file
labels = download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt") |> readlines

# directory from which the converted weights are read
weightsdir = joinpath(@__DIR__,"..","weights")

# For every entry of `modellib`: load the converted weights into the Flux
# model, run the Flux and the PyTorch model on the same sample image, print
# the five most probable classes of each and test that both agree on the best
# class and its probability.
for (modelname, jlmodel, pymodel) in modellib
    println(modelname)

    model = jlmodel()

    # weights written by the conversion script
    saved_model = BSON.load(joinpath(weightsdir, "$(modelname).bson"))
    Flux.loadmodel!(model, saved_model[:model])

    pytorchmodel = pymodel(pretrained=true)

    Flux.testmode!(model)

    sz = (224, 224)
    img = Images.load(guitar_path)
    img = imresize(img, sz)
    # CHW -> WHC
    data = permutedims(convert(Array{Float32}, channelview(img)), (3, 2, 1))
    # indexing with a trailing 1:1 appends a fourth dimension of size 1
    data = normalize(data[:, :, :, 1:1])

    out = model(data) |> softmax
    out = out[:, 1]

    println(" Flux:")

    for i in sortperm(out, rev=true)[1:5]
        println(" $(labels[i]): $(out[i])")
    end

    pytorchmodel.eval()
    # tr reverses the dimension order of the Flux input before it is handed
    # to PyTorch
    output = pytorchmodel(torch.Tensor(tr(data)))
    probabilities = torch.nn.functional.softmax(output[0], dim=0).detach().numpy()

    println(" PyTorch:")

    for i in sortperm(probabilities[:, 1], rev=true)[1:5]
        println(" $(labels[i]): $(probabilities[i])")
    end

    # @test needs a single boolean expression: compare the floating-point
    # maxima approximately and the winning class index exactly
    @test maximum(out) ≈ maximum(probabilities)
    @test argmax(out) == argmax(probabilities)

    println()
end

0 comments on commit ab10732

Please sign in to comment.