From 530d113f69e6f8b69865607e2951f798e8c908eb Mon Sep 17 00:00:00 2001 From: Matt Bossart Date: Sat, 24 Aug 2024 16:16:48 -0600 Subject: [PATCH] wip: rules are hit --- ext/DiffEqBaseEnzymeExt.jl | 44 ++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index fc9d2041b..43e6030c3 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -53,22 +53,31 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, return ntuple(_ -> nothing, Val(length(args) + 4)) end -function EnzymeRules.forward(func::Const{typeof(fastpow)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, - BatchDuplicated,BatchDuplicatedNoNeed}}, - _x::Annotation, _y::Annotation) +function Enzyme.EnzymeRules.forward(func::Const{typeof(DiffEqBase.fastpow)}, + RT::Type{<:Union{Duplicated, DuplicatedNoNeed}}, + _x::Union{Const, Duplicated}, _y::Union{Const, Duplicated}) x = _x.val y = _y.val - ret = func.val(x.val, y.val) - dxval = x.dval * y * (fastpow(x,y - 1)) - dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) - return Duplicated(ret, dxval + dyval) + ret = func.val(x, y) + if !(_x isa Const) + dxval = _x.dval * y * (fastpow(x,y - 1)) + else + dxval = make_zero(_x.val) + end + if !(_y isa Const) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : _y.dval*(fastpow(x,y))*log(x) + else + dyval = make_zero(_y.val) + end + if RT <: DuplicatedNoNeed + return Float32(dxval + dyval) + else + return Duplicated(ret, Float32(dxval + dyval)) + end end -function EnzymeRules.augmented_primal(config::ConfigWidth{1}, - func::Const{typeof(fastpow)}, - ::Type{<:Active}, - x::Active, x::Active) +function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1}, + func::Const{typeof(fastpow)}, ::Type{<:Active}, x::Active, y::Active) if EnzymeRules.needs_primal(config) primal = func.val(x.val, y.val) else @@ -77,14 +86,13 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, - func::Const{DiffEqBase.fastpow}, dret, tape::Nothing, - _x, _y) +function EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, + func::Const{typeof(DiffEqBase.fastpow)}, dret::Active, tape, _x::Active, _y::Active) x = _x.val y = _y.val - dxval = x.dval * y * (fastpow(x,y - 1)) - dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) + dxval = y * (fastpow(x,y - 1)) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : (fastpow(x,y))*log(x) return (dxval, dyval) end -end +end \ No newline at end of file