Commit: try out Ref for withgradient
mcabbott committed Nov 23, 2024
1 parent 0bef0d6 commit ecef1f0
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -48,17 +48,33 @@ function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::B
tape, result, shadow_result = forward(Const(f), args...)
reverse(Const(f), args..., _sensitivity(result), tape)

-# Take III, it may be more efficient to have the function write the loss into Ref(0.0), seed Duplicated(that, Ref(1.0))?
+# Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
+# Some cases work, but Flux.withgradient(m -> m(3), Duplicated(model)) does not.
+# dup_loss = DuplicatedNoNeed(Ref(0.0), Ref(1.0))
+# result = autodiff(Reverse, Const(_ref_loss!), Const, dup_loss, Const(f), args...)

(; val = result, grad = map(_grad_or_nothing, args))
end

-_sensitivity(y::Real) = one(y)
-_sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
-_sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
+@inline _sensitivity(y::Real) = one(y)
+@inline _sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
+@inline _sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
_sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
    or else a Tuple or NamedTuple whose first element is a real number.""")
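# A minimal sketch (editor's illustration, not part of this commit) of what the
# `_sensitivity` methods above produce: one(...) for the real loss in first position,
# Enzyme.make_zero for everything else, so the reverse pass is seeded only by the loss:
#   _sensitivity(1.5)                        # 1.0
#   _sensitivity((1.5, [2.0, 3.0]))          # (1.0, [0.0, 0.0])
#   _sensitivity((; loss = 1.5, aux = 2.0))  # (loss = 1.0, aux = 0.0)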

+# function _ref_loss!(out::Ref, f, args...)  # for Take III above
+#     val = f(args...)
+#     @show val
+#     out[] = _get_loss(val)  # saves loss by mutation
+#     val  # returns the whole thing
+# end
+
+# @inline _get_loss(y::Real) = y
+# @inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1]
+# @inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1]
+# _get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+#     or else a Tuple or NamedTuple whose first element is a real number.""")

### Flux.Train, for train!

_applyloss(loss, model, d...) = loss(model, d...)
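# `_applyloss` just matches `train!`'s `loss(model, data...)` calling convention so
# the model can be the active argument. A hedged usage sketch (editor's illustration,
# assuming the Enzyme extension is active and `train!` accepts a Duplicated model):
#   using Flux, Enzyme
#   dup = Duplicated(Dense(2 => 1))           # shadow will hold the gradient
#   opt_state = Flux.setup(Adam(), dup.val)   # optimiser state over the plain model
#   data = [(Float32[1, 2], Float32[3])]
#   Flux.train!((m, x, y) -> Flux.mse(m(x), y), dup, data, opt_state)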
