Skip to content

Commit

Permalink
Avoid too many build_mprk_matrix! methods (#77)
Browse files Browse the repository at this point in the history
* Avoid too many build_mprk_matrix! methods

-WORK IN PROGRESS

- MPRK22_old uses old implementation

- MPRK22 now uses a single method of build_mprk_matrix! per data type

* removed old implementation of build_mprk_matrix!

* removed one more method of build_mprk_matrix!

* use build_mprk_matrix! within build_mprk_matrix

* Removed extra allocation of memory for linprob.A from MPRK22 (inplace).

- Now we use P2, which is also used to store an evaluation of the PDS

- The implementation is now analogous to MPE.

* Added comments about the usage of P in the code of MPE.

* moved comments about usage of P in MPE and P2 in MPRK22 to correct position
  • Loading branch information
SKopecz authored May 24, 2024
1 parent 7c66a3c commit 902e118
Showing 1 changed file with 73 additions and 142 deletions.
215 changes: 73 additions & 142 deletions src/mprk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,26 @@ p_prototype(u, f::ConservativePDSFunction) = zero(f.p_prototype)
#####################################################################
# out-of-place for dense and static arrays
"""
    build_mprk_matrix(P, sigma, dt)

Out-of-place construction of the MPRK system matrix.

Allocates `M = similar(P)` and fills it by calling the in-place method
`build_mprk_matrix!(M, P, sigma, dt)`. According to the comments on that
method, the entries are

    M[i, i] = (sigma[i] + dt * sum_j P[j, i]) / sigma[i]
    M[i, j] = -dt * P[i, j] / sigma[j]

# Returns
`SMatrix(M)` if `P isa StaticArray`, otherwise `M` itself.
"""
function build_mprk_matrix(P, sigma, dt)
    # re-use the in-place version implemented below
    M = similar(P)
    build_mprk_matrix!(M, P, sigma, dt)

    # `M` was filled by mutation; for static input it is converted
    # with `SMatrix` before being handed back
    if P isa StaticArray
        return SMatrix(M)
    else
        return M
    end
end

# in-place for dense arrays
function build_mprk_matrix!(M, P, sigma, dt)
# M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i]
# M[i,j] = -dt * P[i,j] / sigma[j]
# TODO: the performance of this can likely be improved
Base.require_one_based_indexing(M, P, sigma)
@assert size(M, 1) == size(M, 2) == size(P, 1) == size(P, 2) == length(sigma)

zeroM = zero(eltype(P))

# Set sigma on diagonal
@inbounds for i in eachindex(sigma)
Expand All @@ -41,139 +57,39 @@ function build_mprk_matrix(P, sigma, dt)
M[i, i] /= sigma[i]
end

if P isa StaticArray
return SMatrix(M)
else
return M
end
end

# in-place for dense arrays
"""
    build_mprk_matrix!(M, a, P, D, sigma, dt)

Overwrite `M` in place with

    M[i, i] = 1 + dt * a * D[i] / sigma[i]
    M[i, j] = -dt * a * P[i, j] / sigma[j]    (i != j)

and return `M`.

All arrays must use one-based indexing. `M` and `P` must be square with
side length `length(D) == length(sigma)`; this is checked with `@assert`.
Every entry of `M` is computed from the entry of `P` at the same position
(or from `D` on the diagonal), and the diagonal of `P` is never read.
"""
function build_mprk_matrix!(M, a, P, D, sigma, dt)
    # TODO: the performance of this can likely be improved
    Base.require_one_based_indexing(M, P, D, sigma)
    @assert size(M, 1) == size(M, 2) == size(P, 1) == size(P, 2) == length(D) ==
            length(sigma)

    n = length(sigma)
    # column-major traversal: the column index is the outer loop
    for col in 1:n, row in 1:n
        if row != col
            M[row, col] = -dt * a * P[row, col] / sigma[col]
        else
            M[row, row] = 1 + dt * a * D[row] / sigma[row]
        end
    end

    return M
end

"""
    build_mprk_matrix!(M, b1, P1, D1, b2, P2, D2, sigma, dt)

Overwrite `M` in place with

    M[i, i] = 1 + dt * (b1 * D1[i] + b2 * D2[i]) / sigma[i]
    M[i, j] = -dt * (b1 * P1[i, j] + b2 * P2[i, j]) / sigma[j]    (i != j)

and return `M`.

All arrays must use one-based indexing. `M`, `P1` and `P2` must be square
with side length `length(D1) == length(D2) == length(sigma)`; this is
checked with `@assert`. Every entry of `M` is computed from the entries of
`P1`/`P2` at the same position (or from `D1`/`D2` on the diagonal).
"""
function build_mprk_matrix!(M, b1, P1, D1, b2, P2, D2, sigma, dt)
    # TODO: the performance of this can likely be improved
    Base.require_one_based_indexing(M, P1, D1, P2, D2, sigma)
    # Both destruction vectors are checked (the length of `D2` must match, too)
    @assert size(M, 1) == size(M, 2) == size(P1, 1) == size(P1, 2) == length(D1) ==
            size(P2, 1) == size(P2, 2) == length(D2) == length(sigma)

    for j in 1:length(sigma)
        for i in 1:length(sigma)
            if i == j
                M[i, i] = 1 + dt * (b1 * D1[i] + b2 * D2[i]) / sigma[i]
            else
                M[i, j] = -dt * (b1 * P1[i, j] + b2 * P2[i, j]) / sigma[j]
            end
        end
    end

    return M
end

# optimized versions for Tridiagonal matrices
function build_mprk_matrix!(M::Tridiagonal,
a, P::Tridiagonal, D,
sigma, dt)
function build_mprk_matrix!(M::Tridiagonal, P::Tridiagonal, σ, dt)
# M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i]
# M[i,j] = -dt * P[i,j] / sigma[j]
Base.require_one_based_indexing(M.dl, M.d, M.du,
P.dl, P.d, P.du,
D, sigma)
P.dl, P.d, P.du, σ)
@assert length(M.dl) + 1 == length(M.d) == length(M.du) + 1 ==
length(P.dl) + 1 == length(P.d) == length(P.du) + 1 ==
length(D) == length(sigma)

