Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add array function registration for Lux #1054

Merged
merged 10 commits into from
Mar 3, 2024
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Symbolics"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
authors = ["Shashi Gowda <[email protected]>"]
version = "5.22.1"
version = "5.23.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -40,11 +40,13 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsPreallocationToolsExt = "PreallocationTools"
SymbolicsSymPyExt = "SymPy"

Expand All @@ -65,6 +67,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.11"
MacroTools = "0.5"
NaNMath = "1"
PrecompileTools = "1"
Expand All @@ -83,8 +86,10 @@ julia = "1.10"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Expand All @@ -95,4 +100,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux", "ComponentArrays"]
11 changes: 11 additions & 0 deletions ext/SymbolicsLuxCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module SymbolicsLuxCoreExt

using LuxCore, Symbolics

@register_array_symbolic LuxCore.stateless_apply(
model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
size = LuxCore.outputsize(model, x, LuxCore._default_rng())
eltype = Real
end

end
86 changes: 86 additions & 0 deletions test/extensions/lux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using Symbolics, Lux, Random, Test
using ComponentArrays

@testset "Dense" begin
model = Dense(5, 6)
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, _ = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5]

out_ref = LuxCore.stateless_apply(model, x, ps)
@test out_ref isa Vector{Float32}

out = LuxCore.stateless_apply(model, sym_x, ps)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 6
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ps)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, ca)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ca)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref
end

@testset "Chain" begin

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also add some tests for Layers where we haven't defined outputsize to check that our fallback works

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one seems a bit tricky, as the outputsize fallback receives a symbolic array that's not an AbstractArray, so we might need to register additional functions, but that would need to make the extension on Lux and not just LuxCore.

I tried changing the extension to Lux and registering apply for Dense, but I hit applyparallel and it tries to iterate SymbolicUtils.BasicSymbolic{AbstractVector{Real}}, which fails. I'm not sure if we can continue on this path.

@ChrisRackauckas Should I leave the extension for LuxCore only or change to Lux? There's no immediate benefit for doing this, but it might make it easier if we wan to support more complex networks in the future, not sure how though, as the fact that we don't have an AbstractArray seems like a deal breaker.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, leave as LuxCore

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, in that case this should be good to go. Can you trigger / allow CI?

model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3))
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, _ = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

out_ref = LuxCore.stateless_apply(model, x, ps)
@test out_ref isa Vector{Float32}

out = LuxCore.stateless_apply(model, sym_x, ps)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 3
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ps)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, ca)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ca)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
end

Expand Down
Loading