Skip to content

Commit

Permalink
No need to use image resolution scaling (implicitly done)
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Jun 19, 2022
1 parent 0ecaab1 commit 69e1ac4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/convnets/efficientnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ function EfficientNet(name::Symbol; pretrain = false)
@assert name in keys(efficientnet_global_configs)
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"

model = EfficientNet(efficientnet_global_configs[name]..., efficientnet_block_configs)
model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
pretrain && loadpretrain!(model, string("efficientnet-", name))

return model
Expand Down
6 changes: 4 additions & 2 deletions test/convnets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,16 @@ GC.gc()

@testset "EfficientNet" begin
@testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]
xsz = Metalhead.efficientnet_global_configs[name][1]
x = rand(Float32, xsz...)
m = EfficientNet(name)
@test size(m(x_256)) == (1000, 1)
@test size(m(x)) == (1000, 1)
if (EfficientNet, name) in PRETRAINED_MODELS
@test (EfficientNet(name, pretrain = true); true)
else
@test_throws ArgumentError EfficientNet(name, pretrain = true)
end
@test gradtest(m, x_256)
@test gradtest(m, x)
GC.safepoint()
GC.gc()
end
Expand Down

0 comments on commit 69e1ac4

Please sign in to comment.