Skip to content

Commit

Permalink
Automatically convert incompatible vector problems to svector
Browse files Browse the repository at this point in the history
Required for diffeqr and diffeqpy. Note that the `f` functions are okay when generated by MTK, since it generates a dispatch for SA -> SA, so therefore all that's needed is the u0 and p conversion for this to work downstream.
  • Loading branch information
ChrisRackauckas committed Oct 20, 2023
1 parent 518d95b commit 8c95e15
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ imparts some extra limitations on the use.
## Positional Arguments
- `backend`: the KernelAbstractions backend for performing the computation.
- `cpu_offload`: the percentage of trajectories to offload to the CPU. Default is 0.2 or
20% of trajectories.
- `cpu_offload`: the percentage of trajectories to offload to the CPU. Default is 0.0 or
0% of trajectories.
## Limitations
Expand Down Expand Up @@ -165,7 +165,7 @@ function EnsembleGPUArray(dev)
end

function EnsembleGPUKernel(dev)
EnsembleGPUKernel(dev, 0.2)
EnsembleGPUKernel(dev, 0.0)
end

function ChainRulesCore.rrule(::Type{<:EnsembleGPUArray})
Expand Down
4 changes: 4 additions & 0 deletions src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
timeseries = prob.tspan[1]:dt:prob.tspan[2]
nsteps = length(timeseries)

prob = convert(ImmutableODEProblem, prob)

dt = convert(eltype(prob.tspan), dt)

if saveat === nothing
Expand Down Expand Up @@ -154,6 +156,8 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
backend = get_backend(probs)
backend = maybe_prefer_blocks(backend)

prob = convert(ImmutableODEProblem, prob)

dt = convert(eltype(prob.tspan), dt)
abstol = convert(eltype(prob.tspan), abstol)
reltol = convert(eltype(prob.tspan), reltol)
Expand Down
10 changes: 8 additions & 2 deletions src/ensemblegpukernel/problems/ode_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,17 @@ function ImmutableODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
end

staticarray_itize(x) = x
staticarray_itize(x::Vector) = SVector{length(x)}(x)
staticarray_itize(x::SizedVector) = SVector{length(x)}(x)
staticarray_itize(x::Matrix) = SMatrix{size(x)...}(x)
staticarray_itize(x::SizedMatrix) = SMatrix{size(x)...}(x)

function Base.convert(::Type{ImmutableODEProblem}, prob::T) where {T <: ODEProblem}
ImmutableODEProblem(prob.f,
prob.u0,
staticarray_itize(prob.u0),
prob.tspan,
prob.p,
staticarray_itize(prob.p),
prob.problem_type;
prob.kwargs...)
end
22 changes: 22 additions & 0 deletions test/gpu_kernel_de/conversions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra
include("../utils.jl")

function lorenz(u, p, t)
σ = p[1]
ρ = p[2]
β = p[3]
du1 = σ * (u[2] - u[1])
du2 = u[1] *- u[3]) - u[2]
du3 = u[1] * u[2] - β * u[3]
return SVector{3}(du1, du2, du3)
end

u0 = [1.0f0; 0.0f0; 0.0f0]
tspan = (0.0f0, 10.0f0)
p = [10.0f0, 28.0f0, 8 / 3.0f0]
prob = ODEProblem{false}(lorenz, u0, tspan, p)
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1.0f0);
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ end
include("gpu_kernel_de/finite_diff.jl")
end

@time @safetestset "GPU Kernelized Auto-Conversion tests" begin
include("gpu_kernel_de/conversions.jl")
end

if GROUP in SUPPORTS_LUFACT
@time @safetestset "EnsembleGPUArray" begin
include("ensemblegpuarray.jl")
Expand Down

0 comments on commit 8c95e15

Please sign in to comment.