Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some automated conversions in saveat #309

Merged
merged 7 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10"
KernelAbstractions = "0.9"
LinearSolve = "1.15, 2"
Metal = "0.4"
Metal = "0.5"
MuladdMacro = "0.2"
Parameters = "0.12"
RecursiveArrayTools = "2"
Expand Down
1 change: 1 addition & 0 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ include("ensemblegpukernel/integrators/stiff/interpolants.jl")
include("ensemblegpukernel/integrators/nonstiff/interpolants.jl")
include("ensemblegpukernel/nlsolve/type.jl")
include("ensemblegpukernel/nlsolve/utils.jl")
include("ensemblegpukernel/kernels.jl")

include("ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl")
include("ensemblegpukernel/perform_step/gpu_vern7_perform_step.jl")
Expand Down
6 changes: 5 additions & 1 deletion src/ensemblegpuarray/problem_generation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
function generate_problem(prob::SciMLBase.AbstractODEProblem, u0, p, jac_prototype, colorvec)
function generate_problem(prob::SciMLBase.AbstractODEProblem,
u0,
p,
jac_prototype,
colorvec)
_f = let f = prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop
function (du, u, p, t)
version = get_backend(u)
Expand Down
114 changes: 114 additions & 0 deletions src/ensemblegpukernel/kernels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@

# saveat is just a bool here:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not allowed?

# true: ts is a vector of timestamps to read from
# false: each ODE has its own timestamps, so ts is a vector to write to
# Per-work-item solve kernel: work-item `tid` integrates `probs[tid]` and writes
# its results into column `tid` of `_ts` / `_us`.
#
# - `saveat` is compared against `nothing` and indexed with `saveat[1]` below, so
#   it is handled as "either `nothing` or an indexable collection". A `:saveat`
#   entry found in `prob.kwargs` replaces the argument.
# - `save_everystep` arrives as a `Val` type parameter.
# - `nsteps` is accepted but not read anywhere in this body.
# NOTE(review): no tolerances are passed to `init`, so this is presumably the
# fixed-`dt` variant — confirm against the `init` method being dispatched to.
@kernel function ode_solve_kernel(@Const(probs), alg, _us, _ts, dt, callback,
    tstops, nsteps,
    saveat, ::Val{save_everystep}) where {save_everystep}
    tid = @index(Global, Linear)

    # Problem handled by this work-item, plus the pieces of it reused below.
    prob = @inbounds probs[tid]
    u0 = prob.u0
    tspan = prob.tspan
    t0 = tspan[1]
    tf = tspan[2]

    # Column `tid` of the shared output buffers belongs to this work-item.
    ts = @inbounds view(_ts, :, tid)
    us = @inbounds view(_us, :, tid)

    # A `:saveat` stored on the problem takes precedence over the argument.
    prob_saveat = get(prob.kwargs, :saveat, nothing)
    saveat = prob_saveat === nothing ? saveat : prob_saveat

    integ = init(alg, prob.f, false, u0, t0, dt, prob.p, tstops,
        callback, save_everystep, saveat)

    if saveat === nothing
        # No save times given: record the initial time and state at the
        # integrator's current `step_idx`.
        integ.cur_t = 0
        @inbounds ts[integ.step_idx] = t0
        @inbounds us[integ.step_idx] = u0
    else
        integ.cur_t = 1
        # When the first save time coincides with the start time, store the
        # initial state in slot 1 and move `cur_t` past it.
        if t0 == saveat[1]
            integ.cur_t += 1
            @inbounds us[1] = u0
        end
    end
    integ.step_idx += 1

    # Step until the end time is reached or the integrator reports termination.
    # `step!` returns whether saving already happened during the step; only
    # call `savevalues!` when it did not.
    while integ.t < tf && integ.retcode != DiffEqBase.ReturnCode.Terminated
        if !step!(integ, ts, us)
            savevalues!(integ, ts, us)
        end
    end

    if saveat === nothing
        # Overshot the end time: interpolate back to `tf` for the final slot.
        if integ.t > tf
            @inbounds us[end] = integ(tf)
            @inbounds ts[end] = tf
        end
        # Without per-step saving, slot 2 holds the integrator's final state.
        # This must run after the interpolation above, as in the original order.
        if !save_everystep
            @inbounds us[2] = integ.u
            @inbounds ts[2] = integ.t
        end
    end
end

