Skip to content

Commit

Permalink
add total
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 8, 2024
1 parent e60b71e commit b9bbdc4
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 7 deletions.
4 changes: 3 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ Optimisers.trainable
Optimisers.isnumeric
```

Such restrictions are also obeyed by this function for flattening a model:
Such restrictions are also obeyed by this function for flattening a model,
and by one for applying a function to every parameter and summing the results:

```@docs
Optimisers.destructure
Optimisers.Restructure
Optimisers.total
```

## Rule Definition
Expand Down
2 changes: 1 addition & 1 deletion src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export AbstractRule
include("adjust.jl")

include("destructure.jl")
export destructure
export destructure, total

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
Expand Down
73 changes: 68 additions & 5 deletions src/destructure.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
const NoT = NoTangent()
using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
const NoT = ChainRulesCore.NoTangent()

"""
destructure(model) -> vector, reconstructor
Expand Down Expand Up @@ -124,9 +124,11 @@ function (::_Tangent_biwalk)(f, x, aux) # use with prune = NoT
y = _trainmap(f, ch, _trainable(x), au)
y isa Tuple{} && return NoT
p = ProjectTo(x)
if p isa ProjectTo # e.g. Array, NamedTuple
p(y)
else # p === identity for unknown structs
# if p isa ProjectTo # e.g. Array, NamedTuple
# p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
if x isa Union{Number, AbstractArray} # these don't use Tangent
ProjectTo(x)(unthunk(y))
else
Tangent{typeof(x), typeof(y)}(y)
end
end
Expand Down Expand Up @@ -174,3 +176,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
@warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
nothing, _ -> (NoT,)
end

"""
total(f, model)
Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
the model, and returns the sum. Differentiable. Counts shared weights once.
# Examples
```jldoctest
julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));
julia> total(sum, m)
12.0
julia> total(norm, m)
10.0
julia> total(length, m) == length(destructure(m)[1])
true
```
"""
function total(f, x)
values = []
fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
sum(values)
end

# Reverse-mode rule for `total(f, x)`. The forward pass records one pullback
# per numeric leaf (via `_total_hobbit`); the backward pass replays them
# (via `_total_grad`) to produce tangents for `f` and for `x`.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
  result, pullbacks = _total_hobbit(config, f, x)
  function total_back(dz)
    df, dx = _total_grad(unthunk(dz), x, pullbacks)
    # No tangent for the function `total` itself.
    return (NoT, df, dx)
  end
  return result, total_back
end

# Forward pass of the `total` rule: apply `f` to every numeric leaf through
# `rrule_via_ad`, returning the summed values together with a structure,
# built by `fmap` over the trainable children, holding each leaf's pullback.
function _total_hobbit(config::RuleConfig, f, x)
  terms = []
  # Evaluate `f` on one leaf, remember its value, hand back its pullback.
  function leaf_pullback(leaf)
    term, pullback = rrule_via_ad(config, f, leaf)
    push!(terms, term)
    return pullback
  end
  # Unlike in `total`, the walk uses `map` so that the pullbacks are kept.
  map_children(recurse, node) = map(recurse, _trainable(node))
  pullbacks = fmap(leaf_pullback, x; exclude = isnumeric, walk = map_children)
  return sum(terms), pullbacks
end

# Backward pass of the `total` rule.
#
# Walks the model `x` alongside `backs` (the per-leaf pullbacks recorded by
# `_total_hobbit`), calls each pullback with the scalar cotangent `dz`, and
# returns `(df, dx)`: the summed tangent with respect to `f`, and a tangent
# for `x` assembled by `_Tangent_biwalk`.
function _total_grad(dz, x, backs)
  dfs = []
  # `_Tangent_biwalk` is a callable struct (its method is defined as
  # `(::_Tangent_biwalk)(f, x, aux)`), so `fmap` must be given an instance:
  # passing the bare type would make the walk call its constructor instead.
  # It is documented as "use with prune = NoT".
  dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk(), prune = NoT) do y, b
    df, dy = b(dz)
    push!(dfs, df)
    dy
  end
  sum(dfs), dx
end

# Rule for differentiating the backward pass itself, i.e. second derivatives
# of `total(f, x)`. This is a best-effort placeholder: it warns on use, and
# only attempts a gradient with respect to `dz`.
function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
  @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
  function grad_back((df, dx))
    # `Zero` is not a name in scope here (current ChainRulesCore calls its zero
    # tangents `ZeroTangent` / `NoTangent`, both `<: AbstractZero`), so test
    # against the abstract supertype, fully qualified.
    df isa ChainRulesCore.AbstractZero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
    # NOTE(review): `total` is only defined above with two arguments, `total(f, x)`;
    # this one-argument call looks unfinished. Left as written, since the
    # intended reduction over `dx` is not clear from here. TODO confirm.
    (NoT, total(dx), NoT, NoT)
  end
  _total_grad(dz, x, backs), grad_back
end
23 changes: 23 additions & 0 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,26 @@ tmp1
y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4))
@test bk(1.0) == (nothing,)
end

# Tests for `total(f, model)`: values, first derivatives with respect to the
# model and to a variable captured by `f`, and (still broken) second derivatives.
@testset "total" begin
  @test total(sum, m1) == sum(1:3)
  @test total(prod, m2) == prod(1:3) + prod(4:6)
  @test total(sum, m3) == sum(1:6)
  @test total(sum, m4) == sum(1:6)  # shared only counts once
  @test total(sum, m6) == 6 + 4 + im

  @test gradient(m -> total(sum, m), m1) == ([1,1,1],)
  @test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
  @test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
  g6 = gradient(m -> abs2(total(sum, m)), m6)[1]
  @test g6.a isa Vector{Float64}

  # Gradient with respect to `λ`, which is closed over by the function given to `total`.
  @test gradient(λ -> total(x -> sum(x.*λ), m3), 1.0) == (21.0,)
  @test gradient(λ -> total(x -> sum(x.*λ), m4), 1.0) == (21.0,)

  @testset "second derivatives" begin
    f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0]))
    @test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) ≈ Zygote.hessian_dual(f3, [1,2,3.0])
    # typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple...
  end
end

0 comments on commit b9bbdc4

Please sign in to comment.