From e69bd98a045d4dd83f1d4828effa7ed1cba1bda2 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 16 Nov 2024 11:02:56 +0100 Subject: [PATCH] cleanup --- src/deprecations.jl | 3 +-- src/layers/macro.jl | 26 +++++++++++++------------- src/train.jl | 13 ++++++++++++- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index 2ee56f05e7..1b85345f5d 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -81,8 +81,7 @@ function reset!(x) return x end - -function params!(p::Params, x, seen = IdSet()) +function params!(p::Zygote.Params, x, seen = IdSet()) if x isa AbstractArray{<:Number} && Functors.isleaf(x) return push!(p, x) elseif x in seen diff --git a/src/layers/macro.jl b/src/layers/macro.jl index 065774602a..9f9d0435ec 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -117,19 +117,19 @@ end _macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,))) # lets you forget a comma function _default_functor(::Type{T}, x) where {T} - if @generated - F = fieldnames(T) - args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F) - C = Base.typename(T).wrapper # constructor - # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C)) - recon = :(Base.splat($C)) - :((NamedTuple{$F}(($(args...),)), $recon)) - else - # Getting this parameterless type takes about 2μs, every time: - # spl = VERSION > v"1.9-" ? Splat : Base.splat - spl = Base.splat - namedtuple(x), spl(Base.typename(T).wrapper) - end + if @generated + F = fieldnames(T) + args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F) + C = Base.typename(T).wrapper # constructor + # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C)) + recon = :(Base.splat($C)) + :((NamedTuple{$F}(($(args...),)), $recon)) + else + # Getting this parameterless type takes about 2μs, every time: + # spl = VERSION > v"1.9-" ? Splat : Base.splat + spl = Base.splat + namedtuple(x), spl(Base.typename(T).wrapper) + end end function namedtuple(x::T) where T diff --git a/src/train.jl b/src/train.jl index 2815c15762..bf52256c53 100644 --- a/src/train.jl +++ b/src/train.jl @@ -3,7 +3,7 @@ module Train using LinearAlgebra using Optimisers: Optimisers using Functors: fmap, fmapstructure -using ..Flux: Flux # used only in docstring +using ..Flux: Flux using ProgressLogging: @progress, @withprogress, @logprogress using Zygote: Zygote @@ -133,4 +133,15 @@ function _rule_to_state(model, rule::Optimisers.AbstractRule) state end + +# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1] +# Can't catch every case, but can catch many simple Flux models: + +function Optimisers.update!(opt, model::Flux.Chain, grads::Tuple) + # Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent + @warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone, + not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.""" + return Optimisers.update!(opt, model, grads[1]) +end + end # module Train