diff --git a/src/solution_sampling.jl b/src/solution_sampling.jl index aad2d2f11..38db090fa 100644 --- a/src/solution_sampling.jl +++ b/src/solution_sampling.jl @@ -20,7 +20,8 @@ end function sample(sol::ProbODESolution, n::Int=1) @unpack d, q = sol.cache sample_path = sample_states(sol, n) - return sample_path[:, 1:q+1:d*(q+1), :] + ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(sample_path; dims=3))...; dims=3) + return ys end function sample_states(ts, xs, diffusions, difftimes, cache, n::Int=1) @assert length(diffusions) + 1 == length(difftimes) @@ -81,5 +82,6 @@ end function dense_sample(sol::ProbODESolution, n::Int=1; density=1000) samples, times = dense_sample_states(sol, n; density=density) @unpack d, q = sol.cache - return samples[:, 1:q+1:d*(q+1), :], times + ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(samples; dims=3))...; dims=3) + return ys, times end diff --git a/test/solution.jl b/test/solution.jl index d389cc206..a740dae05 100644 --- a/test/solution.jl +++ b/test/solution.jl @@ -85,8 +85,8 @@ using ODEProblemLibrary: prob_ode_lotkavolterra # Sampling @testset "Solution Sampling" begin - if Alg == EK1 - n_samples = 2 + @testset "Discrete" begin + n_samples = 10 samples = ProbNumDiffEq.sample(sol, n_samples) @@ -94,38 +94,60 @@ using ODEProblemLibrary: prob_ode_lotkavolterra m, n, o = size(samples) @test m == length(sol) - @test_skip n == length(sol.u[1]) + @test n == length(sol.u[1]) @test o == n_samples - # Dense sampling - dense_samples, dense_times = ProbNumDiffEq.dense_sample(sol, n_samples) - m, n, o = size(dense_samples) - @test m == length(dense_times) - @test_skip n == length(sol.u[1]) - @test o == n_samples + us, es = stack(sol.u), stack(std.(sol.pu)) + for (interval_width, (low, high)) in ( + (1, (0.5, 0.8)), + (2, (0.8, 0.99)), + (3, (0.95, 1)), + (4, (0.99, 1)), + ) + percent_in_interval = + sum( + ( + sum( + abs.(us .- samples[:, :, i]') .<= + interval_width * es, + ) + for i in 1:n_samples + ) + ) / (m * n * o) + @test low <= percent_in_interval <= high + end end - end - - @testset "Sampling states from the solution" begin - if Alg == EK1 - n_samples = 2 - - samples = ProbNumDiffEq.sample_states(sol, n_samples) - - @test samples isa Array - m, n, o = size(samples) - @test m == length(sol) - @test_skip n == length(sol.u[1]) * (sol.cache.q + 1) - @test o == n_samples - - # Dense sampling + @testset "Dense" begin + n_samples = 10 dense_samples, dense_times = - ProbNumDiffEq.dense_sample_states(sol, n_samples) + ProbNumDiffEq.dense_sample(sol, n_samples) + m, n, o = size(dense_samples) @test m == length(dense_times) - @test_skip n == length(sol.u[1]) * (sol.cache.q + 1) + @test n == length(sol.u[1]) @test o == n_samples + + pu = sol(dense_times).u + us, es = stack(mean.(pu)), stack(std.(pu)) + for (interval_width, (low, high)) in ( + (1, (0.5, 0.8)), + (2, (0.8, 0.99)), + (3, (0.95, 1)), + (4, (0.99, 1)), + ) + percent_in_interval = + sum( + ( + sum( + abs.(us .- dense_samples[:, :, i]') .<= + interval_width * es, + ) + for i in 1:n_samples + ) + ) / (m * n * o) + @test_broken low <= percent_in_interval <= high + end end end