diff --git a/Project.toml b/Project.toml index 157f7bc9e..979a8f633 100644 --- a/Project.toml +++ b/Project.toml @@ -86,6 +86,7 @@ julia = "1.10" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -99,4 +100,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux"] +test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux", "ComponentArrays"] diff --git a/test/extensions/lux.jl b/test/extensions/lux.jl index 9cb7093f3..41497c313 100644 --- a/test/extensions/lux.jl +++ b/test/extensions/lux.jl @@ -1,13 +1,124 @@ using Symbolics, Lux, Random, Test +using ComponentArrays -@testset "Symbolics extension" begin +@testset "Dense" begin model = Dense(5, 6) rng = Random.default_rng() x = randn(rng, Float32, 5) ps, st = LuxCore.setup(rng, model) - Symbolics.@variables sym_ps[1:5] = Float32[1, 2, 3, 4, 5] + ca = ComponentArray(ps) + Symbolics.@variables sym_ca[1:length(ca)] = ca + Symbolics.@variables sym_ps::typeof(ps) = ps + Symbolics.@variables sym_st::typeof(st) = st + Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5] - out = LuxCore.partial_apply(model, sym_ps, ps, st) + out_ref = LuxCore.partial_apply(model, x, ps, st) + @test out_ref isa Vector{Float32} + + out = LuxCore.partial_apply(model, sym_x, ps, st) + # if the symbolic function regstration wouldn't work, we'd get the + # vector propagated through the neural network and get a Vector{Num} + @test out isa Symbolics.Arr + @test length(out) == 6 + # test that we can recover the same value as when using concrete numbers + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ps, st) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps=>ps),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, ps, sym_st) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, ca, st) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ca, st) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),))) + @test out_sub == out_ref +end + +@testset "Chain" begin + model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3)) + rng = Random.default_rng() + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + ca = ComponentArray(ps) + Symbolics.@variables sym_ca[1:length(ca)] = ca + Symbolics.@variables sym_ps::typeof(ps) = ps + Symbolics.@variables sym_st::typeof(st) = st + Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5] + + out_ref = LuxCore.partial_apply(model, x, ps, st) + @test out_ref isa Vector{Float32} + + out = LuxCore.partial_apply(model, sym_x, ps, st) + # if the symbolic function regstration wouldn't work, we'd get the + # vector propagated through the neural network and get a Vector{Num} + @test out isa Symbolics.Arr + @test length(out) == 3 + # test that we can recover the same value as when using concrete numbers + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ps, st) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, ps, sym_st) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, ca, st) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ca, st) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),))) + @test out_sub == out_ref + + out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st) @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),))) + @test out_sub == out_ref end