# Per-work-item solve kernel taking error tolerances: work-item `i` integrates
# `probs[i]` and writes its results into column `i` of `_ts` / `_us`.
#
# - `abstol` / `reltol` are forwarded to `init` together with
#   `DiffEqBase.ODE_DEFAULT_NORM`.
# - `saveat` is compared against `nothing` and indexed with `saveat[1]` below, so
#   it is handled as "either `nothing` or an indexable collection". A `:saveat`
#   entry found in `prob.kwargs` replaces the argument.
# - `save_everystep` arrives as a `Val` type parameter.
# NOTE(review): presumably the adaptive-step counterpart of `ode_solve_kernel`
# — confirm against the `init` method being dispatched to.
@kernel function ode_asolve_kernel(@Const(probs), alg, _us, _ts, dt, callback, tstops,
    abstol, reltol,
    saveat,
    ::Val{save_everystep}) where {save_everystep}
    i = @index(Global, Linear)

    # get the actual problem for this thread
    prob = @inbounds probs[i]
    # get the input/output arrays for this thread
    ts = @inbounds view(_ts, :, i)
    us = @inbounds view(_us, :, i)
    # TODO: optimize contiguous view to return a CuDeviceArray

    # A `:saveat` stored on the problem takes precedence over the argument.
    _saveat = get(prob.kwargs, :saveat, nothing)

    saveat = _saveat === nothing ? saveat : _saveat

    u0 = prob.u0
    tspan = prob.tspan

    integ = init(alg, prob.f, false, u0, tspan[1], tspan[2], dt,
        prob.p,
        abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback,
        saveat)

    integ.cur_t = 0
    if saveat !== nothing
        integ.cur_t = 1
        # When the first save time coincides with the start time, store the
        # initial state in slot 1 and move `cur_t` past it.
        if tspan[1] == saveat[1]
            integ.cur_t += 1
            @inbounds us[1] = u0
        end
    else
        # No save times given: slot 1 holds the initial time and state.
        @inbounds ts[1] = tspan[1]
        @inbounds us[1] = u0
    end

    # Step until the end time is reached or the integrator reports termination.
    # `step!` returns whether saving already happened during the step; only
    # call `savevalues!` when it did not.
    while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated
        saved_in_cb = step!(integ, ts, us)
        !saved_in_cb && savevalues!(integ, ts, us)
    end

    if integ.t > tspan[2] && saveat === nothing
        ## Interpolate to tf (= tspan[2])
        @inbounds us[end] = integ(tspan[2])
        @inbounds ts[end] = tspan[2]
    end

    # Without per-step saving, slot 2 holds the integrator's final state. This
    # must stay after the interpolation above.
    if saveat === nothing && !save_everystep
        @inbounds us[2] = integ.u
        @inbounds ts[2] = integ.t
    end
end
159 changes: 41 additions & 118 deletions src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
nsteps = length(timeseries)

prob = convert(ImmutableODEProblem, prob)

dt = convert(eltype(prob.tspan), dt)

