From 8c95e153d214cb0afc7e3f1ac0e7799b93190c0b Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 19 Oct 2023 21:58:14 -0400 Subject: [PATCH] Automatically convert incompatible vector problems to svector Required for diffeqr and diffeqpy. Note that the `f` functions are okay when generated by MTK, since it generates a dispatch for SA -> SA, so therefore all that's needed is the u0 and p conversion for this to work downstream. --- src/algorithms.jl | 6 ++--- src/ensemblegpukernel/lowerlevel_solve.jl | 4 ++++ .../problems/ode_problems.jl | 10 +++++++-- test/gpu_kernel_de/conversions.jl | 22 +++++++++++++++++++ test/runtests.jl | 4 ++++ 5 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 test/gpu_kernel_de/conversions.jl diff --git a/src/algorithms.jl b/src/algorithms.jl index 6c34988f..dd330734 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -99,8 +99,8 @@ imparts some extra limitations on the use. ## Positional Arguments - `backend`: the KernelAbstractions backend for performing the computation. - - `cpu_offload`: the percentage of trajectories to offload to the CPU. Default is 0.2 or - 20% of trajectories. + - `cpu_offload`: the percentage of trajectories to offload to the CPU. Default is 0.0 or + 0% of trajectories. ## Limitations @@ -165,7 +165,7 @@ function EnsembleGPUArray(dev) end function EnsembleGPUKernel(dev) - EnsembleGPUKernel(dev, 0.2) + EnsembleGPUKernel(dev, 0.0) end function ChainRulesCore.rrule(::Type{<:EnsembleGPUArray}) diff --git a/src/ensemblegpukernel/lowerlevel_solve.jl b/src/ensemblegpukernel/lowerlevel_solve.jl index 03acae98..0bb72034 100644 --- a/src/ensemblegpukernel/lowerlevel_solve.jl +++ b/src/ensemblegpukernel/lowerlevel_solve.jl @@ -34,6 +34,8 @@ function vectorized_solve(probs, prob::ODEProblem, alg; timeseries = prob.tspan[1]:dt:prob.tspan[2] nsteps = length(timeseries) + prob = convert(ImmutableODEProblem, prob) + dt = convert(eltype(prob.tspan), dt) if saveat === nothing @@ -154,6 +156,8 @@ function vectorized_asolve(probs, prob::ODEProblem, alg; backend = get_backend(probs) backend = maybe_prefer_blocks(backend) + prob = convert(ImmutableODEProblem, prob) + dt = convert(eltype(prob.tspan), dt) abstol = convert(eltype(prob.tspan), abstol) reltol = convert(eltype(prob.tspan), reltol) diff --git a/src/ensemblegpukernel/problems/ode_problems.jl b/src/ensemblegpukernel/problems/ode_problems.jl index 628ca923..33d715ff 100644 --- a/src/ensemblegpukernel/problems/ode_problems.jl +++ b/src/ensemblegpukernel/problems/ode_problems.jl @@ -95,11 +95,17 @@ function ImmutableODEProblem(f, u0, tspan, p = NullParameters(); kwargs...) ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...) end +staticarray_itize(x) = x +staticarray_itize(x::Vector) = SVector{length(x)}(x) +staticarray_itize(x::SizedVector) = SVector{length(x)}(x) +staticarray_itize(x::Matrix) = SMatrix{size(x)...}(x) +staticarray_itize(x::SizedMatrix) = SMatrix{size(x)...}(x) + function Base.convert(::Type{ImmutableODEProblem}, prob::T) where {T <: ODEProblem} ImmutableODEProblem(prob.f, - prob.u0, + staticarray_itize(prob.u0), prob.tspan, - prob.p, + staticarray_itize(prob.p), prob.problem_type; prob.kwargs...) end diff --git a/test/gpu_kernel_de/conversions.jl b/test/gpu_kernel_de/conversions.jl new file mode 100644 index 00000000..0c616909 --- /dev/null +++ b/test/gpu_kernel_de/conversions.jl @@ -0,0 +1,22 @@ +using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra +include("../utils.jl") + +function lorenz(u, p, t) + σ = p[1] + ρ = p[2] + β = p[3] + du1 = σ * (u[2] - u[1]) + du2 = u[1] * (ρ - u[3]) - u[2] + du3 = u[1] * u[2] - β * u[3] + return SVector{3}(du1, du2, du3) +end + +u0 = [1.0f0; 0.0f0; 0.0f0] +tspan = (0.0f0, 10.0f0) +p = [10.0f0, 28.0f0, 8 / 3.0f0] +prob = ODEProblem{false}(lorenz, u0, tspan, p) +prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) +monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) +sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1.0f0); \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c3f4303f..175c36ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,6 +43,10 @@ end include("gpu_kernel_de/finite_diff.jl") end +@time @safetestset "GPU Kernelized Auto-Conversion tests" begin + include("gpu_kernel_de/conversions.jl") +end + if GROUP in SUPPORTS_LUFACT @time @safetestset "EnsembleGPUArray" begin include("ensemblegpuarray.jl")