From 17ad4fcf61de6fe0fa8e5e1a095ba865ac4d25da Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Wed, 10 Jul 2024 14:12:01 +0200 Subject: [PATCH] Fix the solution sampling code --- src/solution_sampling.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/solution_sampling.jl b/src/solution_sampling.jl index aad2d2f11..c3fea4ab0 100644 --- a/src/solution_sampling.jl +++ b/src/solution_sampling.jl @@ -20,7 +20,8 @@ end function sample(sol::ProbODESolution, n::Int=1) @unpack d, q = sol.cache sample_path = sample_states(sol, n) - return sample_path[:, 1:q+1:d*(q+1), :] + ys = stack(map(x -> (sol.cache.SolProj * x')', eachslice(sample_path; dims=3))) + return ys end function sample_states(ts, xs, diffusions, difftimes, cache, n::Int=1) @assert length(diffusions) + 1 == length(difftimes) @@ -81,5 +82,6 @@ end function dense_sample(sol::ProbODESolution, n::Int=1; density=1000) samples, times = dense_sample_states(sol, n; density=density) @unpack d, q = sol.cache - return samples[:, 1:q+1:d*(q+1), :], times + ys = stack(map(x -> (sol.cache.SolProj * x')', eachslice(samples; dims=3))) + return ys, times end