diff --git a/Project.toml b/Project.toml index 8299dbe1aa..079f1e9ac0 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" +FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -74,6 +75,7 @@ Distributions = "0.23, 0.24, 0.25" DocStringExtensions = "0.7, 0.8, 0.9" DomainSets = "0.6" DynamicQuantities = "^0.11.2" +FindFirstFunctions = "1" ForwardDiff = "0.10.3" FunctionWrappersWrappers = "0.1" Graphs = "1.5.2" diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 3ddb88f77b..efb79aae50 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -153,7 +153,7 @@ the `constraint`. mask, constraint) eadj = M.row_cols - for i in range + @inbounds for i in range vertices = eadj[i] if constraint(length(vertices)) for (j, v) in enumerate(vertices) @@ -170,7 +170,7 @@ end range, mask, constraint) - for i in range + @inbounds for i in range row = @view M[i, :] if constraint(count(!iszero, row)) for (v, val) in enumerate(row) @@ -382,13 +382,6 @@ end swap!(v, i, j) = v[i], v[j] = v[j], v[i] -function getcoeff(vars, coeffs, var) - for (vj, v) in enumerate(vars) - v == var && return coeffs[vj] - end - return 0 -end - """ $(SIGNATURES) diff --git a/src/systems/sparsematrixclil.jl b/src/systems/sparsematrixclil.jl index dca48973c4..cddf316084 100644 --- a/src/systems/sparsematrixclil.jl +++ b/src/systems/sparsematrixclil.jl @@ -129,6 +129,8 @@ end # build something that works for us here and worry about it later. nonzerosmap(a::CLILVector) = NonZeros(a) +using FindFirstFunctions: findfirstequal + function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot, last_pivot; pivot_equal_optimization = true) # for ei in nzrows(>= k) @@ -168,12 +170,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap # conservative, we leave it at this, as this captures the most important # case for MTK (where most pivots are `1` or `-1`). pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot) - - for ei in (k + 1):size(M, 1) + @inbounds for ei in (k + 1):size(M, 1) # eliminate `v` coeff = 0 ivars = eadj[ei] - vj = findfirst(isequal(vpivot), ivars) + vj = findfirstequal(vpivot, ivars) if vj !== nothing coeff = old_cadj[ei][vj] deleteat!(old_cadj[ei], vj) @@ -189,24 +190,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap ivars = eadj[ei] icoeffs = old_cadj[ei] - tmp_incidence = similar(eadj[ei], 0) - tmp_coeffs = similar(old_cadj[ei], 0) - # TODO: We know both ivars and kvars are sorted, we could just write - # a quick iterator here that does this without allocation/faster. - vars = sort(union(ivars, kvars)) - - for v in vars - v == vpivot && continue - ck = getcoeff(kvars, kcoeffs, v) - ci = getcoeff(ivars, icoeffs, v) - p1 = Base.Checked.checked_mul(pivot, ci) - p2 = Base.Checked.checked_mul(coeff, ck) - ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) - if !iszero(ci) - push!(tmp_incidence, v) - push!(tmp_coeffs, ci) + numkvars = length(kvars) + numivars = length(ivars) + tmp_incidence = similar(eadj[ei], numkvars + numivars) + tmp_coeffs = similar(old_cadj[ei], numkvars + numivars) + tmp_len = 0 + kvind = ivind = 0 + if _debug_mode + # in debug mode, we at least check to confirm we're iterating over + # `v`s in the correct order + vars = sort(union(ivars, kvars)) + vi = 0 + end + if numivars > 0 && numkvars > 0 + kvv = kvars[kvind += 1] + ivv = ivars[ivind += 1] + dobreak = false + while true + if kvv == ivv + v = kvv + ck = kcoeffs[kvind] + ci = icoeffs[ivind] + kvind += 1 + ivind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + p1 = Base.Checked.checked_mul(pivot, ci) + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) + elseif kvv < ivv + v = kvv + ck = kcoeffs[kvind] + kvind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot) + else # kvv > ivv + v = ivv + ci = icoeffs[ivind] + ivind += 1 + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot) + end + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot && !iszero(ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci + end + dobreak && break + end + elseif numkvars > 0 + ivind = 1 + kvv = kvars[kvind += 1] + elseif numivars > 0 + kvind = 1 + ivv = ivars[ivind += 1] + end + if kvind <= numkvars + v = kvv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + ck = kcoeffs[kvind] + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot) + if !iszero(ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci + end + end + (kvind == numkvars) && break + v = kvars[kvind += 1] + end + elseif ivind <= numivars + v = ivv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind]) + ci = exactdiv(p1, last_pivot) + if !iszero(ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci + end + end + (ivind == numivars) && break + v = ivars[ivind += 1] end end + resize!(tmp_incidence, tmp_len) + resize!(tmp_coeffs, tmp_len) eadj[ei] = tmp_incidence old_cadj[ei] = tmp_coeffs end