Skip to content

Commit

Permalink
Merge pull request SciML#2393 from chriselrod/getcoeffchunk
Browse files Browse the repository at this point in the history
Getcoeffchunk
  • Loading branch information
YingboMa authored Jan 24, 2024
2 parents 1b11b47 + 0b1092d commit 7685996
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 28 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand Down Expand Up @@ -74,6 +75,7 @@ Distributions = "0.23, 0.24, 0.25"
DocStringExtensions = "0.7, 0.8, 0.9"
DomainSets = "0.6"
DynamicQuantities = "^0.11.2"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
Expand Down
11 changes: 2 additions & 9 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ the `constraint`.
mask,
constraint)
eadj = M.row_cols
for i in range
@inbounds for i in range
vertices = eadj[i]
if constraint(length(vertices))
for (j, v) in enumerate(vertices)
Expand All @@ -170,7 +170,7 @@ end
range,
mask,
constraint)
for i in range
@inbounds for i in range
row = @view M[i, :]
if constraint(count(!iszero, row))
for (v, val) in enumerate(row)
Expand Down Expand Up @@ -382,13 +382,6 @@ end

swap!(v, i, j) = v[i], v[j] = v[j], v[i]

function getcoeff(vars, coeffs, var)
for (vj, v) in enumerate(vars)
v == var && return coeffs[vj]
end
return 0
end

"""
$(SIGNATURES)
Expand Down
133 changes: 114 additions & 19 deletions src/systems/sparsematrixclil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ end
# build something that works for us here and worry about it later.
nonzerosmap(a::CLILVector) = NonZeros(a)

using FindFirstFunctions: findfirstequal

function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot,
last_pivot; pivot_equal_optimization = true)
# for ei in nzrows(>= k)
Expand Down Expand Up @@ -168,12 +170,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
# conservative, we leave it at this, as this captures the most important
# case for MTK (where most pivots are `1` or `-1`).
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)

for ei in (k + 1):size(M, 1)
@inbounds for ei in (k + 1):size(M, 1)
# eliminate `v`
coeff = 0
ivars = eadj[ei]
vj = findfirst(isequal(vpivot), ivars)
vj = findfirstequal(vpivot, ivars)
if vj !== nothing
coeff = old_cadj[ei][vj]
deleteat!(old_cadj[ei], vj)
Expand All @@ -189,24 +190,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
ivars = eadj[ei]
icoeffs = old_cadj[ei]

tmp_incidence = similar(eadj[ei], 0)
tmp_coeffs = similar(old_cadj[ei], 0)
# TODO: We know both ivars and kvars are sorted, we could just write
# a quick iterator here that does this without allocation/faster.
vars = sort(union(ivars, kvars))

for v in vars
v == vpivot && continue
ck = getcoeff(kvars, kcoeffs, v)
ci = getcoeff(ivars, icoeffs, v)
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
if !iszero(ci)
push!(tmp_incidence, v)
push!(tmp_coeffs, ci)
numkvars = length(kvars)
numivars = length(ivars)
tmp_incidence = similar(eadj[ei], numkvars + numivars)
tmp_coeffs = similar(old_cadj[ei], numkvars + numivars)
tmp_len = 0
kvind = ivind = 0
if _debug_mode
# in debug mode, we at least check to confirm we're iterating over
# `v`s in the correct order
vars = sort(union(ivars, kvars))
vi = 0
end
if numivars > 0 && numkvars > 0
kvv = kvars[kvind += 1]
ivv = ivars[ivind += 1]
dobreak = false
while true
if kvv == ivv
v = kvv
ck = kcoeffs[kvind]
ci = icoeffs[ivind]
kvind += 1
ivind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
p1 = Base.Checked.checked_mul(pivot, ci)
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
elseif kvv < ivv
v = kvv
ck = kcoeffs[kvind]
kvind += 1
if kvind > numkvars
dobreak = true
else
kvv = kvars[kvind]
end
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
else # kvv > ivv
v = ivv
ci = icoeffs[ivind]
ivind += 1
if ivind > numivars
dobreak = true
else
ivv = ivars[ivind]
end
ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot)
end
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot && !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
dobreak && break
end
elseif numkvars > 0
ivind = 1
kvv = kvars[kvind += 1]
elseif numivars > 0
kvind = 1
ivv = ivars[ivind += 1]
end
if kvind <= numkvars
v = kvv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
ck = kcoeffs[kvind]
p2 = Base.Checked.checked_mul(coeff, ck)
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
if !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
end
(kvind == numkvars) && break
v = kvars[kvind += 1]
end
elseif ivind <= numivars
v = ivv
while true
if _debug_mode
@assert v == vars[vi += 1]
end
if v != vpivot
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
ci = exactdiv(p1, last_pivot)
if !iszero(ci)
tmp_incidence[tmp_len += 1] = v
tmp_coeffs[tmp_len] = ci
end
end
(ivind == numivars) && break
v = ivars[ivind += 1]
end
end
resize!(tmp_incidence, tmp_len)
resize!(tmp_coeffs, tmp_len)
eadj[ei] = tmp_incidence
old_cadj[ei] = tmp_coeffs
end
Expand Down

0 comments on commit 7685996

Please sign in to comment.