Skip to content

Commit

Permalink
fix: pass in exclude to layer_map
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 3, 2024
1 parent 069b237 commit 8cd616d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
11 changes: 3 additions & 8 deletions src/contrib/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,10 @@ Recurses into the `layer` and replaces the inner most non Container Layers with
See [`Lux.Experimental.DebugLayer`](@ref) for details about the Keyword Arguments.
"""
macro debug_mode(layer, kwargs...)
kws = esc.(kwargs)
return esc(:(
$(fmap_with_path)(
(kp, l) -> $(DebugLayer)(l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kws...)),
$(layer); exclude=$(debug_leaf))
(kp, l) -> $(DebugLayer)(
l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kwargs...)),
$(layer); exclude=$(layer_map_leaf))
))
end

debug_leaf(::KeyPath, ::AbstractLuxLayer) = true
debug_leaf(::KeyPath, ::AbstractLuxWrapperLayer) = false
debug_leaf(::KeyPath, ::AbstractLuxContainerLayer) = false
debug_leaf(::KeyPath, x) = Functors.isleaf(x)
8 changes: 7 additions & 1 deletion src/contrib/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ true
```
"""
function layer_map(f, l, ps, st)
return fmap_with_path(l, ps, st; walk=LayerWalkWithPath()) do kp, layer, ps_, st_
return fmap_with_path(
l, ps, st; walk=LayerWalkWithPath(), exclude=layer_map_leaf) do kp, layer, ps_, st_
return f(layer, ps_, st_, kp)
end
end
Expand Down Expand Up @@ -103,3 +104,8 @@ function perform_layer_map(recurse, kp, ps_children, st_children, layer_children

return layer_children_new, ps_children_new, st_children_new
end

layer_map_leaf(::KeyPath, ::AbstractLuxLayer) = true
layer_map_leaf(::KeyPath, ::AbstractLuxWrapperLayer) = false
layer_map_leaf(::KeyPath, ::AbstractLuxContainerLayer) = false
layer_map_leaf(::KeyPath, x) = Functors.isleaf(x)
6 changes: 3 additions & 3 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ for layer_op in (:Max, :Mean, :LP)
layer <: PoolingLayer
end

Experimental.debug_leaf(::KeyPath, ::$(layer_name)) = true
Experimental.layer_map_leaf(::KeyPath, ::$(layer_name)) = true

function $(layer_name)(
window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2)
Expand Down Expand Up @@ -226,7 +226,7 @@ for layer_op in (:Max, :Mean, :LP)
layer <: PoolingLayer
end

Experimental.debug_leaf(::KeyPath, ::$(global_layer_name)) = true
Experimental.layer_map_leaf(::KeyPath, ::$(global_layer_name)) = true

function $(global_layer_name)(; p=2)
return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p))
Expand All @@ -247,7 +247,7 @@ for layer_op in (:Max, :Mean, :LP)
layer <: PoolingLayer
end

Experimental.debug_leaf(::KeyPath, ::$(adaptive_layer_name)) = true
Experimental.layer_map_leaf(::KeyPath, ::$(adaptive_layer_name)) = true

function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2)
return $(adaptive_layer_name)(PoolingLayer(
Expand Down

0 comments on commit 8cd616d

Please sign in to comment.