Add a macro to opt-in to fancy printing, and to everything else (#1932)
mcabbott authored Mar 5, 2024
1 parent da11bf2 commit 5e80211
Showing 15 changed files with 350 additions and 71 deletions.
8 changes: 8 additions & 0 deletions NEWS.md
@@ -2,6 +2,13 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.14.13
* New macro `Flux.@layer` which should be used in place of `@functor`.
This also adds `show` methods for pretty printing.
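  For instance, a layer that previously called `Flux.@functor` can switch over like this (a minimal sketch of the intended usage; the documentation changes further down this diff show more):

  ```julia
  struct Affine; W; b; end

  Flux.@layer Affine   # replaces `Flux.@functor Affine`, and also defines `show` methods
  ```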

## v0.14.12
* New `SignDecay` optimiser, like `WeightDecay` but for L1 norm.

## v0.14.0 (July 2023)
* Flux now requires julia v1.9 or later.
* CUDA.jl is not a hard dependency anymore. Support is now provided through the extension mechanism, by loading `using Flux, CUDA`.
@@ -51,6 +58,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl

## v0.13.6
* Use the package [OneHotArrays.jl](https://github.com/FluxML/OneHotArrays.jl) instead of having the same code here.
* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
25 changes: 16 additions & 9 deletions docs/src/models/advanced.md
@@ -18,8 +18,8 @@ function (m::CustomModel)(x)
return m.chain(x) + x
end

# Call @functor to allow for training. Described below in more detail.
Flux.@functor CustomModel
# Call @layer to allow for training. Described below in more detail.
Flux.@layer CustomModel
```

You can then use the model like:
@@ -39,15 +39,15 @@ Taking reference from our example `Affine` layer from the [basics](@ref man-basi
By default all the fields in the `Affine` type are collected as its parameters. However, in some cases we may want to hold other metadata in our "layers" that is not needed for training, and hence should be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function:

```julia-repl
julia> Flux.@functor Affine
julia> Flux.@layer Affine
julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9])
Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0])
julia> Flux.params(a) # default behavior
Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]])
julia> Flux.trainable(a::Affine) = (; a.W) # returns a NamedTuple using the field's name
julia> Flux.trainable(a::Affine) = (; W = a.W) # returns a NamedTuple using the field's name
julia> Flux.params(a)
Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]])
@@ -67,7 +67,14 @@ julia> Flux.params(Affine(true, [10, 11, 12.0]))
Params([])
```

It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired).
The exact same method of `trainable` can also be defined using the macro, for convenience:

```julia
Flux.@layer Affine trainable=(W,)
```

There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that no exploration of the model will ever visit the other fields: they will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument.
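
Concretely, that stricter opt-in might look something like this (a sketch; the one-argument constructor shown is a hypothetical example, not part of this diff):

```julia
Functors.@functor Affine (W,)

# Functors now rebuilds an Affine from `W` alone, so a matching
# one-argument constructor must exist (hypothetical):
Affine(W) = Affine(W, zeros(Float32, size(W, 1)))
```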


## Freezing Layer Parameters

@@ -135,9 +142,9 @@ Join(combine, paths...) = Join(combine, paths)
```
Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field.
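
For illustration, the parameterised struct this paragraph refers to looks roughly like the following (a sketch; the full definition sits in the file just above the lines shown in this hunk):

```julia
struct Join{T, F}
  combine::F   # merging function, e.g. vcat or +
  paths::T     # a Tuple (or Vector) of layers, one per input
end
```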

The next step is to use [`Functors.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
```julia
Flux.@functor Join
Flux.@layer Join
```

Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results.
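A sketch of that forward pass (the actual definition is in the collapsed lines just below):

```julia
# Apply each path to the matching input, then merge the results with `combine`.
(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)
```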
@@ -194,7 +201,7 @@ model(xs)

Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs.

We start by following the same steps as the `Join` layer: define a struct, use [`Functors.@functor`](@ref), and define the forward pass.
We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass.
```julia
using Flux
using CUDA
@@ -206,7 +213,7 @@ end

Split(paths...) = Split(paths)

Flux.@functor Split
Flux.@layer Split

(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
```
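
A brief usage sketch (shapes chosen arbitrarily; the docs' full example also moves the model to the GPU):

```julia
model = Chain(
  Dense(10 => 5),
  Split(Dense(5 => 1, tanh), Dense(5 => 3, tanh), Dense(5 => 2)),
)

model(rand(Float32, 10))  # returns a Tuple of three output vectors
```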
9 changes: 7 additions & 2 deletions docs/src/models/basics.md
@@ -257,8 +257,8 @@ m(5) # => 26

There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@functor`](@ref Functors.@functor) macro:

```
Flux.@functor Affine
```julia
Flux.@layer Affine
```

Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias):
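
Roughly, that refined constructor could look like the following (a sketch; the full version is in the collapsed portion of the diff):

```julia
function Affine((in, out)::Pair; bias = true, init = Flux.randn32)
  W = init(out, in)
  b = Flux.create_bias(W, bias, out)   # returns `false` when bias=false, else a vector
  return Affine(W, b)
end
```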
@@ -272,3 +272,8 @@ end
Affine(3 => 1, bias=false, init=ones) |> gpu
```

```@docs
Flux.@layer
Flux.create_bias
```
5 changes: 4 additions & 1 deletion src/Flux.jl
@@ -9,6 +9,7 @@ using MacroTools: @forward

@reexport using NNlib
using MLUtils
const stack = MLUtils.stack # now exported by Base
import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
using Optimisers: freeze!, thaw!, adjust!
using Random: default_rng
@@ -69,14 +70,16 @@ include("functor.jl")
# Pirate error to catch a common mistake.
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")

include("layers/show.jl")
include("layers/macro.jl")

include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/attention.jl")
include("layers/show.jl")

include("loading.jl")

1 change: 1 addition & 0 deletions src/functor.jl
@@ -81,6 +81,7 @@ function params!(p::Params, x, seen = IdSet())
elseif x in seen
nothing
else
_check_new_macro(x) # complains if you used @functor not @layer
push!(seen, x)
for child in trainable(x)
params!(p, child, seen)
44 changes: 41 additions & 3 deletions src/layers/attention.jl
@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
out_proj::P2
end

@functor MultiHeadAttention
@layer MultiHeadAttention

function MultiHeadAttention(dims;
nheads::Int = 8,
@@ -83,8 +83,8 @@ function MultiHeadAttention(dims;
dropout_prob = 0.0)

dims = normalize_mha_dims(dims)
@assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
@assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)"))
dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)"))
q_proj = Dense(dims.q_in => dims.qk; bias, init)
k_proj = Dense(dims.k_in => dims.qk; bias, init)
v_proj = Dense(dims.v_in => dims.v; bias, init)
@@ -131,3 +131,41 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
# [α] = [kv_len, q_len, nheads, batch_size]
return x, α
end

function Base.show(io::IO, mha::MultiHeadAttention)
qk, q_in = size(mha.q_proj.weight)
qk, k_in = size(mha.k_proj.weight)
v, v_in = size(mha.v_proj.weight)
out, v = size(mha.out_proj.weight)
# @show q_in, k_in, v_in, qk, v, out
print(io, "MultiHeadAttention(")
if q_in == k_in == v_in == qk == v == out
print(io, q_in)
elseif q_in == k_in == v_in && qk == v
print(io, q_in, " => ", qk, " => ", out)
elseif q_in == k_in == v_in
print(io, q_in, " => (", qk, ", ", v,") => ", out)
else
print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out)
end
print(io, "; nheads=", mha.nheads)
if mha.q_proj.bias !== false
print(io, ", bias=true")
end
if mha.attn_drop.p != 0
print(io, ", dropout_prob=", mha.attn_drop.p) # can't we rename this?
end
print(io, ")")
end


#=
# Test cases for printing:
MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)
MultiHeadAttention(3 => (6, 7) => 8; nheads=1)
MultiHeadAttention(3 => 6 => 8; nheads=1)
MultiHeadAttention(8; bias=true)
=#
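
For reference, a hedged usage sketch matching the shape comments above (self-attention when only `q` is given; sizes follow the `[kv_len, q_len, nheads, batch_size]` convention noted in the forward pass):

```julia
mha = MultiHeadAttention(64; nheads = 8)   # prints as MultiHeadAttention(64; nheads=8)
q = rand(Float32, 64, 10, 32)              # (features, seq_len, batch)
y, α = mha(q)                              # one argument => self-attention
size(y)  # (64, 10, 32)
size(α)  # (10, 10, 8, 32)
```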
18 changes: 9 additions & 9 deletions src/layers/basic.jl
@@ -46,7 +46,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@functor Chain
@layer :expand Chain # the :expand option opts in to container-style pretty-printing

(c::Chain)(x) = _applychain(c.layers, x)

@@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
Dense(init(out, in), bias, σ)
end

@functor Dense
@layer Dense

function (a::Dense)(x::AbstractVecOrMat)
_size_check(a, x, 1 => size(a.weight, 2))
@@ -251,7 +251,7 @@ end
Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

@functor Scale
@layer Scale

function (a::Scale)(x::AbstractArray)
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
@@ -306,7 +306,7 @@ end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@functor Maxout
@layer :expand Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -353,7 +353,7 @@ struct SkipConnection{T,F}
connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@functor SkipConnection
@layer :expand SkipConnection

function (skip::SkipConnection)(input)
skip.connection(skip.layers(input), input)
@@ -423,7 +423,7 @@ struct Bilinear{F,A,B}
end
end

@functor Bilinear
@layer Bilinear

function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
bias = true, init = glorot_uniform)
@@ -522,7 +522,7 @@ function Parallel(connection; kw...)
Parallel(connection, layers)
end

@functor Parallel
@layer :expand Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)
@@ -643,7 +643,7 @@ end
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

@functor PairwiseFusion
@layer :expand PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
@@ -701,7 +701,7 @@ struct Embedding{W<:AbstractMatrix}
weight::W
end

@functor Embedding
@layer Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

6 changes: 3 additions & 3 deletions src/layers/conv.jl
@@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init(filter..., cin÷groups, cout)
end

@functor Conv
@layer Conv

conv_dims(c::Conv, x::AbstractArray) =
DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
@@ -309,7 +309,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
end

@functor ConvTranspose
@layer ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
@@ -460,7 +460,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden
return CrossCor(weight, bias, σ; stride, pad, dilation)
end

@functor CrossCor
@layer CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)