From 8c7b8f19744bb0d8aec7a7490a647cd9f759767a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 17 Aug 2024 11:34:12 -0400 Subject: [PATCH] WIP: add Enzyme support for fastpow Straightforward since fastpow is simply ^. Still needs: - [ ] Tests - [ ] Generalize to batchduplicated --- ext/DiffEqBaseEnzymeExt.jl | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index 2b2b7c001..fc9d2041b 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -1,7 +1,7 @@ module DiffEqBaseEnzymeExt using DiffEqBase -import DiffEqBase: value +import DiffEqBase: value, fastpow using Enzyme import Enzyme: Const using ChainRulesCore @@ -53,4 +53,38 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, return ntuple(_ -> nothing, Val(length(args) + 4)) end +function EnzymeRules.forward(func::Const{typeof(fastpow)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicated,BatchDuplicatedNoNeed}}, + _x::Annotation, _y::Annotation) + x = _x.val + y = _y.val + ret = func.val(x.val, y.val) + dxval = x.dval * y * (fastpow(x,y - 1)) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) + return Duplicated(ret, dxval + dyval) +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, + func::Const{typeof(fastpow)}, + ::Type{<:Active}, + x::Active, x::Active) + if EnzymeRules.needs_primal(config) + primal = func.val(x.val, y.val) + else + primal = nothing + end + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, + func::Const{DiffEqBase.fastpow}, dret, tape::Nothing, + _x, _y) + x = _x.val + y = _y.val + dxval = x.dval * y * (fastpow(x,y - 1)) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) + return (dxval, dyval) +end + end