diff --git a/Project.toml b/Project.toml index 09deda68b..e32377051 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Symbolics" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" authors = ["Shashi Gowda "] -version = "5.22.1" +version = "5.23.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -40,11 +40,13 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [weakdeps] Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" [extensions] SymbolicsGroebnerExt = "Groebner" +SymbolicsLuxCoreExt = "LuxCore" SymbolicsPreallocationToolsExt = "PreallocationTools" SymbolicsSymPyExt = "SymPy" @@ -65,6 +67,7 @@ LaTeXStrings = "1.3" LambertW = "0.4.5" Latexify = "0.16" LogExpFunctions = "0.3" +LuxCore = "0.1.11" MacroTools = "0.5" NaNMath = "1" PrecompileTools = "1" @@ -83,8 +86,10 @@ julia = "1.10" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" @@ -95,4 +100,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"] +test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux", "ComponentArrays"] diff --git a/ext/SymbolicsLuxCoreExt.jl b/ext/SymbolicsLuxCoreExt.jl new file mode 100644 index 000000000..8f5be992b --- /dev/null +++ b/ext/SymbolicsLuxCoreExt.jl @@ -0,0 +1,11 @@ +module SymbolicsLuxCoreExt + +using LuxCore, Symbolics + +@register_array_symbolic LuxCore.stateless_apply( + model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin + size = LuxCore.outputsize(model, x, LuxCore._default_rng()) + eltype = Real +end + +end diff --git a/test/extensions/lux.jl b/test/extensions/lux.jl new file mode 100644 index 000000000..2194e8528 --- /dev/null +++ b/test/extensions/lux.jl @@ -0,0 +1,86 @@ +using Symbolics, Lux, Random, Test +using ComponentArrays + +@testset "Dense" begin + model = Dense(5, 6) + rng = Random.default_rng() + x = randn(rng, Float32, 5) + ps, _ = LuxCore.setup(rng, model) + + ca = ComponentArray(ps) + Symbolics.@variables sym_ca[1:length(ca)] = ca + Symbolics.@variables sym_ps::typeof(ps) = ps + Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5] + + out_ref = LuxCore.stateless_apply(model, x, ps) + @test out_ref isa Vector{Float32} + + out = LuxCore.stateless_apply(model, sym_x, ps) + # if the symbolic function regstration wouldn't work, we'd get the + # vector propagated through the neural network and get a Vector{Num} + @test out isa Symbolics.Arr + @test length(out) == 6 + # test that we can recover the same value as when using concrete numbers + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),))) + @test out_sub == out_ref + + out = LuxCore.stateless_apply(model, sym_x, sym_ps) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),))) + @test out_sub == out_ref + + out = LuxCore.stateless_apply(model, sym_x, ca) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),))) + @test out_sub == out_ref + + out = LuxCore.stateless_apply(model, sym_x, sym_ca) + @test out isa Symbolics.Arr + @test length(out) == 6 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),))) + @test out_sub == out_ref +end + +@testset "Chain" begin + model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3)) + rng = Random.default_rng() + x = randn(rng, Float32, 5) + ps, _ = LuxCore.setup(rng, model) + + ca = ComponentArray(ps) + Symbolics.@variables sym_ca[1:length(ca)] = ca + Symbolics.@variables sym_ps::typeof(ps) = ps + Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5] + + out_ref = LuxCore.stateless_apply(model, x, ps) + @test out_ref isa Vector{Float32} + + out = LuxCore.stateless_apply(model, sym_x, ps) + # if the symbolic function regstration wouldn't work, we'd get the + # vector propagated through the neural network and get a Vector{Num} + @test out isa Symbolics.Arr + @test length(out) == 3 + # test that we can recover the same value as when using concrete numbers + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),))) + @test out_sub == out_ref + + out = LuxCore.stateless_apply(model, sym_x, sym_ps) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),))) + @test out_sub == out_ref + + out = LuxCore.stateless_apply(model, sym_x, ca) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),))) + @test out_sub == out_ref + + out = LuxCore.stateless_apply(model, sym_x, sym_ca) + @test out isa Symbolics.Arr + @test length(out) == 3 + out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),))) + @test out_sub == out_ref +end diff --git a/test/runtests.jl b/test/runtests.jl index 521d86cfc..f151d0384 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Integral Test" begin include("integral.jl") end @safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end + @safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end @safetestset "Registration without using Test" begin include("registration_without_using.jl") end end