diff --git a/src/solution_sampling.jl b/src/solution_sampling.jl index aad2d2f11..c3fea4ab0 100644 --- a/src/solution_sampling.jl +++ b/src/solution_sampling.jl @@ -20,7 +20,8 @@ end function sample(sol::ProbODESolution, n::Int=1) @unpack d, q = sol.cache sample_path = sample_states(sol, n) - return sample_path[:, 1:q+1:d*(q+1), :] + ys = stack(map(x -> (sol.cache.SolProj * x')', eachslice(sample_path; dims=3))) + return ys end function sample_states(ts, xs, diffusions, difftimes, cache, n::Int=1) @assert length(diffusions) + 1 == length(difftimes) @@ -81,5 +82,6 @@ end function dense_sample(sol::ProbODESolution, n::Int=1; density=1000) samples, times = dense_sample_states(sol, n; density=density) @unpack d, q = sol.cache - return samples[:, 1:q+1:d*(q+1), :], times + ys = stack(map(x -> (sol.cache.SolProj * x')', eachslice(samples; dims=3))) + return ys, times end