factor = a * dt
length(P.dl) + 1 == length(P.d) == length(P.du) + 1 == length(σ)

for i in eachindex(M.d, D, sigma)
M.d[i] = 1 + factor * D[i] / sigma[i]
for i in eachindex(M.d, σ)
M.d[i] = σ[i]
end

for i in eachindex(M.dl, P.dl)
M.dl[i] = -factor * P.dl[i] / sigma[i]
dtP = dt * P.dl[i]
M.dl[i] = -dtP / σ[i]
M.d[i] += dtP
end

for i in eachindex(M.dl, P.dl)
M.du[i] = -factor * P.du[i] / sigma[i + 1]
for i in eachindex(M.du, P.du)
dtP = dt * P.du[i]
M.du[i] = -dtP / σ[i + 1]
M.d[i + 1] += dtP
end

return M
end

function build_mprk_matrix!(M::Tridiagonal,
b1, P1::Tridiagonal, D1,
b2, P2::Tridiagonal, D2,
sigma, dt)
# M[i,i] = (sigma[i] + dt * sum_j P[j,i]) / sigma[i]
# M[i,j] = -dt * P[i,j] / sigma[j]
Base.require_one_based_indexing(M.dl, M.d, M.du,
P1.dl, P1.d, P1.du, D1,
P2.dl, P2.d, P2.du, D2,
sigma)
@assert length(M.dl) + 1 == length(M.d) == length(M.du) + 1 ==
length(P1.dl) + 1 == length(P1.d) == length(P1.du) + 1 ==
length(D1) ==
length(P2.dl) + 1 == length(P2.d) == length(P2.du) + 1 ==
length(D2) == length(sigma)

factor1 = b1 * dt
factor2 = b2 * dt

