Skip to content

Commit

Permalink
Merge pull request #549 from JuliaDiff/kc/arit
Browse files Browse the repository at this point in the history
define ArithmeticStyle for Dual
  • Loading branch information
ChrisRackauckas authored Oct 3, 2021
2 parents bb85ea2 + ff251a0 commit 49bdd7a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ struct Dual{T,V,N} <: Real
end
end

##########
# Traits #
##########
Base.ArithmeticStyle(::Type{<:Dual{T,V}}) where {T,V} = Base.ArithmeticStyle(V)

##############
# Exceptions #
##############
Expand Down
7 changes: 7 additions & 0 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,11 @@ end
@test_throws DimensionMismatch ForwardDiff.gradient(identity, fill(2pi, 10^6)) # chunk_mode_gradient
end

@testset "ArithmeticStyle" begin
function f(p)
sum(collect(0.0:p[1]:p[2]))
end
@test ForwardDiff.gradient(f, [0.2,25.0]) == [7875.0, 0.0]
end

end # module

0 comments on commit 49bdd7a

Please sign in to comment.