Skip to content
This repository has been archived by the owner on Mar 18, 2023. It is now read-only.

add common interface bindings and tests #28

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ julia 0.5-
Polynomials
ForwardDiff
Compat 0.4.1
DiffEqBase
3 changes: 3 additions & 0 deletions src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ module ODE

using Polynomials
using Compat
using DiffEqBase
import DiffEqBase: solve
import Compat.String
using ForwardDiff

Expand Down Expand Up @@ -48,6 +50,7 @@ include("integrators/rosenbrock.jl")

# User interface to solvers
include("top-interface.jl")
include("common.jl")


"""
Expand Down
161 changes: 161 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
abstract ODEIterAlgorithm <: AbstractODEAlgorithm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Four spaces indent, if not too much hassle.

immutable feuler <: ODEIterAlgorithm end
immutable rk23 <: ODEIterAlgorithm end
immutable rk45 <: ODEIterAlgorithm end
immutable feh78 <: ODEIterAlgorithm end
immutable ModifiedRosenbrock <: ODEIterAlgorithm end
immutable midpoint <: ODEIterAlgorithm end
immutable heun <: ODEIterAlgorithm end
immutable rk4 <: ODEIterAlgorithm end
immutable feh45 <: ODEIterAlgorithm end

typealias KW Dict{Symbol,Any}

function solve{uType,tType,isinplace,algType<:ODEIterAlgorithm,F}(prob::AbstractODEProblem{uType,tType,isinplace,F},
alg::algType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true,
saveat=[],callback=()->nothing,timeseries_errors=true,dense_errors=false,
kwargs...)
tspan = prob.tspan

if tspan[end]-tspan[1]<tType(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto (SciML/ODE.jl#119 (comment))

What about the sort(unique(...)) thing?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the sort(unique(...)) thing?

That's for "saveat" style chosen outputs. The things I tried didn't work, so I just don't have saveat implemented at all in this.

error("final time must be greater than starting time. Aborting.")
end
atomloaded = isdefined(Main,:Atom)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment Check if Juno IDE is used

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, this snuck in there. It actually isn't needed.

o = KW(kwargs)
t = tspan[1]
u0 = prob.u0
o[:T] = tspan[end]

if typeof(u0) <: Number
u = [u0]
else
u = deepcopy(u0)
end

sizeu = size(u)

opts = buildOptions(o,ODEJL_OPTION_LIST,ODEJL_ALIASES,ODEJL_ALIASES_REVERSED)

if !isinplace && typeof(u)<:AbstractArray
f! = (t,u,du) -> (du[:] = prob.f(t,u))
else
f! = prob.f
end
ode = ODE.ExplicitODE(t,u,f!)
# adaptive==true ? FoA=:adaptive : FoA=:fixed #Currently limied to only adaptive
FoA = :adaptive
if typeof(alg) <: rk23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

algType <: rk23

solver = ODE.RKIntegrator{FoA,:rk23}
elseif typeof(alg) <: rk45
solver = ODE.RKIntegrator{FoA,:dopri5}
elseif typeof(alg) <: feh78
solver = ODE.RKIntegrator{FoA,:feh78}
elseif typeof(alg) <: ModifiedRosenbrock
solver = ODE.ModifiedRosenbrockIntegrator
elseif typeof(alg) <: feuler
solver = ODE.RKIntegratorFixed{:feuler}
elseif typeof(alg) <: midpoint
solver = ODE.RKIntegratorFixed{:midpoint}
elseif typeof(alg) <: heun
solver = ODE.RKIntegratorFixed{:heun}
elseif typeof(alg) <: rk4
solver = ODE.RKIntegratorFixed{:rk4}
elseif typeof(alg) <: feh45
solver = ODE.RKIntegrator{FoA,:rk45}
end
out = ODE.solve(ode;solver=solver,opts...)
timeseries = out.y
ts = out.t
ks = out.dy
if length(out.y[1])==1
tmp = Vector{eltype(out.y[1])}(length(out.y))
tmp_dy = Vector{eltype(out.dy[1])}(length(out.dy))
for i in 1:length(out.y)
tmp[i] = out.y[i][1]
tmp_dy[i] = out.dy[i][1]
end
timeseries = tmp
ks = tmp_dy
end

saveat_idxs = find((x)->x∈saveat,ts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x ∈ saveat use spaces

t_nosaveat = view(ts,symdiff(1:length(ts),saveat_idxs))
u_nosaveat = view(timeseries,symdiff(1:length(ts),saveat_idxs))

if dense
interp = (tvals) -> common_interpolation(tvals,t_nosaveat,u_nosaveat,ks,alg,f!)
else
interp = (tvals) -> nothing
end

build_solution(prob,alg,ts,timeseries,
dense=dense,k=ks,interp=interp,
timeseries_errors = timeseries_errors,
dense_errors = dense_errors)
end

const ODEJL_OPTION_LIST = Set([:tout,:tstop,:reltol,:abstol,:minstep,:maxstep,:initstep,:norm,:maxiters,:isoutofdomain])
const ODEJL_ALIASES = Dict{Symbol,Symbol}(:minstep=>:dtmin,:maxstep=>:dtmax,:initstep=>:dt,:tstop=>:T,:maxiters=>:maxiters)
const ODEJL_ALIASES_REVERSED = Dict{Symbol,Symbol}([(v,k) for (k,v) in ODEJL_ALIASES])

function buildOptions(o,optionlist,aliases,aliases_reversed)
dict1 = Dict{Symbol,Any}([Pair(k,o[k]) for k in (keys(o) ∩ optionlist)])
dict2 = Dict([Pair(aliases_reversed[k],o[k]) for k in (keys(o) ∩ values(aliases))])
merge(dict1,dict2)
end

"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the interpolation stuff go elsewhere as that is not interface specific?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed that sooner rather than later you might be replacing it with the interpolation functions that you already have internally (so that way it could also be specialized to the method), but wanted to make sure that the PR at least had the functionality it had before this code was moved.

common_interpolation(tvals,ts,timeseries,ks)

Get the value at tvals where the solution is known at the
times ts (sorted), with values timeseries and derivatives ks
"""
function common_interpolation(tvals,ts,timeseries,ks,alg,f)
idx = sortperm(tvals)
i = 2 # Start the search thinking it's between ts[1] and ts[2]
vals = Vector{eltype(timeseries)}(length(tvals))
for j in idx
t = tvals[j]
i = findfirst((x)->x>=t,ts[i:end])+i-1 # It's in the interval ts[i-1] to ts[i]
if ts[i] == t
vals[j] = timeseries[i]
elseif ts[i-1] == t # Can happen if it's the first value!
vals[j] = timeseries[i-1]
else
dt = ts[i] - ts[i-1]
Θ = (t-ts[i-1])/dt
vals[j] = common_interpolant(Θ,dt,timeseries[i-1],timeseries[i],ks[i-1],ks[i],alg)
end
end
vals
end