for i in eachindex(M.d, D1, D2, sigma)
M.d[i] = 1 + (factor1 * D1[i] + factor2 * D2[i]) / sigma[i]
end

for i in eachindex(M.dl, P1.dl, P2.dl)
M.dl[i] = -(factor1 * P1.dl[i] + factor2 * P2.dl[i]) / sigma[i]
end

for i in eachindex(M.dl, P1.du, P2.du)
M.du[i] = -(factor1 * P1.du[i] + factor2 * P2.du[i]) / sigma[i + 1]
for i in eachindex(M.d, σ)
M.d[i] /= σ[i]
end

return M
end

#####################################################################
# Generic fallback (for dense arrays)
"""
    sum_destruction_terms!(D, P)

Store the column sums of `P` in `D`, i.e. `D[i] = sum_j P[j, i]`, by
summing `P` into the adjoint (row) view of `D`. Returns that adjoint view,
which shares its memory with `D`.
"""
function sum_destruction_terms!(D, P)
    row_view = adjoint(D)
    return sum!(row_view, P)
end

"""
    sum_destruction_terms!(D, P::Tridiagonal)

Specialization for `Tridiagonal` matrices: store the column sums of `P`
in `D` using only the three stored diagonals, i.e.

    D[i] = P.du[i - 1] + P.d[i] + P.dl[i]

where the terms that fall outside the matrix (first and last column) are
omitted. Returns `D`.

`D` and the diagonals of `P` must use one-based indexing, and `length(D)`
must equal the size of `P`; this is checked with `@assert`.
"""
function sum_destruction_terms!(D, P::Tridiagonal)
    Base.require_one_based_indexing(D, P.dl, P.d, P.du)
    @assert length(D) == length(P.dl) + 1 == length(P.d) == length(P.du) + 1

    n = length(D)

    # A 1x1 matrix has empty sub- and superdiagonals, so only the main
    # diagonal entry contributes
    if n == 1
        D[1] = P.d[1]
        return D
    end

    # first column: no superdiagonal entry above the diagonal
    D[1] = P.d[1] + P.dl[1]
    # interior columns: superdiagonal, diagonal and subdiagonal entries
    for i in 2:(n - 1)
        D[i] = P.du[i - 1] + P.d[i] + P.dl[i]
    end
    # last column: no subdiagonal entry below the diagonal
    D[n] = P.du[n - 1] + P.d[n]

    return D
end

#####################################################################
Expand Down Expand Up @@ -291,6 +207,8 @@ function alg_cache(alg::MPE, u, rate_prototype, ::Type{uEltypeNoUnits},
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)

# We use P to store the evaluation of the PDS
# as well as to store the system matrix of the linear system
linprob = LinearProblem(P, _vec(linsolve_tmp); u0 = _vec(tmp))
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))
Expand Down Expand Up @@ -339,7 +257,7 @@ function perform_step!(integrator, cache::MPEConstantCache, repeat_step = false)
@unpack small_constant = cache

# Attention: Implementation assumes that the pds is conservative,
# i.e., P[i, i] == 0 for all i
# i.e., P[i, i] == 0 for all i

