Skip to content

Commit

Permalink
Merge pull request #306 from SciML/u/immutablefix
Browse files Browse the repository at this point in the history
Add immutable ODE Problem for GPU compilation
  • Loading branch information
ChrisRackauckas authored Oct 19, 2023
2 parents 6d1b08b + 5f62f6b commit 76f8431
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ MuladdMacro = "0.2"
Parameters = "0.12"
RecursiveArrayTools = "2"
Requires = "1.0"
SciMLBase = "1.26"
SciMLBase = "1.26, 2"
Setfield = "1"
SimpleDiffEq = "1"
StaticArrays = "1"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/lower_level_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ prob = ODEProblem{false}(lorenz, u0, tspan, p)
## Building different problems for different parameters
probs = map(1:trajectories) do i
remake(prob, p = (@SVector rand(Float32, 3)) .* p)
DiffEqGPU.make_prob_compatible(remake(prob, p = (@SVector rand(Float32, 3)) .* p))
end
## Move the arrays to the GPU
Expand Down
2 changes: 2 additions & 0 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ include("ensemblegpukernel/tableaus/verner_tableaus.jl")
include("ensemblegpukernel/tableaus/rodas_tableaus.jl")
include("ensemblegpukernel/tableaus/kvaerno_tableaus.jl")

include("ensemblegpukernel/problems/ode_problems.jl")

include("utils.jl")
include("algorithms.jl")
include("solve.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/ensemblegpuarray/problem_generation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function generate_problem(prob::ODEProblem, u0, p, jac_prototype, colorvec)
function generate_problem(prob::SciMLBase.AbstractODEProblem, u0, p, jac_prototype, colorvec)
_f = let f = prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop
function (du, u, p, t)
version = get_backend(u)
Expand Down
105 changes: 105 additions & 0 deletions src/ensemblegpukernel/problems/ode_problems.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import SciMLBase: @add_kwonly, AbstractODEProblem, AbstractODEFunction,
FunctionWrapperSpecialize, StandardODEProblem, prepare_initial_state, promote_tspan,
warn_paramtype

struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <:
AbstractODEProblem{uType, tType, isinplace}
"""The ODE is `du = f(u,p,t)` for out-of-place and f(du,u,p,t) for in-place."""
f::F
"""The initial condition is `u(tspan[1]) = u0`."""
u0::uType
"""The solution `u(t)` will be computed for `tspan[1] ≤ t ≤ tspan[2]`."""
tspan::tType
"""Constant parameters to be supplied as the second argument of `f`."""
p::P
"""A callback to be applied to every solver which uses the problem."""
kwargs::K
"""An internal argument for storing traits about the solving process."""
problem_type::PT
@add_kwonly function ImmutableODEProblem{iip}(f::AbstractODEFunction{iip},
u0, tspan, p = NullParameters(),
problem_type = StandardODEProblem();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(_u0), typeof(_tspan),
isinplace(f), typeof(p), typeof(f),
typeof(kwargs),
typeof(problem_type)}(f,
_u0,
_tspan,
p,
kwargs,
problem_type)
end

"""
ImmutableODEProblem{isinplace}(f,u0,tspan,p=NullParameters(),callback=CallbackSet())
Define an ODE problem with the specified function.
`isinplace` optionally sets whether the function is inplace or not.
This is determined automatically, but not inferred.
"""
function ImmutableODEProblem{iip}(f,
u0,
tspan,
p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
end

@add_kwonly function ImmutableODEProblem{iip, recompile}(f, u0, tspan,
p = NullParameters();
kwargs...) where {iip, recompile}
ImmutableODEProblem{iip}(ODEFunction{iip, recompile}(f), u0, tspan, p; kwargs...)
end

function ImmutableODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan,
p = NullParameters();
kwargs...) where {iip}
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if iip
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
(_u0, _u0, p,
_tspan[1])))
else
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
(_u0, p,
_tspan[1])))
end
end
ImmutableODEProblem{iip}(ff, _u0, _tspan, p; kwargs...)
end
end

"""
ImmutableODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet())
Define an ODE problem from an [`ODEFunction`](@ref).
"""
function ImmutableODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
ImmutableODEProblem{isinplace(f)}(f, u0, tspan, args...; kwargs...)
end

function ImmutableODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
_u0 = prepare_initial_state(u0)
_tspan = promote_tspan(tspan)
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
end

function Base.convert(::Type{ImmutableODEProblem}, prob::T) where {T <: ODEProblem}
ImmutableODEProblem(prob.f,
prob.u0,
prob.tspan,
prob.p,
prob.problem_type;
prob.kwargs...)
end
29 changes: 20 additions & 9 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,21 @@ function batch_solve(ensembleprob, alg,
ensemblealg::Union{EnsembleArrayAlgorithm, EnsembleKernelAlgorithm}, I,
adaptive;
kwargs...)
if ensembleprob.safetycopy
probs = map(I) do i
ensembleprob.prob_func(deepcopy(ensembleprob.prob), i, 1)
end
else
probs = map(I) do i
ensembleprob.prob_func(ensembleprob.prob, i, 1)
end
end
@assert !isempty(I)
#@assert all(p->p.f === probs[1].f,probs)

if ensemblealg isa EnsembleGPUKernel
if ensembleprob.safetycopy
probs = map(I) do i
make_prob_compatible(ensembleprob.prob_func(deepcopy(ensembleprob.prob),
i,
1))
end
else
probs = map(I) do i
make_prob_compatible(ensembleprob.prob_func(ensembleprob.prob, i, 1))
end
end
# Using inner saveat requires all of them to be of same size,
# because the dimension of CuMatrix is decided by it.
# The columns of it are accessed at each thread.
Expand Down Expand Up @@ -192,6 +194,15 @@ function batch_solve(ensembleprob, alg,
error("We don't have solvers implemented for this algorithm yet")
end
else
if ensembleprob.safetycopy
probs = map(I) do i
ensembleprob.prob_func(deepcopy(ensembleprob.prob), i, 1)
end
else
probs = map(I) do i
ensembleprob.prob_func(ensembleprob.prob, i, 1)
end
end
u0 = reduce(hcat, Array(probs[i].u0) for i in 1:length(I))

if !all(Base.Fix2((prob1, prob2) -> isequal(prob1.tspan, prob2.tspan),
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ function diffeqgpunorm(u::AbstractArray{<:ForwardDiff.Dual}, t)
sqrt.(sum(abs2 ForwardDiff.value, u) ./ length(u))
end
diffeqgpunorm(u::ForwardDiff.Dual, t) = abs(ForwardDiff.value(u))

make_prob_compatible(prob) = prob

function make_prob_compatible(prob::T) where {T <: ODEProblem}
convert(ImmutableODEProblem, prob)
end
2 changes: 1 addition & 1 deletion test/lower_level_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ prob = ODEProblem{false}(func, u0, tspan, p)
## Building different problems for different parameters
batch = 1:trajectories
probs = map(batch) do i
remake(prob, p = (@SVector rand(Float32, 3)) .* p)
DiffEqGPU.make_prob_compatible(remake(prob, p = (@SVector rand(Float32, 3)) .* p))
end

## Move the arrays to the GPU
Expand Down

0 comments on commit 76f8431

Please sign in to comment.