if saveat === nothing
Expand All @@ -52,7 +51,29 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
_saveat = range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
convert(StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64,
},
_saveat)
elseif saveat isa AbstractVector
adapt(backend, convert.(eltype(prob.tspan), saveat))
else
_saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
convert(StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64,
},
_saveat)
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -99,7 +120,15 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
elseif saveat isa AbstractVector
convert.(eltype(prob.tspan), adapt(backend, saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -176,7 +205,15 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
elseif saveat isa AbstractVector
adapt(backend, convert.(eltype(prob.tspan), saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -211,117 +248,3 @@ function vectorized_asolve(probs, prob::SDEProblem, alg;
kwargs...)
error("Adaptive time-stepping is not supported yet with GPUEM.")
end

# saveat is just a bool here:
# true: ts is a vector of timestamps to read from
# false: each ODE has its own timestamps, so ts is a vector to write to
# Per-work-item solve kernel: work-item `tid` integrates `probs[tid]` and writes
# its results into column `tid` of `_ts` / `_us`.
#
# - `saveat` is compared against `nothing` and indexed with `saveat[1]` below, so
#   it is handled as "either `nothing` or an indexable collection". A `:saveat`
#   entry found in `prob.kwargs` replaces the argument.
# - `save_everystep` arrives as a `Val` type parameter.
# - `nsteps` is accepted but not read anywhere in this body.
# NOTE(review): no tolerances are passed to `init`, so this is presumably the
# fixed-`dt` variant — confirm against the `init` method being dispatched to.
@kernel function ode_solve_kernel(@Const(probs), alg, _us, _ts, dt, callback,
    tstops, nsteps,
    saveat, ::Val{save_everystep}) where {save_everystep}
    tid = @index(Global, Linear)

    # Problem handled by this work-item, plus the pieces of it reused below.
    prob = @inbounds probs[tid]
    u0 = prob.u0
    tspan = prob.tspan
    t0 = tspan[1]
    tf = tspan[2]

    # Column `tid` of the shared output buffers belongs to this work-item.
    ts = @inbounds view(_ts, :, tid)
    us = @inbounds view(_us, :, tid)

    # A `:saveat` stored on the problem takes precedence over the argument.
    prob_saveat = get(prob.kwargs, :saveat, nothing)
    saveat = prob_saveat === nothing ? saveat : prob_saveat

    integ = init(alg, prob.f, false, u0, t0, dt, prob.p, tstops,
        callback, save_everystep, saveat)

    if saveat === nothing
        # No save times given: record the initial time and state at the
        # integrator's current `step_idx`.
        integ.cur_t = 0
        @inbounds ts[integ.step_idx] = t0
        @inbounds us[integ.step_idx] = u0
    else
        integ.cur_t = 1
        # When the first save time coincides with the start time, store the
        # initial state in slot 1 and move `cur_t` past it.
        if t0 == saveat[1]
            integ.cur_t += 1
            @inbounds us[1] = u0
        end
    end
    integ.step_idx += 1

    # Step until the end time is reached or the integrator reports termination.
    # `step!` returns whether saving already happened during the step; only
    # call `savevalues!` when it did not.
    while integ.t < tf && integ.retcode != DiffEqBase.ReturnCode.Terminated
        if !step!(integ, ts, us)
            savevalues!(integ, ts, us)
        end
    end

    if saveat === nothing
        # Overshot the end time: interpolate back to `tf` for the final slot.
        if integ.t > tf
            @inbounds us[end] = integ(tf)
            @inbounds ts[end] = tf
        end
        # Without per-step saving, slot 2 holds the integrator's final state.
        # This must run after the interpolation above, as in the original order.
        if !save_everystep
            @inbounds us[2] = integ.u
            @inbounds ts[2] = integ.t
        end
    end
end

# Per-work-item solve kernel taking error tolerances: work-item `i` integrates
# `probs[i]` and writes its results into column `i` of `_ts` / `_us`.
#
# - `probs` is only ever read here, so it is marked `@Const`, matching
#   `ode_solve_kernel`.
# - `abstol` / `reltol` are forwarded to `init` together with
#   `DiffEqBase.ODE_DEFAULT_NORM`.
# - `saveat` is compared against `nothing` and indexed with `saveat[1]` below, so
#   it is handled as "either `nothing` or an indexable collection". A `:saveat`
#   entry found in `prob.kwargs` replaces the argument.
# - `save_everystep` arrives as a `Val` type parameter.
# NOTE(review): presumably the adaptive-step counterpart of `ode_solve_kernel`
# — confirm against the `init` method being dispatched to.
@kernel function ode_asolve_kernel(@Const(probs), alg, _us, _ts, dt, callback, tstops,
    abstol, reltol,
    saveat,
    ::Val{save_everystep}) where {save_everystep}
    i = @index(Global, Linear)

    # get the actual problem for this thread
    prob = @inbounds probs[i]
    # get the input/output arrays for this thread
    ts = @inbounds view(_ts, :, i)
    us = @inbounds view(_us, :, i)
    # TODO: optimize contiguous view to return a CuDeviceArray

    # A `:saveat` stored on the problem takes precedence over the argument.
    _saveat = get(prob.kwargs, :saveat, nothing)

    saveat = _saveat === nothing ? saveat : _saveat

    u0 = prob.u0
    tspan = prob.tspan

    integ = init(alg, prob.f, false, u0, tspan[1], tspan[2], dt,
        prob.p,
        abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback,
        saveat)

    integ.cur_t = 0
    if saveat !== nothing
        integ.cur_t = 1
        # When the first save time coincides with the start time, store the
        # initial state in slot 1 and move `cur_t` past it.
        if tspan[1] == saveat[1]
            integ.cur_t += 1
            @inbounds us[1] = u0
        end
    else
        # No save times given: slot 1 holds the initial time and state.
        @inbounds ts[1] = tspan[1]
        @inbounds us[1] = u0
    end

    # Step until the end time is reached or the integrator reports termination.
    # `step!` returns whether saving already happened during the step; only
    # call `savevalues!` when it did not.
    while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated
        saved_in_cb = step!(integ, ts, us)
        !saved_in_cb && savevalues!(integ, ts, us)
    end

    if integ.t > tspan[2] && saveat === nothing
        ## Interpolate to tf (= tspan[2])
        @inbounds us[end] = integ(tspan[2])
        @inbounds ts[end] = tspan[2]
    end

    # Without per-step saving, slot 2 holds the integrator's final state. This
    # must stay after the interpolation above.
    if saveat === nothing && !save_everystep
        @inbounds us[2] = integ.u
        @inbounds ts[2] = integ.t
    end
end
Loading