Skip to content

Commit

Permalink
Improve tests for the Lux extenstion
Browse files Browse the repository at this point in the history
Add tests for parameter types for `ps` and `st`, `ComponentArrays`
and use both 1 layer and a `Chain`
  • Loading branch information
SebastianM-C committed Feb 16, 2024
1 parent 356e4c9 commit 639a354
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 4 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ julia = "1.10"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand All @@ -99,4 +100,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux"]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux", "ComponentArrays"]
117 changes: 114 additions & 3 deletions test/extensions/lux.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,124 @@
using Symbolics, Lux, Random, Test
using ComponentArrays

@testset "Symbolics extension" begin
@testset "Dense" begin
model = Dense(5, 6)
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, st = LuxCore.setup(rng, model)

Symbolics.@variables sym_ps[1:5] = Float32[1, 2, 3, 4, 5]
ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_st::typeof(st) = st
Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5]

out = LuxCore.partial_apply(model, sym_ps, ps, st)
out_ref = LuxCore.partial_apply(model, x, ps, st)
@test out_ref isa Vector{Float32}

out = LuxCore.partial_apply(model, sym_x, ps, st)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 6
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps=>ps),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ca, st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),)))
@test out_sub == out_ref
end

@testset "Chain" begin
model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3))
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, st = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_st::typeof(st) = st
Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

out_ref = LuxCore.partial_apply(model, x, ps, st)
@test out_ref isa Vector{Float32}

out = LuxCore.partial_apply(model, sym_x, ps, st)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 3
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ca, st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),)))
@test out_sub == out_ref
end

0 comments on commit 639a354

Please sign in to comment.