Add array function registration for Lux #1054

Merged · 10 commits · Mar 3, 2024
7 changes: 6 additions & 1 deletion Project.toml
@@ -40,11 +40,13 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsPreallocationToolsExt = "PreallocationTools"
SymbolicsSymPyExt = "SymPy"

@@ -65,6 +67,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.10"
MacroTools = "0.5"
NaNMath = "1"
PrecompileTools = "1"
@@ -83,8 +86,10 @@ julia = "1.10"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
@@ -95,4 +100,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
-test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
+test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux", "ComponentArrays"]
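The `[weakdeps]`/`[extensions]` pairing above means `SymbolicsLuxCoreExt` is only compiled and loaded once both Symbolics and LuxCore are present in the session (Julia ≥ 1.9 package extensions). A quick way to check this, as a sketch assuming both packages are installed in the active environment:

```julia
# Sketch: verify the extension activates once the weak dependency is loaded.
using Symbolics
Base.get_extension(Symbolics, :SymbolicsLuxCoreExt)  # nothing yet — LuxCore not loaded

using LuxCore
Base.get_extension(Symbolics, :SymbolicsLuxCoreExt)  # now returns the extension module
```

This is why the tests below can `using Lux` directly: loading Lux pulls in LuxCore, which triggers the extension.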
11 changes: 11 additions & 0 deletions ext/SymbolicsLuxCoreExt.jl
@@ -0,0 +1,11 @@
module SymbolicsLuxCoreExt

using LuxCore, Symbolics

@register_array_symbolic LuxCore.stateless_apply(
    model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
    size = LuxCore.outputsize(model)
    eltype = Real
end

end

Review discussion on `size = LuxCore.outputsize(model)`:

- Reviewer: This needs to change based on the new LuxCore PR.
- SebastianM-C (contributor, author): I'm not sure how to get the rng here 🤔
- SebastianM-C (Feb 27, 2024): This already works for layers & layer containers that can define the 1-arg `outputsize`, like `Dense` and `Chain`, but it would be nice to be able to use the general form.
- Reviewer: just pass Xoshiro? `LuxCore.outputsize(model, x, Xoshiro(0))`
- SebastianM-C (author): We don't have Random in deps, so we would need to either add it as a passive dep or make the extension available only if Random is loaded (which would be quite confusing). I opted for the passive dep, but I'm not sure if it's okay to load the package only in the extension.
- Reviewer: I don't mind making the default in LuxCore `outputsize(layer, x, rng = _default_rng())`.
- SebastianM-C (author): No need, I forgot that I can also use stuff from LuxCore in here 😅
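The thread above discusses moving to the general 3-arg `outputsize`. A hypothetical variant of the registration using it might look like the following — this is a sketch only, not the merged code: it assumes a `LuxCore.outputsize(model, x, rng)` method as proposed in the review, and it would require Random as a dependency, which is exactly the concern raised above.

```julia
# Hypothetical sketch — the merged extension uses the 1-arg outputsize.
using Random: Xoshiro

@register_array_symbolic LuxCore.stateless_apply(
    model::LuxCore.AbstractExplicitLayer, x::AbstractArray,
    ps::Union{NamedTuple, <:AbstractVector}) begin
    # A fixed-seed rng is fine here: only the output *shape* is queried,
    # not the values, so the seed does not affect the registered size.
    size = LuxCore.outputsize(model, x, Xoshiro(0))
    eltype = Real
end
```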
86 changes: 86 additions & 0 deletions test/extensions/lux.jl
@@ -0,0 +1,86 @@
using Symbolics, Lux, Random, Test
using ComponentArrays

@testset "Dense" begin
model = Dense(5, 6)
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, _ = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5]

out_ref = LuxCore.stateless_apply(model, x, ps)
@test out_ref isa Vector{Float32}

out = LuxCore.stateless_apply(model, sym_x, ps)
# if the symbolic function registration didn't work, the input vector would
# be propagated through the neural network, yielding a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 6
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ps)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, ca)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ca)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref
end
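Beyond substitution, a registered symbolic output can in principle be compiled back into a callable with `Symbolics.build_function`. A sketch reusing the `Dense` setup from the testset above (not part of the PR's tests; `build_function` with `expression = Val(false)` returns an out-of-place and an in-place variant for array expressions):

```julia
# Sketch: compile the symbolic network output into a Julia function and
# check it agrees with the concrete stateless_apply result.
out = LuxCore.stateless_apply(model, sym_x, ps)
f, f! = Symbolics.build_function(Symbolics.scalarize(out), sym_x;
                                 expression = Val(false))
f(x) ≈ out_ref  # expected to agree up to floating-point precision
```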

Review discussion on the `Chain` testset:

- Reviewer: Maybe also add some tests for layers where we haven't defined `outputsize`, to check that our fallback works.
- SebastianM-C (contributor, author): This one seems a bit tricky, as the `outputsize` fallback receives a symbolic array that's not an `AbstractArray`, so we might need to register additional functions, but that would require making the extension depend on Lux and not just LuxCore. I tried changing the extension to Lux and registering `apply` for `Dense`, but I hit `applyparallel` and it tries to iterate `SymbolicUtils.BasicSymbolic{AbstractVector{Real}}`, which fails. I'm not sure if we can continue on this path. @ChrisRackauckas Should I leave the extension for LuxCore only or change to Lux? There's no immediate benefit in changing, but it might make it easier if we want to support more complex networks in the future; I'm not sure how, though, as the lack of an `AbstractArray` seems like a deal breaker.
- ChrisRackauckas (member): For now, leave as LuxCore.
- SebastianM-C (author): ok, in that case this should be good to go. Can you trigger / allow CI?

@testset "Chain" begin
model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3))
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, _ = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

out_ref = LuxCore.stateless_apply(model, x, ps)
@test out_ref isa Vector{Float32}

out = LuxCore.stateless_apply(model, sym_x, ps)
# if the symbolic function registration didn't work, the input vector would
# be propagated through the neural network, yielding a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 3
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ps)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, ca)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ca)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -47,6 +47,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
end
