set expand option as default for @layer #2532

Merged · 6 commits · Dec 3, 2024
Changes from 2 commits
4 changes: 2 additions & 2 deletions README.md
@@ -27,9 +27,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

optim = Flux.setup(Adam(), model)
opt_state = Flux.setup(Adam(), model)
for epoch in 1:1000
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, opt_state)
end

plot(x -> 2x-x^3, -2, 2, legend=false)
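For context, here is the updated quickstart gathered into one self-contained sketch (editorial addition, not part of the diff; the `plot` lines shown above need Plots.jl and are left out):

```julia
# Self-contained version of the README example after this PR's rename of
# `optim` to `opt_state`; the data fits y = 2x - x^3 on [-2, 2].
using Flux

data = [([x], 2x - x^3) for x in -2:0.1f0:2]

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

opt_state = Flux.setup(Adam(), model)
for epoch in 1:1000
    Flux.train!((m, x, y) -> (m(x) - y)^2, model, data, opt_state)
end
```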
4 changes: 2 additions & 2 deletions src/Flux.jl
@@ -27,7 +27,7 @@ export gradient
CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice,
XLADevice,
# get_device, # we define get_device here for retrocompatibility
# gpu_backend!, # have to define here due to https://github.com/JuliaPackaging/Preferences.jl/issues/39
gpu_backend!,
Member Author commented:
Unrelated change: with v0.15 we don't need to define `gpu_backend!`, we can just re-export the one from MLDataDevices.
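A minimal sketch of what that re-export could look like (editorial illustration, not the PR's literal code; it assumes the `gpu_backend!` preference setter provided by MLDataDevices):

```julia
# Instead of wrapping the preference setter inside Flux, pull in and
# re-export the MLDataDevices implementation directly.
using MLDataDevices: gpu_backend!

export gpu_backend!   # users keep calling Flux.gpu_backend!("CUDA"), etc.
```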

get_device_type,
DeviceIterator

@@ -118,7 +118,7 @@ include("losses/Losses.jl")
using .Losses

include("devices.jl")
export get_device, gpu_backend!
export get_device

# Distributed Training
include("distributed/backend.jl")
12 changes: 0 additions & 12 deletions src/functor.jl
@@ -102,18 +102,6 @@ julia> m.bias
"""
cpu(x) = cpu_device()(x)

# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089
ChainRulesCore.@non_differentiable cpu_device()


# Remove when
# https://github.com/JuliaPackaging/Preferences.jl/issues/39
# is resolved
function gpu_backend!(backend::String)
@set_preferences!("gpu_backend" => backend)
MLDataDevices.gpu_backend!(backend)
end

"""
gpu(m)

10 changes: 5 additions & 5 deletions src/layers/basic.jl
@@ -60,7 +60,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex

@layer :expand Chain # the option :expand opts-in to container-style pretty-printing
@layer Chain

(c::Chain)(x) = _applychain(c.layers, x)
(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))
@@ -334,7 +334,7 @@ end
Maxout(layers...) = Maxout(layers)
Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)

@layer :expand Maxout
@layer Maxout

function (mo::Maxout)(input::AbstractArray)
# Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -381,7 +381,7 @@ struct SkipConnection{T,F}
connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b
end

@layer :expand SkipConnection
@layer SkipConnection

function (skip::SkipConnection)(input)
skip.connection(skip.layers(input), input)
@@ -575,7 +575,7 @@ end
Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))

@layer :expand Parallel
@layer Parallel

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument

@@ -705,7 +705,7 @@ end
end
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

@layer :expand PairwiseFusion
@layer PairwiseFusion

Base.getindex(m::PairwiseFusion, i) = m.layers[i]
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
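The changes above drop `:expand` from Flux's own containers because it is now the default. As an editorial illustration (the `TwoStage` type is hypothetical, not from the diff), a user-defined container now only needs a plain `@layer` call to get `Chain`-style printing:

```julia
using Flux

struct TwoStage      # hypothetical container layer
    encoder
    decoder
end

Flux.@layer TwoStage   # before this PR: Flux.@layer :expand TwoStage

(m::TwoStage)(x) = m.decoder(m.encoder(x))

m = TwoStage(Dense(4 => 2, relu), Dense(2 => 4))
# `show` now unfolds the contents, roughly:
# TwoStage(
#   Dense(4 => 2, relu),  # 10 parameters
#   Dense(2 => 4),        # 12 parameters
# )                       # Total: 4 arrays, 22 parameters
```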
30 changes: 15 additions & 15 deletions src/layers/macro.jl
@@ -1,21 +1,20 @@

"""
@layer Dense
@layer :expand Chain
@layer BatchNorm trainable=(β,γ)
@layer MyModel
@layer MyModel trainable=(β,γ)

This macro adds convenience functionality to a custom type to serve
as a neural network layer, module, or entire model.
as a neural network layer, as a module, or as an entire model.

The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`.
The keyword `trainable` allows you to specify which fields of your model can be trained,
instead of assuming all `fieldnames(MyModel)` to be trainable.
Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
This can be also done by defining [`trainable(::MyModel)`](@ref Optimisers.trainable) for your type.

The macro also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
* To disable all `show` overloads, there is an `:ignore` option too.

(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
The macro also handles overloads of `show` for pretty printing.
It adds methods to `show(::IO, ::MIME"text/plain", ::MyModel)` to treat your layer much like `Dense` or `Chain`.
To opt out of this, use `@layer :ignore MyModel`.
In any case, you probably still want to define 2-arg `show(::IO, ::MyModel)`; the macro does not touch this.

Note that re-running the macro with different options may not remove all methods, you will need to restart.

@@ -26,7 +25,7 @@
julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4))
Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4))

julia> Flux.@layer :expand Trio
julia> Flux.@layer Trio

julia> tri # now the layer is printed like Chain
Trio(
@@ -46,14 +45,15 @@

# These functions are defined in show.jl, and each return an expression overloading Base.show
type, rest... = if exs[1] == QuoteNode(:expand)
push!(out.args, _macro_big_show(esc(exs[2])))
@warn "The `:expand` option is deprecated, and will be removed in a future release. Use `@layer` without options instead." maxlog=1
push!(out.args, _macro_big_show(esc(exs[1])))

[Codecov (codecov/patch) check warning, src/layers/macro.jl#L48-L49: added lines #L48 - L49 were not covered by tests.]
exs[2:end]
elseif exs[1] == QuoteNode(:ignore)
exs[2:end]
elseif exs[1] isa QuoteNode
error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)")
error("`@layer` accepts only the option `:ignore` before the layer type (to control `show`).")

[Codecov (codecov/patch) check warning, src/layers/macro.jl#L54: added line #L54 was not covered by tests.]
else
push!(out.args, _macro_layer_show(esc(exs[1])))
push!(out.args, _macro_big_show(esc(exs[1])))
exs
end
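As a behaviour sketch of the deprecation branch above (editorial; the `Wrapped` type is illustrative), passing `:expand` still works but now only emits a warning:

```julia
using Flux

struct Wrapped; inner; end   # illustrative type, not from the diff

Flux.@layer :expand Wrapped
# ┌ Warning: The `:expand` option is deprecated, and will be removed in a future
# │ release. Use `@layer` without options instead.
# i.e. it now behaves like a plain `Flux.@layer Wrapped`.
```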

12 changes: 6 additions & 6 deletions src/layers/recurrent.jl
@@ -158,7 +158,7 @@ struct Model
h0::AbstractVector
end

Flux.@layer :expand Model
Flux.@layer Model

(m::Model)(x) = m.rnn(x, m.h0)

@@ -169,7 +169,7 @@ struct RNN{M}
cell::M
end

@layer :expand RNN
@layer RNN

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
@@ -344,7 +344,7 @@ struct Model
c0::AbstractVector
end

Flux.@layer :expand Model
Flux.@layer Model

(m::Model)(x) = m.lstm(x, (m.h0, m.c0))

@@ -359,7 +359,7 @@ struct LSTM{M}
cell::M
end

@layer :expand LSTM
@layer LSTM

function LSTM((in, out)::Pair; cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
@@ -531,7 +531,7 @@ struct GRU{M}
cell::M
end

@layer :expand GRU
@layer GRU

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
@@ -669,7 +669,7 @@ struct GRUv3{M}
cell::M
end

@layer :expand GRUv3
@layer GRUv3

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
74 changes: 35 additions & 39 deletions src/layers/show.jl
@@ -1,6 +1,6 @@
@nospecialize # just for this file, for startup time

# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression:
# This is called by @layer and returns an expression:
function _macro_big_show(ex)
quote
# Entry point:
@@ -83,23 +83,6 @@ function _flat_children(x)
gamma = ((beta...)...,)
end

# This is called by @layer, on layers which should be treated like Dense, and returns an expression:
function _macro_layer_show(ex)
quote
# Entry point:
function Base.show(io::IO, m::MIME"text/plain", x::$ex)
if !get(io, :compact, false)
_layer_show(io, x)
else
show(io, x)
end
end

# Exit from _big_show recursion:
Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
end
end

function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * _layer_string(io, layer)
@@ -176,40 +159,53 @@ _all(f, xs) = !_any(!f, xs)

#=

julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
julia> struct Tmp2; x; y; end;

# Before, notice Array(), NamedTuple(), and values
julia> t = Tmp2([Dense(2,3), randn(3,4)'], (x=1:4, y=Dense(3,4), z=rand(3)))
Tmp2(Any[Dense(2 => 3), [-0.559390071462934 -0.6357914190386781 -0.8516823037180543; -2.187495592853204 -0.6807254521505784 -1.2334639710489697; -0.12790952072543338 -1.4672700459421741 1.3687526519721238; 0.5232171922680576 -1.012045481192333 1.4953790632112915]], (x = 1:4, y = Dense(3 => 4), z = [0.29222096031585143, 0.6562195256556428, 0.9741896713499167]))

julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
julia> Chain(t)
Chain(
Tmp2(
Array(
[
Dense(2 => 3), # 9 parameters
[0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
),
NamedTuple(
1:3, # 3 parameters
Dense(3 => 4), # 16 parameters
[0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
(;
x = 4-element UnitRange{Int64},
y = Dense(3 => 4), # 16 parameters
z = 3-element Vector{Float64}, # 3 parameters
),
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.


# After, (; x=, y=, z=) and "3-element Array"
julia> Flux.@layer Tmp2

julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
julia> t
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
4-element UnitRange{Int64},
Dense(3 => 4), # 16 parameters
3-element Vector{Float64}, # 3 parameters
) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.

julia> Chain(t)
Chain(
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint, # 12 parameters
4×3 Adjoint{Float64,...}, # 12 parameters
],
(;
x = 3-element UnitRange, # 3 parameters
y = Dense(3 => 4), # 16 parameters
z = 3-element Array, # 3 parameters
),
4-element UnitRange{Int64},
Dense(3 => 4), # 16 parameters
3-element Vector{Float64}, # 3 parameters
),
) # Total: 7 arrays, 43 parameters, 644 bytes.

) # Total: 6 trainable arrays, 40 parameters,
# plus 1 non-trainable, 4 parameters, summarysize 620 bytes.
=#
8 changes: 4 additions & 4 deletions test/ext_common/recurrent_gpu_ad.jl
@@ -27,7 +27,7 @@ end
h0::AbstractVector
end

Flux.@layer :expand ModelRNN
Flux.@layer ModelRNN

(m::ModelRNN)(x) = m.rnn(x, m.h0)

@@ -74,7 +74,7 @@ end
c0::AbstractVector
end

Flux.@layer :expand ModelLSTM
Flux.@layer ModelLSTM

(m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0))

@@ -113,7 +113,7 @@ end
h0::AbstractVector
end

Flux.@layer :expand ModelGRU
Flux.@layer ModelGRU

(m::ModelGRU)(x) = m.gru(x, m.h0)

@@ -150,7 +150,7 @@ end
h0::AbstractVector
end

Flux.@layer :expand ModelGRUv3
Flux.@layer ModelGRUv3

(m::ModelGRUv3)(x) = m.gru(x, m.h0)

4 changes: 2 additions & 2 deletions test/layers/macro.jl
@@ -4,7 +4,7 @@ module MacroTest
using Flux: @layer

struct Duo{T,S}; x::T; y::S; end
@layer :expand Duo
@layer Duo

struct Trio; a; b; c end
# @layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget
@@ -33,7 +33,7 @@ end

m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6])
# Check that we can use the macro with a qualified type name, outside the defining module:
Flux.@layer :expand MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes
Flux.@layer MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes

m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60]))
@test m23re isa MacroTest.TwoThirds