Skip to content

Commit

Permalink
Add the mean solution tests again
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 3, 2024
1 parent 0dfb2ea commit b489594
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
18 changes: 15 additions & 3 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,21 @@ function mean(sol::ProbODESolution{T,N}) where {T,N}
sol.cache, sol.dense, sol.tslocation, sol.stats, sol.retcode, sol,
)
end
(sol::MeanProbODESolution)(t::Real, args...) = mean(sol.probsol(t, args...))
(sol::MeanProbODESolution)(t::AbstractVector, args...) =
DiffEqArray(sol.probsol(t, args...).u.μ, t)

function (sol::MeanProbODESolution)(
t::Number, ::Type{deriv}=Val{0}; idxs=nothing, continuity=:left) where {deriv}
return mean(sol.probsol(t, deriv; idxs, continuity))
end
function (sol::MeanProbODESolution)(
t::AbstractArray{<:Number}, ::Type{deriv}=Val{0}; idxs=nothing, continuity=:left,
) where {deriv}
return DiffEqArray(mean.(sol.probsol(t, deriv; idxs, continuity).u), t)
end
function (sol::MeanProbODESolution)(
v, t, ::Type{deriv}=Val{0}; idxs=nothing, continuity=:left) where {deriv}
return mean(sol.probsol(v, t, deriv; idxs, continuity))
end

DiffEqBase.calculate_solution_errors!(sol::ProbODESolution, args...; kwargs...) =
DiffEqBase.calculate_solution_errors!(mean(sol), args...; kwargs...)

Expand Down
11 changes: 7 additions & 4 deletions test/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,13 @@ using ODEProblemLibrary: prob_ode_lotkavolterra

@testset "Mean Solution" begin
msol = mean(sol)
# @test_nowarn msol(prob.tspan[1])
# @test_nowarn msol(sol.t[1:2])
# @test_nowarn msol
# @test_nowarn plot(msol)
x = @test_nowarn msol(prob.tspan[1])
@test x isa AbstractArray
xs = @test_nowarn msol(sol.t[1:2])
@test xs isa ProbNumDiffEq.DiffEqArray
@test xs.u isa AbstractArray{<:AbstractArray}
@test_nowarn msol
@test_nowarn plot(msol)
end
end
end
Expand Down

0 comments on commit b489594

Please sign in to comment.