diff --git a/src/mprk.jl b/src/mprk.jl index 65594914..bfec0ded 100644 --- a/src/mprk.jl +++ b/src/mprk.jl @@ -12,10 +12,26 @@ p_prototype(u, f::ConservativePDSFunction) = zero(f.p_prototype) ##################################################################### # out-of-place for dense and static arrays function build_mprk_matrix(P, sigma, dt) - # M[i,i] = (sigma[i] + dt*sum_j P[j,i])/sigma[i] - # M[i,j] = -dt*P[i,j]/sigma[j] + # re-use the in-place version implemented below M = similar(P) - zeroM = zero(eltype(M)) + build_mprk_matrix!(M, P, sigma, dt) + + if P isa StaticArray + return SMatrix(M) + else + return M + end +end + +# in-place for dense arrays +function build_mprk_matrix!(M, P, sigma, dt) + # M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i] + # M[i,j] = -dt * P[i,j] / sigma[j] + # TODO: the performance of this can likely be improved + Base.require_one_based_indexing(M, P, sigma) + @assert size(M, 1) == size(M, 2) == size(P, 1) == size(P, 2) == length(sigma) + + zeroM = zero(eltype(P)) # Set sigma on diagonal @inbounds for i in eachindex(sigma) @@ -41,139 +57,39 @@ function build_mprk_matrix(P, sigma, dt) M[i, i] /= sigma[i] end - if P isa StaticArray - return SMatrix(M) - else - return M - end -end - -# in-place for dense arrays -function build_mprk_matrix!(M, a, P, D, sigma, dt) - # M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i] - # M[i,j] = -dt * P[i,j] / sigma[j] - # TODO: the performance of this can likely be improved - Base.require_one_based_indexing(M, P, D, sigma) - @assert size(M, 1) == size(M, 2) == size(P, 1) == size(P, 2) == length(D) == - length(sigma) - - for j in 1:length(sigma) - for i in 1:length(sigma) - if i == j - M[i, i] = 1 + dt * a * D[i] / sigma[i] - else - M[i, j] = -dt * a * P[i, j] / sigma[j] - end - end - end - - return M -end - -function build_mprk_matrix!(M, b1, P1, D1, b2, P2, D2, sigma, dt) - # M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i] - # M[i,j] = -dt * P[i,j] / sigma[j] - # TODO: the performance of this can likely be improved - Base.require_one_based_indexing(M, P1, D1, P2, D2, sigma) - @assert size(M, 1) == size(M, 2) == size(P1, 1) == size(P1, 2) == length(D1) == - size(P2, 1) == size(P2, 2) == length(D1) == length(sigma) - - for j in 1:length(sigma) - for i in 1:length(sigma) - if i == j - M[i, i] = 1 + dt * (b1 * D1[i] + b2 * D2[i]) / sigma[i] - else - M[i, j] = -dt * (b1 * P1[i, j] + b2 * P2[i, j]) / sigma[j] - end - end - end - - return M + return nothing end # optimized versions for Tridiagonal matrices -function build_mprk_matrix!(M::Tridiagonal, - a, P::Tridiagonal, D, - sigma, dt) +function build_mprk_matrix!(M::Tridiagonal, P::Tridiagonal, σ, dt) # M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i] # M[i,j] = -dt * P[i,j] / sigma[j] Base.require_one_based_indexing(M.dl, M.d, M.du, - P.dl, P.d, P.du, - D, sigma) + P.dl, P.d, P.du, σ) @assert length(M.dl) + 1 == length(M.d) == length(M.du) + 1 == - length(P.dl) + 1 == length(P.d) == length(P.du) + 1 == - length(D) == length(sigma) - - factor = a * dt + length(P.dl) + 1 == length(P.d) == length(P.du) + 1 == length(σ) - for i in eachindex(M.d, D, sigma) - M.d[i] = 1 + factor * D[i] / sigma[i] + for i in eachindex(M.d, σ) + M.d[i] = σ[i] end for i in eachindex(M.dl, P.dl) - M.dl[i] = -factor * P.dl[i] / sigma[i] + dtP = dt * P.dl[i] + M.dl[i] = -dtP / σ[i] + M.d[i] += dtP end - for i in eachindex(M.dl, P.dl) - M.du[i] = -factor * P.du[i] / sigma[i + 1] + for i in eachindex(M.du, P.du) + dtP = dt * P.du[i] + M.du[i] = -dtP / σ[i + 1] + M.d[i + 1] += dtP end - return M -end - -function build_mprk_matrix!(M::Tridiagonal, - b1, P1::Tridiagonal, D1, - b2, P2::Tridiagonal, D2, - sigma, dt) - # M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i] - # M[i,j] = -dt * P[i,j] / sigma[j] - Base.require_one_based_indexing(M.dl, M.d, M.du, - P1.dl, P1.d, P1.du, D1, - P2.dl, P2.d, P2.du, D2, - sigma) - @assert length(M.dl) + 1 == length(M.d) == length(M.du) + 1 == - length(P1.dl) + 1 == length(P1.d) == length(P1.du) + 1 == - length(D1) == - length(P2.dl) + 1 == length(P2.d) == length(P2.du) + 1 == - length(D2) == length(sigma) - - factor1 = b1 * dt - factor2 = b2 * dt - - for i in eachindex(M.d, D1, D2, sigma) - M.d[i] = 1 + (factor1 * D1[i] + factor2 * D2[i]) / sigma[i] - end - - for i in eachindex(M.dl, P1.dl, P2.dl) - M.dl[i] = -(factor1 * P1.dl[i] + factor2 * P2.dl[i]) / sigma[i] - end - - for i in eachindex(M.dl, P1.du, P2.du) - M.du[i] = -(factor1 * P1.du[i] + factor2 * P2.du[i]) / sigma[i + 1] + for i in eachindex(M.d, σ) + M.d[i] /= σ[i] end - return M -end - -##################################################################### -# Generic fallback (for dense arrays) -sum_destruction_terms!(D, P) = sum!(D', P) - -function sum_destruction_terms!(D, P::Tridiagonal) - Base.require_one_based_indexing(D, P.dl, P.d, P.du) - @assert length(D) == length(P.dl) + 1 == length(P.d) == length(P.du) + 1 - - let i = 1 - D[i] = P.d[i] + P.dl[i] - end - for i in 2:(length(D) - 1) - D[i] = P.du[i - 1] + P.d[i] + P.dl[i] - end - let i = lastindex(D) - D[i] = P.du[i - 1] + P.d[i] - end - - return D + return nothing end ##################################################################### @@ -291,6 +207,8 @@ function alg_cache(alg::MPE, u, rate_prototype, ::Type{uEltypeNoUnits}, weight = similar(u, uEltypeNoUnits) recursivefill!(weight, false) + # We use P to store the evaluation of the PDS + # as well as to store the system matrix of the linear system linprob = LinearProblem(P, _vec(linsolve_tmp); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -339,7 +257,7 @@ function perform_step!(integrator, cache::MPEConstantCache, repeat_step = false) @unpack small_constant = cache # Attention: Implementation assumes that the pds is conservative, - # i.e., P[i, i] == 0 for all i + # i.e., P[i, i] == 0 for all i # evaluate production matrix P = f.p(uprev, p, t) @@ -375,14 +293,17 @@ function perform_step!(integrator, cache::MPECache, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator @unpack P, D, weight = cache + # We use P to store the last evaluation of the PDS + # as well as to store the system matrix of the linear system + # TODO: Shall we require the users to set unused entries to zero? fill!(P, zero(eltype(P))) f.p(P, uprev, p, t) # evaluate production terms - sum_destruction_terms!(D, P) # store destruction terms in D integrator.stats.nf += 1 - build_mprk_matrix!(P, 1, P, D, uprev, dt) + build_mprk_matrix!(P, P, uprev, dt) + # Same as linres = P \ uprev linres = dolinsolve(integrator, cache.linsolve; A = P, b = _vec(uprev), @@ -508,7 +429,6 @@ struct MPRK22Cache{uType, rateType, PType, tabType, Thread, F, uNoUnitsType} <: P2::PType D::uType D2::uType - M::PType σ::uType tab::tabType thread::Thread @@ -526,12 +446,14 @@ function alg_cache(alg::MPRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, tmp = zero(u) - M = p_prototype(u, f) + P2 = p_prototype(u, f) linsolve_tmp = zero(u) weight = similar(u, uEltypeNoUnits) recursivefill!(weight, false) - linprob = LinearProblem(M, _vec(linsolve_tmp); u0 = _vec(tmp)) + # We use P2 to store the last evaluation of the PDS + # as well as to store the system matrix of the linear system + linprob = LinearProblem(P2, _vec(linsolve_tmp); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, assumptions = LinearSolve.OperatorAssumptions(true)) @@ -540,10 +462,9 @@ function alg_cache(alg::MPRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, zero(rate_prototype), # k zero(rate_prototype), #fsalfirst p_prototype(u, f), # P - p_prototype(u, f), # P2 + P2, # P2 zero(u), # D zero(u), # D2 - M, zero(u), # σ tab, alg.thread, linsolve_tmp, linsolve, weight) @@ -661,42 +582,52 @@ end function perform_step!(integrator, cache::MPRK22Cache, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator - @unpack tmp, atmp, P, P2, D, D2, M, σ, thread, weight = cache + @unpack tmp, atmp, P, P2, D, D2, σ, thread, weight = cache @unpack a21, b1, b2, c2, small_constant = cache.tab - uprev .= uprev .+ small_constant + # We use P2 to store the last evaluation of the PDS + # as well as to store the system matrix of the linear system f.p(P, uprev, p, t) # evaluate production terms - sum_destruction_terms!(D, P) # store destruction terms in D integrator.stats.nf += 1 + @.. broadcast=false P2=a21 * P + + # avoid division by zero due to zero Patankar weights + @.. broadcast=false σ=uprev + small_constant + + build_mprk_matrix!(P2, P2, σ, dt) - build_mprk_matrix!(M, a21, P, D, uprev, dt) - # Same as linres = M \ uprev + # Same as linres = P2 \ uprev linres = dolinsolve(integrator, cache.linsolve; - A = M, b = _vec(uprev), + A = P2, b = _vec(uprev), du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight) u .= linres integrator.stats.nsolve += 1 - u .= u .+ small_constant - - σ .= uprev .* (u ./ uprev) .^ (1 / a21) .+ small_constant + if isone(a21) + σ .= u + else + @.. broadcast=false σ=σ^(1 - 1 / a21) * u^(1 / a21) + end + @.. broadcast=false σ=σ + small_constant f.p(P2, u, p, t + a21 * dt) # evaluate production terms - sum_destruction_terms!(D, P) # store destruction terms in D - sum_destruction_terms!(D2, P2) # store destruction terms in D2 + integrator.stats.nf += 1 + + @.. broadcast=false P2=b1 * P + b2 * P2 + + build_mprk_matrix!(P2, P2, σ, dt) - build_mprk_matrix!(M, b1, P, D, b2, P2, D2, σ, dt) - # Same as linres = M \ uprev + # Same as linres = P2 \ uprev linres = dolinsolve(integrator, cache.linsolve; - A = M, b = _vec(uprev), + A = P2, b = _vec(uprev), du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight) u .= linres integrator.stats.nsolve += 1 - tmp .= u .- σ + @.. broadcast=false tmp=u - σ calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t, thread)