diff --git a/src/kb.jl b/src/kb.jl index e84f462..3abefc4 100644 --- a/src/kb.jl +++ b/src/kb.jl @@ -74,17 +74,13 @@ function kbsolve(f_vert, f_diag, u0, (t0, tmax); # Predictor u_next = predict!(state.t, cache) - for t′ in 1:t - foreach((u,u′) -> u[t,t′] = u′, state.u, u_next[t′]) - update_time(state.t, t, t′) - end + foreach((u,u′) -> u[t,1:t] = u′, state.u, u_next) + foreach(t′ -> update_time(state.t, t, t′), 1:t) # Corrector u_next = correct!((f(t′) for t′ in 1:t), cache) - for t′ in 1:t - foreach((u,u′) -> u[t,t′] = u′, state.u, u_next[t′]) - update_time(state.t, t, t′) - end + foreach((u,u′) -> u[t,1:t] = u′, state.u, u_next) + foreach(t′ -> update_time(state.t, t, t′), 1:t) # Calculate error and, if the step is accepted, adjust order and add a new cache entry adjust_order!(t′ -> f_vert(state.u, state.t, t′, t), (f(t′) for t′ in 1:t), state, cache, opts.kmax, opts.atol, opts.rtol) diff --git a/src/utils.jl b/src/utils.jl index 34a4b1b..2429782 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,3 +11,47 @@ _last2() = throw(ArgumentError("Cannot call last2 on an empty tuple.")) _last2(v) = throw(ArgumentError("Cannot call last2 on 1-element tuple.")) @inline _last2(v1,v2) = (v1, v2) @inline _last2(v, t...) = _last2(t...) + +# From https://github.com/JuliaLang/julia/pull/33515 +function unzip(itrs) + n = Base.haslength(itrs) ? length(itrs) : nothing + outer = iterate(itrs) + outer === nothing && return () + vals, state = outer + vecs = ntuple(length(vals)) do i + x = vals[i] + v = Vector{typeof(x)}(undef, something(n, 1)) + @inbounds v[1] = x + return v + end + unzip_rest(vecs, typeof(vals), n isa Int ? 1 : nothing, itrs, state) +end + +function unzip_rest(vecs, eltypes, i, itrs, state) + while true + i isa Int && (i += 1) + outer = iterate(itrs, state) + outer === nothing && return vecs + itr, state = outer + vals = Tuple(itr) + if vals isa eltypes + for (v, x) in zip(vecs, vals) + if i isa Int + @inbounds v[i] = x + else + push!(v, x) + end + end + else + vecs′ = map(vecs, vals) do v, x + T = Base.promote_typejoin(eltype(v), typeof(x)) + v′ = Vector{T}(undef, length(v) + !(i isa Int)) + copyto!(v′, v) + @inbounds v′[something(i, end)] = x + return v′ + end + eltypes′ = Tuple{map(eltype, vecs′)...} + return unzip_rest(Tuple(vecs′), eltypes′, i, itrs, state) + end + end +end \ No newline at end of file diff --git a/src/vcabm.jl b/src/vcabm.jl index e9beefa..f704bf0 100644 --- a/src/vcabm.jl +++ b/src/vcabm.jl @@ -39,8 +39,8 @@ mutable struct VCABMCache{T,U} error_k::T function VCABMCache{T}(kmax, u_prev, f_prev) where {T} - u_prev = VectorOfArray(u_prev) - f_prev = VectorOfArray(f_prev) + u_prev = VectorOfArray(collect(unzip(u_prev))) + f_prev = VectorOfArray(collect(unzip(f_prev))) ϕ_n = [zero.(f_prev) for _ in 1:kmax+1] ϕstar_nm = [zero.(f_prev) for _ in 1:kmax+1] @@ -58,16 +58,19 @@ function extend_cache!(f_vert, times, cache, kmax) t = length(times) # Insert new terms for `f_vert` at `(t,t)` - insert!(f_prev.u, t, f_vert(t)) - insert!(u_prev.u, t, copy.(u_prev[t])) - insert!(u_next.u, t, zero.(u_prev[t])) - insert!(u_erro.u, t, zero.(u_erro[t])) + f = f_vert(t) + for i in eachindex(f) + insert!(f_prev.u[i], t, f[i]) + insert!(u_prev.u[i], t, copy(u_prev.u[i][t])) + insert!(u_next.u[i], t, zero(u_prev.u[i][t])) + insert!(u_erro.u[i], t, zero(u_erro.u[i][t])) + end # And calculate the ϕs - _ϕ_n = [zero.(f_prev[t]) for _ in 1:kmax+1] - _ϕ_np1 = [zero.(f_prev[t]) for _ in 1:kmax+2] - _ϕstar_n = [zero.(f_prev[t]) for _ in 1:kmax+1] - _ϕstar_nm1 = [zero.(f_prev[t]) for _ in 1:kmax+1] + _ϕ_n = [zero(f) for _ in 1:kmax+1] + _ϕ_np1 = [zero(f) for _ in 1:kmax+2] + _ϕstar_n = [zero(f) for _ in 1:kmax+1] + _ϕstar_nm1 = [zero(f) for _ in 1:kmax+1] t0 = max(1, t - k) for t′ in (t0+1):t @@ -79,10 +82,12 @@ function extend_cache!(f_vert, times, cache, kmax) _ϕstar_nm1, _ϕstar_n = _ϕstar_n, _ϕstar_nm1 end - foreach((ϕ, ϕ′) -> insert!(ϕ.u, t, ϕ′), ϕ_n, _ϕ_n) - foreach((ϕ, ϕ′) -> insert!(ϕ.u, t, ϕ′), ϕ_np1, _ϕ_np1) - foreach((ϕ, ϕ′) -> insert!(ϕ.u, t, ϕ′), ϕstar_n, _ϕstar_n) - foreach((ϕ, ϕ′) -> insert!(ϕ.u, t, ϕ′), ϕstar_nm1, _ϕstar_nm1) + for i in eachindex(f) + foreach((ϕ, ϕ′) -> insert!(ϕ.u[i], t, ϕ′[i]), ϕ_n, _ϕ_n) + foreach((ϕ, ϕ′) -> insert!(ϕ.u[i], t, ϕ′[i]), ϕ_np1, _ϕ_np1) + foreach((ϕ, ϕ′) -> insert!(ϕ.u[i], t, ϕ′[i]), ϕstar_n, _ϕstar_n) + foreach((ϕ, ϕ′) -> insert!(ϕ.u[i], t, ϕ′[i]), ϕstar_nm1, _ϕstar_nm1) + end end # Explicit Adams: Section III.5 Eq. (5.5) @@ -102,7 +107,7 @@ end function correct!(du, cache) @unpack u_next,g,ϕ_np1,ϕstar_n,k = cache @inbounds begin - ϕ_np1!(cache, VectorOfArray(collect(du)), k+1) + ϕ_np1!(cache, VectorOfArray(collect(unzip(du))), k+1) @. u_next = muladd(g[k], ϕ_np1[k], u_next) end u_next @@ -131,7 +136,7 @@ function adjust_order!(f_vert, f, state, cache, kmax, atol, rtol) return end - cache.f_prev = VectorOfArray([f...]) + cache.f_prev = VectorOfArray(collect(unzip(f))) if length(state.t)<=5 || k<3 cache.k = min(k+1, 3, kmax)