Skip to content

Commit

Permalink
Remove param conversion to FP32 StepRangeLen
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh530 committed Oct 22, 2023
1 parent 35f62b9 commit 0ed2b1a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 28 deletions.
18 changes: 2 additions & 16 deletions src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,27 +206,13 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = if saveat isa AbstractRange
_saveat = range(convert(eltype(prob.tspan), first(saveat)),
range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
convert(StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64,
},
_saveat)
elseif saveat isa AbstractVector
adapt(backend, convert.(eltype(prob.tspan), saveat))
else
_saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
convert(StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64,
},
_saveat)
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
Expand Down
28 changes: 16 additions & 12 deletions test/gpu_kernel_de/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ prob = ODEProblem{false}(lorenz, u0, tspan, p)
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

## Don't test the problems in which GPUs don't support FP64 completely yet
## Creating StepRangeLen causes some param types to be FP64 inferred by `float` function
if ENV["GROUP"] ("Metal", "oneAPI")
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:10)[1].t == Float32.(1:10)

@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:0.1:10)[1].t == StepRangeLen{Float32, Float32, Float32, Int32}(1.0f0, 0.1f0, 91)

@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:(1.0f0):10)[1].t == 1:1.0f0:10
end

@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1.0)[1].t == 0.0f0:1.0f0:10.0f0
Expand All @@ -29,15 +45,3 @@ monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = [1.0, 5.0, 10.0])[1].t == [1.0f0, 5.0f0, 10.0f0]

@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:10)[1].t == Float32.(1:10)

@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:0.1:10)[1].t == StepRangeLen{Float32, Float32, Float32, Int32}(1.0f0, 0.1f0, 91)

@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend),
trajectories = 10_000,
saveat = 1:(1.0f0):10)[1].t == 1:1.0f0:10

0 comments on commit 0ed2b1a

Please sign in to comment.