diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl index 37556a9309..0833c3b867 100644 --- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl +++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl @@ -69,14 +69,14 @@ function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::B # _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) # Take II, using split mode. - forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...) - tape, result, shadow_result = forward(Const(f), args...) - reverse(Const(f), args..., _sensitivity(result), tape) + # forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...) + # tape, result, shadow_result = forward(Const(f), args...) + # reverse(Const(f), args..., _sensitivity(result), tape) # Take III, it may be more efficient to have the function write the loss into Ref(0.0)? - # Some cases work, but Flux.withgradient(m -> m(3), Duplicated(model)) does not. - # dup_loss = DuplicatedNoNeed(Ref(0.0), Ref(1.0)) + dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0)) # result = autodiff(Reverse, Const(_ref_loss!), Const, dup_loss, Const(f), args...) + _, result = autodiff(ReverseWithPrimal, Const(_ref_loss!), Const, dup_loss, Const(f), args...) (; val = result, grad = map(_grad_or_nothing, args)) end @@ -87,18 +87,17 @@ end _sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real numnber, or else a Tuple or NamedTuple whose first element is a real number.""") -# function _ref_loss!(out::Ref, f, args...) # for Take III above -# val = f(args...) -# @show val -# out[] = _get_loss(val) # saves loss by mutation -# val # returns the whole thing -# end +function _ref_loss!(out::Ref, f, args...) # for Take III above + val = f(args...) + out[] = _get_loss(val) # saves loss by mutation + val # returns the whole thing +end -# @inline _get_loss(y::Real) = y -# @inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1] -# @inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1] -# _get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real numnber, -# or else a Tuple or NamedTuple whose first element is a real number.""") +@inline _get_loss(y::Real) = y +@inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1] +@inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1] +_get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real numnber, + or else a Tuple or NamedTuple whose first element is a real number.""") ### Flux.Train, for train! diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 775cc41d68..2f0a6f47bf 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -216,6 +216,6 @@ end @test Flux.gradient(|>, z, Duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0] @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing - @test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any} - @test_broken Flux.withgradient(|>, z, Duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any} + @test Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0] + @test Flux.withgradient(|>, z, Duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0] end