Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Fixed to allow using matrices as scalar like entity #102

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ os:
- osx
- linux
julia:
- 0.3
- release
- nightly
git:
Expand Down
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
julia 0.3
julia 0.4
Polynomials
Compat 0.4.1
2 changes: 1 addition & 1 deletion src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ end

# isoutofdomain takes the state and returns true if state is outside
# of the allowed domain. Used in adaptive step-control.
isoutofdomain(x) = isnan(x)
isoutofdomain(x) = any(isnan(x))

function make_consistent_types(fn, y0, tspan, btab::Tableau)
# There are a few types involved in a call to a ODE solver which
Expand Down
45 changes: 27 additions & 18 deletions src/runge_kutta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ immutable TableauRKExplicit{Name, S, T} <: Tableau{Name, S, T}
@assert istril(a)
@assert S==length(c)==size(a,1)==size(a,2)==size(b,2)
@assert size(b,1)==length(order)
@assert norm(sum(a,2)-c'',Inf)<1e-10 # consistency.
@assert norm(sum(a,2)-c'',Inf)<1e-10 # consistency.
new(order,a,b,c)
end
end
Expand Down Expand Up @@ -100,7 +100,7 @@ const bt_rk23 = TableauRKExplicit(:bogacki_shampine,(2,3), Rational{Int64},
2/9 1/3 4/9 0],
[7/24 1/4 1/3 1/8
2/9 1/3 4/9 0],
[0, 1//2, 3//4, 1]
[0, 1//2, 3//4, 1]
)

# Fehlberg https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method
Expand Down Expand Up @@ -163,17 +163,17 @@ ode2_heun(fn, y0, tspan) = oderk_fixed(fn, y0, tspan, bt_heun)
ode4(fn, y0, tspan) = oderk_fixed(fn, y0, tspan, bt_rk4)

function oderk_fixed(fn, y0, tspan, btab::TableauRKExplicit)
# Non-arrays y0 treat as scalar
fn_(t, y) = [fn(t, y[1])]
t,y = oderk_fixed(fn_, [y0], tspan, btab)
# For y0 which are scalar-like, wrap them in a vector:
fn_{T}(t, y::Vector{T}) = T[fn(t, y[1])]
t,y = oderk_fixed(fn_, typeof(y0)[y0], tspan, btab)
return t, vcat_nosplat(y)
end
function oderk_fixed{N,S}(fn, y0::AbstractVector, tspan,
btab_::TableauRKExplicit{N,S})
# TODO: instead of AbstractVector use a Holy-trait

# Needed interface:
# On components:
# On components:
# On y0 container: length, deepcopy, similar, setindex!
# On time container: getindex, convert. length

Expand Down Expand Up @@ -215,9 +215,9 @@ const ode45 = ode45_dp
ode78(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_feh78; kwargs...)

function oderk_adapt(fn, y0, tspan, btab::TableauRKExplicit; kwords...)
# For y0 which don't support indexing.
fn_ = (t, y) -> [fn(t, y[1])]
t,y = oderk_adapt(fn_, [y0], tspan, btab; kwords...)
# For y0 which are scalar-like, wrap them in a vector:
fn_{T}(t, y::Vector{T}) = T[fn(t, y[1])]
t,y = oderk_adapt(fn_, typeof(y0)[y0], tspan, btab; kwords...)
return t, vcat_nosplat(y)
end
function oderk_adapt{N,S}(fn, y0::AbstractVector, tspan, btab_::TableauRKExplicit{N,S};
Expand All @@ -233,7 +233,7 @@ function oderk_adapt{N,S}(fn, y0::AbstractVector, tspan, btab_::TableauRKExplici
# - note that the type of the components might change!
# On y0 container: length, similar, setindex!
# On time container: getindex, convert, length

# For y0 which support indexing. Currently y0<:AbstractVector but
# that could be relaxed with a Holy-trait.
!isadaptive(btab_) && error("Can only use this solver with an adaptive RK Butcher table")
Expand All @@ -254,11 +254,12 @@ function oderk_adapt{N,S}(fn, y0::AbstractVector, tspan, btab_::TableauRKExplici
# work arrays:
y = similar(y0, Eyf, dof) # y at time t
y[:] = y0
ytrial = similar(y0, Eyf, dof) # trial solution at time t+dt
yerr = similar(y0, Eyf, dof) # error of trial solution
ytrial = similar(y, dof) # trial solution at time t+dt
zero!(ytrial, y[1])
yerr = similar(y, dof) # error of trial solution
ks = Array(Ty, S)
# allocate!(ks, y0, dof) # no need to allocate as fn is not in-place
ytmp = similar(y0, Eyf, dof)
# allocate!(ks, y, dof) # no need to allocate as fn is not in-place
ytmp = similar(y, dof)

# output ys
nsteps_fixed = length(tspan) # these are always output
Expand Down Expand Up @@ -366,8 +367,8 @@ function rk_embedded_step!{N,S}(ytrial, yerr, ks, ytmp, y, fn, t, dt, dof, btab:
# On components: arithmetic, zero
# On y0 container: fill!, setindex!, getindex

fill!(ytrial, zero(eltype(ytrial)) )
fill!(yerr, zero(eltype(ytrial)) )
fill!(ytrial, zero(ytrial[1]) )
fill!(yerr, zero(ytrial[1]) )
for d=1:dof
ytrial[d] += btab.b[1,1]*ks[1][d]
yerr[d] += btab.b[2,1]*ks[1][d]
Expand Down Expand Up @@ -441,10 +442,18 @@ end

# Helper functions:
function allocate!{T}(vec::Vector{T}, y0, dof)
# Allocates all vectors inside a Vector{Vector} using the same
# kind of container as y0 has and element type eltype(eltype(vec)).
# Allocates and zeros all vectors inside a Vector{Vector} using
# the same kind of container as y0 has and element type
# eltype(eltype(vec)). Uses zero(y0[1]) to create zeros for each
# element.
for s=1:length(vec)
vec[s] = similar(y0, eltype(T), dof)
zero!(vec[s], y0[1])
end
end
function zero!(vec, zeroofthis)
for i in eachindex(vec)
vec[i] = zero(zeroofthis)
end
end
function index_or_push!(vec, i, val)
Expand Down
41 changes: 33 additions & 8 deletions test/interface-tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Here are tests which test what interface the solvers require.

nonconforming_solvers = [ODE.ode23s, ODE.ode4s_s, ODE.ode4s_kr]

################################################################################
# This is to test a scalar-like state variable
# This is to test a scalar-like state variable which is encoded in a custom type
# (due to @acroy: https://gist.github.com/acroy/28be4f2384d01f38e577)

import Base: +, -, *, /, .+, .-, .*, ./
Expand Down Expand Up @@ -33,6 +35,7 @@ Base.norm(y::CompSol) = norm(y::CompSol, 2.0)
### new for PR #68
Base.abs(y::CompSol) = norm(y, 2.) # TODO not needed anymore once https://github.com/JuliaLang/julia/pull/11043 is in current stable julia
Base.zero(::Type{CompSol}) = CompSol(complex(zeros(2,2)), 0., 0.)
Base.zero(::CompSol) = zero(CompSol)
ODE.isoutofdomain(y::CompSol) = any(isnan, vcat(y.rho[:], y.x, y.p))

# Because the new RK solvers wrap scalars in an array and because of
Expand All @@ -43,21 +46,18 @@ ODE.isoutofdomain(y::CompSol) = any(isnan, vcat(y.rho[:], y.x, y.p))
.*(s::Real, y1::CompSol) = y1*s
./(y1::CompSol, s::Real) = CompSol(y1.rho/s, y1.x/s, y1.p/s)


################################################################################

# define RHSs of differential equations
# delta, V and g are parameters
function rhs(t, y, delta, V, g)
H = [[-delta/2 V]; [V delta/2]]

rho_dot = -im*H*y.rho + im*y.rho*H
x_dot = y.p
p_dot = -y.x

return CompSol( rho_dot, x_dot, p_dot)
end

# inital conditons
rho0 = zeros(2,2);
rho0[1,1]=1.;
Expand All @@ -70,13 +70,38 @@ t,y1 = ODE.ode45((t,y)->rhs(t, y, delta0, V0, g0), y0, [0., endt]) # used as ref
print("Testing interface for scalar-like state... ")
for solver in solvers
# these only work with some Array-like interface defined:
if solver in [ODE.ode23s, ODE.ode4s_s, ODE.ode4s_kr]
if solver in nonconforming_solvers
continue
end
t,y2 = solver((t,y)->rhs(t, y, delta0, V0, g0), y0, linspace(0., endt, 500))
@test norm(y1[end]-y2[end])<0.1
end
println("ok.")

################################################################################
# Tests using a matrix as a scalar
# test due to @gersonjferreira https://github.com/JuliaLang/ODE.jl/issues/86#issuecomment-224889775

h = [5.0 1.0; -1.0 6.0]/10.0;
tspan=linspace(0,1,30);
dt=tspan[2]-tspan[1];

y0 = [0.0 1.0; 1.0 1.0];

function rhs(t, y)
return h*y
end

refsol = [ 0.173109 1.81331
1.81331 1.6402]
solver = 1
for solver in solvers
if solver in nonconforming_solvers
continue
end
t, y = solver(rhs, y0, tspan)
@test norm(y[end]-refsol) < 0.1
end

################################################################################
# TODO: test a vector-like state variable, i.e. one which can be indexed.