"""
common_interpolation(tval::Number,ts,timeseries,ks)

Get the value at tval where the solution is known at the
times ts (sorted), with values timeseries and derivatives ks
"""
function common_interpolation(tval::Number,ts,timeseries,ks,alg,f)
i = findfirst((x)->x>=tval,ts) # It's in the interval ts[i-1] to ts[i]
if ts[i] == tval
val = timeseries[i]
elseif ts[i-1] == tval # Can happen if it's the first value!
push!(vals,timeseries[i-1])
else
dt = ts[i] - ts[i-1]
Θ = (tval-ts[i-1])/dt
val = common_interpolant(Θ,dt,timeseries[i-1],timeseries[i],ks[i-1],ks[i],alg)
end
val
end

"""
Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Problems Page 190
"""
function common_interpolant(Θ,dt,y₀,y₁,k₀,k₁,alg) # Default interpolant is Hermite
(1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k₀ + Θ*dt*k₁)
end

export ODEIterAlgorithm, feuler, rk23, feh45, feh78, ModifiedRosenbrock,
midpoint, heun, rk4, rk45
1 change: 1 addition & 0 deletions test/REQUIRE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DiffEqProblemLibrary
66 changes: 66 additions & 0 deletions test/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using DiffEqProblemLibrary, DiffEqBase

prob = prob_ode_linear
dt=1/2^(4)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write as a loop:

for alg in subtypes(ODEIterAlgorithm)
    ...

sol =solve(prob,feuler();dt=dt)
#plot(sol,plot_analytic=true)

sol =solve(prob,rk23(),dt=dt)

sol =solve(prob,rk45(),dt=dt)

sol =solve(prob,feh78(),dt=dt)

sol =solve(prob,ModifiedRosenbrock(),dt=dt)

sol =solve(prob,midpoint(),dt=dt)

sol =solve(prob,heun(),dt=dt)

sol =solve(prob,rk4(),dt=dt)

sol =solve(prob,feh45(),dt=dt)

prob = prob_ode_2Dlinear

sol =solve(prob,feuler(),dt=dt)

sol =solve(prob,rk23(),dt=dt)

sol =solve(prob,rk45(),dt=dt)

sol =solve(prob,feh78(),dt=dt)

#sol =solve(prob,ModifiedRosenbrock(),dt=dt) #ODE.jl issues with 2D

sol =solve(prob,midpoint(),dt=dt)

sol =solve(prob,heun(),dt=dt)

sol =solve(prob,rk4(),dt=dt)

sol =solve(prob,feh45(),dt=dt)

#=
prob = prob_ode_bigfloat2Dlinear

sol =solve(prob,feuler(),dt=dt)
TEST_PLOT && plot(sol,plot_analytic=true)

sol =solve(prob,rk23(),dt=dt)

sol =solve(prob,rk4()5,dt=dt)

sol =solve(prob,feh78(),dt=dt)

#sol =solve(prob,dt=0,alg=:ode23s) #ODE.jl issues

sol =solve(prob,midpoint(),dt=dt)

sol =solve(prob,heun(),dt=dt)

sol =solve(prob,rk4(),dt=dt)

sol =solve(prob,feh45(),dt=dt)
=#
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ using Base.Test
@testset "ODE tests" begin
include("iterators.jl")
include("top-interface.jl")
include("common.jl")
end