add total
mcabbott committed Feb 14, 2022
1 parent 1e34fa2 commit fffd297
Showing 4 changed files with 95 additions and 7 deletions.
4 changes: 3 additions & 1 deletion docs/src/api.md
@@ -42,11 +42,13 @@ optimiser to act on all suitable fields. To restrict this, define `trainable`:
Optimisers.trainable
```

-Such restrictions are also obeyed by this function for flattening a model:
+Such restrictions are also obeyed by this function for flattening a model,
+and one for applying a function to every parameter:

```@docs
Optimisers.destructure
Optimisers.Restructure
Optimisers.total
```

## Rule Definition
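A quick illustrative sketch (not from the diff) of how the three documented calls fit together; `model`, `v`, `re` are made-up names:

```julia
using Optimisers  # after this commit, exports destructure and total

model = (x = [1.0, 2.0], y = (sin, [3.0]))

v, re = destructure(model)  # v == [1.0, 2.0, 3.0], re isa Optimisers.Restructure
model2 = re(10 .* v)        # rebuilds the same shape: (x = [10.0, 20.0], y = (sin, [30.0]))

total(length, model) == length(v)  # true — both obey the same trainable/isnumeric rules
```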
2 changes: 1 addition & 1 deletion src/Optimisers.jl
@@ -6,7 +6,7 @@ using LinearAlgebra
include("interface.jl")

include("destructure.jl")
-export destructure, total, total2
+export destructure, total

include("rules.jl")
export Descent, ADAM, Momentum, Nesterov, RMSProp,
73 changes: 68 additions & 5 deletions src/destructure.jl
@@ -1,6 +1,6 @@

-using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
-const NoT = NoTangent()
+using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
+const NoT = ChainRulesCore.NoTangent()

"""
    destructure(model) -> vector, reconstructor
@@ -107,9 +107,11 @@ function _Tangent_biwalk(f, x, aux)  # use with prune = NoT
  y = _trainmap(f, ch, _trainable(x), au)
  y isa Tuple{} && return NoT
  p = ProjectTo(x)
-  if p isa ProjectTo  # e.g. Array, NamedTuple
-    p(y)
-  else  # p === identity for unknown structs
+  # if p isa ProjectTo  # e.g. Array, NamedTuple
+  #   p(y)  # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+  if x isa Union{Number, AbstractArray}  # these don't use Tangent
+    ProjectTo(x)(unthunk(y))
+  else
    Tangent{typeof(x), typeof(y)}(y)
  end
end
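The rewritten branch sidesteps the ChainRulesCore.jl issue linked above by projecting only where tangents are natural (numbers and arrays) and wrapping anything else in an explicit `Tangent`. A minimal sketch of the two branches, with a hypothetical `Point` struct standing in for an unknown model layer:

```julia
using ChainRulesCore: ProjectTo, Tangent

ProjectTo([1.0, 2.0])([3, 4])       # array branch: projects to [3.0, 4.0]::Vector{Float64}

struct Point; x; y; end             # unknown struct: takes the Tangent branch
p  = Point(1.0, 2.0)
ch = (x = 10.0, y = 20.0)
Tangent{typeof(p), typeof(ch)}(ch)  # Tangent{Point}(x = 10.0, y = 20.0)
```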
@@ -150,3 +152,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
  @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
  nothing, _ -> (NoT,)
end

"""
total(f, model)
Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
the model, and returns the sum. Differentiable. Counts shared weights once.
# Examples
```jldoctest
julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));
julia> total(sum, m)
12.0
julia> total(norm, m)
10.0
julia> total(length, m) == length(destructure(m)[1])
true
```
"""
function total(f, x)
  values = []
  fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
  sum(values)
end
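Here `exclude = isnumeric` limits the walk to numeric arrays, the custom `walk` recurses only into `_trainable` children, and `fmap`'s `IdDict` cache is what makes shared weights count once. An illustrative check (the names `shared` and `m` are made up for the sketch):

```julia
using Optimisers

shared = [1.0, 2.0]
m = (a = shared, b = shared, c = (sin, [3.0, 4.0]))

total(sum, m)     # 10.0 — the shared array contributes 3.0 once, not twice
total(length, m)  # 4, agreeing with length(destructure(m)[1])
```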

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
  z, backs = _total_hobbit(config, f, x)
  total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...)
  z, total_back
end
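The rule splits the work: `_total_hobbit` below differentiates `f` at every leaf on the forward pass, and `_total_grad` fans the scalar cotangent back out over the model. A sketch of the end-to-end behaviour, mirroring the new `m3` test in `test/destructure.jl`:

```julia
using Zygote, Optimisers

m = (x = [1.0, 2.0, 3.0], y = sin, z = [4.0, 5.0, 6.0])
Zygote.gradient(m -> total(sum, m), m)[1]
# (x = [1.0, 1.0, 1.0], y = nothing, z = [1.0, 1.0, 1.0])
```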

function _total_hobbit(config::RuleConfig, f, x)
  values = []
  backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
    val, back = rrule_via_ad(config, f, y)
    push!(values, val)
    back
  end
  sum(values), backs
end
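`rrule_via_ad` hands each leaf back to the surrounding AD (here Zygote, via its `RuleConfig`), returning a value plus a pullback; because this `walk` uses `map` rather than `foreach`, `backs` comes out as a tree of pullbacks mirroring the model. Per leaf it behaves roughly like this sketch (exact tangent types may vary by AD backend):

```julia
using Zygote, ChainRulesCore

cfg = Zygote.ZygoteRuleConfig()
val, back = ChainRulesCore.rrule_via_ad(cfg, sum, [4.0, 5.0])
val        # 9.0
back(1.0)  # ≈ (NoTangent(), [1.0, 1.0]) — the (df, dy) pair consumed by _total_grad
```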

function _total_grad(dz, x, backs)
  dfs = []
  dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b
    df, dy = b(dz)
    push!(dfs, df)
    dy
  end
  sum(dfs), dx
end

function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
  @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
  function grad_back((df, dx))
    df isa ChainRulesCore.AbstractZero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
    (NoT, total(dx), NoT, NoT)
  end
  _total_grad(dz, x, backs), grad_back
end
23 changes: 23 additions & 0 deletions test/destructure.jl
@@ -164,3 +164,26 @@ end
    4(sum(m.x) + sum(m.y)) + 13*sum(m.z)  # again two gradients are ===, so it eliminates one
  end == ([17,17,4,4],)  # Flux gave ([4.0, 4.0, 13.0, 13.0],)
end

@testset "total" begin
@test total(sum, m1) == sum(1:3)
@test total(prod, m2) == prod(1:3) + prod(4:6)
@test total(sum, m3) == sum(1:6)
@test total(sum, m4) == sum(1:6) # shared only counts once
@test total(sum, m6) == 6 + 4 + im

@test gradient(m -> total(sum, m), m1) == ([1,1,1],)
@test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
@test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
g6 = gradient(m -> abs2(total(sum, m)), m6)[1]
@test g6.a isa Vector{Float64}

@test gradient-> total(x -> sum(x.*λ), m3), 1.0) == (21.0,)
@test gradient-> total(x -> sum(x.*λ), m4), 1.0) == (21.0,)

@testset "second derivatives" begin
f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0]))
@test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) Zygote.hessian_dual(f3, [1,2,3.0])
# typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple...
end
end
