Skip to content

Commit

Permalink
Fix the solution sampling code (#317)
Browse files Browse the repository at this point in the history
* Fix the solution sampling code

* Add a test which should catch such issues with sampling

* Run the sampling tests for both solvers
  • Loading branch information
nathanaelbosch authored Jul 12, 2024
1 parent be22a55 commit c526f7c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
6 changes: 4 additions & 2 deletions src/solution_sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ end
function sample(sol::ProbODESolution, n::Int=1)
@unpack d, q = sol.cache
sample_path = sample_states(sol, n)
return sample_path[:, 1:q+1:d*(q+1), :]
ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(sample_path; dims=3))...; dims=3)
return ys
end
function sample_states(ts, xs, diffusions, difftimes, cache, n::Int=1)
@assert length(diffusions) + 1 == length(difftimes)
Expand Down Expand Up @@ -81,5 +82,6 @@ end
function dense_sample(sol::ProbODESolution, n::Int=1; density=1000)
samples, times = dense_sample_states(sol, n; density=density)
@unpack d, q = sol.cache
return samples[:, 1:q+1:d*(q+1), :], times
ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(samples; dims=3))...; dims=3)
return ys, times
end
74 changes: 48 additions & 26 deletions test/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,47 +85,69 @@ using ODEProblemLibrary: prob_ode_lotkavolterra

# Sampling
@testset "Solution Sampling" begin
if Alg == EK1
n_samples = 2
@testset "Discrete" begin
n_samples = 10

samples = ProbNumDiffEq.sample(sol, n_samples)

@test samples isa Array

m, n, o = size(samples)
@test m == length(sol)
@test_skip n == length(sol.u[1])
@test n == length(sol.u[1])
@test o == n_samples

# Dense sampling
dense_samples, dense_times = ProbNumDiffEq.dense_sample(sol, n_samples)
m, n, o = size(dense_samples)
@test m == length(dense_times)
@test_skip n == length(sol.u[1])
@test o == n_samples
us, es = stack(sol.u), stack(std.(sol.pu))
for (interval_width, (low, high)) in (
(1, (0.5, 0.8)),
(2, (0.8, 0.99)),
(3, (0.95, 1)),
(4, (0.99, 1)),
)
percent_in_interval =
sum(
(
sum(
abs.(us .- samples[:, :, i]') .<=
interval_width * es,
)
for i in 1:n_samples
)
) / (m * n * o)
@test low <= percent_in_interval <= high
end
end
end

@testset "Sampling states from the solution" begin
if Alg == EK1
n_samples = 2

samples = ProbNumDiffEq.sample_states(sol, n_samples)

@test samples isa Array

m, n, o = size(samples)
@test m == length(sol)
@test_skip n == length(sol.u[1]) * (sol.cache.q + 1)
@test o == n_samples

# Dense sampling
@testset "Dense" begin
n_samples = 10
dense_samples, dense_times =
ProbNumDiffEq.dense_sample_states(sol, n_samples)
ProbNumDiffEq.dense_sample(sol, n_samples)

m, n, o = size(dense_samples)
@test m == length(dense_times)
@test_skip n == length(sol.u[1]) * (sol.cache.q + 1)
@test n == length(sol.u[1])
@test o == n_samples

pu = sol(dense_times).u
us, es = stack(mean.(pu)), stack(std.(pu))
for (interval_width, (low, high)) in (
(1, (0.5, 0.8)),
(2, (0.8, 0.99)),
(3, (0.95, 1)),
(4, (0.99, 1)),
)
percent_in_interval =
sum(
(
sum(
abs.(us .- dense_samples[:, :, i]') .<=
interval_width * es,
)
for i in 1:n_samples
)
) / (m * n * o)
@test_broken low <= percent_in_interval <= high
end
end
end

Expand Down

0 comments on commit c526f7c

Please sign in to comment.