set expand option as default for @layer (#2532)
CarloLucibello authored Dec 3, 2024
1 parent e2b3f06 commit e2f58a8
Showing 14 changed files with 98 additions and 80 deletions.
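In short, `:expand` is now the default pretty-printing mode for `Flux.@layer`, and a new `:noexpand` option restores the previous compact behaviour. A minimal sketch of the new defaults, using a hypothetical `Encoder` type (not part of this commit):

```julia
using Flux

# Hypothetical container type, just to illustrate the new default.
struct Encoder
    embed
    layers
end

(e::Encoder)(x) = e.layers(e.embed(x))

# A bare `@layer` now opts in to Chain-style expanded printing; before this
# commit that required `Flux.@layer :expand Encoder`.
Flux.@layer Encoder

# The previous single-line summary is still available via the new option:
#   Flux.@layer :noexpand Encoder

enc = Encoder(Dense(10 => 32, relu), Chain(Dense(32 => 32, relu), Dense(32 => 4)))
display(enc)   # expected to list each sub-layer with a parameter count
```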
4 changes: 2 additions & 2 deletions README.md
@@ -27,9 +27,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

optim = Flux.setup(Adam(), model)
opt_state = Flux.setup(Adam(), model)
for epoch in 1:1000
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, opt_state)
end

plot(x -> 2x-x^3, -2, 2, legend=false)
7 changes: 6 additions & 1 deletion docs/src/guide/saving.md
@@ -21,7 +21,12 @@ julia> Flux.@layer MyModel
julia> MyModel() = MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)));
julia> model = MyModel()
MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) # 67 parameters
MyModel(
Chain(
Dense(10 => 5, relu), # 55 parameters
Dense(5 => 2), # 12 parameters
),
) # Total: 4 arrays, 67 parameters, 484 bytes.
julia> model_state = Flux.state(model);
5 changes: 3 additions & 2 deletions src/Flux.jl
@@ -8,6 +8,7 @@ using MacroTools, Reexport, ProgressLogging, SpecialFunctions
using MacroTools: @forward

@reexport using NNlib
using NNlib: conv, ∇conv_data, depthwiseconv, output_size
using MLUtils

using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update!
@@ -27,7 +28,7 @@ export gradient
CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
XLADevice,
# get_device, # we define get_device here for retrocompatibility
# gpu_backend!, # have to define here due to https://github.com/JuliaPackaging/Preferences.jl/issues/39
gpu_backend!,
get_device_type,
DeviceIterator

@@ -118,7 +119,7 @@ include("losses/Losses.jl")
using .Losses

include("devices.jl")
export get_device, gpu_backend!
export get_device

# Distributed Training
include("distributed/backend.jl")
12 changes: 0 additions & 12 deletions src/functor.jl
@@ -102,18 +102,6 @@ julia> m.bias
"""
cpu(x) = cpu_device()(x)

# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089
ChainRulesCore.@non_differentiable cpu_device()


# Remove when
# https://github.com/JuliaPackaging/Preferences.jl/issues/39
# is resolved
function gpu_backend!(backend::String)
@set_preferences!("gpu_backend" => backend)
MLDataDevices.gpu_backend!(backend)
end

"""
gpu(m)
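The block removed above was a temporary workaround for a Preferences.jl issue; `gpu_backend!` is now re-exported from MLDataDevices instead (see the export changes in `src/Flux.jl`). Usage should be unchanged; a minimal sketch, assuming the CUDA backend is the one wanted:

```julia
using Flux

# Store the preferred GPU backend in LocalPreferences.toml; the choice takes
# effect after restarting Julia. "CUDA" is only an example value, and the
# matching backend package (CUDA.jl here) must be installed separately.
Flux.gpu_backend!("CUDA")
```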
2 changes: 1 addition & 1 deletion src/layers/attention.jl
@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
out_proj::P2
end

@layer MultiHeadAttention
@layer :noexpand MultiHeadAttention

function MultiHeadAttention(dims;
nheads::Int = 8,
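`MultiHeadAttention` contains several projection layers but is meant to print as a single line, so it moves to the new `:noexpand` option (as does `LayerNorm` below). A sketch of the same pattern for a hypothetical `GatedDense` wrapper, not taken from this commit:

```julia
using Flux

# Hypothetical wrapper: two Dense sub-layers that should still print on a
# single line, like Dense itself.
struct GatedDense
    dense::Dense
    gate::Dense
end

GatedDense((in, out)::Pair) = GatedDense(Dense(in => out), Dense(in => out, sigmoid))
(g::GatedDense)(x) = g.dense(x) .* g.gate(x)

# `:noexpand` keeps the compact Dense-style summary instead of the new
# expanded default.
Flux.@layer :noexpand GatedDense

# The macro does not touch 2-arg `show`, so a hand-written one can be added:
Base.show(io::IO, g::GatedDense) =
    print(io, "GatedDense(", size(g.dense.weight, 2), " => ", size(g.dense.weight, 1), ")")
```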
10 changes: 5 additions & 5 deletions src/layers/basic.jl
@@ -60,7 +60,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@layer :expand Chain # the option :expand opts-in to container-style pretty-printing
@layer Chain

(c::Chain)(x) = _applychain(c.layers, x)
(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))
@@ -334,7 +334,7 @@ end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@layer :expand Maxout
@layer Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -381,7 +381,7 @@ struct SkipConnection{T,F}
connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@layer :expand SkipConnection
@layer SkipConnection

function (skip::SkipConnection)(input)
skip.connection(skip.layers(input), input)
@@ -575,7 +575,7 @@ end
Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))

@layer :expand Parallel
@layer Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument

@@ -705,7 +705,7 @@ end
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

@layer :expand PairwiseFusion
@layer PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
1 change: 0 additions & 1 deletion src/layers/conv.jl
@@ -1,4 +1,3 @@
using NNlib: conv, ∇conv_data, depthwiseconv, output_size

# pad dims of x with dims of y until ndims(x) == ndims(y)
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)
44 changes: 28 additions & 16 deletions src/layers/macro.jl
@@ -1,21 +1,24 @@

"""
@layer Dense
@layer :expand Chain
@layer BatchNorm trainable=(β,γ)
@layer [showtype] MyModel [trainable=(field1,...)]
This macro adds convenience functionality to a custom type to serve
as a neural network layer, module, or entire model.
as a neural network layer, as a module, or as an entire model.
The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`.
The optional keyword `trainable` allows you to specify which fields of your model can be trained,
instead of assuming all `fieldnames(MyModel)` to be trainable.
Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
This can also be done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type.
The macro also handles overloads of the 3-arg `show(::IO, ::MIME"text/plain", ::MyModel)` for pretty printing.
The optional argument `showtype` can take any of the following values:
The macro also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
* To disable all `show` overloads, there is an `:ignore` option too.
- `:expand` (default): This will expand the representation of container types like `Chain`,
while maintaining a compact representation of types like `Dense` containing only arrays.
- `:noexpand`: This is to be used in case your type contains other layers but you want to keep the representation simple.
- `:ignore`: To opt out of the pretty printing.
(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
You probably still want to define 2-arg `show(::IO, ::MyModel)`, the macro does not touch this.
Note that re-running the macro with different options may not remove all methods; you will need to restart.
@@ -26,16 +29,22 @@ julia> struct Trio; a; b; c end
julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4))
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))
julia> Flux.@layer :expand Trio
julia> Flux.@layer Trio
julia> tri # now the layer is printed like Chain
Trio(
Dense(2 => 1, tanh), # 3 parameters
Dense(1 => 1; bias=false), # 1 parameters
Dropout(0.4),
) # Total: 3 arrays, 4 parameters, 240 bytes.
```
julia> Flux.@layer :noexpand Trio trainable=(a,b)
julia> tri # now the layer is printed compactly
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) # 4 parameters
julia> opt_state = Flux.setup(Adam(), tri); # `c` is not in the optimizer state
```
"""
macro layer(exs...)
_layer_macro(exs...)
@@ -46,14 +55,17 @@ function _layer_macro(exs...)

# These functions are defined in show.jl, and each return an expression overloading Base.show
type, rest... = if exs[1] == QuoteNode(:expand)
push!(out.args, _macro_big_show(esc(exs[2])))
push!(out.args, _macro_big_show(esc(exs[2])))
exs[2:end]
elseif exs[1] == QuoteNode(:noexpand)
push!(out.args, _macro_layer_show(esc(exs[2])))
exs[2:end]
elseif exs[1] == QuoteNode(:ignore)
exs[2:end]
elseif exs[1] isa QuoteNode
error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)")
error("`@layer` accepts only the options `:ignore`, `:noexpand`, and `:expand` before the layer type (to control `show`).")
else
push!(out.args, _macro_layer_show(esc(exs[1])))
push!(out.args, _macro_big_show(esc(exs[1])))
exs
end

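As the updated docstring notes, the `trainable=` keyword can also be expressed by defining `Optimisers.trainable` by hand. A sketch of that manual route with a hypothetical `Affine` type (restricting training to `W` is arbitrary here, purely to show the mechanism):

```julia
using Flux, Optimisers

# Hypothetical layer, for illustration only.
struct Affine
    W
    b
    σ    # an activation function; non-array fields are ignored automatically
end

Affine((in, out)::Pair, σ = identity) =
    Affine(randn(Float32, out, in), zeros(Float32, out), σ)

(a::Affine)(x) = a.σ.(a.W * x .+ a.b)

Flux.@layer Affine

# Hand-written equivalent of `Flux.@layer Affine trainable=(W,)`: optimisers
# only see `W`, while `b` is still moved by `gpu`/`cpu` and saved by `Flux.state`.
Optimisers.trainable(a::Affine) = (; W = a.W)
```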
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
@@ -198,7 +198,7 @@ end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

@layer LayerNorm
@layer :noexpand LayerNorm

function (a::LayerNorm)(x::AbstractArray)
ChainRulesCore.@ignore_derivatives if a.diag isa Scale
12 changes: 6 additions & 6 deletions src/layers/recurrent.jl
@@ -158,7 +158,7 @@ struct Model
h0::AbstractVector
end
Flux.@layer :expand Model
Flux.@layer Model
(m::Model)(x) = m.rnn(x, m.h0)
@@ -169,7 +169,7 @@ struct RNN{M}
cell::M
end

@layer :expand RNN
@layer RNN

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
@@ -344,7 +344,7 @@ struct Model
c0::AbstractVector
end
Flux.@layer :expand Model
Flux.@layer Model
(m::Model)(x) = m.lstm(x, (m.h0, m.c0))
@@ -359,7 +359,7 @@ struct LSTM{M}
cell::M
end

@layer :expand LSTM
@layer LSTM

function LSTM((in, out)::Pair; cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
@@ -531,7 +531,7 @@ struct GRU{M}
cell::M
end

@layer :expand GRU
@layer GRU

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
@@ -669,7 +669,7 @@ struct GRUv3{M}
cell::M
end

@layer :expand GRUv3
@layer GRUv3

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
59 changes: 36 additions & 23 deletions src/layers/show.jl
@@ -1,6 +1,6 @@
@nospecialize # just for this file, for startup time

# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression:
# This is called by @layer and returns an expression:
function _macro_big_show(ex)
quote
# Entry point:
@@ -83,7 +83,7 @@ function _flat_children(x)
gamma = ((beta...)...,)
end

# This is called by @layer, on layers which should be treated like Dense, and returns an expression:
# This is called by @layer :noexpand, on layers which should be treated like Dense, and returns an expression:
function _macro_layer_show(ex)
quote
# Entry point:
@@ -176,40 +176,53 @@ _all(f, xs) = !_any(!f, xs)

#=
julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
julia> struct Tmp2; x; y; end;
# Before, notice Array(), NamedTuple(), and values
julia> t = Tmp2([Dense(2,3), randn(3,4)'], (x=1:4, y=Dense(3,4), z=rand(3)))
Tmp2(Any[Dense(2 => 3), [-0.559390071462934 -0.6357914190386781 -0.8516823037180543; -2.187495592853204 -0.6807254521505784 -1.2334639710489697; -0.12790952072543338 -1.4672700459421741 1.3687526519721238; 0.5232171922680576 -1.012045481192333 1.4953790632112915]], (x = 1:4, y = Dense(3 => 4), z = [0.29222096031585143, 0.6562195256556428, 0.9741896713499167]))
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
julia> Chain(t)
Chain(
Tmp2(
Array(
[
Dense(2 => 3), # 9 parameters
[0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
),
NamedTuple(
1:3, # 3 parameters
Dense(3 => 4), # 16 parameters
[0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
(;
x = 4-element UnitRange{Int64},
y = Dense(3 => 4), # 16 parameters
z = 3-element Vector{Float64}, # 3 parameters
),
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.
# After, (; x=, y=, z=) and "3-element Array"
julia> Flux.@layer Tmp2
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
julia> t
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
4-element UnitRange{Int64},
Dense(3 => 4), # 16 parameters
3-element Vector{Float64}, # 3 parameters
) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.
julia> Chain(t)
Chain(
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint, # 12 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
(;
x = 3-element UnitRange, # 3 parameters
y = Dense(3 => 4), # 16 parameters
z = 3-element Array, # 3 parameters
),
4-element UnitRange{Int64},
Dense(3 => 4), # 16 parameters
3-element Vector{Float64}, # 3 parameters
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.
=#