Skip to content

Commit

Permalink
Merge pull request #1054 from SebastianM-C/luxcore_ext
Browse files Browse the repository at this point in the history
Add array function registration for Lux
  • Loading branch information
ChrisRackauckas authored Mar 3, 2024
2 parents 1e4635a + d3ebacd commit d24b4c6
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 2 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Symbolics"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
authors = ["Shashi Gowda <[email protected]>"]
version = "5.22.1"
version = "5.23.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -40,11 +40,13 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsPreallocationToolsExt = "PreallocationTools"
SymbolicsSymPyExt = "SymPy"

Expand All @@ -65,6 +67,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.11"
MacroTools = "0.5"
NaNMath = "1"
PrecompileTools = "1"
Expand All @@ -83,8 +86,10 @@ julia = "1.10"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Expand All @@ -95,4 +100,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux", "ComponentArrays"]
11 changes: 11 additions & 0 deletions ext/SymbolicsLuxCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module SymbolicsLuxCoreExt

using LuxCore, Symbolics

@register_array_symbolic LuxCore.stateless_apply(
model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
size = LuxCore.outputsize(model, x, LuxCore._default_rng())
eltype = Real
end

end
86 changes: 86 additions & 0 deletions test/extensions/lux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using Symbolics, Lux, Random, Test
using ComponentArrays

@testset "Dense" begin
model = Dense(5, 6)
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, _ = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_x[1:5] = Float32[1,2,3,4,5]

out_ref = LuxCore.stateless_apply(model, x, ps)
@test out_ref isa Vector{Float32}

out = LuxCore.stateless_apply(model, sym_x, ps)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 6
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ps)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, ca)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ca)
@test out isa Symbolics.Arr
@test length(out) == 6
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref
end

@testset "Chain" begin
model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3))
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, _ = LuxCore.setup(rng, model)

ca = ComponentArray(ps)
Symbolics.@variables sym_ca[1:length(ca)] = ca
Symbolics.@variables sym_ps::typeof(ps) = ps
Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

out_ref = LuxCore.stateless_apply(model, x, ps)
@test out_ref isa Vector{Float32}

out = LuxCore.stateless_apply(model, sym_x, ps)
# if the symbolic function regstration wouldn't work, we'd get the
# vector propagated through the neural network and get a Vector{Num}
@test out isa Symbolics.Arr
@test length(out) == 3
# test that we can recover the same value as when using concrete numbers
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ps)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, ca)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x,),)))
@test out_sub == out_ref

out = LuxCore.stateless_apply(model, sym_x, sym_ca)
@test out isa Symbolics.Arr
@test length(out) == 3
out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
@test out_sub == out_ref
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
end

Expand Down

0 comments on commit d24b4c6

Please sign in to comment.