diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index aaa7846d2..ae3f92e8b 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -1,7 +1,7 @@ module DiffEqBaseEnzymeExt using DiffEqBase -import DiffEqBase: value +import DiffEqBase: value, fastpow using Enzyme import Enzyme: Const using ChainRulesCore @@ -53,4 +53,6 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1} return ntuple(_ -> nothing, Val(length(args) + 4)) end -end +Enzyme.Compiler.known_ops[typeof(DiffEqBase.fastpow)] = (:pow, 2, nothing) + +end \ No newline at end of file diff --git a/src/fastpow.jl b/src/fastpow.jl index b66ee2ae2..36fdb53f4 100644 --- a/src/fastpow.jl +++ b/src/fastpow.jl @@ -51,60 +51,20 @@ const EXP2FT = (Float32(0x1.6a09e667f3bcdp-1), Float32(0x1.3dea64c123422p+0), Float32(0x1.4bfdad5362a27p+0), Float32(0x1.5ab07dd485429p+0)) -@inline function _exp2(x::Float32) - TBLBITS = UInt32(4) - TBLSIZE = UInt32(1 << TBLBITS) - redux = Float32(0x1.8p23) / TBLSIZE - P1 = Float32(0x1.62e430p-1) - P2 = Float32(0x1.ebfbe0p-3) - P3 = Float32(0x1.c6b348p-5) - P4 = Float32(0x1.3b2c9cp-7) - - # Reduce x, computing z, i0, and k. - t::Float32 = x + redux - i0 = reinterpret(UInt32, t) - i0 += TBLSIZE ÷ UInt32(2) - k::UInt32 = unsafe_trunc(UInt32, (i0 >> TBLBITS) << 20) - i0 &= TBLSIZE - UInt32(1) - t -= redux - z = x - t - twopk = Float32(reinterpret(Float64, UInt64(0x3ff00000 + k) << 32)) - - # Compute r = exp2(y) = exp2ft[i0] * p(z). - tv = EXP2FT[i0 + UInt32(1)] - u = tv * z - tv = tv + u * (P1 + z * P2) + u * (z * z) * (P3 + z * P4) - - # Scale by 2**(k>>20) - return tv * twopk -end - -if VERSION < v"1.7.0" - """ - fastpow(x::Real, y::Real) -> Float32 - """ - @inline function fastpow(x::Real, y::Real) - if iszero(x) - return 0.0f0 - elseif isinf(x) && isinf(y) - return Float32(Inf) - else - return _exp2(convert(Float32, y) * fastlog2(convert(Float32, x))) - end - end -else - """ - fastpow(x::Real, y::Real) -> Float32 - """ - @inline function fastpow(x::Real, y::Real) - if iszero(x) - return 0.0f0 - elseif isinf(x) && isinf(y) - return Float32(Inf) - else - return @fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x))) - end +""" + fastpow(x::T, y::T) where {T} -> float(T) + Trips through Float32 for performance. +""" +@inline function fastpow(x::T, y::T) where {T} + outT = float(T) + if iszero(x) + return zero(outT) + elseif isinf(x) && isinf(y) + return convert(outT,Inf) + else + return convert(outT,@fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))) end end + @inline fastpow(x, y) = x^y diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index ef763267f..49c49e10d 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -3,6 +3,7 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" diff --git a/test/downstream/enzyme.jl b/test/downstream/enzyme.jl new file mode 100644 index 000000000..7c820945d --- /dev/null +++ b/test/downstream/enzyme.jl @@ -0,0 +1,23 @@ +using Enzyme, EnzymeTestUtils +using DiffEqBase: fastlog2, fastpow +using Test + +@testset "Fast pow - Enzyme forward rule" begin + @testset for RT in (Duplicated, DuplicatedNoNeed), + Tx in (Const, Duplicated), + Ty in (Const, Duplicated) + x = 3.0 + y = 2.0 + test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.005, rtol=0.005) + end +end + +@testset "Fast pow - Enzyme reverse rule" begin + @testset for RT in (Active,), + Tx in (Active,), + Ty in (Active,) + x = 2.0 + y = 3.0 + test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001) + end +end \ No newline at end of file diff --git a/test/fastpow.jl b/test/fastpow.jl index 96b588aa1..0a39bc8e1 100644 --- a/test/fastpow.jl +++ b/test/fastpow.jl @@ -1,4 +1,4 @@ -using DiffEqBase: fastlog2, _exp2, fastpow +using DiffEqBase: fastlog2, fastpow using Test @testset "Fast log2" begin @@ -7,15 +7,9 @@ using Test end end -@testset "Exp2" begin - for x in -100:0.01:3 - @test exp2(x)≈_exp2(Float32(x)) atol=1e-6 - end -end - @testset "Fast pow" begin - @test fastpow(1, 1) isa Float32 - @test fastpow(1.0, 1.0) isa Float32 + @test fastpow(1, 1) isa Float64 + @test fastpow(1.0, 1.0) isa Float64 errors = [abs(^(x, y) - fastpow(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5] @test maximum(errors) < 1e-4 -end +end \ No newline at end of file