Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Add stateless_apply #18

Merged
merged 9 commits into from
Feb 25, 2024
Merged

Add stateless_apply #18

merged 9 commits into from
Feb 25, 2024

Conversation

SebastianM-C
Copy link
Contributor

@SebastianM-C SebastianM-C commented Feb 14, 2024

This PR adds a small function, partial_apply, which is needed to implement an extension in Symbolics for LuxCore (see JuliaSymbolics/Symbolics.jl#1054), allowing the registration of the application of the layer as a vectror function. The important aspect here is that we need the size of the return type of the function, which is why we only return the first argument from apply.

Let me know what you think about this approach.

cc @ChrisRackauckas

Edit: renamed to stateless_apply.

@avik-pal
Copy link
Member

If Symbolics is the only place this will be used, doesn't it make more sense to define it there exclusively?

function partial_apply end

in the main package, and then place the code you wrote here in the extension

@SebastianM-C
Copy link
Contributor Author

I was thinking that this could also be used when one does not need st outside of the use with Symbolics.

For example https://docs.sciml.ai/Overview/stable/showcase/missing_physics/#Definition-of-the-Universal-Differential-Equation could change the UDE prediction from

= U(u, p, _st)[1] # Network prediction

to

= LuxCore.partial_apply(U, u, p, _st) # Network prediction

This could also be reexported by Lux for discoverability / ease of use.

@ChrisRackauckas
Copy link

To be a nice tool for Lux, it would be nice if it varified the network was stateless, since this is useful only for stateless networks but for such a case it's really common

@SebastianM-C
Copy link
Contributor Author

it would be nice if it varified the network was stateless

Would something like this be a good way to check?

out, st = apply(model, x, ps, st)
@assert isempty(st) "The passed model is not stateless, please use `apply` instead."
out

@avik-pal
Copy link
Member

avik-pal commented Feb 14, 2024

There is https://lux.csail.mit.edu/dev/api/Lux/contrib#stateful-layer for doing this exact thing. (I have been meaning to move it out of contrib soon)

Also maybe if we define it here, partial_apply seems to imply partial application of the arguments which is clearly not the intended wording

@SebastianM-C
Copy link
Contributor Author

Yeah, I was thinking that's not the best name 😅
What should I use instead? stateless_apply?

@SebastianM-C
Copy link
Contributor Author

There is https://lux.csail.mit.edu/dev/api/Lux/contrib#stateful-layer for doing this exact thing. (I have been meaning to move it out of contrib soon)

Ah, so if the dispatch is on Lux.AbstractExplicitLayer, that should be enough and we don't need the isempty(st) check?

Copy link

codecov bot commented Feb 17, 2024

Codecov Report

Attention: Patch coverage is 0% with 2 lines in your changes are missing coverage. Please review.

Project coverage is 88.00%. Comparing base (62f52f3) to head (412c055).

❗ Current head 412c055 differs from pull request most recent head 8d9b57e. Consider uploading reports for the commit 8d9b57e to get more accurate results

Files Patch % Lines
src/LuxCore.jl 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #18      +/-   ##
==========================================
- Coverage   90.41%   88.00%   -2.42%     
==========================================
  Files           1        1              
  Lines          73       75       +2     
==========================================
  Hits           66       66              
- Misses          7        9       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@SebastianM-C SebastianM-C changed the title Add partial_apply Add stateless_apply Feb 23, 2024
This calls `apply` and only returns the first argument.
@SebastianM-C
Copy link
Contributor Author

I renamed the function to stateless_apply and I've added error messages for stateful cases. Does this look good now?

src/LuxCore.jl Outdated
Comment on lines 134 to 142
function stateless_apply(model::AbstractExplicitLayer, x, ps, st)
return first(apply(model, x, ps, st))
end

function stateless_apply(model, x, ps, st)
u, st = apply(model, x, ps, st)
@assert isempty(st) "Model is not stateless. Use `apply` instead."
return u
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If st must be empty by definition, should we not define it as

function stateless_apply(model::AbstractExplicitLayer, x, ps)
	y, st = apply(model, x, ps, NamedTuple())
	@assert isempty(st) "Model is not stateless. use `apply` instead."
    return y
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the fallback method with model::Any would it make sense to use st=NamedTuple in the method signature?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there needs to be a model::Any dispatch at all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly because we are very clear on https://lux.csail.mit.edu/dev/manual/interface that if it is not subtyped then most things won't be defined.

@avik-pal
Copy link
Member

Don't worry about the GPU CI for now, it seems to be down due to storage issues

@SebastianM-C SebastianM-C force-pushed the symbolics_ext branch 3 times, most recently from 6ff9130 to 30a0f45 Compare February 24, 2024 21:28
@SebastianM-C
Copy link
Contributor Author

I think I got the AbstractExplicitContainerLayer implementation now 😅

src/LuxCore.jl Outdated Show resolved Hide resolved
@SebastianM-C
Copy link
Contributor Author

Would it make sense to have a get_state (or a better name) function that constructs the st and implement stateless_apply with the last argument as st=get_state(...)? I was thinking that in performance sensitive scenarios it might make sense to have the st cached instead of re-creating it all the time.

@avik-pal
Copy link
Member

So the st if created correctly is a compile-time constant, so it shouldn't have a performance implication.

But I am still unsure of what we get by having this as part of the API (and just being part of Symbolics), since I have seen very rare instances of anyone directly using apply except when creating custom layers.

@SebastianM-C
Copy link
Contributor Author

This stateless_apply is needed for defining the extension, but it can't be defined in the extension and it doesn't make much sense to define in Symbolics (since it can't be used in Symbolics outside of the extension), which is why I started this PR in the first place 😅

As for usages of apply, I think the UDE tutorial should use this, as mentioned in #18 (comment)

I'm not sure how relevant is that for Lux though.

src/LuxCore.jl Outdated Show resolved Hide resolved
src/LuxCore.jl Outdated Show resolved Hide resolved
src/LuxCore.jl Outdated Show resolved Hide resolved
src/LuxCore.jl Outdated Show resolved Hide resolved
src/LuxCore.jl Outdated Show resolved Hide resolved
@avik-pal
Copy link
Member

getstate is internal so naming it _getstate. Maybe we should name it _getemptystate instead of _getstate.

Can you take a relatively complex Lux model, and check that _getstate is a compile time constant?

@SebastianM-C
Copy link
Contributor Author

Maybe we should name it _getemptystate instead of _getstate.

I was wondering if _getstate would be useful for stateful cases, where one could add a method for that, but again, I don't know it there are any applications where this would be useful.

Can you take a relatively complex Lux model, and check that _getstate is a compile time constant?

Do you have any suggestions? I tried

model = Lux.Chain(Lux.Dense(3,3), Parallel(+, Lux.Dense(3,3), Lux.Dense(3,3)))

locally, but I'm not sure if that's good enough.

Apply suggestions from code review

Co-authored-by: Avik Pal <[email protected]>
@avik-pal
Copy link
Member

For stateful cases that is just initialstates

@avik-pal
Copy link
Member

use a couple more nesting for chain and turn off optimizations (see the kwarg in https://lux.csail.mit.edu/dev/api/Lux/layers#Lux.Chain)

Project.toml Outdated Show resolved Hide resolved
Co-authored-by: Avik Pal <[email protected]>
@SebastianM-C
Copy link
Contributor Author

julia> model = Lux.Chain(Lux.Dense(3,3), Parallel(+, Lux.Dense(3,4), Parallel(+, Lux.Dense(3,4), Parallel(+, Lux.Dense(3,4), Lux.Dense(3,4)), Lux.Dense(3,4)), Lux.Dense(3,4)), disable_optimizations=true)
Chain(
    layer_1 = Dense(3 => 3),            # 12 parameters
    layer_2 = Parallel(
        +
        Dense(3 => 4),                  # 16 parameters
        Parallel(
            +
            Dense(3 => 4),              # 16 parameters
            Parallel(
                +
                Dense(3 => 4),          # 16 parameters
                Dense(3 => 4),          # 16 parameters
            ),
            Dense(3 => 4),              # 16 parameters
        ),
        Dense(3 => 4),                  # 16 parameters
    ),
)         # Total: 108 parameters,
          #        plus 0 states.

julia> x = rand(3)
3-element Vector{Float64}:
 0.5396312192884193
 0.2939048617340425
 0.24219473008713432

julia> ps, st = LuxCore.setup(rng, model)
((layer_1 = (weight = Float32[-0.5638304 0.7064836 -0.68545914; 0.9561393 -0.063563704 -0.115722895; 0.24014795 0.41213036 -0.54157174], bias = Float32[0.0; 0.0; 0.0;;]), layer_2 = (layer_1 = (weight = Float32[0.1435352 0.12118944 0.71831423; 0.038405616 0.57530284 0.88325536; -0.32796544 -0.6643518 0.2772332; 0.818112 -0.06508801 0.00029843062], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), layer_2 = (layer_1 = (weight = Float32[0.0021485018 -0.27785414 0.61988515; 0.39966893 -0.22878593 -0.6069017; 0.78745824 -0.33321932 0.36335143; -0.5742856 0.1219567 -0.31968322], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), layer_2 = (layer_1 = (weight = Float32[0.7534941 -0.603465 0.042327706; 0.24461655 0.42317894 -0.6764314; -0.46775734 -0.21744843 0.7626165; -0.2330453 -0.92067915 0.059065647], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.16759352 -0.121825255 0.85951114; -0.7903339 0.75350535 -0.24590242; -0.0746273 -0.20504689 -0.61860776; 0.0032513929 -0.70960593 0.7057774], bias = Float32[0.0; 0.0; 0.0; 0.0;;])), layer_3 = (weight = Float32[-0.42817247 -0.8049099 -0.3437964; 0.31827924 0.83317935 -0.28295153; -0.7879606 -0.43561655 0.04584453; -0.037287604 -0.8772459 0.5911436], bias = Float32[0.0; 0.0; 0.0; 0.0;;])), layer_3 = (weight = Float32[0.097560994 -0.7195461 0.3436186; 0.5422182 -0.052859966 0.81801933; 0.25912023 -0.013861904 0.33478278; 0.86514133 -0.56700957 0.31364506], bias = Float32[0.0; 0.0; 0.0; 0.0;;]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = NamedTuple()), layer_3 = NamedTuple())))

julia> LuxCore.apply(model, x, ps, st)
([-1.054777648714576, 0.869948369838022, -0.5773222915067829, -1.475737151147033], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = NamedTuple()), layer_3 = NamedTuple())))

julia> LuxCore.stateless_apply(model, x, ps)
4-element Vector{Float64}:
 -1.054777648714576
  0.869948369838022
 -0.5773222915067829
 -1.475737151147033

works after switching the getfield to use names instead of Ints, as that fails on Parallel.

@avik-pal
Copy link
Member

Tested that it infers locally:

julia> @code_typed LuxCore._getemptystate(model)
CodeInfo(
1return $(QuoteNode((layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = NamedTuple()), layer_3 = NamedTuple()))))
) => @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, layer_3::@NamedTuple{}}, layer_3::@NamedTuple{}}}

@avik-pal avik-pal merged commit 1bea85a into LuxDL:main Feb 25, 2024
7 of 8 checks passed
@SebastianM-C SebastianM-C deleted the symbolics_ext branch February 25, 2024 00:40
@avik-pal
Copy link
Member

Do you want me to release this or wait till Symbolics Ext is ready?

@SebastianM-C
Copy link
Contributor Author

It would be useful to have this released so that I can check the extension in CI too.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants