-
Notifications
You must be signed in to change notification settings - Fork 3
Conversation
If Symbolics is the only place this will be used, doesn't it make more sense to define it there exclusively? function partial_apply end in the main package, and then place the code you wrote here in the extension |
I was thinking that this could also be used when one does not need For example https://docs.sciml.ai/Overview/stable/showcase/missing_physics/#Definition-of-the-Universal-Differential-Equation could change the UDE prediction from û = U(u, p, _st)[1] # Network prediction to û = LuxCore.partial_apply(U, u, p, _st) # Network prediction This could also be reexported by Lux for discoverability / ease of use. |
To be a nice tool for Lux, it would be nice if it varified the network was stateless, since this is useful only for stateless networks but for such a case it's really common |
Would something like this be a good way to check? out, st = apply(model, x, ps, st)
@assert isempty(st) "The passed model is not stateless, please use `apply` instead."
out |
There is https://lux.csail.mit.edu/dev/api/Lux/contrib#stateful-layer for doing this exact thing. (I have been meaning to move it out of contrib soon) Also maybe if we define it here, |
Yeah, I was thinking that's not the best name 😅 |
Ah, so if the dispatch is on |
fd9a91f
to
412c055
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #18 +/- ##
==========================================
- Coverage 90.41% 88.00% -2.42%
==========================================
Files 1 1
Lines 73 75 +2
==========================================
Hits 66 66
- Misses 7 9 +2 ☔ View full report in Codecov by Sentry. |
412c055
to
29f2c2a
Compare
This calls `apply` and only returns the first argument.
29f2c2a
to
8d9b57e
Compare
I renamed the function to |
c5bfa7e
to
5611f0e
Compare
src/LuxCore.jl
Outdated
function stateless_apply(model::AbstractExplicitLayer, x, ps, st) | ||
return first(apply(model, x, ps, st)) | ||
end | ||
|
||
function stateless_apply(model, x, ps, st) | ||
u, st = apply(model, x, ps, st) | ||
@assert isempty(st) "Model is not stateless. Use `apply` instead." | ||
return u | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If st
must be empty by definition, should we not define it as
function stateless_apply(model::AbstractExplicitLayer, x, ps)
y, st = apply(model, x, ps, NamedTuple())
@assert isempty(st) "Model is not stateless. use `apply` instead."
return y
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the fallback method with model::Any
would it make sense to use st=NamedTuple
in the method signature?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think there needs to be a model::Any
dispatch at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly because we are very clear on https://lux.csail.mit.edu/dev/manual/interface that if it is not subtyped then most things won't be defined.
Don't worry about the GPU CI for now, it seems to be down due to storage issues |
6ff9130
to
30a0f45
Compare
30a0f45
to
596daa6
Compare
I think I got the |
Would it make sense to have a |
So the But I am still unsure of what we get by having this as part of the API (and just being part of Symbolics), since I have seen very rare instances of anyone directly using |
This As for usages of I'm not sure how relevant is that for Lux though. |
Can you take a relatively complex Lux model, and check that |
802dfe2
to
7665825
Compare
I was wondering if
Do you have any suggestions? I tried
locally, but I'm not sure if that's good enough. |
Apply suggestions from code review Co-authored-by: Avik Pal <[email protected]>
7665825
to
a777715
Compare
For stateful cases that is just |
use a couple more nesting for chain and turn off optimizations (see the kwarg in https://lux.csail.mit.edu/dev/api/Lux/layers#Lux.Chain) |
Co-authored-by: Avik Pal <[email protected]>
0842d61
to
e1344ac
Compare
julia> model = Lux.Chain(Lux.Dense(3,3), Parallel(+, Lux.Dense(3,4), Parallel(+, Lux.Dense(3,4), Parallel(+, Lux.Dense(3,4), Lux.Dense(3,4)), Lux.Dense(3,4)), Lux.Dense(3,4)), disable_optimizations=true)
Chain(
layer_1 = Dense(3 => 3), # 12 parameters
layer_2 = Parallel(
+
Dense(3 => 4), # 16 parameters
Parallel(
+
Dense(3 => 4), # 16 parameters
Parallel(
+
Dense(3 => 4), # 16 parameters
Dense(3 => 4), # 16 parameters
),
Dense(3 => 4), # 16 parameters
),
Dense(3 => 4), # 16 parameters
),
) # Total: 108 parameters,
# plus 0 states.
julia> x = rand(3)
3-element Vector{Float64}:
0.5396312192884193
0.2939048617340425
0.24219473008713432
julia> ps, st = LuxCore.setup(rng, model)
((layer_1 = (weight = Float32[-0.5638304 0.7064836 -0.68545914; 0.9561393 -0.063563704 -0.115722895; 0.24014795 0.41213036 -0.54157174], bias = Float32[0.0; 0.0; 0.0;;]), layer_2 = (layer_1 = (weight = Float32[0.1435352 0.12118944 0.71831423; 0.038405616 0.57530284 0.88325536; -0.32796544 -0.6643518 0.2772332; 0.818112 -0.06508801 0.00029843062], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), layer_2 = (layer_1 = (weight = Float32[0.0021485018 -0.27785414 0.61988515; 0.39966893 -0.22878593 -0.6069017; 0.78745824 -0.33321932 0.36335143; -0.5742856 0.1219567 -0.31968322], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), layer_2 = (layer_1 = (weight = Float32[0.7534941 -0.603465 0.042327706; 0.24461655 0.42317894 -0.6764314; -0.46775734 -0.21744843 0.7626165; -0.2330453 -0.92067915 0.059065647], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.16759352 -0.121825255 0.85951114; -0.7903339 0.75350535 -0.24590242; -0.0746273 -0.20504689 -0.61860776; 0.0032513929 -0.70960593 0.7057774], bias = Float32[0.0; 0.0; 0.0; 0.0;;])), layer_3 = (weight = Float32[-0.42817247 -0.8049099 -0.3437964; 0.31827924 0.83317935 -0.28295153; -0.7879606 -0.43561655 0.04584453; -0.037287604 -0.8772459 0.5911436], bias = Float32[0.0; 0.0; 0.0; 0.0;;])), layer_3 = (weight = Float32[0.097560994 -0.7195461 0.3436186; 0.5422182 -0.052859966 0.81801933; 0.25912023 -0.013861904 0.33478278; 0.86514133 -0.56700957 0.31364506], bias = Float32[0.0; 0.0; 0.0; 0.0;;]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = NamedTuple()), layer_3 = NamedTuple())))
julia> LuxCore.apply(model, x, ps, st)
([-1.054777648714576, 0.869948369838022, -0.5773222915067829, -1.475737151147033], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = NamedTuple()), layer_3 = NamedTuple())))
julia> LuxCore.stateless_apply(model, x, ps)
4-element Vector{Float64}:
-1.054777648714576
0.869948369838022
-0.5773222915067829
-1.475737151147033 works after switching the |
Tested that it infers locally: julia> @code_typed LuxCore._getemptystate(model)
CodeInfo(
1 ─ return $(QuoteNode((layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = NamedTuple()), layer_3 = NamedTuple()))))
) => @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, layer_3::@NamedTuple{}}, layer_3::@NamedTuple{}}} |
Do you want me to release this or wait till Symbolics Ext is ready? |
It would be useful to have this released so that I can check the extension in CI too. |
This PR adds a small function,
partial_apply
, which is needed to implement an extension in Symbolics for LuxCore (see JuliaSymbolics/Symbolics.jl#1054), allowing the registration of the application of the layer as a vectror function. The important aspect here is that we need the size of the return type of the function, which is why we only return the first argument fromapply
.Let me know what you think about this approach.
cc @ChrisRackauckas
Edit: renamed to
stateless_apply
.