Skip to content

Commit

Permalink
Auto stash before checking out "origin/master"
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 20, 2023
1 parent 8bf5bb7 commit fc7b0f5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
19 changes: 15 additions & 4 deletions src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
nsteps = length(timeseries)

prob = convert(ImmutableODEProblem, prob)

dt = convert(eltype(prob.tspan), dt)

if saveat === nothing
Expand All @@ -52,7 +51,11 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat))
else
convert.(eltype(prob.tspan),adapt(backend, saveat))
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -99,7 +102,11 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat))
else
convert.(eltype(prob.tspan),adapt(backend, saveat))
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -176,7 +183,11 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat))
else
convert.(eltype(prob.tspan),adapt(backend, saveat))
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down
26 changes: 25 additions & 1 deletion test/gpu_kernel_de/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,28 @@ prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .*
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1.0f0);
saveat = 1.0f0);

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1.0);

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = [1f0, 5f0, 10f0]);

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = [1.0, 5.0, 10.0]);

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:10);

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:0.1:10);

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:(1f0):10);

0 comments on commit fc7b0f5

Please sign in to comment.