Add array function registration for Lux #1054
New file, `@@ -0,0 +1,11 @@`:

```julia
module SymbolicsLuxCoreExt

using LuxCore, Symbolics

@register_array_symbolic LuxCore.stateless_apply(
        model::LuxCore.AbstractExplicitLayer, x::AbstractArray,
        ps::Union{NamedTuple, <:AbstractVector}) begin
    size = LuxCore.outputsize(model)
    eltype = Real
end

end
```
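The registration above tells Symbolics to treat `LuxCore.stateless_apply` as an opaque array-valued function of known size and eltype, rather than tracing through the network scalar by scalar. A minimal usage sketch (the model and variable names here are illustrative, not part of the PR):

```julia
using Symbolics, Lux, LuxCore, Random

# Illustrative model and parameters.
model = Dense(5, 6)
ps, _ = LuxCore.setup(Random.default_rng(), model)

# A symbolic input vector.
Symbolics.@variables x[1:5]

# With the registration in place, this stays a single symbolic array term
# (a Symbolics.Arr of length 6) instead of scalarizing into a Vector{Num}.
y = LuxCore.stateless_apply(model, x, ps)
```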
New file, `@@ -0,0 +1,86 @@`:
```julia
using Symbolics, Lux, LuxCore, Random, Test
using ComponentArrays

@testset "Dense" begin
    model = Dense(5, 6)
    rng = Random.default_rng()
    x = randn(rng, Float32, 5)
    ps, _ = LuxCore.setup(rng, model)

    ca = ComponentArray(ps)
    Symbolics.@variables sym_ca[1:length(ca)] = ca
    Symbolics.@variables sym_ps::typeof(ps) = ps
    Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

    out_ref = LuxCore.stateless_apply(model, x, ps)
    @test out_ref isa Vector{Float32}

    out = LuxCore.stateless_apply(model, sym_x, ps)
    # If the symbolic function registration didn't work, the input would be
    # propagated through the neural network and we'd get a Vector{Num}.
    @test out isa Symbolics.Arr
    @test length(out) == 6
    # Test that we can recover the same value as when using concrete numbers.
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.stateless_apply(model, sym_x, sym_ps)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
    @test out_sub == out_ref

    out = LuxCore.stateless_apply(model, sym_x, ca)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.stateless_apply(model, sym_x, sym_ca)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
    @test out_sub == out_ref
end
```
Review thread (on the `"Chain"` testset):

**Review comment:** Maybe also add some tests for layers where we haven't defined `outputsize`?

**Reply:** This one seems a bit tricky, as the … I tried changing the extension to Lux and registering … @ChrisRackauckas Should I leave the extension for …?

**Reply:** For now, leave as LuxCore.

**Reply:** ok, in that case this should be good to go. Can you trigger / allow CI?

```julia
@testset "Chain" begin
    model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3))
    rng = Random.default_rng()
    x = randn(rng, Float32, 5)
    ps, _ = LuxCore.setup(rng, model)

    ca = ComponentArray(ps)
    Symbolics.@variables sym_ca[1:length(ca)] = ca
    Symbolics.@variables sym_ps::typeof(ps) = ps
    Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

    out_ref = LuxCore.stateless_apply(model, x, ps)
    @test out_ref isa Vector{Float32}

    out = LuxCore.stateless_apply(model, sym_x, ps)
    # If the symbolic function registration didn't work, the input would be
    # propagated through the neural network and we'd get a Vector{Num}.
    @test out isa Symbolics.Arr
    @test length(out) == 3
    # Test that we can recover the same value as when using concrete numbers.
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.stateless_apply(model, sym_x, sym_ps)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
    @test out_sub == out_ref

    out = LuxCore.stateless_apply(model, sym_x, ca)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.stateless_apply(model, sym_x, sym_ca)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
    @test out_sub == out_ref
end
```
Review thread (on the `LuxCore.outputsize(model)` call in the extension):

**Review comment:** This needs to change based on the new LuxCore PR.

**Reply:** I'm not sure how to get the `rng` here 🤔

**Reply:** This already works for layers & layer containers that can define the 1-arg `outputsize`, like `Dense` and `Chain`, but it would be nice to be able to use the general form.

**Reply:** just pass `Xoshiro`? `LuxCore.outputsize(model, x, Xoshiro(0))`

**Reply:** We don't have `Random` in deps, so we would need to either add it as a passive dep or make the extension available only if `Random` is loaded (which would be quite confusing). I opted for the passive dep, but I'm not sure if it's okay to load the package only in the extension.

**Reply:** I don't mind making the default in LuxCore `outputsize(layer, x, rng = _default_rng())`.

**Reply:** No need, I forgot that I can also use stuff from LuxCore in here 😅
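To illustrate the distinction the thread is drawing, here is a sketch of the two `outputsize` forms under discussion. This is illustrative only: the 1-arg form is what the extension relies on, and the 3-arg form (model, sample input, RNG) is the general variant suggested by the reviewer; the concrete model is a made-up example.

```julia
using Lux, LuxCore, Random

# A container of layers that can report its output size statically.
model = Chain(Dense(5, 6), Dense(6, 2))

# 1-arg form used in the extension: works for layers and layer containers
# like Dense and Chain that define it.
LuxCore.outputsize(model)

# General form discussed above: derives the size from a sample input and an
# explicit RNG (Xoshiro comes from Random), covering layers without a
# static output size.
LuxCore.outputsize(model, randn(Float32, 5), Xoshiro(0))
```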