diff --git a/NEWS.md b/NEWS.md index ac8883a091..68d36fdc34 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,13 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.13 +* New macro `Flux.@layer` which should be used in place of `@functor`. + This also adds `show` methods for pretty printing. + +## v0.14.12 +* New `SignDecay` optimiser, like `WeightDecay` but for L1 norm. + ## v0.14.0 (July 2023) * Flux now requires julia v1.9 or later. * CUDA.jl is not a hard dependency anymore. Support is now provided through the extension mechanism, by loading `using Flux, CUDA`. @@ -51,6 +58,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl ## v0.13.6 * Use the package [OneHotArrays.jl](https://github.com/FluxML/OneHotArrays.jl) instead of having the same code here. +* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078) ## v0.13.4 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983) diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index b7161b8c59..ab045d96be 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -18,8 +18,8 @@ function (m::CustomModel)(x) return m.chain(x) + x end -# Call @functor to allow for training. Described below in more detail. -Flux.@functor CustomModel +# Call @layer to allow for training. Described below in more detail. +Flux.@layer CustomModel ``` You can then use the model like: @@ -39,7 +39,7 @@ Taking reference from our example `Affine` layer from the [basics](@ref man-basi By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function: ```julia-repl -julia> Flux.@functor Affine +julia> Flux.@layer Affine julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]) Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) @@ -47,7 +47,7 @@ Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) julia> Flux.params(a) # default behavior Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]]) -julia> Flux.trainable(a::Affine) = (; a.W) # returns a NamedTuple using the field's name +julia> Flux.trainable(a::Affine) = (; W = a.W) # returns a NamedTuple using the field's name julia> Flux.params(a) Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]]) @@ -67,7 +67,14 @@ julia> Flux.params(Affine(true, [10, 11, 12.0])) Params([]) ``` -It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired). +The exact same method of `trainable` can also be defined using the macro, for convenience: + +```julia +Flux.@layer Affine trainable=(W,) +``` + +There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that no exploration of the model will ever visit the other fields: they will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`.
This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. + ## Freezing Layer Parameters @@ -135,9 +142,9 @@ Join(combine, paths...) = Join(combine, paths) ``` Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field. -The next step is to use [`Functors.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. +The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. ```julia -Flux.@functor Join +Flux.@layer Join ``` Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results. @@ -194,7 +201,7 @@ model(xs) Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs. -We start by following the same steps as the `Join` layer: define a struct, use [`Functors.@functor`](@ref), and define the forward pass. +We start by following the same steps as the `Join` layer: define a struct, use [`Flux.@layer`](@ref), and define the forward pass. ```julia using Flux using CUDA @@ -206,7 +213,7 @@ end Split(paths...) = Split(paths) -Flux.@functor Split +Flux.@layer Split (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths) ``` diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index ca95dc747d..fb0f2d5488 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -257,8 +257,8 @@ m(5) # => 26 There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@functor`](@ref Functors.@functor) macro: -``` -Flux.@functor Affine +```julia +Flux.@layer Affine ``` Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias): @@ -272,3 +272,8 @@ end Affine(3 => 1, bias=false, init=ones) |> gpu ``` + +```@docs +Flux.@layer +Flux.create_bias +``` diff --git a/src/Flux.jl b/src/Flux.jl index d3ca611dbd..5675f7c10f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,6 +9,7 @@ using MacroTools: @forward @reexport using NNlib using MLUtils +const stack = MLUtils.stack # now exported by Base import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions using Optimisers: freeze!, thaw!, adjust! using Random: default_rng @@ -69,6 +70,9 @@ include("functor.jl") # Pirate error to catch a common mistake.
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.") +include("layers/show.jl") +include("layers/macro.jl") + include("layers/stateless.jl") include("layers/basic.jl") include("layers/conv.jl") @@ -76,7 +80,6 @@ include("layers/recurrent.jl") include("layers/normalise.jl") include("layers/upsample.jl") include("layers/attention.jl") -include("layers/show.jl") include("loading.jl") diff --git a/src/functor.jl b/src/functor.jl index 8215b92863..34fe52db35 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -81,6 +81,7 @@ function params!(p::Params, x, seen = IdSet()) elseif x in seen nothing else + _check_new_macro(x) # complains if you used @functor not @layer push!(seen, x) for child in trainable(x) params!(p, child, seen) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 3701be2bb0..d4a33283d9 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2} out_proj::P2 end -@functor MultiHeadAttention +@layer MultiHeadAttention function MultiHeadAttention(dims; nheads::Int = 8, @@ -83,8 +83,8 @@ function MultiHeadAttention(dims; dropout_prob = 0.0) dims = normalize_mha_dims(dims) - @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads" - @assert dims.v % nheads == 0 "v_dim should be divisible by nheads" + dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)")) + dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)")) q_proj = Dense(dims.q_in => dims.qk; bias, init) k_proj = Dense(dims.k_in => dims.qk; bias, init) v_proj = Dense(dims.v_in => dims.v; bias, init) @@ -131,3 +131,41 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, # [α] = [kv_len, q_len, nheads, batch_size] return x, α end + +function Base.show(io::IO, mha::MultiHeadAttention) + qk, q_in = size(mha.q_proj.weight) + qk, k_in = size(mha.k_proj.weight) + v, v_in = size(mha.v_proj.weight) + out, v = size(mha.out_proj.weight) + # @show q_in, k_in, v_in, qk, v, out + print(io, "MultiHeadAttention(") + if q_in == k_in == v_in == qk == v == out + print(io, q_in) + elseif q_in == k_in == v_in && qk == v + print(io, q_in, " => ", qk, " => ", out) + elseif q_in == k_in == v_in + print(io, q_in, " => (", qk, ", ", v,") => ", out) + else + print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out) + end + print(io, "; nheads=", mha.nheads) + if mha.q_proj.bias !== false + print(io, ", bias=true") + end + if mha.attn_drop.p != 0 + print(io, ", dropout_prob=", mha.attn_drop.p) # can't we rename this? 
+ end + print(io, ")") +end + + +#= + +# Test cases for printing: + +MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => 6 => 8; nheads=1) +MultiHeadAttention(8; bias=true) + +=# diff --git a/src/layers/basic.jl b/src/layers/basic.jl index b7027f5007..018b19b31d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -46,7 +46,7 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys, Base.firstindex -@functor Chain +@layer :expand Chain # the option :expand opts-in to container-style pretty-printing (c::Chain)(x) = _applychain(c.layers, x) @@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity; Dense(init(out, in), bias, σ) end -@functor Dense +@layer Dense function (a::Dense)(x::AbstractVecOrMat) _size_check(a, x, 1 => size(a.weight, 2)) @@ -251,7 +251,7 @@ end Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act) Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end]) -@functor Scale +@layer Scale function (a::Scale)(x::AbstractArray) σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc @@ -306,7 +306,7 @@ end Maxout(layers...) = Maxout(layers) Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...) -@functor Maxout +@layer :expand Maxout function (mo::Maxout)(input::AbstractArray) # Perhaps surprisingly, pairwise max broadcast is often faster, @@ -353,7 +353,7 @@ struct SkipConnection{T,F} connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b end -@functor SkipConnection +@layer :expand SkipConnection function (skip::SkipConnection)(input) skip.connection(skip.layers(input), input) @@ -423,7 +423,7 @@ struct Bilinear{F,A,B} end end -@functor Bilinear +@layer Bilinear function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity; bias = true, init = glorot_uniform) @@ -522,7 +522,7 @@ function Parallel(connection; kw...) Parallel(connection, layers) end -@functor Parallel +@layer :expand Parallel (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) (m::Parallel)(xs::Tuple) = m(xs...)
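For illustration of the `:expand` option used in the hunk above, here is a hedged sketch of how a user-defined container would opt into the same behaviour under the new macro. The `Sandwich` type is hypothetical and not part of this patch.

```julia
using Flux

# A hypothetical container wrapping two sub-layers, analogous to SkipConnection.
struct Sandwich{A,B}
    inner::A
    outer::B
end

(s::Sandwich)(x) = s.outer(s.inner(x))

# Plain `Flux.@layer Sandwich` would print it on one line, like Dense;
# `:expand` unfolds the children at the REPL, like Chain.
Flux.@layer :expand Sandwich

# Sandwich(Dense(2 => 4, relu), Dense(4 => 2)) should then show as a
# multi-line summary with per-layer parameter counts, and `gpu` / `Flux.params`
# will see the weights of both Dense layers.
```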
@@ -643,7 +643,7 @@ end end applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x) -@functor PairwiseFusion +@layer :expand PairwiseFusion Base.getindex(m::PairwiseFusion, i) = m.layers[i] Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i]) @@ -701,7 +701,7 @@ struct Embedding{W<:AbstractMatrix} weight::W end -@functor Embedding +@layer Embedding Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in)) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ca275d4a16..4e6044dcfb 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init(filter..., cin÷groups, cout) end -@functor Conv +@layer Conv conv_dims(c::Conv, x::AbstractArray) = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) @@ -309,7 +309,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = ConvTranspose(weight, bias, σ; stride, pad, dilation, groups) end -@functor ConvTranspose +@layer ConvTranspose function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) # Calculate size of "input", from ∇conv_data()'s perspective... @@ -460,7 +460,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden return CrossCor(weight, bias, σ; stride, pad, dilation) end -@functor CrossCor +@layer CrossCor function crosscor(x, w, ddims::DenseConvDims) ddims = DenseConvDims(ddims, F=true) diff --git a/src/layers/macro.jl b/src/layers/macro.jl new file mode 100644 index 0000000000..9e770add87 --- /dev/null +++ b/src/layers/macro.jl @@ -0,0 +1,156 @@ + +""" + @layer Dense + @layer :expand Chain + @layer BatchNorm trainable=(β,γ) + +This macro replaces most uses of `@functor`. Its basic purpose is the same: +When you define a new layer, this tells Flux to explore inside it +to see the parameters it trains, and also to move them to the GPU, change precision, etc. +Like `@functor`, this assumes your struct has the default constructor, to enable re-building. + +The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. +Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. +* If some fields look like parameters but should not be trained, + then `trainable` lets you specify which fields to include, while the rest are ignored. + +The macro also handles overloads of `show` for pretty printing. +* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. +* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents. +* To disable all `show` overloads, there is an `:ignore` option too. + +(You probably still want to define 2-arg `show(io::IO, x::Layer)`; the macro does not touch this.) + +Note that re-running the macro with different options may not overwrite all methods; you will need to restart Julia.
+ +# Example +```jldoctest +julia> struct Trio; a; b; c end + +julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4)) +Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) + +julia> Flux.destructure(tri) # parameters are not yet visible to Flux +(Bool[], Restructure(Trio, ..., 0)) + +julia> Flux.@layer :expand Trio + +julia> Flux.destructure(tri) # now gpu, params, train!, etc will see inside too +([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4)) + +julia> tri # and layer is printed like Chain +Trio( + Dense(2 => 1, tanh), # 3 parameters + Dense(1 => 1; bias=false), # 1 parameters + Dropout(0.4), +) # Total: 3 arrays, 4 parameters, 224 bytes. +``` + +""" +macro layer(exs...) + out = quote end + + # These functions are defined in show.jl, and each return an expression overloading Base.show + type, rest... = if exs[1] == QuoteNode(:expand) + push!(out.args, _macro_big_show(esc(exs[2]))) + exs[2:end] + elseif exs[1] == QuoteNode(:ignore) + exs[2:end] + elseif exs[1] isa QuoteNode + error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)") + else + push!(out.args, _macro_layer_show(esc(exs[1]))) + exs + end + + # This function exists only for depwarns when you use @functor directly + push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) + + push!(out.args, _macro_functor(esc(type))) + + for j in 1:length(rest) + ex = rest[j] + Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects here `keyword = (fields...,)`, got $ex") + + name = if ex.args[1] == :trainable + :(Optimisers.trainable) + else + error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.") + # @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1 + # esc(ex.args[1]) + end + push!(out.args, _macro_trainable(esc(type), name, ex.args[2])) + end + + out +end + +# Temporary depwarn function, called within `params`, is also called by `show`. + +function _check_new_macro(x::T) where T + Functors.isleaf(x) && return + Base.depwarn("This type should probably now use `Flux.@layer` instead of `@functor`: $T", Symbol("@functor")) +end +_check_new_macro(::Tuple) = nothing # defined by Functors.jl, not by users +_check_new_macro(::NamedTuple) = nothing +_check_new_macro(::AbstractArray) = nothing +_check_new_macro(::Ref) = nothing + +# @layer's code for Functors & Adapt +# Unlike @functor, _default_functor doesn't need to eval anything + +function _macro_functor(type) + quote + Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x) + Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) + end +end + +function _macro_functor(type, fields) + Meta.isexpr(fields, :tuple) || error("expected a tuple of field names") + symbols = Tuple(map(_noquotenode, fields.args)) + quote + Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols)) + Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) + end +end +_macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,))) # lets you forget a comma + +function _default_functor(::Type{T}, x) where {T} + if @generated + F = fieldnames(T) + args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F) + C = Base.typename(T).wrapper # constructor + # recon = VERSION > v"1.9-" ? 
:(Splat($C)) : :(Base.splat($C)) + recon = :(Base.splat($C)) + :((NamedTuple{$F}(($(args...),)), $recon)) + else + # Getting this parameterless type takes about 2μs, every time: + # spl = VERSION > v"1.9-" ? Splat : Base.splat + spl = Base.splat + namedtuple(x), spl(Base.typename(T).wrapper) + end +end + +function namedtuple(x::T) where T + F = fieldnames(T) + NamedTuple{F}(map(sy -> getfield(x, sy), F)) +end + +# @layer's code for Optimisers.trainable, and perhaps anything else, +# with the pattern that keywords mean function names & what fields they pick. + +function _macro_trainable(type, fun, fields) + Meta.isexpr(fields, :tuple) || error("expected a tuple of field names") + symbols = Tuple(map(_noquotenode, fields.args)) + quoted = map(QuoteNode, symbols) + gets = [:(getfield(x, $f)) for f in quoted] + quote + $fun(x::$type) = NamedTuple{$symbols}(($(gets...),)) + end +end +_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,))) # lets you forget a comma + +_noquotenode(s::Symbol) = s +_noquotenode(q::QuoteNode) = q.value # lets you write trainable=(:x,:y) instead of (x,y) +_noquotenode(ex) = error("expected a symbol here, as a field name, but got $ex") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 1c8fbff5a1..c0a86c8796 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -78,8 +78,7 @@ function Dropout(p::Real; dims=:, active::Union{Bool,Nothing} = nothing, rng = d Dropout(p, dims, active, rng) end -@functor Dropout -trainable(a::Dropout) = (;) +@layer Dropout trainable=() (a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims) @@ -131,8 +130,7 @@ function AlphaDropout(p; rng = default_rng(), active::Union{Bool,Nothing} = noth AlphaDropout(p, active, rng) end -@functor AlphaDropout -trainable(a::AlphaDropout) = (;) +@layer AlphaDropout trainable=() function (a::AlphaDropout)(x::AbstractArray{T}) where T _isactive(a, x) || return x @@ -151,6 +149,8 @@ end testmode!(m::AlphaDropout, mode=true) = (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m) +Base.show(io::IO, d::AlphaDropout) = print(io, "AlphaDropout(", d.p, ")") + """ LayerNorm(size..., λ=identity; affine=true, eps=1f-5) @@ -199,7 +199,7 @@ end LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...) LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...) -@functor LayerNorm +@layer LayerNorm function (a::LayerNorm)(x::AbstractArray) ChainRulesCore.@ignore_derivatives if a.diag isa Scale @@ -343,8 +343,7 @@ function BatchNorm(chs::Int, λ=identity; active, chs) end -@functor BatchNorm -trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;) +@layer BatchNorm trainable=(β,γ) function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(BN, x, N-1 => BN.chs) @@ -437,8 +436,7 @@ function InstanceNorm(chs::Int, λ=identity; active, chs) end -@functor InstanceNorm -trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;) +@layer InstanceNorm trainable=(β,γ) function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(l, x, N-1 => l.chs) @@ -517,8 +515,7 @@ mutable struct GroupNorm{F,V,N} chs::Int # number of channels end -@functor GroupNorm -trainable(gn::GroupNorm) = hasaffine(gn) ? 
(β = gn.β, γ = gn.γ) : (;) +@layer GroupNorm trainable=(β,γ) function GroupNorm(chs::Int, G::Int, λ=identity; initβ=zeros32, initγ=ones32, diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 375ff43d52..f55ebb1741 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -135,8 +135,7 @@ function (m::Recur)(x) return y end -@functor Recur -trainable(a::Recur) = (; cell = a.cell) +@layer :expand Recur trainable=(cell,) Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") @@ -209,7 +208,7 @@ function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where return h, reshape_cell_output(h, x) end -@functor RNNCell +@layer RNNCell # state0 is trainable, see issue 807 about this. function Base.show(io::IO, l::RNNCell) print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)) @@ -318,7 +317,7 @@ function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractV return (h′, c′), reshape_cell_output(h′, x) end -@functor LSTMCell +@layer LSTMCell Base.show(io::IO, l::LSTMCell) = print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")") @@ -391,7 +390,7 @@ function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where { return h′, reshape_cell_output(h′, x) end -@functor GRUCell +@layer GRUCell Base.show(io::IO, l::GRUCell) = print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") @@ -461,7 +460,7 @@ function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) wh return h′, reshape_cell_output(h′, x) end -@functor GRUv3Cell +@layer GRUv3Cell Base.show(io::IO, l::GRUv3Cell) = print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") diff --git a/src/layers/show.jl b/src/layers/show.jl index 0ae14dd9ee..a03ddf3754 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,15 +1,21 @@ +@nospecialize # just for this file, for startup time -for T in [ - :Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL - _big_show(io, x) - elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix - _layer_show(io, x) - else - show(io, x) +# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression: +function _macro_big_show(ex) + quote + # Entry point: + function Base.show(io::IO, m::MIME"text/plain", x::$ex) + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + _big_show(io, x) + elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix + _layer_show(io, x) + else + show(io, x) + end end + + # Don't show Chain(Tuple(...)), always splat that. And ignore Recur's non-trainable state: + Flux._show_children(x::$ex) = _flat_children(trainable(x)) end end @@ -17,6 +23,8 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")") children = _show_children(obj) if all(_show_leaflike, children) + # This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids, + # but once all layers use @layer, they stop the recursion by defining a method for _big_show. _layer_show(io, obj, indent, name) else println(io, " "^indent, isnothing(name) ? 
"" : "$name = ", nameof(typeof(obj)), pre) @@ -49,25 +57,32 @@ _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: # note the covariance of tuple, using <:T causes warning or error _show_leaflike(::Tuple{Vararg{Number}}) = true # e.g. stride of Conv _show_leaflike(::Tuple{Vararg{AbstractArray}}) = true # e.g. parameters of LSTMcell -_show_leaflike(::Scale) = true # appears inside LayerNorm _show_leaflike(::AbstractArray{<:Number}) = true # e.g. transposed arrays -_show_children(x) = trainable(x) # except for layers which hide their Tuple: -_show_children(c::Chain) = c.layers -_show_children(m::Maxout) = m.layers -_show_children(p::Parallel) = (p.connection, p.layers...) -_show_children(f::PairwiseFusion) = (f.connection, f.layers...) - -for T in [ - :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag, - :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if !get(io, :compact, false) - _layer_show(io, x) - else - show(io, x) +_show_children(x) = trainable(x) +# This used to have methods for Chain, Maxout, Parallel, PairwiseFusion. Now @layer instead +# writes a method to use this function. It flattens the Tuple within Chain etc. +# (The remaining special cases are for printing of layer names when a NamedTuple, above.) +function _flat_children(x) + alpha = map(f -> getfield(x, f), fieldnames(typeof(x))) + beta = map(y -> y isa Union{Tuple, NamedTuple} ? y : (y,), alpha) + gamma = ((beta...)...,) +end + +# This is called by @layer, on layers which should be treated like Dense, and returns an expression: +function _macro_layer_show(ex) + quote + # Entry point: + function Base.show(io::IO, m::MIME"text/plain", x::$ex) + if !get(io, :compact, false) + _layer_show(io, x) + else + show(io, x) + end end + + # Exit from _big_show recursion: + Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name) end end @@ -126,6 +141,8 @@ function _nan_show(io::IO, x) end end +@specialize # un-does @nospecialze at the top of this file + _any(f, xs::AbstractArray{<:Number}) = any(f, xs) # _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs) _any(f, xs) = any(x -> _any(f, x), xs) diff --git a/test/layers/macro.jl b/test/layers/macro.jl new file mode 100644 index 0000000000..e41d5a2240 --- /dev/null +++ b/test/layers/macro.jl @@ -0,0 +1,47 @@ +using Flux, Functors, Optimisers + +module MacroTest + using Flux: @layer + + struct Duo{T,S}; x::T; y::S; end + @layer :expand Duo + + struct Trio; a; b; c end + # @layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget + @layer Trio trainable=(a,b) # defining a method for test is made an error, for now + + struct TwoThirds; a; b; c; end +end + +@testset "@layer macro" begin + @test !isdefined(MacroTest, :Flux) # That's why the module, to check scope + + m2 = MacroTest.Duo(Dense(2=>2), Chain(Flux.Scale(2), Dropout(0.2))) + + @test Functors.children(m2) isa NamedTuple{(:x, :y)} + @test length(Optimisers.destructure(m2)[1]) == 10 + + m3 = MacroTest.Trio([1.0], [2.0], [3.0]) + + @test Functors.children(m3) isa NamedTuple{(:a, :b, :c)} + @test fmap(zero, m3) isa MacroTest.Trio + + @test Optimisers.trainable(m3) isa NamedTuple{(:a, :b)} + @test Optimisers.destructure(m3)[1] == [1, 2] + + # @test MacroTest.test(m3) == (c = [3.0],) # removed, for now + + m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6]) + # Check that we can use the macro with a qualified type name, outside the 
defining module: + Flux.@layer :expand MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes + + m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60])) + @test m23re isa MacroTest.TwoThirds + @test Flux.namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60]) + + @test Optimisers.trainable(m23) == (a = [1 2],) + + @test_throws LoadError @eval Flux.@layer :zzz MacroTest.TwoThirds + @test_throws LoadError @eval Flux.@layer MacroTest.TwoThirds chidren=(a, b) +end + diff --git a/test/runtests.jl b/test/runtests.jl index 94e0c466e6..8dca6becdd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,7 @@ Random.seed!(0) include("layers/conv.jl") include("layers/upsample.jl") include("layers/show.jl") + include("layers/macro.jl") end @testset "outputsize" begin diff --git a/test/utils.jl b/test/utils.jl index 620a4d40b4..e175eb1f5b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -616,7 +616,7 @@ end a::A b::A end - Flux.@functor Model + Flux.@layer Model (m::Model)(x) = m.a(x) .+ m.b(x) d = Dense(1, 1)
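
As a closing illustration, here is a hedged sketch of the migration pattern this diff applies throughout `normalise.jl`, shown on a made-up layer. `MyNorm` and its fields are hypothetical and not part of the patch.

```julia
using Flux

# Hypothetical layer: trainable shift/scale plus a running mean that is
# updated manually and should not be touched by the optimiser.
mutable struct MyNorm{V<:AbstractVector}
    β::V   # shift, trained
    γ::V   # scale, trained
    μ::V   # running mean, not trained
end

MyNorm(n::Integer) = MyNorm(zeros(Float32, n), ones(Float32, n), zeros(Float32, n))

(m::MyNorm)(x) = m.γ .* (x .- m.μ) .+ m.β

# Old pattern, removed throughout this diff:
#   Flux.@functor MyNorm
#   Flux.trainable(m::MyNorm) = (β = m.β, γ = m.γ)
# New pattern, one line:
Flux.@layer MyNorm trainable=(β,γ)

# `μ` is still moved by `gpu` and converted by `f32`, but optimisers and
# `Flux.params` should now only see β and γ.
```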