Skip to content

Commit

Permalink
Merge pull request #839 from SciML/ensemble_indexing
Browse files Browse the repository at this point in the history
Simplify ensemble indexing
  • Loading branch information
ChrisRackauckas authored Oct 30, 2024
2 parents c341819 + 5e315da commit 1180779
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ PyCall = "1.96"
PythonCall = "0.9.15"
RCall = "0.14.0"
RecipesBase = "1.3.4"
RecursiveArrayTools = "3.26.0"
RecursiveArrayTools = "3.27.2"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
Expand Down
14 changes: 0 additions & 14 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,6 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i::Integer)
return x.u[s].u[i]
end

Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
return x.u[s][i2, i3, idxs...]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x.u]
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
[s(args...; kwargs...) for s in sol]
end
Expand Down
18 changes: 14 additions & 4 deletions test/downstream/ensemble_diffeq.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
using OrdinaryDiffEq
using OrdinaryDiffEq, Test

prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0))
A = [1 2
3 4]
prob = ODEProblem((u, p, t) -> A*u, ones(2,2), (0.0, 1.0))
function prob_func(prob, i, repeat)
remake(prob, u0 = rand() * prob.u0)
remake(prob, u0 = i * prob.u0)
end
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10, saveat=0.01)
@test sim isa EnsembleSolution
@test size(sim[1,:,:,:]) == (2,101,10)
@test size(sim[:,1,:,:]) == (2,101,10)
@test size(sim[:,:,1,:]) == (2,2,10)
@test size(sim[:,:,:,1]) == (2,2,101)
@test Array(sim)[1,:,:,:] == sim[1,:,:,:]
@test Array(sim)[:,1,:,:] == sim[:,1,:,:]
@test Array(sim)[:,:,1,:] == sim[:,:,1,:]
@test Array(sim)[:,:,:,1] == sim[:,:,:,1]
9 changes: 4 additions & 5 deletions test/downstream/ensemble_multi_prob.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0))
ensemble_prob = EnsembleProblem([prob1, prob2, prob3])
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads())
for i in 1:3
@test sol[x, :][i] == sol.u[i][x]
@test sol[y, :][i] == sol.u[i][y]
@test sol[1,:,i] == sol.u[i][x]
@test sol[2,:,i] == sol.u[i][y]
end
# Ensemble is a recursive array
@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] == first.(sol[x, :])
# TODO: fix the interpolation
@test only.(sol(1.0, idxs = [x])) last.(sol[x, :])
@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :]
@test only.(sol(1.0, idxs = [x])) [sol[i][1, end] for i in 1:3]

0 comments on commit 1180779

Please sign in to comment.