Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add array function registration for Lux #1054

Merged
merged 10 commits into from
Mar 3, 2024
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsPreallocationToolsExt = "PreallocationTools"
SymbolicsSymPyExt = "SymPy"

Expand All @@ -65,6 +67,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.8"
MacroTools = "0.5"
NaNMath = "1"
PrecompileTools = "1"
Expand All @@ -83,8 +86,10 @@ julia = "1.10"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Expand All @@ -95,4 +100,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux", "ComponentArrays"]
17 changes: 17 additions & 0 deletions ext/SymbolicsLuxCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module SymbolicsLuxCoreExt

using LuxCore, Symbolics

@register_array_symbolic LuxCore.partial_apply(
model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin
size = LuxCore.outputsize(model[end])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model[end] is not well defined, you need to query the output size of the entire model. Consider:

Parallel(vcat, Dense(2 => 2), Dense(2 => 3)), even if you had a sense of "end" the individual layers can't reason about the output size of the parent layer

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially if done right you shouldn't need this dispatch at all, just the AbstractExplicitLayer one should be enough

Copy link
Contributor Author

@SebastianM-C SebastianM-C Feb 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should outputsize be defined recursively in Lux?
I think it would make more sense to have it for each container.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you will have to rely on users providing the correct outputsize function, it is quite hard to do it in the general case, since you need to trace all the internal operations (in principle if you have outputsize for all primitives Symbolics should be able to compute the output size of the Chain/Parallel or any container layer).

But for now, I would recommend defining the outputsize of the common Lux layers in containers.jl

Copy link
Contributor Author

@SebastianM-C SebastianM-C Feb 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to reason about the size of the containers, can I assume that the connection in SkipConnection and Parallel is stateless? I was wondering if it would be correct to apply the connection function to the layers to figure out the output size.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, let's just rely on the interface and not try to go around it. So then I presume that means size = LuxCore.outputsize(model) and if the user needs to specify this function they will get an error and that's fine. I'd rather stay within the standard API rather than try to do something hacky that only works partially.

eltype = Real
end

@register_array_symbolic LuxCore.partial_apply(
model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin
size = LuxCore.outputsize(model)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to change based on the new LuxCore PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to get the rng here 🤔

Copy link
Contributor Author

@SebastianM-C SebastianM-C Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already works for layers & layer containers that can define the 1 arg outputsize, like Dense and Chain, but it would be nice to be able to use the general form.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just pass Xoshiro?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LuxCore.outputsize(model, x, Xoshiro(0))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have Random in deps, so we would need to either add as a passive dep or make the extension available only if Random is loaded (which would be quite confusing). I opted for the passive dep, but I'm not sure if it's okay to load the package only in the extension.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind making the default in LuxCore as outputsize(layer, x, rng = _default_rng())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need, I forgot that I can also use stuff from LuxCore in here 😅

eltype = Real
end

end
124 changes: 124 additions & 0 deletions test/extensions/lux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
using Symbolics, Lux, Random, Test
using ComponentArrays

@testset "Dense" begin
model = Dense(5, 6)
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, st = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_st::typeof(st) = st
Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5]

out_ref = LuxCore.partial_apply(model, x, ps, st)
@test out_ref isa Vector{Float32}

out = LuxCore.partial_apply(model, sym_x, ps, st)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 6
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps=>ps),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ca, st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),)))
@test out_sub == out_ref
end

@testset "Chain" begin

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also add some tests for Layers where we haven't defined outputsize to check that our fallback works

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one seems a bit tricky, as the outputsize fallback receives a symbolic array that's not an AbstractArray, so we might need to register additional functions, but that would need to make the extension on Lux and not just LuxCore.

I tried changing the extension to Lux and registering apply for Dense, but I hit applyparallel and it tries to iterate SymbolicUtils.BasicSymbolic{AbstractVector{Real}}, which fails. I'm not sure if we can continue on this path.

@ChrisRackauckas Should I leave the extension for LuxCore only or change to Lux? There's no immediate benefit for doing this, but it might make it easier if we wan to support more complex networks in the future, not sure how though, as the fact that we don't have an AbstractArray seems like a deal breaker.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, leave as LuxCore

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, in that case this should be good to go. Can you trigger / allow CI?

model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3))
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, st = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_st::typeof(st) = st
Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

out_ref = LuxCore.partial_apply(model, x, ps, st)
@test out_ref isa Vector{Float32}

out = LuxCore.partial_apply(model, sym_x, ps, st)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 3
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, ca, st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref

out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),)))
@test out_sub == out_ref
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
end

Expand Down
Loading