From 2be90998d33b1a582371105afc837d1c8840c400 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 1 Mar 2024 21:44:43 -0500 Subject: [PATCH] =?UTF-8?q?remove=20children=3D(=CE=B1,=CE=B2)=20keyword?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/layers/macro.jl | 25 +++---------------------- test/layers/macro.jl | 8 +++++--- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/src/layers/macro.jl b/src/layers/macro.jl index 02e5ef4540..9e770add87 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -3,19 +3,16 @@ @layer Dense @layer :expand Chain @layer BatchNorm trainable=(β,γ) - @layer Struct children=(α,β) trainable=(β,) This macro replaces most uses of `@functor`. Its basic purpose is the same: When you define a new layer, this tells Flux to explore inside it to see the parameters it trains, and also to move them to the GPU, change precision, etc. Like `@functor`, this assumes your struct has the default constructor, to enable re-building. -Some keywords allow you to limit this exploration, instead of visiting all `fieldnames(T)`. +The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. * If some fields look like parameters but should not be trained, then `trainable` lets you specify which fields to include, while the rest are ignored. -* You can likewise add restrictions to Functors's `children` (although this is seldom a good idea), - equivalent to `@functor Struct (α,β)`. Any `trainable` limitation must then be a subset of `children`. The macro also handles overloads of `show` for pretty printing. * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. @@ -69,21 +66,14 @@ macro layer(exs...) # This function exists only for depwarns when you use @functor directly push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) - i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest) - if isnothing(i) # then default like @functor Layer - push!(out.args, _macro_functor(esc(type))) - else - push!(out.args, _macro_functor(esc(type), rest[i].args[2])) - end + push!(out.args, _macro_functor(esc(type))) + for j in 1:length(rest) - j == i && continue ex = rest[j] Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects here `keyword = (fields...,)`, got $ex") name = if ex.args[1] == :trainable :(Optimisers.trainable) - elseif ex.args[1] == :functor - error("Can't use `functor=(...)` as a keyword to `@layer`. Use `childen=(...)` to define a method for `functor`.") else error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.") # @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1 @@ -141,15 +131,6 @@ function _default_functor(::Type{T}, x) where {T} namedtuple(x), spl(Base.typename(T).wrapper) end end - -function _custom_functor(::Type{T}, x, ::Val{which}) where {T,which} - if false - # TODO write the @generated version. Or decide we don't care, or should forbid this? - else - remake(nt) = Base.typename(T).wrapper(map(f -> f in which ? getfield(nt, f) : getfield(x, f), fieldnames(T))...) - NamedTuple{which}(map(s -> getfield(x, s), which)), remake - end -end function namedtuple(x::T) where T F = fieldnames(T) diff --git a/test/layers/macro.jl b/test/layers/macro.jl index 1361a895f4..e41d5a2240 100644 --- a/test/layers/macro.jl +++ b/test/layers/macro.jl @@ -33,13 +33,15 @@ end m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6]) # Check that we can use the macro with a qualified type name, outside the defining module: - Flux.@layer :expand MacroTest.TwoThirds children=(:a,:c) trainable=(:a) # documented as (a,c) but allow quotes + Flux.@layer :expand MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes - @test Functors.children(m23) == (a = [1 2], c = [5 6]) - m23re = Functors.functor(m23)[2]((a = [10 20], c = [50 60])) + m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60])) @test m23re isa MacroTest.TwoThirds @test Flux.namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60]) @test Optimisers.trainable(m23) == (a = [1 2],) + + @test_throws LoadError @eval Flux.@layer :zzz MacroTest.TwoThirds + @test_throws LoadError @eval Flux.@layer MacroTest.TwoThirds chidren=(a, b) end