Add total(f, model) to replace implicit sum(f, Flux.params(model)) #57

Status: Closed · wants to merge 1 commit
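Before this change, a parameter-wise reduction such as a norm penalty was typically written against the implicit parameter set, `sum(f, Flux.params(model))`; the PR proposes writing the same reduction directly against the model. A rough usage sketch for comparison (the `model` below is hypothetical, and `total` is only available with this PR applied; the two forms coincide when the trainable parameters are what you want, per the author's comment further down):

```julia
using Flux, Optimisers
using LinearAlgebra: norm

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))   # any Flux model, purely for illustration

# Implicit style this PR wants to replace: reduce over the global parameter set.
reg_old = sum(norm, Flux.params(model))

# Proposed explicit style: walk the model's trainable parameters directly.
reg_new = Optimisers.total(norm, model)
```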
docs/src/api.md · 4 changes: 3 additions & 1 deletion
@@ -52,11 +52,13 @@ Optimisers.trainable
Optimisers.isnumeric
```

- Such restrictions are also obeyed by this function for flattening a model:
+ Such restrictions are also obeyed by this function for flattening a model,
+ and one for applying a function to every parameter:

```@docs
Optimisers.destructure
Optimisers.Restructure
+ Optimisers.total
```

## Rule Definition
src/Optimisers.jl · 2 changes: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ export AbstractRule
include("adjust.jl")

include("destructure.jl")
- export destructure
+ export destructure, total

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
src/destructure.jl · 73 changes: 68 additions & 5 deletions
@@ -1,6 +1,6 @@

- using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
- const NoT = NoTangent()
+ using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad
+ const NoT = ChainRulesCore.NoTangent()

"""
destructure(model) -> vector, reconstructor
@@ -124,9 +124,11 @@ function (::_Tangent_biwalk)(f, x, aux) # use with prune = NoT
y = _trainmap(f, ch, _trainable(x), au)
y isa Tuple{} && return NoT
p = ProjectTo(x)
- if p isa ProjectTo # e.g. Array, NamedTuple
-   p(y)
- else # p === identity for unknown structs
+ # if p isa ProjectTo # e.g. Array, NamedTuple
+ #   p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+ if x isa Union{Number, AbstractArray} # these don't use Tangent
+   ProjectTo(x)(unthunk(y))
+ else
Tangent{typeof(x), typeof(y)}(y)
Comment on lines -127 to 132
Member Author:
This is either a bug in earlier _Tangent_biwalk, or in ChainRulesCore.

end
end
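An aside, not part of the PR: the branch above treats numbers and arrays differently from other structs because their cotangents are "natural" (plain values of the same shape), while generic containers come back structurally. A minimal Zygote illustration of that distinction, assuming any recent Zygote version:

```julia
using Zygote

# Array leaf: the cotangent is a plain array of the same shape.
Zygote.gradient(x -> sum(abs2, x), [3.0, 4.0])
# ([6.0, 8.0],)

# Struct-like container: the cotangent comes back as a NamedTuple, i.e. structurally.
Zygote.gradient(m -> sum(m.x), (x = [3.0, 4.0], y = sin))
# ((x = [1.0, 1.0], y = nothing),)
```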
@@ -174,3 +176,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
@warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
nothing, _ -> (NoT,)
end

"""
total(f, model)

Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in
the model, and returns the sum. Differentiable. Counts shared weights once.

# Examples
```jldoctest
julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));

julia> total(sum, m)
12.0

julia> total(norm, m)
10.0

julia> total(length, m) == length(destructure(m)[1])
true
```
"""

Member Author:
This would solve FluxML/Flux.jl#2043 (as long as trainable parameters are what you want).
Or `total(Base.summarysize, m)` for bytes, `total(_ -> 1, m)` to count arrays.
function total(f, x)
  values = []
  fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
  sum(values)
end
Comment on lines +200 to +204
Member Author:
While Any[] doesn't seem great, this ends up about the same speed as my other idea:

const INIT = Base._InitialValue()

function total2(f, x; init = INIT)
  fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
    val = f(y)
    init = init===INIT ? val : (init+val)
  end
  init
end

julia> @btime total(norm, $model)  # Resnet from the docs
  min 23.863 ms, mean 23.995 ms (1541 allocations, 130.06 KiB)
730.5533f0

julia> @btime total2(norm, $model)
  min 23.834 ms, mean 23.982 ms (1538 allocations, 128.17 KiB)
730.5533f0

julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));

julia> @btime total(norm, $m)
  min 1.750 μs, mean 1.846 μs (16 allocations, 752 bytes)
10.0

julia> @btime total2(norm, $m)
  min 1.675 μs, mean 1.769 μs (15 allocations, 640 bytes)
10.0


function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
  z, backs = _total_hobbit(config, f, x)
  total_back(dz) = (NoT, _total_grad(unthunk(dz), x, backs)...)
  z, total_back
end

# Forward-pass helper for the `total` rrule: `fmap` walks the trainable numeric
# leaves of `x`, calls `rrule_via_ad(config, f, y)` on each one, accumulates the
# primal values (summed as the result), and returns the per-leaf pullbacks laid
# out in the same structure as `x`.
function _total_hobbit(config::RuleConfig, f, x)
  values = []
  backs = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
    val, back = rrule_via_ad(config, f, y)
    push!(values, val)
    back
  end
  sum(values), backs
end

ToucheSir (Member), Aug 28, 2022, commenting on `_total_hobbit`:
A brief comment about what this and _total_grad do would help. "hobbit" in particular is alien terminology for anyone who hasn't read a couple of specific issues on the ChainRules repo 😛. Is there something more concise than _total_value_and_inner_pullbacks?

Member Author:
Indeed. I rebased this but realised I have no memory of how it worked. Will revise or re-write.

# Reverse-pass helper: pairs each stored pullback with the matching leaf of `x`,
# applies it to the incoming cotangent `dz`, and rebuilds a model-shaped gradient
# via `_Tangent_biwalk`; the `df` pieces are summed as the gradient for `f`.
function _total_grad(dz, x, backs)
  dfs = []
  dx = fmap(x, backs; exclude = isnumeric, walk = _Tangent_biwalk, prune = NoT) do y, b
    df, dy = b(dz)
    push!(dfs, df)
    dy
  end
  sum(dfs), dx
end

function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
  @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
  function grad_back((df, dx))
    df isa Zero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
    (NoT, total(dx), NoT, NoT)
  end
  _total_grad(dz, x, backs), grad_back
end
test/destructure.jl · 23 changes: 23 additions & 0 deletions
@@ -296,3 +296,26 @@ tmp1
y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4))
@test bk(1.0) == (nothing,)
end

@testset "total" begin
@test total(sum, m1) == sum(1:3)
@test total(prod, m2) == prod(1:3) + prod(4:6)
@test total(sum, m3) == sum(1:6)
@test total(sum, m4) == sum(1:6) # shared only counts once
@test total(sum, m6) == 6 + 4 + im

@test gradient(m -> total(sum, m), m1) == ([1,1,1],)
@test gradient(m -> total(sum, m), m3)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
@test gradient(m -> total(sum, m), m4)[1] == (x = [1,1,1], y = nothing, z = [1,1,1])
g6 = gradient(m -> abs2(total(sum, m)), m6)[1]
@test g6.a isa Vector{Float64}

@test gradient(λ -> total(x -> sum(x.*λ), m3), 1.0) == (21.0,)
@test gradient(λ -> total(x -> sum(x.*λ), m4), 1.0) == (21.0,)

@testset "second derivatives" begin
f3 = v -> total(norm, (x=v, y=sin, z=[4,5,6.0]))
@test_broken Zygote.hessian_reverse(f3, [1,2,3.0]) ≈ Zygote.hessian_dual(f3, [1,2,3.0])
# typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple...
end
end