From 8cd616db4d61d4513cffc03c6bbe07f935469ead Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Dec 2024 11:38:26 +0530 Subject: [PATCH] fix: pass in exclude to layer_map --- src/contrib/debug.jl | 11 +++-------- src/contrib/map.jl | 8 +++++++- src/layers/pooling.jl | 6 +++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index cbd9e7352..e4c475d65 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -141,15 +141,10 @@ Recurses into the `layer` and replaces the inner most non Container Layers with See [`Lux.Experimental.DebugLayer`](@ref) for details about the Keyword Arguments. """ macro debug_mode(layer, kwargs...) - kws = esc.(kwargs) return esc(:( $(fmap_with_path)( - (kp, l) -> $(DebugLayer)(l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kws...)), - $(layer); exclude=$(debug_leaf)) + (kp, l) -> $(DebugLayer)( + l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kwargs...)), + $(layer); exclude=$(layer_map_leaf)) )) end - -debug_leaf(::KeyPath, ::AbstractLuxLayer) = true -debug_leaf(::KeyPath, ::AbstractLuxWrapperLayer) = false -debug_leaf(::KeyPath, ::AbstractLuxContainerLayer) = false -debug_leaf(::KeyPath, x) = Functors.isleaf(x) diff --git a/src/contrib/map.jl b/src/contrib/map.jl index f5142f0db..efd6acf29 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -57,7 +57,8 @@ true ``` """ function layer_map(f, l, ps, st) - return fmap_with_path(l, ps, st; walk=LayerWalkWithPath()) do kp, layer, ps_, st_ + return fmap_with_path( + l, ps, st; walk=LayerWalkWithPath(), exclude=layer_map_leaf) do kp, layer, ps_, st_ return f(layer, ps_, st_, kp) end end @@ -103,3 +104,8 @@ function perform_layer_map(recurse, kp, ps_children, st_children, layer_children return layer_children_new, ps_children_new, st_children_new end + +layer_map_leaf(::KeyPath, ::AbstractLuxLayer) = true +layer_map_leaf(::KeyPath, ::AbstractLuxWrapperLayer) = false +layer_map_leaf(::KeyPath, ::AbstractLuxContainerLayer) = false +layer_map_leaf(::KeyPath, x) = Functors.isleaf(x) diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 9c630ed01..943eb947c 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -197,7 +197,7 @@ for layer_op in (:Max, :Mean, :LP) layer <: PoolingLayer end - Experimental.debug_leaf(::KeyPath, ::$(layer_name)) = true + Experimental.layer_map_leaf(::KeyPath, ::$(layer_name)) = true function $(layer_name)( window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2) @@ -226,7 +226,7 @@ for layer_op in (:Max, :Mean, :LP) layer <: PoolingLayer end - Experimental.debug_leaf(::KeyPath, ::$(global_layer_name)) = true + Experimental.layer_map_leaf(::KeyPath, ::$(global_layer_name)) = true function $(global_layer_name)(; p=2) return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p)) @@ -247,7 +247,7 @@ for layer_op in (:Max, :Mean, :LP) layer <: PoolingLayer end - Experimental.debug_leaf(::KeyPath, ::$(adaptive_layer_name)) = true + Experimental.layer_map_leaf(::KeyPath, ::$(adaptive_layer_name)) = true function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2) return $(adaptive_layer_name)(PoolingLayer(