From 33583955d16c98e3be14055de238e5adc9b341a1 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 15 Jun 2024 13:58:33 -0400 Subject: [PATCH 1/3] Fix solutions nested in structs for forwarddiff Fully fixes https://github.com/SciML/DiffEqBase.jl/issues/1021 --- src/forwarddiff.jl | 6 ++++++ test/forwarddiff_dual_detection.jl | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 252246492..17bdaab0f 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -181,6 +181,12 @@ function anyeltypedual(x::Type{T}, ForwardDiff.AbstractConfig} Any end + +function anyeltypedual(::Type{<:AbstractTimeseriesSolution{T,N}}, + ::Type{Val{counter}} = Val{0}) where {T, N, counter} + anyeltypedual(T) +end + function anyeltypedual(x::ForwardDiff.DiffResults.DiffResult, ::Type{Val{counter}} = Val{0}) where {counter} Any diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 132b24556..1313424a4 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -335,3 +335,13 @@ ode = ODEProblem(f, [0.0, 0.0], (0, 1)) @inferred DiffEqBase.anyeltypedual(ode) ode = NonlinearProblem(f, [0.0, 0.0], (0, 1)) @inferred DiffEqBase.anyeltypedual(ode) + +# Issue https://github.com/SciML/DiffEqBase.jl/issues/1021 +f(u, p, t) = 1.01*u +struct Foo{T}; sol::T; end +u0 = 1/2 +tspan = (0.0, 1.0) +prob = ODEProblem{false}(f, u0, tspan) +sol = solve(prob, Tsit5()) +foo = Foo(sol) +DiffEqBase.anyeltypedual((;x=foo)) From 862951db1caa8cac39cbcf885645f15a0d4e9621 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 15 Jun 2024 20:40:31 -0400 Subject: [PATCH 2/3] Update test/forwarddiff_dual_detection.jl --- test/forwarddiff_dual_detection.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 1313424a4..1bb712a69 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -342,6 +342,5 @@ struct Foo{T}; sol::T; end u0 = 1/2 tspan = (0.0, 1.0) prob = ODEProblem{false}(f, u0, tspan) -sol = solve(prob, Tsit5()) -foo = Foo(sol) +foo = SciMLBase.build_solution(prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0,u0], [0.0,1.0]) DiffEqBase.anyeltypedual((;x=foo)) From 4f0c45daa5cde60aef87b8785ef4de9793e10fa7 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 15 Jun 2024 22:40:19 -0400 Subject: [PATCH 3/3] format --- src/forwarddiff.jl | 4 ++-- test/forwarddiff_dual_detection.jl | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 17bdaab0f..f63cce1d8 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -182,8 +182,8 @@ function anyeltypedual(x::Type{T}, Any end -function anyeltypedual(::Type{<:AbstractTimeseriesSolution{T,N}}, - ::Type{Val{counter}} = Val{0}) where {T, N, counter} +function anyeltypedual(::Type{<:AbstractTimeseriesSolution{T, N}}, + ::Type{Val{counter}} = Val{0}) where {T, N, counter} anyeltypedual(T) end diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 1bb712a69..02bcd2c8b 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -337,10 +337,13 @@ ode = NonlinearProblem(f, [0.0, 0.0], (0, 1)) @inferred DiffEqBase.anyeltypedual(ode) # Issue https://github.com/SciML/DiffEqBase.jl/issues/1021 -f(u, p, t) = 1.01*u -struct Foo{T}; sol::T; end -u0 = 1/2 +f(u, p, t) = 1.01 * u +struct Foo{T} + sol::T +end +u0 = 1 / 2 tspan = (0.0, 1.0) prob = ODEProblem{false}(f, u0, tspan) -foo = SciMLBase.build_solution(prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0,u0], [0.0,1.0]) -DiffEqBase.anyeltypedual((;x=foo)) +foo = SciMLBase.build_solution( + prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0]) +DiffEqBase.anyeltypedual((; x = foo))