diff --git a/src/dual.jl b/src/dual.jl index dd5758d4..421c1c10 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -380,8 +380,8 @@ Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V})) Base.convert(::Type{D}, d::D) where {D<:Dual} = d -Base.float(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d) -Base.AbstractFloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d) +Base.float(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,float(V),N} +Base.float(d::Dual) = convert(float(typeof(d)), d) ################################### # General Mathematical Operations # diff --git a/test/DualTest.jl b/test/DualTest.jl index 285fe3ab..637b7236 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -517,7 +517,7 @@ end @test length(UnitRange(Dual(1.5), Dual(3.5))) == 3 @test length(UnitRange(Dual(1.5,1), Dual(3.5,3))) == 3 end - + if VERSION >= v"1.6.0-rc1" @testset "@printf" begin for T in (Float16, Float32, Float64, BigFloat) @@ -528,4 +528,11 @@ if VERSION >= v"1.6.0-rc1" end end +@testset "float" begin # issue #492 + @test float(Dual{Nothing, Int, 2}) === Dual{Nothing, Float64, 2} + @test float(Dual(1)) isa Dual{Nothing, Float64, 0} + @test value.(float.(Dual.(1:4, 2:5, 3:6))) isa Vector{Float64} + @test ForwardDiff.derivative(float, 1)::Float64 === 1.0 +end + end # module