From ff251a017c5279cc8d7c139c0757238561c9a7aa Mon Sep 17 00:00:00 2001 From: Kristoffer Date: Sun, 3 Oct 2021 12:12:25 +0200 Subject: [PATCH] define ArithmeticStyle for Dual --- src/dual.jl | 5 +++++ test/GradientTest.jl | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/src/dual.jl b/src/dual.jl index 580b5197..9c9b0d82 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -20,6 +20,11 @@ struct Dual{T,V,N} <: Real end end +########## +# Traits # +########## +Base.ArithmeticStyle(::Type{<:Dual{T,V}}) where {T,V} = Base.ArithmeticStyle(V) + ############## # Exceptions # ############## diff --git a/test/GradientTest.jl b/test/GradientTest.jl index ac949944..707fcd68 100644 --- a/test/GradientTest.jl +++ b/test/GradientTest.jl @@ -162,4 +162,11 @@ end @test_throws DimensionMismatch ForwardDiff.gradient(identity, fill(2pi, 10^6)) # chunk_mode_gradient end +@testset "ArithmeticStyle" begin + function f(p) + sum(collect(0.0:p[1]:p[2])) + end + @test ForwardDiff.gradient(f, [0.2,25.0]) == [7875.0, 0.0] +end + end # module