set expand option as default for @layer #2532

Merged: 6 commits, Dec 3, 2024
Changes from all commits
README.md: 2 additions & 2 deletions
@@ -27,9 +27,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

optim = Flux.setup(Adam(), model)
opt_state = Flux.setup(Adam(), model)
for epoch in 1:1000
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, opt_state)
end

plot(x -> 2x-x^3, -2, 2, legend=false)
docs/src/guide/saving.md: 6 additions & 1 deletion
@@ -21,7 +21,12 @@ julia> Flux.@layer MyModel
julia> MyModel() = MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)));

julia> model = MyModel()
MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) # 67 parameters
MyModel(
Chain(
Dense(10 => 5, relu), # 55 parameters
Dense(5 => 2), # 12 parameters
),
) # Total: 4 arrays, 67 parameters, 484 bytes.

julia> model_state = Flux.state(model);

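The saving.md hunk above now shows the expanded printing for the guide's MyModel wrapper, followed by `Flux.state(model)`. For context, here is a minimal sketch of the save/restore round trip that this guide section documents, assuming JLD2 for the file format and a simple one-field `MyModel` struct (the struct definition is not visible in this hunk, and the file name is illustrative):

```julia
using Flux, JLD2

struct MyModel          # assumed shape of the guide's wrapper type
    net
end
Flux.@layer MyModel
MyModel() = MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)))

model = MyModel()
model_state = Flux.state(model)        # nested (named) tuples holding the arrays
jldsave("mymodel.jld2"; model_state)   # save the state, not the struct itself

# Later, or in a fresh session: rebuild the model and copy the state back in.
model2 = MyModel()
Flux.loadmodel!(model2, JLD2.load("mymodel.jld2", "model_state"))
```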
src/Flux.jl: 3 additions & 2 deletions
@@ -8,6 +8,7 @@ using MacroTools, Reexport, ProgressLogging, SpecialFunctions
using MacroTools: @forward

@reexport using NNlib
using NNlib: conv, ∇conv_data, depthwiseconv, output_size
using MLUtils

using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update!
@@ -27,7 +28,7 @@ export gradient
CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
XLADevice,
# get_device, # we define get_device here for retrocompatibility
# gpu_backend!, # have to define here due to https://github.com/JuliaPackaging/Preferences.jl/issues/39
gpu_backend!,

[Review comment from the PR author] Unrelated change: with v0.15 we don't need to define gpu_backend!, we can just reexport the one from MLDataDevices.

get_device_type,
DeviceIterator

@@ -118,7 +119,7 @@ include("losses/Losses.jl")
using .Losses

include("devices.jl")
export get_device, gpu_backend!
export get_device

# Distributed Training
include("distributed/backend.jl")
src/functor.jl: 0 additions & 12 deletions
@@ -102,18 +102,6 @@ julia> m.bias
"""
cpu(x) = cpu_device()(x)

# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089
ChainRulesCore.@non_differentiable cpu_device()


# Remove when
# https://github.com/JuliaPackaging/Preferences.jl/issues/39
# is resolved
function gpu_backend!(backend::String)
@set_preferences!("gpu_backend" => backend)
MLDataDevices.gpu_backend!(backend)
end

"""
gpu(m)

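The `gpu_backend!` method deleted above existed only to work around the linked Preferences.jl issue; as the author's review comment notes, Flux now re-exports the MLDataDevices version instead. A minimal sketch of the (unchanged) user-facing call, with the backend string purely illustrative; the preference only takes effect after restarting Julia:

```julia
using Flux

# Now resolves to MLDataDevices.gpu_backend!, re-exported by Flux;
# it records the "gpu_backend" preference via Preferences.jl.
Flux.gpu_backend!("CUDA")

# After a restart, and with the chosen backend package loaded,
# the usual device helpers pick up the preference:
device = gpu_device()
x = device(rand(Float32, 3, 4))
```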
src/layers/attention.jl: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
out_proj::P2
end

@layer MultiHeadAttention
@layer :noexpand MultiHeadAttention

function MultiHeadAttention(dims;
nheads::Int = 8,
src/layers/basic.jl: 5 additions & 5 deletions
@@ -60,7 +60,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@layer :expand Chain # the option :expand opts-in to container-style pretty-printing
@layer Chain

(c::Chain)(x) = _applychain(c.layers, x)
(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))
@@ -334,7 +334,7 @@ end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@layer :expand Maxout
@layer Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -381,7 +381,7 @@ struct SkipConnection{T,F}
connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@layer :expand SkipConnection
@layer SkipConnection

function (skip::SkipConnection)(input)
skip.connection(skip.layers(input), input)
@@ -575,7 +575,7 @@ end
Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))

@layer :expand Parallel
@layer Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument

@@ -705,7 +705,7 @@ end
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

@layer :expand PairwiseFusion
@layer PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
src/layers/conv.jl: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
using NNlib: conv, ∇conv_data, depthwiseconv, output_size

# pad dims of x with dims of y until ndims(x) == ndims(y)
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)
src/layers/macro.jl: 28 additions & 16 deletions
@@ -1,21 +1,24 @@

"""
@layer Dense
@layer :expand Chain
@layer BatchNorm trainable=(β,γ)

@layer [showtype] MyModel [trainable=(field1,...)]

This macro adds convenience functionality to a custom type to serve
as a neural network layer, module, or entire model.
as a neural network layer, as a module, or as an entire model.

The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`.
The optional keyword `trainable` allows you to specify which fields of your model can be trained,
instead of assuming all `fieldnames(MyModel)` are trainable.
Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
This can also be done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type.

The macro also handles overloads of the 3-arg `show(::IO, ::MIME"text/plain", ::MyModel)` for pretty printing.
The optional argument `showtype` can take any of the following values:

The macro also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
* To disable all `show` overloads, there is an `:ignore` option too.
- `:expand` (default): This will expand the representation of container types like `Chain`,
  while maintaining a compact representation of types like `Dense` containing only arrays.
- `:noexpand`: Use this when your type contains other layers but you want to keep the representation simple.
- `:ignore`: Opt out of the pretty printing entirely.

(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
You probably still want to define 2-arg `show(::IO, ::MyModel)`; the macro does not touch this.

Note that re-running the macro with different options may not remove all methods; you will need to restart Julia.

@@ -26,16 +29,22 @@ julia> struct Trio; a; b; c end
julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4))
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))

julia> Flux.@layer :expand Trio
julia> Flux.@layer Trio

julia> tri # now the layer is printed like Chain
Trio(
Dense(2 => 1, tanh), # 3 parameters
Dense(1 => 1; bias=false), # 1 parameters
Dropout(0.4),
) # Total: 3 arrays, 4 parameters, 240 bytes.
```

julia> Flux.@layer :noexpand Trio trainable=(a,b)

julia> tri # now the layer is printed compactly
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) # 4 parameters

julia> opt_state = Flux.setup(Adam(), tri); # `c` is not in the optimizer state
```
"""
macro layer(exs...)
_layer_macro(exs...)
Expand All @@ -46,14 +55,17 @@ function _layer_macro(exs...)

# These functions are defined in show.jl, and each return an expression overloading Base.show
type, rest... = if exs[1] == QuoteNode(:expand)
push!(out.args, _macro_big_show(esc(exs[2])))
push!(out.args, _macro_big_show(esc(exs[2])))
exs[2:end]
elseif exs[1] == QuoteNode(:noexpand)
push!(out.args, _macro_layer_show(esc(exs[2])))
exs[2:end]
elseif exs[1] == QuoteNode(:ignore)
exs[2:end]
elseif exs[1] isa QuoteNode
error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)")
error("`@layer` accepts only the options `:ignore`, `:noexpand`, and `:expand` before the layer type (to control `show`).")
else
push!(out.args, _macro_layer_show(esc(exs[1])))
push!(out.args, _macro_big_show(esc(exs[1])))
exs
end

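The `_layer_macro` change above is the core of this PR: when no showtype is given, the fallback branch now calls `_macro_big_show`, so bare `@layer MyModel` produces the expanded, Chain-style printing, and `:noexpand` restores the old compact form. A small sketch of the migration, using a hypothetical container type (names and sizes are made up):

```julia
using Flux

struct TwoTower     # hypothetical two-branch container, for illustration only
    user
    item
end

TwoTower() = TwoTower(Chain(Dense(8 => 4, relu)), Chain(Dense(8 => 4, relu)))
(m::TwoTower)(xu, xi) = sum(m.user(xu) .* m.item(xi); dims=1)

# Before this PR, container-style printing had to be requested explicitly:
#   Flux.@layer :expand TwoTower
# After this PR, it is the default:
Flux.@layer TwoTower
# Code that relied on the old compact default should opt out explicitly:
#   Flux.@layer :noexpand TwoTower
```

Only one spelling is left active here because, as the updated docstring notes, re-running the macro with different options may not cleanly replace the `show` methods.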
src/layers/normalise.jl: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ end
LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

@layer LayerNorm
@layer :noexpand LayerNorm

function (a::LayerNorm)(x::AbstractArray)
ChainRulesCore.@ignore_derivatives if a.diag isa Scale
src/layers/recurrent.jl: 6 additions & 6 deletions
@@ -158,7 +158,7 @@ struct Model
h0::AbstractVector
end

Flux.@layer :expand Model
Flux.@layer Model

(m::Model)(x) = m.rnn(x, m.h0)

@@ -169,7 +169,7 @@ struct RNN{M}
cell::M
end

@layer :expand RNN
@layer RNN

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
@@ -344,7 +344,7 @@ struct Model
c0::AbstractVector
end

Flux.@layer :expand Model
Flux.@layer Model

(m::Model)(x) = m.lstm(x, (m.h0, m.c0))

@@ -359,7 +359,7 @@ struct LSTM{M}
cell::M
end

@layer :expand LSTM
@layer LSTM

function LSTM((in, out)::Pair; cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
@@ -531,7 +531,7 @@ struct GRU{M}
cell::M
end

@layer :expand GRU
@layer GRU

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
@@ -669,7 +669,7 @@ struct GRUv3{M}
cell::M
end

@layer :expand GRUv3
@layer GRUv3

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
src/layers/show.jl: 36 additions & 23 deletions
@@ -1,6 +1,6 @@
@nospecialize # just for this file, for startup time

# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression:
# This is called by @layer and returns an expression:
function _macro_big_show(ex)
quote
# Entry point:
@@ -83,7 +83,7 @@ function _flat_children(x)
gamma = ((beta...)...,)
end

# This is called by @layer, on layers which should be treated like Dense, and returns an expression:
# This is called by @layer :noexpand, on layers which should be treated like Dense, and returns an expression:
function _macro_layer_show(ex)
quote
# Entry point:
@@ -176,40 +176,53 @@ _all(f, xs) = !_any(!f, xs)

#=

julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
julia> struct Tmp2; x; y; end;

# Before, notice Array(), NamedTuple(), and values
julia> t = Tmp2([Dense(2,3), randn(3,4)'], (x=1:4, y=Dense(3,4), z=rand(3)))
Tmp2(Any[Dense(2 => 3), [-0.559390071462934 -0.6357914190386781 -0.8516823037180543; -2.187495592853204 -0.6807254521505784 -1.2334639710489697; -0.12790952072543338 -1.4672700459421741 1.3687526519721238; 0.5232171922680576 -1.012045481192333 1.4953790632112915]], (x = 1:4, y = Dense(3 => 4), z = [0.29222096031585143, 0.6562195256556428, 0.9741896713499167]))

julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
julia> Chain(t)
Chain(
Tmp2(
Array(
[
Dense(2 => 3), # 9 parameters
[0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
),
NamedTuple(
1:3, # 3 parameters
Dense(3 => 4), # 16 parameters
[0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
(;
x = 4-element UnitRange{Int64},
y = Dense(3 => 4), # 16 parameters
z = 3-element Vector{Float64}, # 3 parameters
),
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.


# After, (; x=, y=, z=) and "3-element Array"
julia> Flux.@layer Tmp2

julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
julia> t
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
4-element UnitRange{Int64},
Dense(3 => 4), # 16 parameters
3-element Vector{Float64}, # 3 parameters
) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.

julia> Chain(t)
Chain(
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint, # 12 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
(;
x = 3-element UnitRange, # 3 parameters
y = Dense(3 => 4), # 16 parameters
z = 3-element Array, # 3 parameters
),
4-element UnitRange{Int64},
Dense(3 => 4), # 16 parameters
3-element Vector{Float64}, # 3 parameters
),
) # Total: 7 arrays, 43 parameters, 644 bytes.

) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.
=#