From 40f0ace7356a0415b46f9bc68b4dbe63c0197fe0 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Sun, 11 Dec 2016 22:25:37 -0800 Subject: [PATCH 1/4] add common interface bindings and tests --- REQUIRE | 1 + src/ODE.jl | 3 + src/common.jl | 161 +++++++++++++++++++++++++++++++++++++++++++++++ test/REQUIRE | 1 + test/common.jl | 66 +++++++++++++++++++ test/runtests.jl | 1 + 6 files changed, 233 insertions(+) create mode 100644 src/common.jl create mode 100644 test/REQUIRE create mode 100644 test/common.jl diff --git a/REQUIRE b/REQUIRE index ffc6c5b..88b7472 100644 --- a/REQUIRE +++ b/REQUIRE @@ -2,3 +2,4 @@ julia 0.5- Polynomials ForwardDiff Compat 0.4.1 +DiffEqBase diff --git a/src/ODE.jl b/src/ODE.jl index 73de9a6..eecde24 100644 --- a/src/ODE.jl +++ b/src/ODE.jl @@ -20,6 +20,8 @@ module ODE using Polynomials using Compat +using DiffEqBase +import DiffEqBase: solve import Compat.String using ForwardDiff @@ -48,6 +50,7 @@ include("integrators/rosenbrock.jl") # User interface to solvers include("top-interface.jl") +include("common.jl") """ diff --git a/src/common.jl b/src/common.jl new file mode 100644 index 0000000..e963036 --- /dev/null +++ b/src/common.jl @@ -0,0 +1,161 @@ +abstract ODEIterAlgorithm <: AbstractODEAlgorithm +immutable feuler <: ODEIterAlgorithm end +immutable rk23 <: ODEIterAlgorithm end +immutable rk45 <: ODEIterAlgorithm end +immutable feh78 <: ODEIterAlgorithm end +immutable ModifiedRosenbrock <: ODEIterAlgorithm end +immutable midpoint <: ODEIterAlgorithm end +immutable heun <: ODEIterAlgorithm end +immutable rk4 <: ODEIterAlgorithm end +immutable feh45 <: ODEIterAlgorithm end + +typealias KW Dict{Symbol,Any} + +function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}(prob::AbstractODEProblem{uType,tType,isinplace,F}, + alg::algType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true, + saveat=[],callback=()->nothing,timeseries_errors=true,dense_errors=false, + kwargs...) + tspan = prob.tspan + + if tspan[end]-tspan[1] (du[:] = prob.f(t,u)) + else + f! = prob.f + end + ode = ODE.ExplicitODE(t,u,f!) + # adaptive==true ? FoA=:adaptive : FoA=:fixed #Currently limied to only adaptive + FoA = :adaptive + if typeof(alg) <: rk23 + solver = ODE.RKIntegrator{FoA,:rk23} + elseif typeof(alg) <: rk45 + solver = ODE.RKIntegrator{FoA,:dopri5} + elseif typeof(alg) <: feh78 + solver = ODE.RKIntegrator{FoA,:feh78} + elseif typeof(alg) <: ModifiedRosenbrock + solver = ODE.ModifiedRosenbrockIntegrator + elseif typeof(alg) <: feuler + solver = ODE.RKIntegratorFixed{:feuler} + elseif typeof(alg) <: midpoint + solver = ODE.RKIntegratorFixed{:midpoint} + elseif typeof(alg) <: heun + solver = ODE.RKIntegratorFixed{:heun} + elseif typeof(alg) <: rk4 + solver = ODE.RKIntegratorFixed{:rk4} + elseif typeof(alg) <: feh45 + solver = ODE.RKIntegrator{FoA,:rk45} + end + out = ODE.solve(ode;solver=solver,opts...) + timeseries = out.y + ts = out.t + ks = out.dy + if length(out.y[1])==1 + tmp = Vector{eltype(out.y[1])}(length(out.y)) + tmp_dy = Vector{eltype(out.dy[1])}(length(out.dy)) + for i in 1:length(out.y) + tmp[i] = out.y[i][1] + tmp_dy[i] = out.dy[i][1] + end + timeseries = tmp + ks = tmp_dy + end + + saveat_idxs = find((x)->x∈saveat,ts) + t_nosaveat = view(ts,symdiff(1:length(ts),saveat_idxs)) + u_nosaveat = view(timeseries,symdiff(1:length(ts),saveat_idxs)) + + if dense + interp = (tvals) -> common_interpolation(tvals,t_nosaveat,u_nosaveat,ks,alg,f!) + else + interp = (tvals) -> nothing + end + + build_solution(prob,alg,ts,timeseries, + dense=dense,k=ks,interp=interp, + timeseries_errors = timeseries_errors, + dense_errors = dense_errors) +end + +const ODEJL_OPTION_LIST = Set([:tout,:tstop,:reltol,:abstol,:minstep,:maxstep,:initstep,:norm,:maxiters,:isoutofdomain]) +const ODEJL_ALIASES = Dict{Symbol,Symbol}(:minstep=>:dtmin,:maxstep=>:dtmax,:initstep=>:dt,:tstop=>:T,:maxiters=>:maxiters) +const ODEJL_ALIASES_REVERSED = Dict{Symbol,Symbol}([(v,k) for (k,v) in ODEJL_ALIASES]) + +function buildOptions(o,optionlist,aliases,aliases_reversed) + dict1 = Dict{Symbol,Any}([Pair(k,o[k]) for k in (keys(o) ∩ optionlist)]) + dict2 = Dict([Pair(aliases_reversed[k],o[k]) for k in (keys(o) ∩ values(aliases))]) + merge(dict1,dict2) +end + +""" +common_interpolation(tvals,ts,timeseries,ks) + +Get the value at tvals where the solution is known at the +times ts (sorted), with values timeseries and derivatives ks +""" +function common_interpolation(tvals,ts,timeseries,ks,alg,f) + idx = sortperm(tvals) + i = 2 # Start the search thinking it's between ts[1] and ts[2] + vals = Vector{eltype(timeseries)}(length(tvals)) + for j in idx + t = tvals[j] + i = findfirst((x)->x>=t,ts[i:end])+i-1 # It's in the interval ts[i-1] to ts[i] + if ts[i] == t + vals[j] = timeseries[i] + elseif ts[i-1] == t # Can happen if it's the first value! + vals[j] = timeseries[i-1] + else + dt = ts[i] - ts[i-1] + Θ = (t-ts[i-1])/dt + vals[j] = common_interpolant(Θ,dt,timeseries[i-1],timeseries[i],ks[i-1],ks[i],alg) + end + end + vals +end + +""" +common_interpolation(tval::Number,ts,timeseries,ks) + +Get the value at tval where the solution is known at the +times ts (sorted), with values timeseries and derivatives ks +""" +function common_interpolation(tval::Number,ts,timeseries,ks,alg,f) + i = findfirst((x)->x>=tval,ts) # It's in the interval ts[i-1] to ts[i] + if ts[i] == tval + val = timeseries[i] + elseif ts[i-1] == tval # Can happen if it's the first value! + push!(vals,timeseries[i-1]) + else + dt = ts[i] - ts[i-1] + Θ = (tval-ts[i-1])/dt + val = common_interpolant(Θ,dt,timeseries[i-1],timeseries[i],ks[i-1],ks[i],alg) + end + val +end + +""" +Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Problems Page 190 +""" +function common_interpolant(Θ,dt,y₀,y₁,k₀,k₁,alg) # Default interpolant is Hermite + (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k₀ + Θ*dt*k₁) +end + +export ODEIterAlgorithm, feuler, rk23, feh45, feh78, ModifiedRosenbrock, + midpoint, heun, rk4, rk45 diff --git a/test/REQUIRE b/test/REQUIRE new file mode 100644 index 0000000..62544ef --- /dev/null +++ b/test/REQUIRE @@ -0,0 +1 @@ +DiffEqProblemLibrary diff --git a/test/common.jl b/test/common.jl new file mode 100644 index 0000000..ea179fd --- /dev/null +++ b/test/common.jl @@ -0,0 +1,66 @@ +using DiffEqProblemLibrary, DiffEqBase + +prob = prob_ode_linear +dt=1/2^(4) + +sol =solve(prob,feuler();dt=dt) +#plot(sol,plot_analytic=true) + +sol =solve(prob,rk23(),dt=dt) + +sol =solve(prob,rk45(),dt=dt) + +sol =solve(prob,feh78(),dt=dt) + +sol =solve(prob,ModifiedRosenbrock(),dt=dt) + +sol =solve(prob,midpoint(),dt=dt) + +sol =solve(prob,heun(),dt=dt) + +sol =solve(prob,rk4(),dt=dt) + +sol =solve(prob,feh45(),dt=dt) + +prob = prob_ode_2Dlinear + +sol =solve(prob,feuler(),dt=dt) + +sol =solve(prob,rk23(),dt=dt) + +sol =solve(prob,rk45(),dt=dt) + +sol =solve(prob,feh78(),dt=dt) + +#sol =solve(prob,ModifiedRosenbrock(),dt=dt) #ODE.jl issues with 2D + +sol =solve(prob,midpoint(),dt=dt) + +sol =solve(prob,heun(),dt=dt) + +sol =solve(prob,rk4(),dt=dt) + +sol =solve(prob,feh45(),dt=dt) + +#= +prob = prob_ode_bigfloat2Dlinear + +sol =solve(prob,feuler(),dt=dt) +TEST_PLOT && plot(sol,plot_analytic=true) + +sol =solve(prob,rk23(),dt=dt) + +sol =solve(prob,rk4()5,dt=dt) + +sol =solve(prob,feh78(),dt=dt) + +#sol =solve(prob,dt=0,alg=:ode23s) #ODE.jl issues + +sol =solve(prob,midpoint(),dt=dt) + +sol =solve(prob,heun(),dt=dt) + +sol =solve(prob,rk4(),dt=dt) + +sol =solve(prob,feh45(),dt=dt) +=# diff --git a/test/runtests.jl b/test/runtests.jl index 27be75b..ee108c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,4 +5,5 @@ using Base.Test @testset "ODE tests" begin include("iterators.jl") include("top-interface.jl") + include("common.jl") end From ceb0632ba47c11b0ad766efa9594d4c8969fb6e1 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Wed, 28 Dec 2016 00:43:36 -0800 Subject: [PATCH 2/4] update to review --- src/common.jl | 147 ++++++++++++++++++++++++------------------------- test/common.jl | 70 +++++------------------ 2 files changed, 87 insertions(+), 130 deletions(-) diff --git a/src/common.jl b/src/common.jl index e963036..188da74 100644 --- a/src/common.jl +++ b/src/common.jl @@ -12,86 +12,83 @@ immutable feh45 <: ODEIterAlgorithm end typealias KW Dict{Symbol,Any} function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}(prob::AbstractODEProblem{uType,tType,isinplace,F}, - alg::algType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true, - saveat=[],callback=()->nothing,timeseries_errors=true,dense_errors=false, - kwargs...) - tspan = prob.tspan + alg::algType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true, + saveat=[],callback=()->nothing,timeseries_errors=true,dense_errors=false, + kwargs...) - if tspan[end]-tspan[1] (du[:] = prob.f(t,u)) - else - f! = prob.f - end - ode = ODE.ExplicitODE(t,u,f!) - # adaptive==true ? FoA=:adaptive : FoA=:fixed #Currently limied to only adaptive - FoA = :adaptive - if typeof(alg) <: rk23 - solver = ODE.RKIntegrator{FoA,:rk23} - elseif typeof(alg) <: rk45 - solver = ODE.RKIntegrator{FoA,:dopri5} - elseif typeof(alg) <: feh78 - solver = ODE.RKIntegrator{FoA,:feh78} - elseif typeof(alg) <: ModifiedRosenbrock - solver = ODE.ModifiedRosenbrockIntegrator - elseif typeof(alg) <: feuler - solver = ODE.RKIntegratorFixed{:feuler} - elseif typeof(alg) <: midpoint - solver = ODE.RKIntegratorFixed{:midpoint} - elseif typeof(alg) <: heun - solver = ODE.RKIntegratorFixed{:heun} - elseif typeof(alg) <: rk4 - solver = ODE.RKIntegratorFixed{:rk4} - elseif typeof(alg) <: feh45 - solver = ODE.RKIntegrator{FoA,:rk45} - end - out = ODE.solve(ode;solver=solver,opts...) - timeseries = out.y - ts = out.t - ks = out.dy - if length(out.y[1])==1 - tmp = Vector{eltype(out.y[1])}(length(out.y)) - tmp_dy = Vector{eltype(out.dy[1])}(length(out.dy)) - for i in 1:length(out.y) - tmp[i] = out.y[i][1] - tmp_dy[i] = out.dy[i][1] + sizeu = size(u) + + opts = buildOptions(o,ODEJL_OPTION_LIST,ODEJL_ALIASES,ODEJL_ALIASES_REVERSED) + + if !isinplace && typeof(u)<:AbstractArray + f! = (t,u,du) -> (du[:] = prob.f(t,u)) + else + f! = prob.f + end + ode = ODE.ExplicitODE(t,u,f!) + # adaptive==true ? FoA=:adaptive : FoA=:fixed #Currently limied to only adaptive + FoA = :adaptive + if algType <: rk23 + solver = ODE.RKIntegrator{FoA,:rk23} + elseif algType <: rk45 + solver = ODE.RKIntegrator{FoA,:dopri5} + elseif algType <: feh78 + solver = ODE.RKIntegrator{FoA,:feh78} + elseif algType <: ModifiedRosenbrock + solver = ODE.ModifiedRosenbrockIntegrator + elseif algType <: feuler + solver = ODE.RKIntegratorFixed{:feuler} + elseif algType <: midpoint + solver = ODE.RKIntegratorFixed{:midpoint} + elseif algType <: heun + solver = ODE.RKIntegratorFixed{:heun} + elseif algType <: rk4 + solver = ODE.RKIntegratorFixed{:rk4} + elseif algType <: feh45 + solver = ODE.RKIntegrator{FoA,:rk45} + end + out = ODE.solve(ode;solver=solver,opts...) + timeseries = out.y + ts = out.t + ks = out.dy + if length(out.y[1])==1 + tmp = Vector{eltype(out.y[1])}(length(out.y)) + tmp_dy = Vector{eltype(out.dy[1])}(length(out.dy)) + for i in 1:length(out.y) + tmp[i] = out.y[i][1] + tmp_dy[i] = out.dy[i][1] + end + timeseries = tmp + ks = tmp_dy end - timeseries = tmp - ks = tmp_dy - end - saveat_idxs = find((x)->x∈saveat,ts) - t_nosaveat = view(ts,symdiff(1:length(ts),saveat_idxs)) - u_nosaveat = view(timeseries,symdiff(1:length(ts),saveat_idxs)) + saveat_idxs = find((x)-> x ∈ saveat,ts) + t_nosaveat = view(ts,symdiff(1:length(ts),saveat_idxs)) + u_nosaveat = view(timeseries,symdiff(1:length(ts),saveat_idxs)) - if dense - interp = (tvals) -> common_interpolation(tvals,t_nosaveat,u_nosaveat,ks,alg,f!) - else - interp = (tvals) -> nothing - end + if dense + interp = (tvals) -> common_interpolation(tvals,t_nosaveat,u_nosaveat,ks,alg,f!) + else + interp = (tvals) -> nothing + end - build_solution(prob,alg,ts,timeseries, - dense=dense,k=ks,interp=interp, - timeseries_errors = timeseries_errors, - dense_errors = dense_errors) + build_solution(prob,alg,ts,timeseries, + dense=dense,k=ks,interp=interp, + timeseries_errors = timeseries_errors, + dense_errors = dense_errors) end const ODEJL_OPTION_LIST = Set([:tout,:tstop,:reltol,:abstol,:minstep,:maxstep,:initstep,:norm,:maxiters,:isoutofdomain]) @@ -99,9 +96,9 @@ const ODEJL_ALIASES = Dict{Symbol,Symbol}(:minstep=>:dtmin,:maxstep=>:dtmax,:ini const ODEJL_ALIASES_REVERSED = Dict{Symbol,Symbol}([(v,k) for (k,v) in ODEJL_ALIASES]) function buildOptions(o,optionlist,aliases,aliases_reversed) - dict1 = Dict{Symbol,Any}([Pair(k,o[k]) for k in (keys(o) ∩ optionlist)]) - dict2 = Dict([Pair(aliases_reversed[k],o[k]) for k in (keys(o) ∩ values(aliases))]) - merge(dict1,dict2) + dict1 = Dict{Symbol,Any}([Pair(k,o[k]) for k in (keys(o) ∩ optionlist)]) + dict2 = Dict([Pair(aliases_reversed[k],o[k]) for k in (keys(o) ∩ values(aliases))]) + merge(dict1,dict2) end """ diff --git a/test/common.jl b/test/common.jl index ea179fd..df4e4b8 100644 --- a/test/common.jl +++ b/test/common.jl @@ -1,66 +1,26 @@ -using DiffEqProblemLibrary, DiffEqBase +using DiffEqProblemLibrary, DiffEqBase, ODE prob = prob_ode_linear dt=1/2^(4) -sol =solve(prob,feuler();dt=dt) -#plot(sol,plot_analytic=true) +algs = [feuler(),rk23(),rk45(),feh78(),ModifiedRosenbrock(),midpoint(),heun(),rk4(),feh45()] -sol =solve(prob,rk23(),dt=dt) - -sol =solve(prob,rk45(),dt=dt) - -sol =solve(prob,feh78(),dt=dt) - -sol =solve(prob,ModifiedRosenbrock(),dt=dt) - -sol =solve(prob,midpoint(),dt=dt) - -sol =solve(prob,heun(),dt=dt) - -sol =solve(prob,rk4(),dt=dt) - -sol =solve(prob,feh45(),dt=dt) +for alg in algs + sol = solve(prob,alg;dt=dt) +end prob = prob_ode_2Dlinear -sol =solve(prob,feuler(),dt=dt) - -sol =solve(prob,rk23(),dt=dt) - -sol =solve(prob,rk45(),dt=dt) - -sol =solve(prob,feh78(),dt=dt) - -#sol =solve(prob,ModifiedRosenbrock(),dt=dt) #ODE.jl issues with 2D - -sol =solve(prob,midpoint(),dt=dt) +for alg in algs + if alg != ModifiedRosenbrock() #ODE.jl issues with 2D + sol = solve(prob,alg;dt=dt) + end +end -sol =solve(prob,heun(),dt=dt) - -sol =solve(prob,rk4(),dt=dt) - -sol =solve(prob,feh45(),dt=dt) - -#= prob = prob_ode_bigfloat2Dlinear -sol =solve(prob,feuler(),dt=dt) -TEST_PLOT && plot(sol,plot_analytic=true) - -sol =solve(prob,rk23(),dt=dt) - -sol =solve(prob,rk4()5,dt=dt) - -sol =solve(prob,feh78(),dt=dt) - -#sol =solve(prob,dt=0,alg=:ode23s) #ODE.jl issues - -sol =solve(prob,midpoint(),dt=dt) - -sol =solve(prob,heun(),dt=dt) - -sol =solve(prob,rk4(),dt=dt) - -sol =solve(prob,feh45(),dt=dt) -=# +for alg in algs + if alg != ModifiedRosenbrock() #ODE.jl issues with 2D + sol = solve(prob,alg;dt=dt) + end +end From 6588b229b5f90128d1c9780d831e2a73dda4cb21 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Wed, 28 Dec 2016 00:55:44 -0800 Subject: [PATCH 3/4] some cleanup --- src/common.jl | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/common.jl b/src/common.jl index 188da74..7332c42 100644 --- a/src/common.jl +++ b/src/common.jl @@ -9,17 +9,14 @@ immutable heun <: ODEIterAlgorithm end immutable rk4 <: ODEIterAlgorithm end immutable feh45 <: ODEIterAlgorithm end -typealias KW Dict{Symbol,Any} - -function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}(prob::AbstractODEProblem{uType,tType,isinplace,F}, - alg::algType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true, - saveat=[],callback=()->nothing,timeseries_errors=true,dense_errors=false, - kwargs...) +function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}( + prob::AbstractODEProblem{uType,tType,isinplace,F}, + alg::algType,timeseries=[],ts=[],ks=[];dense=true, + timeseries_errors=true,dense_errors=false,kwargs...) tspan = prob.tspan - o = KW(kwargs) - t = tspan[1] + o = Dict{Symbol,Any}(kwargs) u0 = prob.u0 o[:T] = tspan[end] @@ -38,7 +35,7 @@ function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}(prob::Abstract else f! = prob.f end - ode = ODE.ExplicitODE(t,u,f!) + ode = ODE.ExplicitODE(tspan[1],u,f!) # adaptive==true ? FoA=:adaptive : FoA=:fixed #Currently limied to only adaptive FoA = :adaptive if algType <: rk23 @@ -61,9 +58,9 @@ function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}(prob::Abstract solver = ODE.RKIntegrator{FoA,:rk45} end out = ODE.solve(ode;solver=solver,opts...) - timeseries = out.y - ts = out.t - ks = out.dy + y = out.y + t = out.t + dy = out.dy if length(out.y[1])==1 tmp = Vector{eltype(out.y[1])}(length(out.y)) tmp_dy = Vector{eltype(out.dy[1])}(length(out.dy)) @@ -71,22 +68,22 @@ function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}(prob::Abstract tmp[i] = out.y[i][1] tmp_dy[i] = out.dy[i][1] end - timeseries = tmp - ks = tmp_dy + y = tmp + dy = tmp_dy end - saveat_idxs = find((x)-> x ∈ saveat,ts) - t_nosaveat = view(ts,symdiff(1:length(ts),saveat_idxs)) - u_nosaveat = view(timeseries,symdiff(1:length(ts),saveat_idxs)) + #saveat_idxs = find((x)-> x ∈ saveat,ts) + #t_nosaveat = view(ts,symdiff(1:length(ts),saveat_idxs)) + #u_nosaveat = view(timeseries,symdiff(1:length(ts),saveat_idxs)) if dense - interp = (tvals) -> common_interpolation(tvals,t_nosaveat,u_nosaveat,ks,alg,f!) + interp = (tvals) -> common_interpolation(tvals,t,y,dy,alg,f!) else interp = (tvals) -> nothing end - build_solution(prob,alg,ts,timeseries, - dense=dense,k=ks,interp=interp, + build_solution(prob,alg,t,y, + dense=dense,k=dy,interp=interp, timeseries_errors = timeseries_errors, dense_errors = dense_errors) end From 12dc7edd0c56dcf28c97635d5d1d6e2371756b9a Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Wed, 28 Dec 2016 00:58:17 -0800 Subject: [PATCH 4/4] remove size save --- src/common.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/common.jl b/src/common.jl index 7332c42..d1f7eb2 100644 --- a/src/common.jl +++ b/src/common.jl @@ -26,8 +26,6 @@ function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}( u = deepcopy(u0) end - sizeu = size(u) - opts = buildOptions(o,ODEJL_OPTION_LIST,ODEJL_ALIASES,ODEJL_ALIASES_REVERSED) if !isinplace && typeof(u)<:AbstractArray