Skip to content

Commit

Permalink
wip: rules are hit
Browse files Browse the repository at this point in the history
  • Loading branch information
m-bossart committed Aug 24, 2024
1 parent 7c67680 commit 530d113
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,31 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
return ntuple(_ -> nothing, Val(length(args) + 4))
end

function EnzymeRules.forward(func::Const{typeof(fastpow)},
RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated,
BatchDuplicated,BatchDuplicatedNoNeed}},
_x::Annotation, _y::Annotation)
function Enzyme.EnzymeRules.forward(func::Const{typeof(DiffEqBase.fastpow)},
RT::Type{<:Union{Duplicated, DuplicatedNoNeed}},
_x::Union{Const, Duplicated}, _y::Union{Const, Duplicated})
x = _x.val
y = _y.val
ret = func.val(x.val, y.val)
dxval = x.dval * y * (fastpow(x,y - 1))
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
return Duplicated(ret, dxval + dyval)
ret = func.val(x, y)
if !(_x isa Const)
dxval = _x.dval * y * (fastpow(x,y - 1))
else
dxval = make_zero(_x.val)
end
if !(_y isa Const)
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : _y.dval*(fastpow(x,y))*log(x)
else
dyval = make_zero(_y.val)
end
if RT <: DuplicatedNoNeed
return Float32(dxval + dyval)
else
return Duplicated(ret, Float32(dxval + dyval))
end
end

function EnzymeRules.augmented_primal(config::ConfigWidth{1},
func::Const{typeof(fastpow)},
::Type{<:Active},
x::Active, x::Active)
function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1},
func::Const{typeof(fastpow)}, ::Type{<:Active}, x::Active, y::Active)
if EnzymeRules.needs_primal(config)
primal = func.val(x.val, y.val)
else
Expand All @@ -77,14 +86,13 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1},
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},
func::Const{DiffEqBase.fastpow}, dret, tape::Nothing,
_x, _y)
function EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
func::Const{typeof(DiffEqBase.fastpow)}, dret::Active, tape, _x::Active, _y::Active)
x = _x.val
y = _y.val
dxval = x.dval * y * (fastpow(x,y - 1))
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
dxval = y * (fastpow(x,y - 1))
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : (fastpow(x,y))*log(x)
return (dxval, dyval)
end

end
end

0 comments on commit 530d113

Please sign in to comment.