Commit

fix: update tests to the new API
avik-pal committed Sep 9, 2024
1 parent 6cbed33 commit 217d2f5
Showing 3 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion ext/BoltzDynamicExpressionsExt.jl
@@ -65,7 +65,7 @@ function Layers.DynamicExpressionsLayer(operator_enum::OperatorEnum, expressions
                     i -> Layers.InternalDynamicExpressionWrapper(
                         operator_enum, expressions[i], eval_options),
                     length(expressions))...),
-            WrappedFunction{:direct_call}(Lux.Utils.stack1))
+            WrappedFunction(Lux.Utils.stack1))
     end
     return Layers.DynamicExpressionsLayer(internal_layer)
 end
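
The change above drops the {:direct_call} type parameter from WrappedFunction; the newer Lux constructor takes only the wrapped function. A minimal sketch of the new usage (not part of the commit; Lux.Utils.stack1 is an internal helper, so an illustrative stand-in function is used here):

using Lux, Random

# Wrap a plain function as a stateless Lux layer with the single-argument constructor.
layer = WrappedFunction(x -> sum(abs2, x; dims=1))
ps, st = Lux.setup(Random.default_rng(), layer)   # empty parameters/states
y, _ = layer(rand(Float32, 4, 3), ps, st)         # y has size (1, 3)
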
2 changes: 1 addition & 1 deletion src/initialize.jl
@@ -7,7 +7,7 @@ using Random: Random
 
 using LuxCore: LuxCore
 
-using ..Utils: is_extension_loaded, unwrap_val
+using ..Utils: is_extension_loaded
 
 get_pretrained_weights_path(name::Symbol) = get_pretrained_weights_path(string(name))
 function get_pretrained_weights_path(name::String)
54 changes: 27 additions & 27 deletions test/vision_tests.jl
@@ -1,9 +1,9 @@
 @testitem "AlexNet" setup=[SharedTestSetup] tags=[:vision] begin
     for (mode, aType, dev, ongpu) in MODES
         @testset "pretrained: $(pretrained)" for pretrained in [true, false]
-            model, ps, st = Vision.AlexNet(; pretrained)
-            ps = ps |> dev
-            st = Lux.testmode(st) |> dev
+            model = Vision.AlexNet(; pretrained)
+            ps, st = Lux.setup(Random.default_rng(), model) |> dev
+            st = Lux.testmode(st)
             img = randn(Float32, 224, 224, 3, 2) |> aType
 
             @jet model(img, ps, st)
Expand All @@ -16,9 +16,9 @@ end

@testitem "ConvMixer" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, name in [:small, :base, :large]
model, ps, st = Vision.ConvMixer(name; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.ConvMixer(name; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 256, 256, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -30,9 +30,9 @@ end

@testitem "GoogLeNet" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES
model, ps, st = Vision.GoogLeNet(; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.GoogLeNet(; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -44,9 +44,9 @@ end

@testitem "MobileNet" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, name in [:v1, :v2, :v3_small, :v3_large]
model, ps, st = Vision.MobileNet(name; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.MobileNet(name; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -58,9 +58,9 @@ end

@testitem "ResNet" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, depth in [18, 34, 50, 101, 152]
model, ps, st = Vision.ResNet(depth; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.ResNet(depth; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -72,9 +72,9 @@ end

@testitem "ResNeXt" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, depth in [50, 101, 152]
model, ps, st = Vision.ResNeXt(depth; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.ResNeXt(depth; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -90,9 +90,9 @@ end
false, true],
batchnorm in [false, true]

model, ps, st = Vision.VGG(depth; batchnorm, pretrained)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.VGG(depth; batchnorm, pretrained)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 224, 224, 3, 2) |> aType

@jet model(img, ps, st)
Expand All @@ -106,17 +106,17 @@ end
@testitem "VisionTransformer" setup=[SharedTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, name in [:tiny, :small, :base]
# :large, :huge, :giant, :gigantic --> too large for CI
model, ps, st = Vision.VisionTransformer(name; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.VisionTransformer(name; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 256, 256, 3, 2) |> aType

@jet model(img, ps, st)
@test size(first(model(img, ps, st))) == (1000, 2)

model, ps, st = Vision.VisionTransformer(name; pretrained=false)
ps = ps |> dev
st = Lux.testmode(st) |> dev
model = Vision.VisionTransformer(name; pretrained=false)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
img = randn(Float32, 256, 256, 3, 2) |> aType

@jet model(img, ps, st)
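
All of the test updates above follow the same migration: the Vision constructors now return only the model, and parameters/states are created explicitly with Lux.setup before being moved to the device. A minimal standalone sketch of the new pattern (not part of the commit; dev comes from SharedTestSetup/MODES in the real tests, so cpu_device() is used as a stand-in, and ResNet-18 is an arbitrary choice):

using Boltz: Vision
using Lux, Random

model = Vision.ResNet(18; pretrained=false)             # constructor returns only the model
dev = cpu_device()                                      # stand-in for the device from MODES
ps, st = Lux.setup(Random.default_rng(), model) |> dev  # set up, then move params/states to the device
st = Lux.testmode(st)                                   # inference mode; no second device move needed

img = randn(Float32, 224, 224, 3, 2) |> dev
y, _ = model(img, ps, st)                               # ImageNet head: 1000 classes, batch of 2
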
