getkeypath and layer_map not fully working with model with Parallel layers #1068

Closed
dmetivie opened this issue Nov 13, 2024 · 6 comments · Fixed by #1115

@dmetivie
Contributor

Trying the example of layer_map here, I wondered how to get back a specific layer given a KeyPath.
In the example, calling it on the parameters ps

getkeypath(ps, KeyPath(:chain, :dense_1))

works; however, with the model c it does not:

getkeypath(c, KeyPath(:chain, :dense_1))
ERROR: type Parallel has no field chain

I wondered whether you intended this to work, as it would be very convenient to target a specific layer and get its parameters (with ps) or its type (with c).
Note that doing

getkeypath(c.layers, KeyPath(:chain, :dense_1))

works, and it works on regular layers too.
It looks like a method such as getkeypath(c::Lux.Parallel, kp) = getkeypath(c.layers, kp) could do the job (although defining it directly does not work).

@dmetivie
Contributor Author

dmetivie commented Nov 13, 2024

Another issue, which I think is related because the keys of the layer and of ps are treated differently: layer_map fails with MaxPool layers.
Using the chain from the MNIST tutorial:

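using Lux, Random

c = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
ps, st = Lux.setup(Random.default_rng(), c)
# zero_dense_params is the function from the layer_map docstring example
_, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st);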
ERROR: ArgumentError: keys(layer_children) == keys(ps_children) must hold. Got
keys(layer_children) => Base.OneTo(0)
keys(ps_children) => ()

The same example works when removing the two MaxPool.
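The assertion in the error compares two collections that are both empty. Judging only from the printed keys (an inference, not something stated in this thread), the layer side is presumably an empty `Tuple` and the parameter side an empty `NamedTuple`. In plain Julia their keys differ in type, so the check fails even though neither side has any children:

```julia
# Assumed values, inferred from the printed keys in the error message above.
layer_children = ()          # empty Tuple: keys(()) is Base.OneTo(0)
ps_children = NamedTuple()   # empty NamedTuple: keys(NamedTuple()) is ()

println(keys(layer_children))                       # Base.OneTo(0)
println(keys(ps_children))                          # ()
println(keys(layer_children) == keys(ps_children))  # false
```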

@avik-pal
Member

This is mostly by design of AbstractLuxWrapperLayer. See

function (::LayerWalkWithPath)(
        recurse::R, kp::KeyPath, layer::AbstractLuxWrapperLayer{field},
        ps, st) where {R, field}
    layer_children, layer_re = functor(getfield(layer, field))
    ps_children, ps_re = functor(ps)
    st_children, st_re = functor(st)
    layer_children_new, ps_children_new, st_children_new = perform_layer_map(
        recurse, kp, ps_children, st_children, layer_children)
    inner_layer = layer_re(layer_children_new)
    return (Setfield.set(layer, Setfield.PropertyLens{field}(), inner_layer),
        ps_re(ps_children_new), st_re(st_children_new))
end
for how we do the traversal.
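A minimal sketch of resolving such a parameter-style KeyPath on the model itself, assuming (as the method above suggests) that every `AbstractLuxWrapperLayer{field}` is stepped through via its `field` while `ps` has no matching level. `unwrap` and `layer_at` are hypothetical helpers, not Lux API; the model mirrors the shape of the `layer_map` docstring example.

```julia
using Lux, Random
using Lux.Functors: KeyPath, getkeypath

# Hypothetical helper: step through wrapper layers (Chain and Parallel wrap
# their `layers` field) until reaching something that is not a wrapper.
unwrap(x) = x
unwrap(l::Lux.AbstractLuxWrapperLayer{field}) where {field} = unwrap(getfield(l, field))

# Hypothetical helper: follow a KeyPath built from the parameter structure
# (like those reported by layer_map) through the model instead of through ps.
function layer_at(model, kp::KeyPath)
    x = model
    for k in kp.keys
        x = unwrap(x)
        x = k isa Integer ? x[k] : getfield(x, k)
    end
    return x
end

c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), dense_2=Dense(3 => 2)),
    dense_3=Dense(2 => 2))
ps, st = Lux.setup(Random.default_rng(), c)

kp = KeyPath(:chain, :dense_1)
layer = layer_at(c, kp)             # the Dense(2 => 3) layer
weight = getkeypath(ps, kp).weight  # its weight matrix
```

The design choice here is to leave `getkeypath` alone and add a separate lookup, rather than adding a `getkeypath(::Parallel, kp)` method, which would only cover one container type.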

That said, can you tell us about your specific use case? If you are trying to debug something, then https://lux.csail.mit.edu/stable/api/Lux/contrib#Lux.Experimental.@debug_mode should print out the exact path. As for layer_map, the layer, ps and st are already passed directly to the input function.

@dmetivie
Contributor Author

Thanks. I did look into map.jl quite a bit, but could not understand everything.
To me, it looks like a bug that the layer_map errors with MaxPool layers.

My use case: I am trying to implement a Concrete Dropout (CD) layer. It is basically a Dropout with a trainable rate.
See here for PyTorch and TensorFlow implementations. I tried a Julia implementation with Flux and, more recently, Lux, but I am struggling with a few things.
First (unrelated), I wanted to use a package extension to conditionally load the Flux or Lux version. I did not succeed, so the latest version supports only Lux.
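What makes the rate trainable is replacing the hard Bernoulli mask with a relaxed ("concrete") one that is differentiable in p. The plain-Julia sketch below follows the relaxation used in the original Concrete Dropout code by Gal et al., as I recall it; the function names and the temperature and epsilon values are assumptions, and it does not depend on Lux or Flux.

```julia
using Random

logistic(z) = 1 / (1 + exp(-z))

# Hypothetical sketch: apply a relaxed (concrete) dropout mask to x.
# p is the drop probability in (0, 1); a lower temperature gives masks closer to 0/1.
function concrete_dropout(rng::AbstractRNG, x::AbstractArray, p::Real;
        temperature=0.1, eps=1e-7)
    u = rand(rng, size(x)...)  # uniform noise, one sample per entry of x
    logit_drop = (log(p + eps) - log(1 - p + eps)) .+
                 log.(u .+ eps) .- log.(1 .- u .+ eps)
    drop_prob = logistic.(logit_drop ./ temperature)
    # Keep each entry with weight (1 - drop_prob), then rescale by the retain probability.
    return x .* (1 .- drop_prob) ./ (1 - p)
end

y = concrete_dropout(Xoshiro(0), ones(4, 3), 0.5)
```

Because p enters through smooth functions only, gradients of the loss with respect to p (or a logit parametrizing it) are well defined, unlike with a standard Dropout mask.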

I implemented the CD layer; however, the original implementation adds a regularization term to the loss that depends on

  • the L2 norm of the weights (of the Dense or Conv layer where the CD is applied)
  • a term depending on the rate value and on the shape of that Dense or Conv layer
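As far as I recall the Keras reference implementation of Concrete Dropout (Gal et al.), the penalty for one wrapped layer has the form sketched below, where `W` are that layer's weights, `p` its dropout rate and `input_dim` the number of (non-batch) input features being dropped; `weight_reg` and `dropout_reg` are user-chosen coefficients. This is a hedged plain-Julia sketch with hypothetical names, not code taken from the linked implementations.

```julia
# Hypothetical sketch of the Concrete Dropout regularization term for one layer:
#   weight_reg * sum(W.^2) / (1 - p)  +  dropout_reg * input_dim * (p log p + (1 - p) log(1 - p))
# The second term is minus the Bernoulli entropy, so it favours rates away from 0 and 1.
function concrete_dropout_penalty(W::AbstractArray, p::Real, input_dim::Integer;
        weight_reg::Real, dropout_reg::Real)
    weight_term = weight_reg * sum(abs2, W) / (1 - p)  # squared L2 norm of the weights
    entropy_term = dropout_reg * input_dim * (p * log(p) + (1 - p) * log(1 - p))
    return weight_term + entropy_term
end

penalty = concrete_dropout_penalty(ones(2, 3), 0.5, 3; weight_reg=1.0, dropout_reg=1.0)
```

Summing this over every (dropout rate, weights) pair is what makes it necessary to locate the CD layers and their neighbouring layers automatically.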

To automatically get the paths of the relevant layers, i.e. all CD layers and the layer just before each of them (where the CD is applied), I wanted to design a layer_map-like function to call before training.
This function would give easy access to these layers' coefficients during training, i.e. the weights and the CD rates.
Before v1.0, I used @layer_map, and it worked with a lot of hacks (the path was a string, if I remember correctly). Here is the pre-v1.0 Lux code.

To update it for v1.0, I tried:

using Lux, Random
using Lux.Functors: KeyPath, getkeypath

function get_key_type!(kp_cd, kp_layer, t_layer, l, ps, st, name, name_prev, t_prev)
  if l isa Dropout
    # the example uses plain Dropout so anyone can test without `ConcreteDropout.jl`
    push!(kp_cd, name)
    push!(kp_layer, name_prev)
    push!(t_layer, t_prev)
  end
  return l, ps, st
end

function layer_map_with_previous(l, ps, st)
  kp_cd = KeyPath[]
  kp_layer = KeyPath[]
  t_layer = AbstractLuxLayer[]
  kp_prev = KeyPath(1)
  t_prev = Dense(1 => 1)
  Lux.Functors.fmap_with_path(l, ps, st; walk=Lux.Experimental.LayerWalkWithPath()) do kp, layer, ps_, st_
    l__, ps__, st__ = get_key_type!(kp_cd, kp_layer, t_layer, layer, ps_, st_, kp, kp_prev, t_prev)
    kp_prev = kp
    t_prev = layer
    return l__, ps__, st__ # needed for the code not to error but useless here
  end
  return kp_cd, kp_layer, t_layer
end

rng = Random.default_rng()
m = Chain(
    Dense(10 => 100),
    Dropout(0.5),
    Dense(100 => 2)
)
ps, st = Lux.setup(rng, m)
key_CD, key_layer_before, type_of_layer_before = layer_map_with_previous(m, ps, st)

Now we can get all the weights to put in the loss with getkeypath:

getkeypath(ps, key_layer_before[1]).weight

This works as intended, but not with a MaxPool layer (and maybe others?).

There is probably a simpler way to code all that.
Do you have any idea how to do that?
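One possibly simpler route, assuming Lux v1's `Lux.Experimental.layer_map(f, model, ps, st)` calls `f(layer, ps, st, keypath)` on each leaf layer in order (the same ordering assumption the code above relies on). `dropout_pairs` is a hypothetical helper that records, for each `Dropout`, its KeyPath together with the KeyPath of the layer visited just before it.

```julia
using Lux, Random
using Lux.Functors: KeyPath, getkeypath

# Hypothetical helper: pair every Dropout with the layer visited just before it.
function dropout_pairs(model, ps, st)
    result = Tuple{KeyPath, KeyPath}[]  # (dropout path, previous layer path)
    prev = nothing                      # KeyPath of the last leaf layer visited
    Lux.Experimental.layer_map(model, ps, st) do layer, ps_, st_, kp
        if layer isa Dropout && prev !== nothing
            push!(result, (kp, prev))
        end
        prev = kp
        return layer, ps_, st_  # leave layer, parameters and states unchanged
    end
    return result
end

m = Chain(Dense(10 => 100), Dropout(0.5), Dense(100 => 2))
ps, st = Lux.setup(Random.default_rng(), m)
cd_pairs = dropout_pairs(m, ps, st)
weights = [getkeypath(ps, prev_kp).weight for (_, prev_kp) in cd_pairs]
```

Since `layer_map` already hands the function the layer, its parameters and its KeyPath, there is no need to call `fmap_with_path` with `LayerWalkWithPath` by hand.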
BTW, I don't know where to put this layer (it is probably too specific to put directly in Lux.jl).

@avik-pal
Member

To me, it looks like a bug that the layer_map errors with MaxPool layers.

I agree; pooling layers are implemented slightly differently, so this is possible. I will dig into this.

BTW, I don't know where to put this layer (it is probably too specific to put directly in Lux.jl).

https://github.com/LuxDL/Boltz.jl is exactly the repo for these kinds of layers.

I implemented the CD layer; however, the original implementation adds a regularization term to the loss that depends on

Would it be possible to implement this as an AbstractLuxWrapperLayer over a Conv/Dense, similar to how WeightNorm is implemented?

@dmetivie
Contributor Author

Thanks. Note that @debug_mode also fails with MaxPool layers.

OK, I'll open a PR to Boltz once I get it working.
I see how an AbstractLuxWrapperLayer could implement the Concrete Dropout layer; however, there is still the issue of automatically extracting the weights and dropout rates for the penalty term.

@avik-pal
Member

avik-pal commented Dec 3, 2024

layer_map and @debug_mode should work with pooling layers after #1115
