From d159171c3b86912d8af3472dbc7c7a1f65eac6c4 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Wed, 10 Jul 2024 14:12:01 +0200 Subject: [PATCH 1/3] Fix the solution sampling code --- src/solution_sampling.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/solution_sampling.jl b/src/solution_sampling.jl index aad2d2f11..38db090fa 100644 --- a/src/solution_sampling.jl +++ b/src/solution_sampling.jl @@ -20,7 +20,8 @@ end function sample(sol::ProbODESolution, n::Int=1) @unpack d, q = sol.cache sample_path = sample_states(sol, n) - return sample_path[:, 1:q+1:d*(q+1), :] + ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(sample_path; dims=3))...; dims=3) + return ys end function sample_states(ts, xs, diffusions, difftimes, cache, n::Int=1) @assert length(diffusions) + 1 == length(difftimes) @@ -81,5 +82,6 @@ end function dense_sample(sol::ProbODESolution, n::Int=1; density=1000) samples, times = dense_sample_states(sol, n; density=density) @unpack d, q = sol.cache - return samples[:, 1:q+1:d*(q+1), :], times + ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(samples; dims=3))...; dims=3) + return ys, times end From 497562c89ad6ff12a6fd17598bf6f2857191666b Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Fri, 12 Jul 2024 09:23:28 +0200 Subject: [PATCH 2/3] Add a test which should catch such issues with sampling --- test/solution.jl | 55 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/test/solution.jl b/test/solution.jl index d389cc206..0015bb05e 100644 --- a/test/solution.jl +++ b/test/solution.jl @@ -84,9 +84,9 @@ using ODEProblemLibrary: prob_ode_lotkavolterra end # Sampling - @testset "Solution Sampling" begin - if Alg == EK1 - n_samples = 2 + if Alg == EK1 @testset "Solution Sampling" begin + @testset "Discrete" begin + n_samples = 10 samples = ProbNumDiffEq.sample(sol, n_samples) @@ -94,17 +94,56 @@ using ODEProblemLibrary: prob_ode_lotkavolterra m, n, o = size(samples) @test m == length(sol) - @test_skip n == length(sol.u[1]) + @test n == length(sol.u[1]) @test o == n_samples - # Dense sampling + us, es = stack(sol.u), stack(std.(sol.pu)) + for (interval_width, (low, high)) in ( + (1, (0.5, 0.8)), + (2, (0.8, 0.99)), + (3, (0.95, 1)), + (4, (0.99, 1)), + ) + percent_in_interval = sum( + (sum(abs.(us .- samples[:, :, i]') .<= interval_width * es) + for i in 1:n_samples) + ) / (m*n*o) + @test low <= percent_in_interval <= high + end + end + + @testset "Dense" begin + n_samples = 10 dense_samples, dense_times = ProbNumDiffEq.dense_sample(sol, n_samples) + m, n, o = size(dense_samples) @test m == length(dense_times) - @test_skip n == length(sol.u[1]) + @test n == length(sol.u[1]) @test o == n_samples + + pu = sol(dense_times).u + us, es = stack(mean.(pu)), stack(std.(pu)) + for (interval_width, (low, high)) in ( + (1, (0.5, 0.8)), + (2, (0.8, 0.99)), + (3, (0.95, 1)), + (4, (0.99, 1)), + ) + percent_in_interval = + sum( + ( + sum( + abs.(us .- dense_samples[:, :, i]') .<= + interval_width * es, + ) + for i in 1:n_samples + ) + ) / (m * n * o) + @test_broken low <= percent_in_interval <= high + end end end + end @testset "Sampling states from the solution" begin if Alg == EK1 @@ -116,7 +155,7 @@ using ODEProblemLibrary: prob_ode_lotkavolterra m, n, o = size(samples) @test m == length(sol) - @test_skip n == length(sol.u[1]) * (sol.cache.q + 1) + @test n == sol.cache.d * (sol.cache.q + 1) @test o == n_samples # Dense sampling @@ -124,7 +163,7 @@ using ODEProblemLibrary: prob_ode_lotkavolterra ProbNumDiffEq.dense_sample_states(sol, n_samples) m, n, o = size(dense_samples) @test m == length(dense_times) - @test_skip n == length(sol.u[1]) * (sol.cache.q + 1) + @test n == sol.cache.d * (sol.cache.q + 1) @test o == n_samples end end From 39d389c54709c27cbda92a33773a523405bd0012 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Fri, 12 Jul 2024 09:36:37 +0200 Subject: [PATCH 3/3] Run the sampling tests for both solvers --- test/solution.jl | 57 +++++++++++++++++------------------------------- 1 file changed, 20 insertions(+), 37 deletions(-) diff --git a/test/solution.jl b/test/solution.jl index 0015bb05e..a740dae05 100644 --- a/test/solution.jl +++ b/test/solution.jl @@ -84,7 +84,7 @@ using ODEProblemLibrary: prob_ode_lotkavolterra end # Sampling - if Alg == EK1 @testset "Solution Sampling" begin + @testset "Solution Sampling" begin @testset "Discrete" begin n_samples = 10 @@ -103,18 +103,25 @@ using ODEProblemLibrary: prob_ode_lotkavolterra (2, (0.8, 0.99)), (3, (0.95, 1)), (4, (0.99, 1)), - ) - percent_in_interval = sum( - (sum(abs.(us .- samples[:, :, i]') .<= interval_width * es) - for i in 1:n_samples) - ) / (m*n*o) + ) + percent_in_interval = + sum( + ( + sum( + abs.(us .- samples[:, :, i]') .<= + interval_width * es, + ) + for i in 1:n_samples + ) + ) / (m * n * o) @test low <= percent_in_interval <= high end end @testset "Dense" begin n_samples = 10 - dense_samples, dense_times = ProbNumDiffEq.dense_sample(sol, n_samples) + dense_samples, dense_times = + ProbNumDiffEq.dense_sample(sol, n_samples) m, n, o = size(dense_samples) @test m == length(dense_times) @@ -132,41 +139,17 @@ using ODEProblemLibrary: prob_ode_lotkavolterra percent_in_interval = sum( ( - sum( - abs.(us .- dense_samples[:, :, i]') .<= - interval_width * es, - ) - for i in 1:n_samples - ) + sum( + abs.(us .- dense_samples[:, :, i]') .<= + interval_width * es, + ) + for i in 1:n_samples + ) ) / (m * n * o) @test_broken low <= percent_in_interval <= high end end end - end - - @testset "Sampling states from the solution" begin - if Alg == EK1 - n_samples = 2 - - samples = ProbNumDiffEq.sample_states(sol, n_samples) - - @test samples isa Array - - m, n, o = size(samples) - @test m == length(sol) - @test n == sol.cache.d * (sol.cache.q + 1) - @test o == n_samples - - # Dense sampling - dense_samples, dense_times = - ProbNumDiffEq.dense_sample_states(sol, n_samples) - m, n, o = size(dense_samples) - @test m == length(dense_times) - @test n == sol.cache.d * (sol.cache.q + 1) - @test o == n_samples - end - end @testset "Plotting" begin @test_nowarn plot(sol)