diff --git a/src/ensemblegpukernel/lowerlevel_solve.jl b/src/ensemblegpukernel/lowerlevel_solve.jl index 0bb72034..26dd7958 100644 --- a/src/ensemblegpukernel/lowerlevel_solve.jl +++ b/src/ensemblegpukernel/lowerlevel_solve.jl @@ -35,7 +35,6 @@ function vectorized_solve(probs, prob::ODEProblem, alg; nsteps = length(timeseries) prob = convert(ImmutableODEProblem, prob) - dt = convert(eltype(prob.tspan), dt) if saveat === nothing @@ -52,7 +51,11 @@ function vectorized_solve(probs, prob::ODEProblem, alg; fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (len, length(probs))) else - saveat = adapt(backend, saveat) + saveat = if saveat isa AbstractRange + range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat)) + else + convert.(eltype(prob.tspan),adapt(backend, saveat)) + end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs))) @@ -99,7 +102,11 @@ function vectorized_solve(probs, prob::SDEProblem, alg; fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (len, length(probs))) else - saveat = adapt(backend, saveat) + saveat = if saveat isa AbstractRange + range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat)) + else + convert.(eltype(prob.tspan),adapt(backend, saveat)) + end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs))) @@ -176,7 +183,11 @@ function vectorized_asolve(probs, prob::ODEProblem, alg; fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (len, length(probs))) else - saveat = adapt(backend, saveat) + saveat = if saveat isa AbstractRange + range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat)) + else + convert.(eltype(prob.tspan),adapt(backend, saveat)) + end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs))) diff --git a/test/gpu_kernel_de/conversions.jl b/test/gpu_kernel_de/conversions.jl index 0c616909..f5f8b195 100644 --- a/test/gpu_kernel_de/conversions.jl +++ b/test/gpu_kernel_de/conversions.jl @@ -19,4 +19,28 @@ prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = 1.0f0); \ No newline at end of file + saveat = 1.0f0); + +sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1.0); + +sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = [1f0, 5f0, 10f0]); + +sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = [1.0, 5.0, 10.0]); + +sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:10); + +sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:0.1:10); + +sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:(1f0):10); \ No newline at end of file