Skip to content
This repository has been archived by the owner on Aug 4, 2022. It is now read-only.

Commit

Permalink
split testing of model from convertion
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Apr 14, 2022
1 parent 324cc3e commit 65db9c6
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 58 deletions.
61 changes: 3 additions & 58 deletions pytorch2flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,14 @@ modellib = [
("vgg13", () -> Metalhead.VGG(13), torchvision.models.vgg13),
("vgg16", () -> Metalhead.VGG(16), torchvision.models.vgg16),
("vgg19", () -> Metalhead.VGG(19), torchvision.models.vgg19),
#=
("resnet18", () -> Metalhead.ResNet(18), torchvision.models.resnet18),
("resnet34", () -> Metalhead.ResNet(34), torchvision.models.resnet34),
("resnet50", () -> Metalhead.ResNet(50), torchvision.models.resnet50), # works
("resnet101",() -> Metalhead.ResNet(101),torchvision.models.resnet101), # works
("resnet152",() -> Metalhead.ResNet(152),torchvision.models.resnet152), # works
=#
("resnet50", () -> Metalhead.ResNet(50), torchvision.models.resnet50),
("resnet101",() -> Metalhead.ResNet(101),torchvision.models.resnet101),
("resnet152",() -> Metalhead.ResNet(152),torchvision.models.resnet152),
]


tr(tmp) = permutedims(tmp,ndims(tmp):-1:1)

function _list_state(node::Flux.BatchNorm,channel,prefix)
# use the same order of parameters than PyTorch
put!(channel, (prefix * "", node.γ)) # weigth (learnable)
Expand Down Expand Up @@ -63,22 +59,6 @@ function list_state(node; prefix = "model")
end
end

function normalize(data)
cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1))
cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1))
return (data .- cmean) ./ cstd
end

# test image
guitar_path = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")

# image net labels
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))

#=
modelname,jlmodel,pymodel = modellib[1]
=#

for (modelname,jlmodel,pymodel) in modellib

model = jlmodel()
Expand Down Expand Up @@ -114,41 +94,6 @@ for (modelname,jlmodel,pymodel) in modellib
end
end


Flux.testmode!(model)

sz = (224, 224)
img = Images.load(guitar_path);
img = imresize(img, sz);
# CHW -> WHC
data = permutedims(convert(Array{Float32}, channelview(img)), (3,2,1))
data = normalize(data[:,:,:,1:1])

out = model(data) |> softmax;
out = out[:,1]

println(modelname)
println("Flux:")

for i in sortperm(out,rev=true)[1:5]
println("$(labels[i]): $(out[i])")
end


pytorchmodel.eval()
output = pytorchmodel(torch.Tensor(tr(data)));
probabilities = torch.nn.functional.softmax(output[0], dim=0).detach().numpy();

println("PyTorch:")

for i in sortperm(probabilities[:,1],rev=true)[1:5]
println("$(labels[i]): $(probabilities[i])")
end

@test maximum(out) maximum(probabilities)
@test argmax(out) argmax(probabilities)

@info "saving model $modelname"
BSON.@save joinpath(@__DIR__,"weights","$(modelname).bson") model
println()
end
96 changes: 96 additions & 0 deletions test/compare_pytorch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Compare Flux model from Metalhead to PyTorch model
# for a sample image

# PyTorch need to be installed

# Tested on ResNet and VGG models

using Flux
import Metalhead
using DataStructures
using Statistics
using BSON
using PyCall
using Images
using Test

using MLUtils
using Random

torchvision = pyimport("torchvision")
torch = pyimport("torch")

modellib = [
("vgg11", () -> Metalhead.VGG(11), torchvision.models.vgg11),
("vgg13", () -> Metalhead.VGG(13), torchvision.models.vgg13),
("vgg16", () -> Metalhead.VGG(16), torchvision.models.vgg16),
("vgg19", () -> Metalhead.VGG(19), torchvision.models.vgg19),
("resnet18", () -> Metalhead.ResNet(18), torchvision.models.resnet18),
("resnet34", () -> Metalhead.ResNet(34), torchvision.models.resnet34),
("resnet50", () -> Metalhead.ResNet(50), torchvision.models.resnet50),
("resnet101",() -> Metalhead.ResNet(101),torchvision.models.resnet101),
("resnet152",() -> Metalhead.ResNet(152),torchvision.models.resnet152),
]


tr(tmp) = permutedims(tmp,ndims(tmp):-1:1)


function normalize(data)
cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1))
cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1))
return (data .- cmean) ./ cstd
end

# test image
guitar_path = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")

# image net labels
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))

weightsdir = joinpath(@__DIR__,"..","weights")

for (modelname,jlmodel,pymodel) in modellib
println(modelname)

model = jlmodel()

saved_model = BSON.load(joinpath(weightsdir,"$(modelname).bson"))
Flux.loadmodel!(model,saved_model[:model])

pytorchmodel = pymodel(pretrained=true)

Flux.testmode!(model)

sz = (224, 224)
img = Images.load(guitar_path);
img = imresize(img, sz);
# CHW -> WHC
data = permutedims(convert(Array{Float32}, channelview(img)), (3,2,1))
data = normalize(data[:,:,:,1:1])

out = model(data) |> softmax;
out = out[:,1]

println(" Flux:")

for i in sortperm(out,rev=true)[1:5]
println(" $(labels[i]): $(out[i])")
end


pytorchmodel.eval()
output = pytorchmodel(torch.Tensor(tr(data)));
probabilities = torch.nn.functional.softmax(output[0], dim=0).detach().numpy();

println(" PyTorch:")

for i in sortperm(probabilities[:,1],rev=true)[1:5]
println(" $(labels[i]): $(probabilities[i])")
end

@test maximum(out) maximum(probabilities)
@test argmax(out) argmax(probabilities)

println()
end

0 comments on commit 65db9c6

Please sign in to comment.