
deprecate Flux.params #2413

Closed
CarloLucibello opened this issue Mar 26, 2024 · 7 comments · Fixed by #2495

Comments

@CarloLucibello
Member

CarloLucibello commented Mar 26, 2024

Is there any reason why we keep it around?

If we need a vector or iterable over the trainable leaves, we can build something (trainables(model)?) on top of Functors.fleaves, so that we have a function decoupled from Zygote.
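As a rough illustration of the idea, here is a minimal pure-Julia sketch that stands in for Functors.fleaves. The helper array_leaves is hypothetical: it only collects the numeric arrays from a nested structure of NamedTuples and Tuples (such as the output of Flux.state). It does not walk model structs and does not separate trainable from non-trainable fields, both of which a real trainables(model) would need to handle.

```julia
# Hypothetical helper (illustration only, not Flux or Functors API): collect the
# numeric-array leaves of a nested structure made of NamedTuples and Tuples.
# It neither walks model structs nor distinguishes trainable from non-trainable fields.
array_leaves(x::AbstractArray{<:Number}) = Any[x]
array_leaves(x::Union{Tuple,NamedTuple}) =
    reduce(vcat, map(array_leaves, values(x)); init = Any[])  # empty tuples give Any[]
array_leaves(x) = Any[]  # any other leaf (numbers, functions, nothing, ...) contributes nothing

# A nested state shaped like that of Chain(Dense(2 => 1, tanh), Dense(1 => 1))
st = (layers = ((weight = Float32[0.5 0.3], bias = Float32[0.0], σ = ()),
                (weight = Float32[0.9;;], bias = Float32[0.0], σ = ())),)

leaves = array_leaves(st)
println(length(leaves), " arrays, ", sum(length, leaves), " numbers")  # prints "4 arrays, 5 numbers"
```

The returned arrays alias the originals rather than copying them, so writing into leaves[1] updates st.layers[1].weight.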

@CarloLucibello CarloLucibello added this to the v0.15 milestone Mar 26, 2024
@ToucheSir
Member

I think we were waiting for a couple more features to land so we could have parity with some of the remaining use cases people might use implicit params for. FluxML/Optimisers.jl#57 is the main one I can think of.

@darsnack
Member

I think that's the only one left.

@kishore-nori

Would there be an alternative way to perform copy! between a flat vector and a Params-like object, or perhaps even directly into nn (a Flux.Chain), something like copy!(x, nn) and copy!(nn, x)?

Along these lines, I also wanted to ask whether Flux.jl will support ComponentArrays the way Lux.jl does. And would it be optional, as in Lux.jl, with NamedTuple being the default for parameters?

@mcabbott
Member

mcabbott commented Mar 30, 2024

That already exists, roughly:

```julia
julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1));

julia> st = Flux.state(model)
(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())),)

julia> Flux.loadmodel!(model, st);  # this is a nested copyto!

julia> using ComponentArrays

julia> ca = ComponentArray(; Flux.state(model)...)
ComponentVector{Tuple{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}}}(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())))

julia> ca.layers[1].weight .= NaN
1×2 Matrix{Float32}:
 NaN  NaN

julia> Flux.loadmodel!(model, ca)
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (some NaN)
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.
```

The caveats are: (1) what Flux.state returns includes non-trainable parameters; (2) I've no idea what will happen to shared parameters, as ComponentArrays ignores them; and (3) this is designed for loading from disk, not for use within gradients, so Zygote may hate it, but that's fixable. (Edit: (4) my use of ComponentArray does not seem to produce something backed by one big vector, e.g. getfield(ca, :data); maybe I need to read their docs.)

Flux.loadmodel! is for nested structures; we also have Flux.destructure, which works with flat vectors of parameters (and should respect points 1, 2 and 3).
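To make the "nested copyto!" description concrete, here is a minimal sketch under stated assumptions. The function nested_copyto! is hypothetical and handles only NamedTuples, Tuples and arrays of matching shape; Flux.loadmodel! itself also walks model structs and does its own checking, so this is not its implementation.

```julia
# Hypothetical helper (illustration only): copy every array leaf of `src` into the
# matching leaf of `dst`, where both are nested NamedTuples/Tuples of the same shape.
# copyto! copies elementwise in linear order and errors if src is longer than dst.
nested_copyto!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, src)

function nested_copyto!(dst::NamedTuple, src::NamedTuple)
    keys(dst) == keys(src) || throw(ArgumentError("field names differ"))
    foreach(nested_copyto!, values(dst), values(src))  # recurse field by field
    return dst
end

function nested_copyto!(dst::Tuple, src::Tuple)
    length(dst) == length(src) || throw(ArgumentError("tuple lengths differ"))
    foreach(nested_copyto!, dst, src)  # recurse element by element
    return dst
end

nested_copyto!(dst, src) = dst  # anything else (e.g. functions) is left untouched

# Example: load "saved" values into freshly allocated arrays, in place
current = (layers = ((weight = zeros(Float32, 1, 2), bias = zeros(Float32, 1), σ = ()),),)
saved   = (layers = ((weight = Float32[0.5 0.3], bias = Float32[0.1], σ = ()),),)
nested_copyto!(current, saved)
```

After the call, current holds the saved values in its own arrays; no array is shared between the two structures.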

Possibly off-topic here, but perhaps worth opening an issue, ideally with an example of what you wish would work?

@kishore-nori

kishore-nori commented Apr 3, 2024

Hi Michael, thanks a lot for the detailed reply (and sorry for the delay in replying). I wasn't aware of Flux.state. My use case has been to use Flux.jl with Optim.jl, which requires a flat vector. With Flux.Params I could use the existing copy! provided by Zygote.jl (earlier from FluxOptTools.jl) between Flux.Params and a flat vector, and this was also useful for converting the gradient into a flat vector for Optim.jl. Of course, all the calls to copy! happened outside the code that Zygote differentiates.

Now, if I understand correctly, I have to write my own copy! to convert between the output of Flux.state and a flat vector. This would also be useful for the object (which seems similar to st) returned by Zygote.gradient with the new explicit Flux usage, Zygote.gradient(loss, model). That is not very hard, but the problem you mentioned, "(1) what Flux.state returns includes non-trainable parameters", needs to be tackled (is trainables(model) intended to solve this issue?).

As for destructure, it makes the whole process more expensive because a new model is created every single epoch; I have observed that this hurts performance, so I have set it aside.

As for ComponentArrays, I think it works when we have nested NamedTuples, which for a neural network would be a layer-wise NamedTuple of NamedTuples. Flux.state doesn't return that; it returns a Tuple of NamedTuples, hence the discrepancy observed above. Still, this doesn't seem conceptually far from the intended usage.

So for now I can write a copy! between the output of Flux.state and a flat vector that ignores the non-trainable parameters, but I would be happy to know whether the trainables(model) and ComponentArrays solutions work. Thanks a lot!
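A copy! of this kind can be sketched in a few lines, assuming the parameters are already available as an ordered list of arrays (for example, the trainable leaves of a model). The names flatten_into! and unflatten_into! are hypothetical and are not part of Flux, Zygote or Optim.jl.

```julia
# Hypothetical helpers (illustration only, not part of Flux, Zygote or Optim.jl):
# copy between a flat vector and an ordered list of arrays. Both functions visit the
# arrays in the same order, using Julia's column-major element order within each array.

function flatten_into!(x::AbstractVector, arrays)
    total = sum(length, arrays; init = 0)
    length(x) == total || throw(DimensionMismatch("expected a flat vector of length $total"))
    offset = 0
    for a in arrays
        n = length(a)
        copyto!(x, offset + 1, a, 1, n)  # x[offset+1:offset+n] .= vec(a)
        offset += n
    end
    return x
end

function unflatten_into!(arrays, x::AbstractVector)
    total = sum(length, arrays; init = 0)
    length(x) == total || throw(DimensionMismatch("expected a flat vector of length $total"))
    offset = 0
    for a in arrays
        n = length(a)
        copyto!(a, 1, x, offset + 1, n)  # vec(a) .= x[offset+1:offset+n], in place
        offset += n
    end
    return arrays
end

# Example with two arrays standing in for a model's trainable leaves
W = Float32[1 2; 3 4]
b = Float32[5, 6]
params_list = [W, b]

x = zeros(Float32, 6)
flatten_into!(x, params_list)
x_before = copy(x)               # column-major: Float32[1, 3, 2, 4, 5, 6]
x .*= 2                          # stands in for an update computed by an optimizer
unflatten_into!(params_list, x)  # W and b are modified in place
```

Because both directions traverse the arrays in the same order, a round trip is lossless, and the same flatten_into! could be applied to gradient arrays listed in that order before handing them to Optim.jl.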

@CarloLucibello
Member Author

Hi @kishore-nori, could you open a new issue and provide a specific example that we can reason about? Your case seems to be well served by destructure; if it's slow, we should try to understand why.

@kishore-nori

Sure, I will come up with an MWE and open an issue, thank you. By the way, I have realized that the idea of destructure! (FluxML/Optimisers.jl#165) would be really beneficial and would fit my purpose well.

Labels
None yet
Projects
None yet

5 participants