Skip to content

Commit

Permalink
Minimal unzip (still effectively in-place)
Browse files Browse the repository at this point in the history
  • Loading branch information
fmeirinhos committed Apr 17, 2021
1 parent 712847a commit 2f9d066
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 24 deletions.
12 changes: 4 additions & 8 deletions src/kb.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,13 @@ function kbsolve(f_vert, f_diag, u0, (t0, tmax);

# Predictor
u_next = predict!(state.t, cache)
for t′ in 1:t
foreach((u,u′) -> u[t,t′] = u′, state.u, u_next[t′])
update_time(state.t, t, t′)
end
foreach((u,u′) -> u[t,1:t] = u′, state.u, u_next)
foreach(t′ -> update_time(state.t, t, t′), 1:t)

# Corrector
u_next = correct!((f(t′) for t′ in 1:t), cache)
for t′ in 1:t
foreach((u,u′) -> u[t,t′] = u′, state.u, u_next[t′])
update_time(state.t, t, t′)
end
foreach((u,u′) -> u[t,1:t] = u′, state.u, u_next)
foreach(t′ -> update_time(state.t, t, t′), 1:t)

# Calculate error and, if the step is accepted, adjust order and add a new cache entry
adjust_order!(t′ -> f_vert(state.u, state.t, t′, t), (f(t′) for t′ in 1:t), state, cache, opts.kmax, opts.atol, opts.rtol)
Expand Down
44 changes: 44 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,47 @@ _last2() = throw(ArgumentError("Cannot call last2 on an empty tuple."))
_last2(v) = throw(ArgumentError("Cannot call last2 on 1-element tuple."))
@inline _last2(v1,v2) = (v1, v2)
@inline _last2(v, t...) = _last2(t...)

# From https://github.com/JuliaLang/julia/pull/33515
"""
    unzip(itrs)

Split an iterator of tuple-like elements into a tuple of vectors, one vector per
field: `unzip([(1, 'a'), (2, 'b')])` gives `([1, 2], ['a', 'b'])`.

Returns `()` when `itrs` is empty. Each output vector starts out with the concrete
type of the corresponding field of the first element; `unzip_rest` takes care of
the remaining elements (and of widening the element types when needed).

Every element, including the first, is converted with `Tuple`, so any iterable
element (tuple, pair, vector, ...) is accepted.
"""
function unzip(itrs)
    # When the length is known up front the output vectors are allocated once and
    # filled by index; otherwise they start with one slot and grow via `push!`.
    n = if Base.IteratorSize(itrs) isa Union{Base.HasLength,Base.HasShape}
        Int(length(itrs))
    else
        nothing
    end

    outer = iterate(itrs)
    outer === nothing && return ()
    itr, state = outer

    # Normalise the first element exactly like `unzip_rest` normalises the others,
    # so that `typeof(vals)` is a tuple type the later elements can actually match.
    vals = Tuple(itr)
    vecs = map(vals) do x
        v = Vector{typeof(x)}(undef, something(n, 1))
        v[1] = x
        return v
    end

    # `i` is the index of the last slot written (known length) or `nothing`
    # (unknown length, append mode).
    return unzip_rest(vecs, typeof(vals), n === nothing ? nothing : 1, itrs, state)
end

"""
    unzip_rest(vecs, eltypes, i, itrs, state)

Continue unzipping `itrs` from iteration state `state` into the tuple of vectors
`vecs`, and return the completed tuple of vectors.

# Arguments
- `vecs`: tuple of output vectors, already holding the elements consumed so far.
- `eltypes`: tuple type that an element must match to be stored without widening.
- `i`: index of the last slot written when the vectors were preallocated, or
  `nothing` when the vectors are grown with `push!`.
- `itrs`, `state`: the iterator being unzipped and its current iteration state.

When an element does not match `eltypes`, new vectors with element types widened
through `Base.promote_typejoin` are allocated, the data collected so far is copied
over, and the function recurses with the widened vectors.

# Throws
- `DimensionMismatch` if an element does not have exactly `length(vecs)` fields.
"""
function unzip_rest(vecs, eltypes, i, itrs, state)
    while true
        i isa Int && (i += 1)
        outer = iterate(itrs, state)
        outer === nothing && return vecs
        itr, state = outer
        vals = Tuple(itr)

        # Without this check the `map` below would silently truncate to the
        # shorter of `vecs`/`vals`, dropping fields or entire output columns.
        length(vals) == length(vecs) || throw(DimensionMismatch(
            "cannot unzip an element with $(length(vals)) fields into $(length(vecs)) columns"))

        if vals isa eltypes
            for (v, x) in zip(vecs, vals)
                if i isa Int
                    @inbounds v[i] = x
                else
                    push!(v, x)
                end
            end
        else
            # At least one field does not fit: rebuild every column with a widened
            # element type, store the current element, and restart the fast loop.
            vecs′ = map(vecs, vals) do v, x
                T = Base.promote_typejoin(eltype(v), typeof(x))
                grow = i === nothing   # append mode needs one extra slot
                v′ = Vector{T}(undef, length(v) + grow)
                copyto!(v′, v)
                v′[grow ? lastindex(v′) : i] = x
                return v′
            end
            eltypes′ = Tuple{map(eltype, vecs′)...}
            return unzip_rest(Tuple(vecs′), eltypes′, i, itrs, state)
        end
    end
end
37 changes: 21 additions & 16 deletions src/vcabm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ mutable struct VCABMCache{T,U}
error_k::T

function VCABMCache{T}(kmax, u_prev, f_prev) where {T}
u_prev = VectorOfArray(u_prev)
f_prev = VectorOfArray(f_prev)
u_prev = VectorOfArray(collect(unzip(u_prev)))
f_prev = VectorOfArray(collect(unzip(f_prev)))

ϕ_n = [zero.(f_prev) for _ in 1:kmax+1]
ϕstar_nm = [zero.(f_prev) for _ in 1:kmax+1]
Expand All @@ -58,16 +58,19 @@ function extend_cache!(f_vert, times, cache, kmax)
t = length(times)

# Insert new terms for `f_vert` at `(t,t)`
insert!(f_prev.u, t, f_vert(t))
insert!(u_prev.u, t, copy.(u_prev[t]))
insert!(u_next.u, t, zero.(u_prev[t]))
insert!(u_erro.u, t, zero.(u_erro[t]))
f = f_vert(t)
for i in eachindex(f)
insert!(f_prev.u[i], t, f[i])
insert!(u_prev.u[i], t, copy(u_prev.u[i][t]))
insert!(u_next.u[i], t, zero(u_prev.u[i][t]))
insert!(u_erro.u[i], t, zero(u_erro.u[i][t]))
end

# And calculate the ϕs
_ϕ_n = [zero.(f_prev[t]) for _ in 1:kmax+1]
_ϕ_np1 = [zero.(f_prev[t]) for _ in 1:kmax+2]
_ϕstar_n = [zero.(f_prev[t]) for _ in 1:kmax+1]
_ϕstar_nm1 = [zero.(f_prev[t]) for _ in 1:kmax+1]
_ϕ_n = [zero(f) for _ in 1:kmax+1]
_ϕ_np1 = [zero(f) for _ in 1:kmax+2]
_ϕstar_n = [zero(f) for _ in 1:kmax+1]
_ϕstar_nm1 = [zero(f) for _ in 1:kmax+1]

t0 = max(1, t - k)
for t′ in (t0+1):t
Expand All @@ -79,10 +82,12 @@ function extend_cache!(f_vert, times, cache, kmax)
_ϕstar_nm1, _ϕstar_n = _ϕstar_n, _ϕstar_nm1
end

foreach((ϕ, ϕ′) -> insert!.u, t, ϕ′), ϕ_n, _ϕ_n)
foreach((ϕ, ϕ′) -> insert!.u, t, ϕ′), ϕ_np1, _ϕ_np1)
foreach((ϕ, ϕ′) -> insert!.u, t, ϕ′), ϕstar_n, _ϕstar_n)
foreach((ϕ, ϕ′) -> insert!.u, t, ϕ′), ϕstar_nm1, _ϕstar_nm1)
for i in eachindex(f)
foreach((ϕ, ϕ′) -> insert!.u[i], t, ϕ′[i]), ϕ_n, _ϕ_n)
foreach((ϕ, ϕ′) -> insert!.u[i], t, ϕ′[i]), ϕ_np1, _ϕ_np1)
foreach((ϕ, ϕ′) -> insert!.u[i], t, ϕ′[i]), ϕstar_n, _ϕstar_n)
foreach((ϕ, ϕ′) -> insert!.u[i], t, ϕ′[i]), ϕstar_nm1, _ϕstar_nm1)
end
end

# Explicit Adams: Section III.5 Eq. (5.5)
Expand All @@ -102,7 +107,7 @@ end
function correct!(du, cache)
@unpack u_next,g,ϕ_np1,ϕstar_n,k = cache
@inbounds begin
ϕ_np1!(cache, VectorOfArray(collect(du)), k+1)
ϕ_np1!(cache, VectorOfArray(collect(unzip(du))), k+1)
@. u_next = muladd(g[k], ϕ_np1[k], u_next)
end
u_next
Expand Down Expand Up @@ -131,7 +136,7 @@ function adjust_order!(f_vert, f, state, cache, kmax, atol, rtol)
return
end

cache.f_prev = VectorOfArray([f...])
cache.f_prev = VectorOfArray(collect(unzip(f)))

if length(state.t)<=5 || k<3
cache.k = min(k+1, 3, kmax)
Expand Down

0 comments on commit 2f9d066

Please sign in to comment.