
Commit

maybe this works?
mcabbott committed Nov 27, 2024
1 parent 5c1650f commit 016cfe6
Showing 2 changed files with 17 additions and 18 deletions.
31 changes: 15 additions & 16 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -69,14 +69,14 @@ function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::B
# _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)

# Take II, using split mode.
- forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
- tape, result, shadow_result = forward(Const(f), args...)
- reverse(Const(f), args..., _sensitivity(result), tape)
+ # forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
+ # tape, result, shadow_result = forward(Const(f), args...)
+ # reverse(Const(f), args..., _sensitivity(result), tape)

# Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
# Some cases work, but Flux.withgradient(m -> m(3), Duplicated(model)) does not.
- # dup_loss = DuplicatedNoNeed(Ref(0.0), Ref(1.0))
+ dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
- # result = autodiff(Reverse, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
+ _, result = autodiff(ReverseWithPrimal, Const(_ref_loss!), Const, dup_loss, Const(f), args...)

(; val = result, grad = map(_grad_or_nothing, args))
end
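An aside, not part of the diff: the "Take III" idea used above is that a mutating wrapper writes the scalar loss into a Ref, and the reverse pass is seeded through the shadow Ref rather than through an Active return. A minimal standalone sketch of that trick, mirroring the DuplicatedNoNeed seeding used above (the names write_loss!, f, x here are hypothetical, not part of Flux or this commit):

using Enzyme

# Hypothetical wrapper: stores the loss by mutation, returns nothing.
function write_loss!(out::Ref, f, x)
    out[] = f(x)
    return nothing
end

f(x) = sum(abs2, x)
x  = [1.0, 2.0, 3.0]
dx = zero(x)

# Setting the shadow Ref to 1.0 plays the role of the usual Active return seed,
# so the gradient of f lands in dx even though write_loss! itself returns nothing.
autodiff(Reverse, Const(write_loss!), Const,
         DuplicatedNoNeed(Ref(0.0), Ref(1.0)),
         Const(f), Duplicated(x, dx))

dx  # == [2.0, 4.0, 6.0], the gradient of f at x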
@@ -87,18 +87,17 @@ end
_sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
or else a Tuple or NamedTuple whose first element is a real number.""")

- # function _ref_loss!(out::Ref, f, args...) # for Take III above
- # val = f(args...)
- # @show val
- # out[] = _get_loss(val) # saves loss by mutation
- # val # returns the whole thing
- # end
+ function _ref_loss!(out::Ref, f, args...) # for Take III above
+ val = f(args...)
+ out[] = _get_loss(val) # saves loss by mutation
+ val # returns the whole thing
+ end

- # @inline _get_loss(y::Real) = y
- # @inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1]
- # @inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1]
- # _get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
- # or else a Tuple or NamedTuple whose first element is a real number.""")
+ @inline _get_loss(y::Real) = y
+ @inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1]
+ @inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1]
+ _get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+ or else a Tuple or NamedTuple whose first element is a real number.""")

### Flux.Train, for train!

4 changes: 2 additions & 2 deletions test/ext_enzyme/enzyme.jl
@@ -216,6 +216,6 @@ end
@test Flux.gradient(|>, z, Duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0]
@test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing

- @test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
- @test_broken Flux.withgradient(|>, z, Duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
+ @test Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0]
+ @test Flux.withgradient(|>, z, Duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
end
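For orientation, not part of the diff: the call shape these tests exercise is roughly the following; the model, data, and loss function here are made up for illustration only.

using Flux, Enzyme  # loading Enzyme activates FluxEnzymeExt

model = Dense(3 => 2)                # hypothetical model, just to show the call shape
x = randn(Float32, 3)
loss(m, x) = sum(abs2, m(x))         # returns a real number, as _get_loss expects

out = Flux.withgradient(loss, Duplicated(model), Const(x))
out.val    # the loss value
out.grad   # (gradient for the model, nothing for the Const argument)

Per _get_loss above, the function may also return a Tuple or NamedTuple whose first element is the real-valued loss; in that case val carries the whole returned value.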
