Skip to content

Commit

Permalink
fix: handle debug leafs with dispatch (#1115)
Browse files Browse the repository at this point in the history
* fix: handle debug leafs with dispatch

* test: add a test for pooling layers

* fix: pass in exclude to layer_map
  • Loading branch information
avik-pal authored Dec 3, 2024
1 parent 277f6ab commit b362324
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Compat: @compat
using ConcreteStructs: @concrete
using EnzymeCore: EnzymeRules
using FastClosures: @closure
using Functors: Functors, fmap
using Functors: Functors, KeyPath, fmap
using GPUArraysCore: @allowscalar
using Markdown: @doc_str
using NNlib: NNlib
Expand Down
3 changes: 2 additions & 1 deletion src/contrib/contrib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ using Static: StaticSymbol, StaticBool, True, known, static, dynamic

using ..Lux: Lux, Optional
using ..Utils: Utils, BoolType, SymbolType
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer, apply
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer,
AbstractLuxContainerLayer, apply

const CRC = ChainRulesCore

Expand Down
10 changes: 6 additions & 4 deletions src/contrib/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,10 @@ Recurses into the `layer` and replaces the inner most non Container Layers with
See [`Lux.Experimental.DebugLayer`](@ref) for details about the Keyword Arguments.
"""
macro debug_mode(layer, kwargs...)
kws = esc.(kwargs)
return :($(fmap_with_path)(
(kp, l) -> DebugLayer(l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kws...)),
$(esc(layer))))
return esc(:(
$(fmap_with_path)(
(kp, l) -> $(DebugLayer)(
l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kwargs...)),
$(layer); exclude=$(layer_map_leaf))
))
end
8 changes: 7 additions & 1 deletion src/contrib/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ true
```
"""
function layer_map(f, l, ps, st)
return fmap_with_path(l, ps, st; walk=LayerWalkWithPath()) do kp, layer, ps_, st_
return fmap_with_path(
l, ps, st; walk=LayerWalkWithPath(), exclude=layer_map_leaf) do kp, layer, ps_, st_
return f(layer, ps_, st_, kp)
end
end
Expand Down Expand Up @@ -103,3 +104,8 @@ function perform_layer_map(recurse, kp, ps_children, st_children, layer_children

return layer_children_new, ps_children_new, st_children_new
end

layer_map_leaf(::KeyPath, ::AbstractLuxLayer) = true
layer_map_leaf(::KeyPath, ::AbstractLuxWrapperLayer) = false
layer_map_leaf(::KeyPath, ::AbstractLuxContainerLayer) = false
layer_map_leaf(::KeyPath, x) = Functors.isleaf(x)
28 changes: 14 additions & 14 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,24 @@ for layer_op in (:Max, :Mean, :LP)
layer <: PoolingLayer
end

Experimental.layer_map_leaf(::KeyPath, ::$(layer_name)) = true

function $(layer_name)(
window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2)
return $(layer_name)(PoolingLayer(static(:generic), static($(Meta.quot(op))),
window; stride, pad, dilation, p))
end

function Base.show(io::IO, m::$(layer_name))
kernel_size = m.layer.mode.kernel_size
(; mode, op) = m.layer
(; kernel_size, pad, stride, dilation) = mode
print(io, string($(Meta.quot(layer_name))), "($(kernel_size)")
pad = m.layer.mode.pad
all(==(0), pad) || print(io, ", pad=", PrettyPrinting.tuple_string(pad))
stride = m.layer.mode.stride
stride == kernel_size ||
print(io, ", stride=", PrettyPrinting.tuple_string(stride))
dilation = m.layer.mode.dilation
all(==(1), dilation) ||
print(io, ", dilation=", PrettyPrinting.tuple_string(dilation))
if $(Meta.quot(op)) == :lp
m.layer.op.p == 2 || print(io, ", p=", m.layer.op.p)
end
$(Meta.quot(op)) == :lp && (op.p == 2 || print(io, ", p=", op.p))
print(io, ")")
end

Expand All @@ -228,15 +226,16 @@ for layer_op in (:Max, :Mean, :LP)
layer <: PoolingLayer
end

Experimental.layer_map_leaf(::KeyPath, ::$(global_layer_name)) = true

function $(global_layer_name)(; p=2)
return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p))
end

function Base.show(io::IO, g::$(global_layer_name))
(; op) = g.layer
print(io, string($(Meta.quot(global_layer_name))), "(")
if $(Meta.quot(op)) == :lp
g.layer.op.p == 2 || print(io, ", p=", g.layer.op.p)
end
$(Meta.quot(op)) == :lp && (op.p == 2 || print(io, ", p=", op.p))
print(io, ")")
end

Expand All @@ -248,16 +247,17 @@ for layer_op in (:Max, :Mean, :LP)
layer <: PoolingLayer
end

Experimental.layer_map_leaf(::KeyPath, ::$(adaptive_layer_name)) = true

function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2)
return $(adaptive_layer_name)(PoolingLayer(
static(:adaptive), $(Meta.quot(op)), out_size; p))
end

function Base.show(io::IO, a::$(adaptive_layer_name))
print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size)
if $(Meta.quot(op)) == :lp
a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p)
end
(; mode, op) = a.layer
print(io, string($(Meta.quot(adaptive_layer_name))), "(", mode.out_size)
$(Meta.quot(op)) == :lp && (op.p == 2 || print(io, ", p=", op.p))
print(io, ")")
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ SimpleChains = "0.4.7"
StableRNGs = "1.0.2"
Static = "1"
StaticArrays = "1.9"
Statistics = "1.11.1"
Statistics = "1.10"
Test = "1.10"
Tracker = "0.2.36"
Zygote = "0.6.70"
17 changes: 17 additions & 0 deletions test/contrib/debug_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,20 @@ end
@test !any(isnan, gs.layer_3.bias)
end
end

@testitem "Debugging Tools: Issue #1068" setup=[SharedTestSetup] tags=[:misc] begin
model = Chain(
Conv((3, 3), 3 => 16, relu; stride=2),
MaxPool((2, 2)),
AdaptiveMaxPool((2, 2)),
GlobalMaxPool()
)

model_debug = Lux.Experimental.@debug_mode model
display(model_debug)

@test model_debug[1] isa Lux.Experimental.DebugLayer
@test model_debug[2] isa Lux.Experimental.DebugLayer
@test model_debug[3] isa Lux.Experimental.DebugLayer
@test model_debug[4] isa Lux.Experimental.DebugLayer
end

0 comments on commit b362324

Please sign in to comment.