Skip to content

Commit

Permalink
Merge pull request #1039 from AayushSabharwal/as/promote-u0-tspan
Browse files Browse the repository at this point in the history
fix: consider type of `t0` in `promote_u0`
  • Loading branch information
ChrisRackauckas authored Jun 6, 2024
2 parents 1559b16 + 114ad0b commit bc50a2b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ DataStructures = "0.18"
Distributions = "0.25"
DocStringExtensions = "0.9"
EnumX = "1"
Enzyme = "0.12"
Enzyme = "0.11.9, 0.12"
EnzymeCore = "0.5, 0.6, 0.7"
FastBroadcast = "0.2, 0.3"
FastClosures = "0.3.2"
Expand Down
49 changes: 34 additions & 15 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,45 @@ end
@inline promote_u0(::Nothing, p, t0) = nothing

@inline function promote_u0(u0, p, t0)
if !(eltype(u0) <: ForwardDiff.Dual)
T = anyeltypedual(p)
T === Any && return u0
if T <: ForwardDiff.Dual
return T.(u0)
end
Tu = eltype(u0)
if Tu <: ForwardDiff.Dual
return u0
end
Tp = anyeltypedual(p)
if Tp == Any
Tp = Tu
end
Tt = anyeltypedual(t0)
if Tt == Any
Tt = Tu
end
Tcommon = promote_type(Tu, Tp, Tt)
return if Tcommon <: ForwardDiff.Dual
Tcommon.(u0)
else
u0
end
u0
end

@inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0)
if !(real(eltype(u0)) <: ForwardDiff.Dual)
T = anyeltypedual(p)
T === Any && return u0
if T <: ForwardDiff.Dual
Ts = promote_type(T, eltype(u0))
return Ts.(u0)
end
Tu = real(eltype(u0))
if Tu <: ForwardDiff.Dual
return u0
end
Tp = anyeltypedual(p)
if Tp == Any
Tp = Tu
end
Tt = anyeltypedual(t0)
if Tt == Any
Tt = Tu
end
Tcommon = promote_type(eltype(u0), Tp, Tt)
return if real(Tcommon) <: ForwardDiff.Dual
Tcommon.(u0)
else
u0
end
u0
end

function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p,
Expand Down
7 changes: 0 additions & 7 deletions test/downstream/unwrapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ prob = ODEProblem(f, [x], tspan)
integ = init(prob, Tsit5(), dt = 0.1)
@test integ.f.f === f

tspan = (ForwardDiff.Dual(0.0, (0.01)), ForwardDiff.Dual(1.0, (0.01)))
prob = ODEProblem(f, [x], tspan)

# Should not error during problem construction but should be unwrapped
integ = init(prob, Tsit5(), dt = 0.1)
@test integ.f.f === f

# Handle functional initial conditions
prob = ODEProblem((dx, x, p, t) -> (dx .= 0), (p, t) -> zeros(2), (0, 10))
solve(prob, TRBDF2())
8 changes: 8 additions & 0 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,11 @@ ow = OutsideWrapper(1.0, iw)
@test !(DiffEqBase.anyeltypedual(ow) <: ForwardDiff.Dual)
@inferred DiffEqBase.anyeltypedual(iw)
@inferred DiffEqBase.anyeltypedual(ow)

# Issue https://github.com/SciML/ModelingToolkit.jl/issues/2717
u0 = [1.0, 2.0, 3.0]
p = [1, 2]
t = ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}(1.0)
@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:ForwardDiff.Dual}
u0 = [1.0 + 1im, 2.0, 3.0]
@test DiffEqBase.promote_u0(u0, p, t) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}

0 comments on commit bc50a2b

Please sign in to comment.