diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl
index 95933912d..265a386ea 100644
--- a/src/convnets/efficientnet.jl
+++ b/src/convnets/efficientnet.jl
@@ -147,7 +147,7 @@ function EfficientNet(name::Symbol; pretrain = false)
 
     @assert name in keys(efficientnet_global_configs)
         "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
-    model = EfficientNet(efficientnet_global_configs[name]..., efficientnet_block_configs)
+    model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
     pretrain && loadpretrain!(model, string("efficientnet-", name))
     return model
 
diff --git a/test/convnets.jl b/test/convnets.jl
index c019a7fac..86875aae8 100644
--- a/test/convnets.jl
+++ b/test/convnets.jl
@@ -78,14 +78,16 @@
 GC.gc()
 
 @testset "EfficientNet" begin
     @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]
+        xsz = Metalhead.efficientnet_global_configs[name][1]
+        x = rand(Float32, xsz...)
         m = EfficientNet(name)
-        @test size(m(x_256)) == (1000, 1)
+        @test size(m(x)) == (1000, 1)
         if (EfficientNet, name) in PRETRAINED_MODELS
             @test (EfficientNet(name, pretrain = true); true)
         else
             @test_throws ArgumentError EfficientNet(name, pretrain = true)
         end
-        @test gradtest(m, x_256)
+        @test gradtest(m, x)
         GC.safepoint()
         GC.gc()
     end