diff --git a/Project.toml b/Project.toml index 44fad15c..3de515f0 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DocStringExtensions = "0.8, 0.9" ForwardDiff = "0.10" KernelAbstractions = "0.9" LinearSolve = "1.15, 2" -Metal = "0.4" +Metal = "0.5" MuladdMacro = "0.2" Parameters = "0.12" RecursiveArrayTools = "2" diff --git a/src/DiffEqGPU.jl b/src/DiffEqGPU.jl index 0305719a..79613416 100644 --- a/src/DiffEqGPU.jl +++ b/src/DiffEqGPU.jl @@ -50,6 +50,7 @@ include("ensemblegpukernel/integrators/stiff/interpolants.jl") include("ensemblegpukernel/integrators/nonstiff/interpolants.jl") include("ensemblegpukernel/nlsolve/type.jl") include("ensemblegpukernel/nlsolve/utils.jl") +include("ensemblegpukernel/kernels.jl") include("ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl") include("ensemblegpukernel/perform_step/gpu_vern7_perform_step.jl") diff --git a/src/ensemblegpuarray/problem_generation.jl b/src/ensemblegpuarray/problem_generation.jl index d78084e3..9075cae2 100644 --- a/src/ensemblegpuarray/problem_generation.jl +++ b/src/ensemblegpuarray/problem_generation.jl @@ -1,4 +1,8 @@ -function generate_problem(prob::SciMLBase.AbstractODEProblem, u0, p, jac_prototype, colorvec) +function generate_problem(prob::SciMLBase.AbstractODEProblem, + u0, + p, + jac_prototype, + colorvec) _f = let f = prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop function (du, u, p, t) version = get_backend(u) diff --git a/src/ensemblegpukernel/kernels.jl b/src/ensemblegpukernel/kernels.jl new file mode 100644 index 00000000..0e4705b0 --- /dev/null +++ b/src/ensemblegpukernel/kernels.jl @@ -0,0 +1,114 @@ + +# saveat is just a bool here: +# true: ts is a vector of timestamps to read from +# false: each ODE has its own timestamps, so ts is a vector to write to +@kernel function ode_solve_kernel(@Const(probs), alg, _us, _ts, dt, callback, + tstops, nsteps, + saveat, ::Val{save_everystep}) where {save_everystep} + i = @index(Global, Linear) + + # get the actual problem for this thread + prob = @inbounds probs[i] + + # get the input/output arrays for this thread + ts = @inbounds view(_ts, :, i) + us = @inbounds view(_us, :, i) + + _saveat = get(prob.kwargs, :saveat, nothing) + + saveat = _saveat === nothing ? saveat : _saveat + + integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops, + callback, save_everystep, saveat) + + u0 = prob.u0 + tspan = prob.tspan + + integ.cur_t = 0 + if saveat !== nothing + integ.cur_t = 1 + if prob.tspan[1] == saveat[1] + integ.cur_t += 1 + @inbounds us[1] = u0 + end + else + @inbounds ts[integ.step_idx] = prob.tspan[1] + @inbounds us[integ.step_idx] = prob.u0 + end + + integ.step_idx += 1 + # FSAL + while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated + saved_in_cb = step!(integ, ts, us) + !saved_in_cb && savevalues!(integ, ts, us) + end + if integ.t > tspan[2] && saveat === nothing + ## Intepolate to tf + @inbounds us[end] = integ(tspan[2]) + @inbounds ts[end] = tspan[2] + end + + if saveat === nothing && !save_everystep + @inbounds us[2] = integ.u + @inbounds ts[2] = integ.t + end +end + +@kernel function ode_asolve_kernel(@Const(probs), alg, _us, _ts, dt, callback, tstops, + abstol, reltol, + saveat, + ::Val{save_everystep}) where {save_everystep} + i = @index(Global, Linear) + + # get the actual problem for this thread + prob = @inbounds probs[i] + # get the input/output arrays for this thread + ts = @inbounds view(_ts, :, i) + us = @inbounds view(_us, :, i) + # TODO: optimize contiguous view to return a CuDeviceArray + + _saveat = get(prob.kwargs, :saveat, nothing) + + saveat = _saveat === nothing ? saveat : _saveat + + u0 = prob.u0 + tspan = prob.tspan + f = prob.f + p = prob.p + + t = tspan[1] + tf = prob.tspan[2] + + integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], prob.tspan[2], dt, + prob.p, + abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback, + saveat) + + integ.cur_t = 0 + if saveat !== nothing + integ.cur_t = 1 + if tspan[1] == saveat[1] + integ.cur_t += 1 + @inbounds us[1] = u0 + end + else + @inbounds ts[1] = tspan[1] + @inbounds us[1] = u0 + end + + while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated + saved_in_cb = step!(integ, ts, us) + !saved_in_cb && savevalues!(integ, ts, us) + end + + if integ.t > tspan[2] && saveat === nothing + ## Intepolate to tf + @inbounds us[end] = integ(tspan[2]) + @inbounds ts[end] = tspan[2] + end + + if saveat === nothing && !save_everystep + @inbounds us[2] = integ.u + @inbounds ts[2] = integ.t + end +end diff --git a/src/ensemblegpukernel/lowerlevel_solve.jl b/src/ensemblegpukernel/lowerlevel_solve.jl index 0bb72034..5c5485ad 100644 --- a/src/ensemblegpukernel/lowerlevel_solve.jl +++ b/src/ensemblegpukernel/lowerlevel_solve.jl @@ -35,7 +35,6 @@ function vectorized_solve(probs, prob::ODEProblem, alg; nsteps = length(timeseries) prob = convert(ImmutableODEProblem, prob) - dt = convert(eltype(prob.tspan), dt) if saveat === nothing @@ -52,7 +51,29 @@ function vectorized_solve(probs, prob::ODEProblem, alg; fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (len, length(probs))) else - saveat = adapt(backend, saveat) + saveat = if saveat isa AbstractRange + _saveat = range(convert(eltype(prob.tspan), first(saveat)), + convert(eltype(prob.tspan), last(saveat)), + length = length(saveat)) + convert(StepRangeLen{ + eltype(_saveat), + eltype(_saveat), + eltype(_saveat), + eltype(_saveat) === Float32 ? Int32 : Int64, + }, + _saveat) + elseif saveat isa AbstractVector + adapt(backend, convert.(eltype(prob.tspan), saveat)) + else + _saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end] + convert(StepRangeLen{ + eltype(_saveat), + eltype(_saveat), + eltype(_saveat), + eltype(_saveat) === Float32 ? Int32 : Int64, + }, + _saveat) + end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs))) @@ -99,7 +120,15 @@ function vectorized_solve(probs, prob::SDEProblem, alg; fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (len, length(probs))) else - saveat = adapt(backend, saveat) + saveat = if saveat isa AbstractRange + range(convert(eltype(prob.tspan), first(saveat)), + convert(eltype(prob.tspan), last(saveat)), + length = length(saveat)) + elseif saveat isa AbstractVector + convert.(eltype(prob.tspan), adapt(backend, saveat)) + else + prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end] + end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs))) @@ -176,7 +205,15 @@ function vectorized_asolve(probs, prob::ODEProblem, alg; fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (len, length(probs))) else - saveat = adapt(backend, saveat) + saveat = if saveat isa AbstractRange + range(convert(eltype(prob.tspan), first(saveat)), + convert(eltype(prob.tspan), last(saveat)), + length = length(saveat)) + elseif saveat isa AbstractVector + adapt(backend, convert.(eltype(prob.tspan), saveat)) + else + prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end] + end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs))) @@ -211,117 +248,3 @@ function vectorized_asolve(probs, prob::SDEProblem, alg; kwargs...) error("Adaptive time-stepping is not supported yet with GPUEM.") end - -# saveat is just a bool here: -# true: ts is a vector of timestamps to read from -# false: each ODE has its own timestamps, so ts is a vector to write to -@kernel function ode_solve_kernel(@Const(probs), alg, _us, _ts, dt, callback, - tstops, nsteps, - saveat, ::Val{save_everystep}) where {save_everystep} - i = @index(Global, Linear) - - # get the actual problem for this thread - prob = @inbounds probs[i] - - # get the input/output arrays for this thread - ts = @inbounds view(_ts, :, i) - us = @inbounds view(_us, :, i) - - _saveat = get(prob.kwargs, :saveat, nothing) - - saveat = _saveat === nothing ? saveat : _saveat - - integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops, - callback, save_everystep, saveat) - - u0 = prob.u0 - tspan = prob.tspan - - integ.cur_t = 0 - if saveat !== nothing - integ.cur_t = 1 - if prob.tspan[1] == saveat[1] - integ.cur_t += 1 - @inbounds us[1] = u0 - end - else - @inbounds ts[integ.step_idx] = prob.tspan[1] - @inbounds us[integ.step_idx] = prob.u0 - end - - integ.step_idx += 1 - # FSAL - while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated - saved_in_cb = step!(integ, ts, us) - !saved_in_cb && savevalues!(integ, ts, us) - end - if integ.t > tspan[2] && saveat === nothing - ## Intepolate to tf - @inbounds us[end] = integ(tspan[2]) - @inbounds ts[end] = tspan[2] - end - - if saveat === nothing && !save_everystep - @inbounds us[2] = integ.u - @inbounds ts[2] = integ.t - end -end - -@kernel function ode_asolve_kernel(probs, alg, _us, _ts, dt, callback, tstops, - abstol, reltol, - saveat, - ::Val{save_everystep}) where {save_everystep} - i = @index(Global, Linear) - - # get the actual problem for this thread - prob = @inbounds probs[i] - # get the input/output arrays for this thread - ts = @inbounds view(_ts, :, i) - us = @inbounds view(_us, :, i) - # TODO: optimize contiguous view to return a CuDeviceArray - - _saveat = get(prob.kwargs, :saveat, nothing) - - saveat = _saveat === nothing ? saveat : _saveat - - u0 = prob.u0 - tspan = prob.tspan - f = prob.f - p = prob.p - - t = tspan[1] - tf = prob.tspan[2] - - integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], prob.tspan[2], dt, - prob.p, - abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback, - saveat) - - integ.cur_t = 0 - if saveat !== nothing - integ.cur_t = 1 - if tspan[1] == saveat[1] - integ.cur_t += 1 - @inbounds us[1] = u0 - end - else - @inbounds ts[1] = tspan[1] - @inbounds us[1] = u0 - end - - while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated - saved_in_cb = step!(integ, ts, us) - !saved_in_cb && savevalues!(integ, ts, us) - end - - if integ.t > tspan[2] && saveat === nothing - ## Intepolate to tf - @inbounds us[end] = integ(tspan[2]) - @inbounds ts[end] = tspan[2] - end - - if saveat === nothing && !save_everystep - @inbounds us[2] = integ.u - @inbounds ts[2] = integ.t - end -end diff --git a/src/kernels.jl b/src/kernels.jl deleted file mode 100644 index c26566b0..00000000 --- a/src/kernels.jl +++ /dev/null @@ -1,228 +0,0 @@ -@kernel function ode_solve_kernel(@Const(ps), alg, _us, _ts, dt, callback, - tstops, nsteps, - saveat, ::Val{save_everystep}) where {save_everystep} - i = @index(Global, Linear) - - # get the actual parameter for this thread - # p_i = @inbounds view(ps, i) - # @KernelAbstractions.print("Here") - - # prob = remake(prob_1; p = p_i) - - # get the input/output arrays for this thread - # ts = @inbounds view(_ts, :, i) - # us = @inbounds view(_us, :, i) - - # _saveat = get(prob.kwargs, :saveat, nothing) - - # saveat = _saveat === nothing ? saveat : _saveat - - # integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops, - # callback, save_everystep, saveat) - - # u0 = prob.u0 - # tspan = prob.tspan - - # integ.cur_t = 0 - # if saveat !== nothing - # integ.cur_t = 1 - # if prob.tspan[1] == saveat[1] - # integ.cur_t += 1 - # @inbounds us[1] = u0 - # end - # else - # @inbounds ts[integ.step_idx] = prob.tspan[1] - # @inbounds us[integ.step_idx] = prob.u0 - # end - - # integ.step_idx += 1 - # # FSAL - # while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated - # saved_in_cb = step!(integ, ts, us) - # !saved_in_cb && savevalues!(integ, ts, us) - # end - # if integ.t > tspan[2] && saveat === nothing - # ## Intepolate to tf - # @inbounds us[end] = integ(tspan[2]) - # @inbounds ts[end] = tspan[2] - # end - - # if saveat === nothing && !save_everystep - # @inbounds us[2] = integ.u - # @inbounds ts[2] = integ.t - # end -end - -# saveat is just a bool here: -# true: ts is a vector of timestamps to read from -# false: each ODE has its own timestamps, so ts is a vector to write to -@kernel function ode_solve_kernel(prob, @Const(probs), alg, _us, _ts, dt, callback, - tstops, nsteps, - saveat, ::Val{save_everystep}) where {save_everystep} - i = @index(Global, Linear) - - # get the actual problem for this thread - prob = @inbounds probs[i] - - # get the input/output arrays for this thread - ts = @inbounds view(_ts, :, i) - us = @inbounds view(_us, :, i) - - _saveat = get(prob.kwargs, :saveat, nothing) - - saveat = _saveat === nothing ? saveat : _saveat - - integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops, - callback, save_everystep, saveat) - - u0 = prob.u0 - tspan = prob.tspan - - integ.cur_t = 0 - if saveat !== nothing - integ.cur_t = 1 - if prob.tspan[1] == saveat[1] - integ.cur_t += 1 - @inbounds us[1] = u0 - end - else - @inbounds ts[integ.step_idx] = prob.tspan[1] - @inbounds us[integ.step_idx] = prob.u0 - end - - integ.step_idx += 1 - # FSAL - while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated - saved_in_cb = step!(integ, ts, us) - !saved_in_cb && savevalues!(integ, ts, us) - end - if integ.t > tspan[2] && saveat === nothing - ## Intepolate to tf - @inbounds us[end] = integ(tspan[2]) - @inbounds ts[end] = tspan[2] - end - - if saveat === nothing && !save_everystep - @inbounds us[2] = integ.u - @inbounds ts[2] = integ.t - end -end - -@kernel function ode_asolve_kernel(@Const(ps), prob, alg, _us, _ts, dt, callback, tstops, - abstol, reltol, - saveat, - ::Val{save_everystep}) where {save_everystep} - i = @index(Global, Linear) - - # get the actual problem for this thread - p_i = @inbounds ps[i] - - # get the input/output arrays for this thread - ts = @inbounds view(_ts, :, i) - us = @inbounds view(_us, :, i) - # TODO: optimize contiguous view to return a CuDeviceArray - - _saveat = get(prob.kwargs, :saveat, nothing) - - saveat = _saveat === nothing ? saveat : _saveat - - u0 = prob.u0 - tspan = prob.tspan - f = prob.f - p = prob.p - - t = tspan[1] - tf = prob.tspan[2] - - integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], prob.tspan[2], dt, - prob.p, - abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback, - saveat) - - integ.cur_t = 0 - if saveat !== nothing - integ.cur_t = 1 - if tspan[1] == saveat[1] - integ.cur_t += 1 - @inbounds us[1] = u0 - end - else - @inbounds ts[1] = tspan[1] - @inbounds us[1] = u0 - end - - while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated - saved_in_cb = step!(integ, ts, us) - !saved_in_cb && savevalues!(integ, ts, us) - end - - if integ.t > tspan[2] && saveat === nothing - ## Intepolate to tf - @inbounds us[end] = integ(tspan[2]) - @inbounds ts[end] = tspan[2] - end - - if saveat === nothing && !save_everystep - @inbounds us[2] = integ.u - @inbounds ts[2] = integ.t - end -end - -@kernel function ode_asolve_kernel(@Const(probs), alg, _us, _ts, dt, callback, tstops, - abstol, reltol, - saveat, - ::Val{save_everystep}) where {save_everystep} - i = @index(Global, Linear) - - # get the actual problem for this thread - prob = @inbounds probs[i] - # get the input/output arrays for this thread - ts = @inbounds view(_ts, :, i) - us = @inbounds view(_us, :, i) - # TODO: optimize contiguous view to return a CuDeviceArray - - _saveat = get(prob.kwargs, :saveat, nothing) - - saveat = _saveat === nothing ? saveat : _saveat - - u0 = prob.u0 - tspan = prob.tspan - f = prob.f - p = prob.p - - t = tspan[1] - tf = prob.tspan[2] - - integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], prob.tspan[2], dt, - prob.p, - abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback, - saveat) - - integ.cur_t = 0 - if saveat !== nothing - integ.cur_t = 1 - if tspan[1] == saveat[1] - integ.cur_t += 1 - @inbounds us[1] = u0 - end - else - @inbounds ts[1] = tspan[1] - @inbounds us[1] = u0 - end - - while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated - saved_in_cb = step!(integ, ts, us) - !saved_in_cb && savevalues!(integ, ts, us) - end - - if integ.t > tspan[2] && saveat === nothing - ## Intepolate to tf - @inbounds us[end] = integ(tspan[2]) - @inbounds ts[end] = tspan[2] - end - - if saveat === nothing && !save_everystep - @inbounds us[2] = integ.u - @inbounds ts[2] = integ.t - end -end diff --git a/test/Project.toml b/test/Project.toml index 369c4dfb..38217d47 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,13 +1,13 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/test/gpu_kernel_de/conversions.jl b/test/gpu_kernel_de/conversions.jl index 0c616909..d524a929 100644 --- a/test/gpu_kernel_de/conversions.jl +++ b/test/gpu_kernel_de/conversions.jl @@ -1,4 +1,4 @@ -using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra +using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra, Test include("../utils.jl") function lorenz(u, p, t) @@ -17,6 +17,31 @@ p = [10.0f0, 28.0f0, 8 / 3.0f0] prob = ODEProblem{false}(lorenz, u0, tspan, p) prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + +## Don't test the problems in which GPUs don't support FP64 completely yet +## Creating StepRangeLen causes some param types to be FP64 inferred by `float` function +if ENV["GROUP"] ∉ ("Metal", "oneAPI") + @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:10)[1].t == Float32.(1:10) + + @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:0.1:10)[1].t == 1.0f0:0.1f0:10.0f0 + + @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:(1.0f0):10)[1].t == 1:1.0f0:10 + + @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1.0)[1].t == 0.0f0:1.0f0:10.0f0 +end + +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = [1.0f0, 5.0f0, 10.0f0])[1].t == [1.0f0, 5.0f0, 10.0f0] + +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = 1.0f0); \ No newline at end of file + saveat = [1.0, 5.0, 10.0])[1].t == [1.0f0, 5.0f0, 10.0f0]