# evaluate production matrix
P = f.p(uprev, p, t)
Expand Down Expand Up @@ -375,14 +293,17 @@ function perform_step!(integrator, cache::MPECache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack P, D, weight = cache

# We use P to store the last evaluation of the PDS
# as well as to store the system matrix of the linear system

# TODO: Shall we require the users to set unused entries to zero?
fill!(P, zero(eltype(P)))

f.p(P, uprev, p, t) # evaluate production terms
sum_destruction_terms!(D, P) # store destruction terms in D
integrator.stats.nf += 1

build_mprk_matrix!(P, 1, P, D, uprev, dt)
build_mprk_matrix!(P, P, uprev, dt)

# Same as linres = P \ uprev
linres = dolinsolve(integrator, cache.linsolve;
A = P, b = _vec(uprev),
Expand Down Expand Up @@ -508,7 +429,6 @@ struct MPRK22Cache{uType, rateType, PType, tabType, Thread, F, uNoUnitsType} <:
P2::PType
D::uType
D2::uType
M::PType
σ::uType
tab::tabType
thread::Thread
Expand All @@ -526,12 +446,14 @@ function alg_cache(alg::MPRK22, u, rate_prototype, ::Type{uEltypeNoUnits},

tmp = zero(u)

M = p_prototype(u, f)
P2 = p_prototype(u, f)
linsolve_tmp = zero(u)
weight = similar(u, uEltypeNoUnits)
recursivefill!(weight, false)

linprob = LinearProblem(M, _vec(linsolve_tmp); u0 = _vec(tmp))
# We use P2 to store the last evaluation of the PDS
# as well as to store the system matrix of the linear system
linprob = LinearProblem(P2, _vec(linsolve_tmp); u0 = _vec(tmp))
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true))

Expand All @@ -540,10 +462,9 @@ function alg_cache(alg::MPRK22, u, rate_prototype, ::Type{uEltypeNoUnits},
zero(rate_prototype), # k
zero(rate_prototype), #fsalfirst
p_prototype(u, f), # P
p_prototype(u, f), # P2
P2, # P2
zero(u), # D
zero(u), # D2
M,
zero(u), # σ
tab, alg.thread,
linsolve_tmp, linsolve, weight)
Expand Down Expand Up @@ -661,42 +582,52 @@ end

function perform_step!(integrator, cache::MPRK22Cache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack tmp, atmp, P, P2, D, D2, M, σ, thread, weight = cache
@unpack tmp, atmp, P, P2, D, D2, σ, thread, weight = cache
@unpack a21, b1, b2, c2, small_constant = cache.tab

uprev .= uprev .+ small_constant
# We use P2 to store the last evaluation of the PDS
# as well as to store the system matrix of the linear system

f.p(P, uprev, p, t) # evaluate production terms
sum_destruction_terms!(D, P) # store destruction terms in D
integrator.stats.nf += 1
@.. broadcast=false P2=a21 * P

# avoid division by zero due to zero Patankar weights
@.. broadcast=false σ=uprev + small_constant

build_mprk_matrix!(P2, P2, σ, dt)

build_mprk_matrix!(M, a21, P, D, uprev, dt)
# Same as linres = M \ uprev
# Same as linres = P2 \ uprev
linres = dolinsolve(integrator, cache.linsolve;
A = M, b = _vec(uprev),
A = P2, b = _vec(uprev),
du = integrator.fsalfirst, u = u, p = p, t = t,
weight = weight)
u .= linres
integrator.stats.nsolve += 1

u .= u .+ small_constant

σ .= uprev .* (u ./ uprev) .^ (1 / a21) .+ small_constant
if isone(a21)
σ .= u
else
@.. broadcast=false σ=σ^(1 - 1 / a21) * u^(1 / a21)
end
@.. broadcast=false σ=σ + small_constant

f.p(P2, u, p, t + a21 * dt) # evaluate production terms
sum_destruction_terms!(D, P) # store destruction terms in D
sum_destruction_terms!(D2, P2) # store destruction terms in D2
integrator.stats.nf += 1

@.. broadcast=false P2=b1 * P + b2 * P2

build_mprk_matrix!(P2, P2, σ, dt)

build_mprk_matrix!(M, b1, P, D, b2, P2, D2, σ, dt)
# Same as linres = M \ uprev
# Same as linres = P2 \ uprev
linres = dolinsolve(integrator, cache.linsolve;
A = M, b = _vec(uprev),
A = P2, b = _vec(uprev),
du = integrator.fsalfirst, u = u, p = p, t = t,
weight = weight)
u .= linres
integrator.stats.nsolve += 1

tmp .= u .- σ
@.. broadcast=false tmp=u - σ
calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t,
thread)
Expand Down

0 comments on commit 902e118

Please sign in to comment.