From 66c6210a30e87f4f76df80308f07083ef6d8aed5 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 19 Oct 2023 23:42:37 -0400 Subject: [PATCH] test conversion better --- src/ensemblegpukernel/lowerlevel_solve.jl | 12 +++++++--- test/gpu_kernel_de/conversions.jl | 29 ++++++++++------------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/ensemblegpukernel/lowerlevel_solve.jl b/src/ensemblegpukernel/lowerlevel_solve.jl index 26dd7958..1b068d80 100644 --- a/src/ensemblegpukernel/lowerlevel_solve.jl +++ b/src/ensemblegpukernel/lowerlevel_solve.jl @@ -53,8 +53,10 @@ function vectorized_solve(probs, prob::ODEProblem, alg; else saveat = if saveat isa AbstractRange range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat)) - else + elseif saveat isa AbstractVector convert.(eltype(prob.tspan),adapt(backend, saveat)) + else + prob.tspan[1]:convert(eltype(prob.tspan),saveat):prob.tspan[end] end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) @@ -104,8 +106,10 @@ function vectorized_solve(probs, prob::SDEProblem, alg; else saveat = if saveat isa AbstractRange range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat)) - else + elseif saveat isa AbstractVector convert.(eltype(prob.tspan),adapt(backend, saveat)) + else + prob.tspan[1]:convert(eltype(prob.tspan),saveat):prob.tspan[end] end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) @@ -185,8 +189,10 @@ function vectorized_asolve(probs, prob::ODEProblem, alg; else saveat = if saveat isa AbstractRange range(convert(eltype(prob.tspan),first(saveat)), convert(eltype(prob.tspan),last(saveat)), length = length(saveat)) - else + elseif saveat isa AbstractVector convert.(eltype(prob.tspan),adapt(backend, saveat)) + else + prob.tspan[1]:convert(eltype(prob.tspan),saveat):prob.tspan[end] end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) diff --git a/test/gpu_kernel_de/conversions.jl b/test/gpu_kernel_de/conversions.jl index f5f8b195..815319a3 100644 --- a/test/gpu_kernel_de/conversions.jl +++ b/test/gpu_kernel_de/conversions.jl @@ -1,4 +1,4 @@ -using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra +using DiffEqGPU, OrdinaryDiffEq, StaticArrays, LinearAlgebra, Test include("../utils.jl") function lorenz(u, p, t) @@ -17,30 +17,27 @@ p = [10.0f0, 28.0f0, 8 / 3.0f0] prob = ODEProblem{false}(lorenz, u0, tspan, p) prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), - trajectories = 10_000, - saveat = 1.0f0); -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = 1.0); + saveat = 1.0)[1].t == 0f0:1f0:10f0 -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = [1f0, 5f0, 10f0]); + saveat = [1f0, 5f0, 10f0])[1].t == [1f0, 5f0, 10f0] -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = [1.0, 5.0, 10.0]); + saveat = [1.0, 5.0, 10.0])[1].t == [1f0, 5f0, 10f0] -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = 1:10); + saveat = 1:10)[1].t == Float32.(1:10) -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = 1:0.1:10); + saveat = 1:0.1:10)[1].t == 1:1f-1:10 -sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000, - saveat = 1:(1f0):10); \ No newline at end of file + saveat = 1:(1f0):10)[1].t == 1:1f0:10 \ No newline at end of file