Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add array function registration for Lux #1054

Merged
merged 10 commits into from
Mar 3, 2024
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsPreallocationToolsExt = "PreallocationTools"
SymbolicsSymPyExt = "SymPy"

Expand All @@ -65,6 +67,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "0.1.7"
MacroTools = "0.5"
NaNMath = "1"
PrecompileTools = "1"
Expand All @@ -85,6 +88,7 @@ julia = "1.10"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Expand All @@ -95,4 +99,4 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "ForwardDiff", "Groebner", "BenchmarkTools", "ReferenceTests", "SymPy", "Random", "Lux"]
17 changes: 17 additions & 0 deletions ext/SymbolicsLuxCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module SymbolicsLuxCoreExt

using LuxCore, Symbolics

@register_array_symbolic LuxCore.partial_apply(
model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin
size = ((model[end].out_dims),)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely won't work for anything other than Dense

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated to LuxCore.outputsize.

eltype = Real
end

@register_array_symbolic LuxCore.partial_apply(
model::LuxCore.AbstractExplicitLayer, x::AbstractArray, ps::NamedTuple, st::NamedTuple) begin
size = ((model.out_dims),)
eltype = Real
end

end
13 changes: 13 additions & 0 deletions test/extensions/lux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Symbolics, Lux, Random, Test

@testset "Symbolics extension" begin
model = Dense(5, 6)
rng = Random.default_rng()
x = randn(rng, Float32, 5)
ps, st = LuxCore.setup(rng, model)

Symbolics.@variables sym_ps[1:5] = Float32[1, 2, 3, 4, 5]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test the creation of parameter types which match ps and st as well.


out = LuxCore.partial_apply(model, sym_ps, ps, st)
@test out isa Symbolics.Arr
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Integral Test" begin include("integral.jl") end
@safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
@safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end
@safetestset "Registration without using Test" begin include("registration_without_using.jl") end
end

Expand Down
Loading