Skip to content

Commit

Permalink
Fix the solution sampling code
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Jul 10, 2024
1 parent 2f2f6ad commit 17ad4fc
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/solution_sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ end
function sample(sol::ProbODESolution, n::Int=1)
@unpack d, q = sol.cache
sample_path = sample_states(sol, n)
return sample_path[:, 1:q+1:d*(q+1), :]
ys = stack(map(x -> (sol.cache.SolProj * x')', eachslice(sample_path; dims=3)))
return ys
end
function sample_states(ts, xs, diffusions, difftimes, cache, n::Int=1)
@assert length(diffusions) + 1 == length(difftimes)
Expand Down Expand Up @@ -81,5 +82,6 @@ end
function dense_sample(sol::ProbODESolution, n::Int=1; density=1000)
samples, times = dense_sample_states(sol, n; density=density)
@unpack d, q = sol.cache
return samples[:, 1:q+1:d*(q+1), :], times
ys = stack(map(x -> (sol.cache.SolProj * x')', eachslice(samples; dims=3)))
return ys, times
end

0 comments on commit 17ad4fc

Please sign in to comment.