diff --git a/Project.toml b/Project.toml index e8679fbc8..1c1402c7c 100644 --- a/Project.toml +++ b/Project.toml @@ -70,7 +70,7 @@ DataStructures = "0.18" Distributions = "0.25" DocStringExtensions = "0.9" EnumX = "1" -Enzyme = "0.12" +Enzyme = "0.11.9, 0.12" EnzymeCore = "0.5, 0.6, 0.7" FastBroadcast = "0.2, 0.3" FastClosures = "0.3.2" diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 26631c30f..5b3e85d7c 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -265,26 +265,45 @@ end @inline promote_u0(::Nothing, p, t0) = nothing @inline function promote_u0(u0, p, t0) - if !(eltype(u0) <: ForwardDiff.Dual) - T = anyeltypedual(p) - T === Any && return u0 - if T <: ForwardDiff.Dual - return T.(u0) - end + Tu = eltype(u0) + if Tu <: ForwardDiff.Dual + return u0 + end + Tp = anyeltypedual(p) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(Tu, Tp, Tt) + return if Tcommon <: ForwardDiff.Dual + Tcommon.(u0) + else + u0 end - u0 end @inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0) - if !(real(eltype(u0)) <: ForwardDiff.Dual) - T = anyeltypedual(p) - T === Any && return u0 - if T <: ForwardDiff.Dual - Ts = promote_type(T, eltype(u0)) - return Ts.(u0) - end + Tu = real(eltype(u0)) + if Tu <: ForwardDiff.Dual + return u0 + end + Tp = anyeltypedual(p) + if Tp == Any + Tp = Tu + end + Tt = anyeltypedual(t0) + if Tt == Any + Tt = Tu + end + Tcommon = promote_type(eltype(u0), Tp, Tt) + return if real(Tcommon) <: ForwardDiff.Dual + Tcommon.(u0) + else + u0 end - u0 end function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, diff --git a/test/downstream/unwrapping.jl b/test/downstream/unwrapping.jl index 2ff69ff0a..eed2f060a 100644 --- a/test/downstream/unwrapping.jl +++ b/test/downstream/unwrapping.jl @@ -19,13 +19,6 @@ prob = ODEProblem(f, [x], tspan) integ = init(prob, Tsit5(), dt = 0.1) @test integ.f.f === f -tspan = (ForwardDiff.Dual(0.0, (0.01)), ForwardDiff.Dual(1.0, (0.01))) -prob = ODEProblem(f, [x], tspan) - -# Should not error during problem construction but should be unwrapped -integ = init(prob, Tsit5(), dt = 0.1) -@test integ.f.f === f - # Handle functional initial conditions prob = ODEProblem((dx, x, p, t) -> (dx .= 0), (p, t) -> zeros(2), (0, 10)) solve(prob, TRBDF2()) diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index fd3517af5..2a244864f 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -320,3 +320,11 @@ ow = OutsideWrapper(1.0, iw) @test !(DiffEqBase.anyeltypedual(ow) <: ForwardDiff.Dual) @inferred DiffEqBase.anyeltypedual(iw) @inferred DiffEqBase.anyeltypedual(ow) + +# Issue https://github.com/SciML/ModelingToolkit.jl/issues/2717 +u0 = [1.0, 2.0, 3.0] +p = [1, 2] +t = ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}(1.0) +@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:ForwardDiff.Dual} +u0 = [1.0 + 1im, 2.0, 3.0] +@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}