From d159171c3b86912d8af3472dbc7c7a1f65eac6c4 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Wed, 10 Jul 2024 14:12:01 +0200 Subject: [PATCH] Fix the solution sampling code --- src/solution_sampling.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/solution_sampling.jl b/src/solution_sampling.jl index aad2d2f11..38db090fa 100644 --- a/src/solution_sampling.jl +++ b/src/solution_sampling.jl @@ -20,7 +20,8 @@ end function sample(sol::ProbODESolution, n::Int=1) @unpack d, q = sol.cache sample_path = sample_states(sol, n) - return sample_path[:, 1:q+1:d*(q+1), :] + ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(sample_path; dims=3))...; dims=3) + return ys end function sample_states(ts, xs, diffusions, difftimes, cache, n::Int=1) @assert length(diffusions) + 1 == length(difftimes) @@ -81,5 +82,6 @@ end function dense_sample(sol::ProbODESolution, n::Int=1; density=1000) samples, times = dense_sample_states(sol, n; density=density) @unpack d, q = sol.cache - return samples[:, 1:q+1:d*(q+1), :], times + ys = cat(map(x -> (sol.cache.SolProj * x')', eachslice(samples; dims=3))...; dims=3) + return ys, times end