-
-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add total(f, model)
to replace implicit sum(f, Flux.params(model))
#57
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
|
||
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk | ||
const NoT = NoTangent() | ||
using ChainRulesCore: ChainRulesCore, ProjectTo, unthunk, RuleConfig, HasReverseMode, rrule_via_ad | ||
const NoT = ChainRulesCore.NoTangent() | ||
|
||
""" | ||
destructure(model) -> vector, reconstructor | ||
|
@@ -124,9 +124,11 @@ function (::_Tangent_biwalk)(f, x, aux) # use with prune = NoT | |
y = _trainmap(f, ch, _trainable(x), au) | ||
y isa Tuple{} && return NoT | ||
p = ProjectTo(x) | ||
if p isa ProjectTo # e.g. Array, NamedTuple | ||
p(y) | ||
else # p === identity for unknown structs | ||
# if p isa ProjectTo # e.g. Array, NamedTuple | ||
# p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538 | ||
if x isa Union{Number, AbstractArray} # these don't use Tangent | ||
ProjectTo(x)(unthunk(y)) | ||
else | ||
Tangent{typeof(x), typeof(y)}(y) | ||
end | ||
end | ||
|
@@ -174,3 +176,64 @@ function ChainRulesCore.rrule(::typeof(_maybewarn)) | |
@warn "second derivatives of destructure may not work yet, sorry!" maxlog=3 | ||
nothing, _ -> (NoT,) | ||
end | ||
|
||
""" | ||
total(f, model) | ||
|
||
Applies `f` to every [`trainable`](@ref), [`isnumeric`](@ref) parameter in | ||
the model, and returns the sum. Differentiable. Counts shared weights once. | ||
|
||
# Examples | ||
```jldoctest | ||
julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7)); | ||
|
||
julia> total(sum, m) | ||
12.0 | ||
|
||
julia> total(norm, m) | ||
10.0 | ||
|
||
julia> total(length, m) == length(destructure(m)[1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This would solve FluxML/Flux.jl#2043 (as long as trainable parameters are what you want). Or |
||
true | ||
``` | ||
""" | ||
function total(f, x)
    # One entry per numeric leaf that `fmap` visits through `_trainable`.
    results = []
    # Custom walk: recurse only into the trainable children of each node,
    # discarding the rebuilt structure (only the side effect below matters).
    walk_trainable = (recurse, node) -> foreach(recurse, _trainable(node))
    fmap(x; exclude = isnumeric, walk = walk_trainable) do leaf
        push!(results, f(leaf))
    end
    return sum(results)
end
Comment on lines
+200
to
+204
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. While
|
||
|
||
# Reverse rule for `total(f, x)`: the forward pass records one pullback per
# leaf via `_total_hobbit`, and the backward pass feeds the incoming cotangent
# through them with `_total_grad`.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(total), f, x)
    result, pullbacks = _total_hobbit(config, f, x)
    function total_back(dz)
        # `_total_grad` returns the pair (gradient for `f`, gradient for `x`).
        df, dx = _total_grad(unthunk(dz), x, pullbacks)
        return (NoT, df, dx)
    end
    return result, total_back
end
|
||
# Forward pass used by the `rrule` for `total`: calls `rrule_via_ad(config, f, leaf)`
# on every numeric leaf reached through `_trainable`, and returns the sum of the
# primal values together with whatever structure `fmap` builds from the pullbacks.
# Review thread: a brief explanatory comment was requested here; the author noted
# this function may be revised or rewritten.
function _total_hobbit(config::RuleConfig, f, x)
    outputs = []
    # Unlike the walk in `total`, this one uses `map` so the pullbacks are kept.
    walk_trainable = (recurse, node) -> map(recurse, _trainable(node))
    pullbacks = fmap(x; exclude = isnumeric, walk = walk_trainable) do leaf
        out, pb = rrule_via_ad(config, f, leaf)
        push!(outputs, out)
        pb
    end
    return sum(outputs), pullbacks
end
|
||
# Backward pass used by the `rrule` for `total`.
#
# Walks `x` and `backs` together, calls each stored pullback on `dz`, and returns
# `(sum of the gradients with respect to f, gradient structure for x)`.
function _total_grad(dz, x, backs)
    dfs = []
    # `_Tangent_biwalk` is defined as a callable type (`(::_Tangent_biwalk)(f, x, aux)`),
    # so an instance must be passed as the walk; passing the bare type would call
    # its constructor instead of the walk method.
    # NOTE(review): assumes `_Tangent_biwalk` has a zero-argument constructor — confirm.
    dx = fmap(
        x, backs;
        exclude = isnumeric, walk = _Tangent_biwalk(), prune = NoT,
    ) do y, b
        df, dy = b(dz)
        push!(dfs, df)
        dy
    end
    return sum(dfs), dx
end
|
||
# Second-order rule: lets AD differentiate through `_total_grad`. It warns on use,
# since second derivatives of `total(f, x)` are not expected to be reliable yet.
function ChainRulesCore.rrule(::typeof(_total_grad), dz, x, backs)
    @warn "second derivatives of total(f, x) may not work yet, sorry!" maxlog=3
    function grad_back((df, dx))
        # `Zero` is not in scope here (it is not among the names imported from
        # ChainRulesCore); `AbstractZero` covers both `ZeroTangent` and `NoTangent`.
        df isa ChainRulesCore.AbstractZero || @error "second derivatives of total(f, x) with respect to the function are wrong!"
        # NOTE(review): only the two-argument method `total(f, x)` is visible in this
        # file section, so `total(dx)` looks like it would throw a MethodError.
        # The intended expression is unclear; left unchanged pending confirmation.
        (NoT, total(dx), NoT, NoT)
    end
    return _total_grad(dz, x, backs), grad_back
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is either a bug in earlier
_Tangent_biwalk
, or in ChainRulesCore.