Skip to content

Commit

Permalink
WIP: add Enzyme support for fastpow
Browse files Browse the repository at this point in the history
Straightforward since fastpow is simply ^. Still needs:

- [ ] Tests
- [ ] Generalize to batchduplicated
  • Loading branch information
ChrisRackauckas committed Aug 17, 2024
1 parent 84cbb9d commit 8c7b8f1
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DiffEqBaseEnzymeExt

using DiffEqBase
import DiffEqBase: value
import DiffEqBase: value, fastpow
using Enzyme
import Enzyme: Const
using ChainRulesCore
Expand Down Expand Up @@ -53,4 +53,38 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
return ntuple(_ -> nothing, Val(length(args) + 4))
end

function EnzymeRules.forward(func::Const{typeof(fastpow)},
RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated,
BatchDuplicated,BatchDuplicatedNoNeed}},
_x::Annotation, _y::Annotation)
x = _x.val
y = _y.val
ret = func.val(x.val, y.val)
dxval = x.dval * y * (fastpow(x,y - 1))
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
return Duplicated(ret, dxval + dyval)
end

function EnzymeRules.augmented_primal(config::ConfigWidth{1},
func::Const{typeof(fastpow)},
::Type{<:Active},
x::Active, x::Active)
if EnzymeRules.needs_primal(config)
primal = func.val(x.val, y.val)
else
primal = nothing
end
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},
func::Const{DiffEqBase.fastpow}, dret, tape::Nothing,
_x, _y)
x = _x.val
y = _y.val
dxval = x.dval * y * (fastpow(x,y - 1))
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
return (dxval, dyval)
end

end

0 comments on commit 8c7b8f1

Please sign in to comment.