Skip to content

Commit

Permalink
fix float(::Dual) and add float(::Type{<:Dual}) (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj authored Jul 28, 2021
1 parent 54fe5d0 commit 102ee4d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero
Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
Base.convert(::Type{D}, d::D) where {D<:Dual} = d

Base.float(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d)
Base.AbstractFloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d)
Base.float(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,float(V),N}
Base.float(d::Dual) = convert(float(typeof(d)), d)

###################################
# General Mathematical Operations #
Expand Down
9 changes: 8 additions & 1 deletion test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ end
@test length(UnitRange(Dual(1.5), Dual(3.5))) == 3
@test length(UnitRange(Dual(1.5,1), Dual(3.5,3))) == 3
end

if VERSION >= v"1.6.0-rc1"
@testset "@printf" begin
for T in (Float16, Float32, Float64, BigFloat)
Expand All @@ -528,4 +528,11 @@ if VERSION >= v"1.6.0-rc1"
end
end

@testset "float" begin # issue #492
@test float(Dual{Nothing, Int, 2}) === Dual{Nothing, Float64, 2}
@test float(Dual(1)) isa Dual{Nothing, Float64, 0}
@test value.(float.(Dual.(1:4, 2:5, 3:6))) isa Vector{Float64}
@test ForwardDiff.derivative(float, 1)::Float64 === 1.0
end

end # module

0 comments on commit 102ee4d

Please sign in to comment.