From b9bbdc42356dd735f80263ca5e145c2abfe94be2 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 14 Feb 2022 17:54:40 -0500
Subject: [PATCH] add total

---
 docs/src/api.md     |  4 ++-
 src/Optimisers.jl   |  2 +-
 src/destructure.jl  | 73 +++++++++++++++++++++++++++++++++++++++++----
 test/destructure.jl | 23 ++++++++++++++
 4 files changed, 95 insertions(+), 7 deletions(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index 4a9d62bb..9151d770 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -52,11 +52,13 @@ Optimisers.trainable
 Optimisers.isnumeric
 ```
 
-Such restrictions are also obeyed by this function for flattening a model:
+Such restrictions are also obeyed by this function for flattening a model,
+and by this one for applying a function to every parameter:
 
 ```@docs
 Optimisers.destructure
 Optimisers.Restructure
+Optimisers.total
 ```
 
 ## Rule Definition
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 9254bb77..ea821723 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -9,7 +9,7 @@ export AbstractRule
 include("adjust.jl")
 
 include("destructure.jl")
-export destructure
+export destructure, total
 
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
diff --git a/src/destructure.jl b/src/destructure.jl
index 3b21d918..da4dd823 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -1,6 +1,6 @@
 
-using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
-const NoT = NoTangent()
+using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
+const NoT = ChainRulesCore.NoTangent()
 
 """
     destructure(model) -> vector, reconstructor
@@ -124,9 +124,11 @@ function (::_Tangent_biwalk)(f, x, aux)  # use with prune = NoT
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
-  if p isa ProjectTo  # e.g. Array, NamedTuple
-    p(y)
-  else  # p === identity for unknown structs
+  # if p isa ProjectTo  # e.g. Array, NamedTuple
+  #   p(y)  # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+  if x isa Union{Number, AbstractArray}  # these don't use Tangent
+    p(unthunk(y))  # re-use the projector computed above
+  else
     Tangent{typeof(x), typeof(y)}(y)
   end
 end
@@ -174,3 +176,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
   @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
   nothing, _ -> (NoT,)
 end
+
+"""
+    total(f, model)
+
+Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in the model,
+and returns the sum of the results. It is differentiable, and shared weights are counted only once.
+
+# Examples
+```jldoctest
+julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));
+
+julia> total(sum, m)
+12.0
+
+julia> total(norm, m)
+10.0
+
+julia> total(length, m) == length(destructure(m)[1])
+true
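+
+julia> m2 = (x = m.x, y = m.x);  # a hypothetical model re-using the same array twice
+
+julia> total(length, m2)  # shared parameters are counted only once
+2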
+```
+"""
+function total(f, x)
+  values = []
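+  # Recurse only through `trainable` fields, applying `f` at each `isnumeric` leaf;
+  # note that the `f` inside the `walk` closure is fmap's own recursion, not this `f`.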
+  fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
+  sum(values)
+end
+
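+# Custom rrule: the forward pass collects a pullback for every parameter, and the
+# reverse pass replays them over the model's structure.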
+function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
+  z, backs = _total_hobbit(config, f, x)
+  total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...)
+  z, total_back
+end
+
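+# Forward pass of the rrule: applies `f` to each parameter via `rrule_via_ad`,
+# summing the values and keeping every pullback in a structure mirroring `x`.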
+function _total_hobbit(config::RuleConfig, f, x)
+  values = []
+  backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
+    val, back = rrule_via_ad(config, f, y)
+    push!(values, val)
+    back
+  end
+  sum(values), backs
+end
+
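+# Reverse pass: walks `x` together with the stored pullbacks, feeding the scalar
+# cotangent `dz` through each one to rebuild a gradient with the structure of `x`.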
+function _total_grad(dz, x, backs)
+  dfs = []
+  dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b
+    df, dy = b(dz)
+    push!(dfs, df)
+    dy
+  end
+  sum(dfs), dx
+end
+
+function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
+  @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
+  function grad_back((df, dx))
+    df isa ChainRulesCore.AbstractZero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
+    (NoT, total(sum, dx), NoT, NoT)  # exact only when each pullback of `f` is all-ones, e.g. `f = sum`
+  end
+  _total_grad(dz, x, backs), grad_back
+end
diff --git a/test/destructure.jl b/test/destructure.jl
index 232a9001..c656444e 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -296,3 +296,26 @@ tmp1
     y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4))
     @test bk(1.0) == (nothing,)
 end
+
+@testset "total" begin
+  @test total(sum, m1) == sum(1:3)
+  @test total(prod, m2) == prod(1:3) + prod(4:6)
+  @test total(sum, m3) == sum(1:6)
+  @test total(sum, m4) == sum(1:6)  # shared only counts once
+  @test total(sum, m6) == 6 + 4 + im
+
+  @test gradient(m -> total(sum, m), m1) == ([1,1,1],)
+  @test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
+  @test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
+  g6 = gradient(m -> abs2(total(sum, m)), m6)[1]
+  @test g6.a isa Vector{Float64}
+
+  @test gradient(λ -> total(x -> sum(x.*λ), m3), 1.0) == (21.0,)
+  @test gradient(λ -> total(x -> sum(x.*λ), m4), 1.0) == (21.0,)
+
+  @testset "second derivatives" begin
+    f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0]))
+    @test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) ≈ Zygote.hessian_dual(f3, [1,2,3.0])
+    # typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple...
+  end
+end