From b9bbdc42356dd735f80263ca5e145c2abfe94be2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 14 Feb 2022 17:54:40 -0500 Subject: [PATCH] add total --- docs/src/api.md | 4 ++- src/Optimisers.jl | 2 +- src/destructure.jl | 73 +++++++++++++++++++++++++++++++++++++++++---- test/destructure.jl | 23 ++++++++++++++ 4 files changed, 95 insertions(+), 7 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 4a9d62bb..9151d770 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -52,11 +52,13 @@ Optimisers.trainable Optimisers.isnumeric ``` -Such restrictions are also obeyed by this function for flattening a model: +Such restrictions are also obeyed by this function for flattening a model, +and one for applying a function to every parameter: ```@docs Optimisers.destructure Optimisers.Restructure +Optimisers.total ``` ## Rule Definition diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 9254bb77..ea821723 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -9,7 +9,7 @@ export AbstractRule include("adjust.jl") include("destructure.jl") -export destructure +export destructure, total include("rules.jl") export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, diff --git a/src/destructure.jl b/src/destructure.jl index 3b21d918..da4dd823 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -1,6 +1,6 @@ -using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk -const NoT = NoTangent() +using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad +const NoT = ChainRulesCore.NoTangent() """ destructure(model) -> vector, reconstructor @@ -124,9 +124,11 @@ function (::_Tangent_biwalk)(f, x, aux) # use with prune = NoT y = _trainmap(f, ch, _trainable(x), au) y isa Tuple{} && return NoT p = ProjectTo(x) - if p isa ProjectTo # e.g. Array, NamedTuple - p(y) - else # p === identity for unknown structs + # if p isa ProjectTo # e.g. Array, NamedTuple + # p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538 + if x isa Union{Number, AbstractArray} # these don't use Tangent + ProjectTo(x)(unthunk(y)) + else Tangent{typeof(x), typeof(y)}(y) end end @@ -174,3 +176,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn)) @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3 nothing, _ -> (NoT,) end + +""" + total(f, model) + +Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in +the model, and returns the sum. Differentiable. Counts shared weights once. + +# Examples +```jldoctest +julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7)); + +julia> total(sum, m) +12.0 + +julia> total(norm, m) +10.0 + +julia> total(length, m) == length(destructure(m)[1]) +true +``` +""" +function total(f, x) + values = [] + fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) + sum(values) +end + +function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x) + z, backs = _total_hobbit(config, f, x) + total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...) + z, total_back +end + +function _total_hobbit(config::RuleConfig, f, x) + values = [] + backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y + val, back = rrule_via_ad(config, f, y) + push!(values, val) + back + end + sum(values), backs +end + +function _total_grad(dz, x, backs) + dfs = [] + dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b + df, dy = b(dz) + push!(dfs, df) + dy + end + sum(dfs), dx +end + +function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs) + @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3 + function grad_back((df, dx)) + df isa Zero || @error "second derivatives of total(f, x) with respect to the function are wrong!" + (NoT, total(dx), NoT, NoT) + end + _total_grad(dz, x, backs), grad_back +end diff --git a/test/destructure.jl b/test/destructure.jl index 232a9001..c656444e 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -296,3 +296,26 @@ tmp1 y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4)) @test bk(1.0) == (nothing,) end + +@testset "total" begin + @test total(sum, m1) == sum(1:3) + @test total(prod, m2) == prod(1:3) + prod(4:6) + @test total(sum, m3) == sum(1:6) + @test total(sum, m4) == sum(1:6) # shared only counts once + @test total(sum, m6) == 6 + 4 + im + + @test gradient(m -> total(sum, m), m1) == ([1,1,1],) + @test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1]) + @test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1]) + g6 = gradient(m -> abs2(total(sum, m)), m6)[1] + @test g6.a isa Vector{Float64} + + @test gradient(λ -> total(x -> sum(x.*λ), m3), 1.0) == (21.0,) + @test gradient(λ -> total(x -> sum(x.*λ), m4), 1.0) == (21.0,) + + @testset "second derivatives" begin + f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0])) + @test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) ≈ Zygote.hessian_dual(f3, [1,2,3.0]) + # typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple... + end +end