Add array function registration for Lux #1054
New file, `SymbolicsLuxCoreExt` (+17 lines):

```julia
module SymbolicsLuxCoreExt

using LuxCore, Symbolics

@register_array_symbolic LuxCore.partial_apply(
    model::LuxCore.AbstractExplicitContainerLayer, x::AbstractArray,
    ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin
    size = LuxCore.outputsize(model[end])
    eltype = Real
end

@register_array_symbolic LuxCore.partial_apply(
    model::LuxCore.AbstractExplicitLayer, x::AbstractArray,
    ps::Union{NamedTuple, <:AbstractVector}, st::NamedTuple) begin
    size = LuxCore.outputsize(model)
    eltype = Real
end

end
```

Review thread on `size = LuxCore.outputsize(model)`:

- "This needs to change based on the new LuxCore PR"
- "I'm not sure how to get the …"
- "This already works for layers & layer containers that can define the 1 arg …"
- "just pass …"
- "`LuxCore.outputsize(model, x, Xoshiro(0))`"
- "We don't have …"
- "I don't mind making the default in LuxCore as …"
- "No need, I forgot that I can also use stuff from LuxCore in here 😅"
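For context, the registration pattern used in the extension can be illustrated on a plain array function, independent of Lux. This is a minimal sketch; the `affine` function and its shapes are made up for illustration:

```julia
using Symbolics

# A made-up array function for illustration (not part of this PR).
affine(A, x) = A * x .+ 1

# Register it so that calls on symbolic arrays stay a single opaque array
# term with a known size and eltype, instead of being traced element by
# element into a Vector{Num}.
@register_array_symbolic affine(A::AbstractMatrix, x::AbstractVector) begin
    size = (size(A, 1),)
    eltype = promote_type(eltype(A), eltype(x))
end

@variables x[1:3]
A = ones(3, 3)
y = affine(A, x)  # a symbolic array term of length 3, not a Vector{Num}
```

The `size` and `eltype` entries in the `begin … end` block are what allow Symbolics to reason about the result without evaluating the function, which is exactly the role `LuxCore.outputsize` plays in the extension above.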
New test file (+124 lines):
```julia
using Symbolics, Lux, Random, Test
using ComponentArrays

@testset "Dense" begin
    model = Dense(5, 6)
    rng = Random.default_rng()
    x = randn(rng, Float32, 5)
    ps, st = LuxCore.setup(rng, model)

    ca = ComponentArray(ps)
    Symbolics.@variables sym_ca[1:length(ca)] = ca
    Symbolics.@variables sym_ps::typeof(ps) = ps
    Symbolics.@variables sym_st::typeof(st) = st
    Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

    out_ref = LuxCore.partial_apply(model, x, ps, st)
    @test out_ref isa Vector{Float32}

    out = LuxCore.partial_apply(model, sym_x, ps, st)
    # if the symbolic function registration didn't work, the vector would be
    # propagated through the neural network and we would get a Vector{Num}
    @test out isa Symbolics.Arr
    @test length(out) == 6
    # test that we can recover the same value as when using concrete numbers
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ps, st)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, ps, sym_st)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, ca, st)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ca, st)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st)
    @test out isa Symbolics.Arr
    @test length(out) == 6
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),)))
    @test out_sub == out_ref
end

@testset "Chain" begin
    model = Chain(Dense(5, 6), Dense(6, 2), Dense(2, 3))
    rng = Random.default_rng()
    x = randn(rng, Float32, 5)
    ps, st = LuxCore.setup(rng, model)

    ca = ComponentArray(ps)
    Symbolics.@variables sym_ca[1:length(ca)] = ca
    Symbolics.@variables sym_ps::typeof(ps) = ps
    Symbolics.@variables sym_st::typeof(st) = st
    Symbolics.@variables sym_x[1:5] = Float32[1, 2, 3, 4, 5]

    out_ref = LuxCore.partial_apply(model, x, ps, st)
    @test out_ref isa Vector{Float32}

    out = LuxCore.partial_apply(model, sym_x, ps, st)
    # if the symbolic function registration didn't work, the vector would be
    # propagated through the neural network and we would get a Vector{Num}
    @test out isa Symbolics.Arr
    @test length(out) == 3
    # test that we can recover the same value as when using concrete numbers
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ps, st)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, ps, sym_st)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_st => st),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ps, sym_st)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ps => ps, sym_st => st),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, ca, st)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ca, st)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca),)))
    @test out_sub == out_ref

    out = LuxCore.partial_apply(model, sym_x, sym_ca, sym_st)
    @test out isa Symbolics.Arr
    @test length(out) == 3
    out_sub = Symbolics.value.(Symbolics.substitute.(Symbolics.scalarize(out), (Dict(sym_x => x, sym_ca => ca, sym_st => st),)))
    @test out_sub == out_ref
end
```

Review thread on the `Chain` testset:

- "Maybe also add some tests for Layers where we haven't defined …"
- "This one seems a bit tricky, as the … I tried changing the extension to Lux and registering … @ChrisRackauckas Should I leave the extension for …"
- "For now, leave as LuxCore"
- "ok, in that case this should be good to go. Can you trigger / allow CI?"
Review thread on `size = LuxCore.outputsize(model[end])`:

- "`model[end]` is not well defined; you need to query the output size of the entire `model`. Consider `Parallel(vcat, Dense(2 => 2), Dense(2 => 3))`: even if you had a sense of 'end', the individual layers can't reason about the output size of the parent layer."
- "Essentially, if done right you shouldn't need this dispatch at all; just the `AbstractExplicitLayer` one should be enough."
- "Should `outputsize` be defined recursively in Lux? I think it would make more sense to have it for each container."
- "You will have to rely on users providing the correct `outputsize` function; it is quite hard to do in the general case, since you need to trace all the internal operations (in principle, if you have `outputsize` for all primitives, Symbolics should be able to compute the output size of the `Chain`/`Parallel` or any container layer). But for now, I would recommend defining the `outputsize` of the common Lux layers in containers.jl."
- "In order to reason about the size of the containers, can I assume that the `connection` in `SkipConnection` and `Parallel` is stateless? I was wondering if it would be correct to apply the connection function to the layers to figure out the output size."
- "Yes, let's just rely on the interface and not try to go around it. So then I presume that means `size = LuxCore.outputsize(model)`, and if the user needs to specify this function they will get an error, and that's fine. I'd rather stay within the standard API than try to do something hacky that only works partially."
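To make the conclusion concrete, a user-defined container can opt in by providing the one-argument `outputsize` method that the registration relies on. This is a sketch only: `MySkip` is a hypothetical layer, and it assumes the one-argument `LuxCore.outputsize(model)` interface discussed above rather than any merged implementation:

```julia
using LuxCore

# Hypothetical container layer: applies an inner layer and adds the
# input back (a skip connection), so input and output sizes must match.
struct MySkip{L} <: LuxCore.AbstractExplicitContainerLayer{(:layer,)}
    layer::L
end

(m::MySkip)(x, ps, st) = begin
    y, st_new = m.layer(x, ps, st)
    y .+ x, st_new
end

# The user supplies the size information the symbolic registration needs:
# the skip connection preserves the inner layer's output size.
LuxCore.outputsize(m::MySkip) = LuxCore.outputsize(m.layer)
```

With such a method in place, `LuxCore.partial_apply(MySkip(...), sym_x, ps, st)` would hit the generic `AbstractExplicitLayer` registration; without it, the user gets an error, which is the behavior the thread settles on.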