Skip to content

Commit

Permalink
test conversion better
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 20, 2023
1 parent fc7b0f5 commit 66c6210
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
12 changes: 9 additions & 3 deletions src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
else
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat))
else
elseif saveat isa AbstractVector
convert.(eltype(prob.tspan),adapt(backend, saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan),saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
Expand Down Expand Up @@ -104,8 +106,10 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
else
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat))
else
elseif saveat isa AbstractVector
convert.(eltype(prob.tspan),adapt(backend, saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan),saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
Expand Down Expand Up @@ -185,8 +189,10 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
else
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat))
else
elseif saveat isa AbstractVector
convert.(eltype(prob.tspan),adapt(backend, saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan),saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
Expand Down
29 changes: 13 additions & 16 deletions test/gpu_kernel_de/conversions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra
using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra, Test
include("../utils.jl")

function lorenz(u, p, t)
Expand All @@ -17,30 +17,27 @@ p = [10.0f0, 28.0f0, 8 / 3.0f0]
prob = ODEProblem{false}(lorenz, u0, tspan, p)
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1.0f0);

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1.0);
saveat = 1.0)[1].t == 0f0:1f0:10f0

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = [1f0, 5f0, 10f0]);
saveat = [1f0, 5f0, 10f0])[1].t == [1f0, 5f0, 10f0]

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = [1.0, 5.0, 10.0]);
saveat = [1.0, 5.0, 10.0])[1].t == [1f0, 5f0, 10f0]

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:10);
saveat = 1:10)[1].t == Float32.(1:10)

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:0.1:10);
saveat = 1:0.1:10)[1].t == 1:1f-1:10

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:(1f0):10);
saveat = 1:(1f0):10)[1].t == 1:1f0:10

0 comments on commit 66c6210

Please sign in to comment.