remove children=(α,β) keyword
mcabbott committed Mar 2, 2024
1 parent 29e0d68 · commit 2be9099
Showing 2 changed files with 8 additions and 25 deletions.
25 changes: 3 additions & 22 deletions src/layers/macro.jl
@@ -3,19 +3,16 @@
     @layer Dense
     @layer :expand Chain
     @layer BatchNorm trainable=(β,γ)
-    @layer Struct children=(α,β) trainable=(β,)
 This macro replaces most uses of `@functor`. Its basic purpose is the same:
 When you define a new layer, this tells Flux to explore inside it
 to see the parameters it trains, and also to move them to the GPU, change precision, etc.
 Like `@functor`, this assumes your struct has the default constructor, to enable re-building.
-Some keywords allow you to limit this exploration, instead of visiting all `fieldnames(T)`.
+The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`.
 Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes.
 * If some fields look like parameters but should not be trained,
   then `trainable` lets you specify which fields to include, while the rest are ignored.
-* You can likewise add restrictions to Functors's `children` (although this is seldom a good idea),
-  equivalent to `@functor Struct (α,β)`. Any `trainable` limitation must then be a subset of `children`.
 The macro also handles overloads of `show` for pretty printing.
 * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
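
Illustration only, not part of the diff: a minimal sketch of the usage that survives this change, with a made-up layer type and field names. The `trainable` keyword restricts what optimisers update, while Functors still walks every field.

using Flux, Functors, Optimisers

# Hypothetical layer: `weight` should be trained, `mean` is running state that should not be.
struct MyNorm
  weight::Vector{Float32}
  mean::Vector{Float32}
end

Flux.@layer MyNorm trainable=(weight,)   # only the `trainable` keyword is accepted now

m = MyNorm(ones(Float32, 3), zeros(Float32, 3))
Optimisers.trainable(m)   # (weight = Float32[1.0, 1.0, 1.0],) -- `mean` is ignored by optimisers
Functors.children(m)      # (weight = ..., mean = ...) -- all fields still move to GPU, change precision, etc.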
@@ -69,21 +66,14 @@ macro layer(exs...)
   # This function exists only for depwarns when you use @functor directly
   push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))
 
-  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest)
-  if isnothing(i) # then default like @functor Layer
-    push!(out.args, _macro_functor(esc(type)))
-  else
-    push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
-  end
+  push!(out.args, _macro_functor(esc(type)))
 
   for j in 1:length(rest)
-    j == i && continue
     ex = rest[j]
     Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects here `keyword = (fields...,)`, got $ex")
 
     name = if ex.args[1] == :trainable
       :(Optimisers.trainable)
-    elseif ex.args[1] == :functor
-      error("Can't use `functor=(...)` as a keyword to `@layer`. Use `childen=(...)` to define a method for `functor`.")
     else
       error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.")
       # @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
@@ -141,15 +131,6 @@ function _default_functor(::Type{T}, x) where {T}
     namedtuple(x), spl(Base.typename(T).wrapper)
   end
 end
 
-function _custom_functor(::Type{T}, x, ::Val{which}) where {T,which}
-  if false
-    # TODO write the @generated version. Or decide we don't care, or should forbid this?
-  else
-    remake(nt) = Base.typename(T).wrapper(map(f -> f in which ? getfield(nt, f) : getfield(x, f), fieldnames(T))...)
-    NamedTuple{which}(map(s -> getfield(x, s), which)), remake
-  end
-end
-
 function namedtuple(x::T) where T
   F = fieldnames(T)
8 changes: 5 additions & 3 deletions test/layers/macro.jl
@@ -33,13 +33,15 @@ end
 
   m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6])
   # Check that we can use the macro with a qualified type name, outside the defining module:
-  Flux.@layer :expand MacroTest.TwoThirds children=(:a,:c) trainable=(:a) # documented as (a,c) but allow quotes
+  Flux.@layer :expand MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes
 
-  @test Functors.children(m23) == (a = [1 2], c = [5 6])
-  m23re = Functors.functor(m23)[2]((a = [10 20], c = [50 60]))
+  m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60]))
   @test m23re isa MacroTest.TwoThirds
   @test Flux.namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60])
 
+  @test Optimisers.trainable(m23) == (a = [1 2],)
+
   @test_throws LoadError @eval Flux.@layer :zzz MacroTest.TwoThirds
+  @test_throws LoadError @eval Flux.@layer MacroTest.TwoThirds chidren=(a, b)
 end
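
Also illustration only, with a hypothetical `Affine` type: now that the keyword is gone, `children=` (like any keyword other than `trainable`) falls through to the macro's generic error branch shown in the diff above.

using Flux

struct Affine  # hypothetical type, just to trigger the error branch
  W
  b
end

Flux.@layer Affine children=(W, b)
# ERROR (at macro expansion): `@layer` cannot define a method for `children` at the moment, sorry.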
