From e089c9eafc52de7db1b240b68b3dbda164cac58a Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 21 Dec 2023 07:18:44 -0500 Subject: [PATCH 1/6] chunk getcoeff --- src/systems/alias_elimination.jl | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 3ddb88f77b..38de4fbb31 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -383,8 +383,31 @@ end swap!(v, i, j) = v[i], v[j] = v[j], v[i] function getcoeff(vars, coeffs, var) - for (vj, v) in enumerate(vars) - v == var && return coeffs[vj] + Nvars = length(vars) + i = 0 + chunk_size = 8 + @inbounds while i < Nvars - chunk_size + 1 + btup = let vars = vars, var = var, i = i + ntuple(Val(chunk_size)) do j + @inbounds vars[i + j] == var + end + end + inds = ntuple(Base.Fix2(-, 1), Val(8)) + eights = ntuple(Returns(8), Val(8)) + inds = map(ifelse, btup, inds, eights) + inds4 = (min(inds[1], inds[5]), + min(inds[2], inds[6]), + min(inds[3], inds[7]), + min(inds[4], inds[8])) + inds2 = (min(inds4[1], inds4[3]), min(inds4[2], inds4[4])) + ind = min(inds2[1], inds2[2]) + if ind != 8 + return coeffs[i + ind + 1] + end + i += chunk_size + end + @inbounds for vj in (i + 1):Nvars + vars[vj] == var && return coeffs[vj] end return 0 end From 04962e526e93009acf226d3454e34a96e3480749 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Fri, 22 Dec 2023 09:07:58 -0500 Subject: [PATCH 2/6] optimize `bareiss_update_virtual_colswap_mtk!` --- src/systems/alias_elimination.jl | 30 -------- src/systems/sparsematrixclil.jl | 120 +++++++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 43 deletions(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 38de4fbb31..6af95b5902 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -382,36 +382,6 @@ end swap!(v, i, j) = v[i], v[j] = v[j], v[i] -function getcoeff(vars, coeffs, var) - Nvars = length(vars) - i = 0 - chunk_size = 8 - @inbounds while i < Nvars - chunk_size + 1 - btup = let vars = vars, var = var, i = i - ntuple(Val(chunk_size)) do j - @inbounds vars[i + j] == var - end - end - inds = ntuple(Base.Fix2(-, 1), Val(8)) - eights = ntuple(Returns(8), Val(8)) - inds = map(ifelse, btup, inds, eights) - inds4 = (min(inds[1], inds[5]), - min(inds[2], inds[6]), - min(inds[3], inds[7]), - min(inds[4], inds[8])) - inds2 = (min(inds4[1], inds4[3]), min(inds4[2], inds4[4])) - ind = min(inds2[1], inds2[2]) - if ind != 8 - return coeffs[i + ind + 1] - end - i += chunk_size - end - @inbounds for vj in (i + 1):Nvars - vars[vj] == var && return coeffs[vj] - end - return 0 -end - """ $(SIGNATURES) diff --git a/src/systems/sparsematrixclil.jl b/src/systems/sparsematrixclil.jl index dca48973c4..92f12ca926 100644 --- a/src/systems/sparsematrixclil.jl +++ b/src/systems/sparsematrixclil.jl @@ -169,7 +169,7 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap # case for MTK (where most pivots are `1` or `-1`). pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot) - for ei in (k + 1):size(M, 1) + @inbounds for ei in (k + 1):size(M, 1) # eliminate `v` coeff = 0 ivars = eadj[ei] @@ -193,18 +193,112 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap tmp_coeffs = similar(old_cadj[ei], 0) # TODO: We know both ivars and kvars are sorted, we could just write # a quick iterator here that does this without allocation/faster. - vars = sort(union(ivars, kvars)) - - for v in vars - v == vpivot && continue - ck = getcoeff(kvars, kcoeffs, v) - ci = getcoeff(ivars, icoeffs, v) - p1 = Base.Checked.checked_mul(pivot, ci) - p2 = Base.Checked.checked_mul(coeff, ck) - ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) - if !iszero(ci) - push!(tmp_incidence, v) - push!(tmp_coeffs, ci) + numkvars = length(kvars) + numivars = length(ivars) + kvind = ivind = 0 + if _debug_mode + # in debug mode, we at least check to confirm we're iterating over + # `v`s in the correct order + vars = sort(union(ivars, kvars)) + vi = 0 + end + if numivars > 0 && numkvars > 0 + kvv = kvars[kvind += 1] + ivv = ivars[ivind += 1] + dobreak = false + while true + if kvv == ivv + v = kvv + ck = kcoeffs[kvind] + ci = icoeffs[ivind] + kvind += 1 + ivind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + elseif kvv < ivv + v = kvv + ck = kcoeffs[kvind] + ci = zero(eltype(icoeffs)) + kvind += 1 + if kvind > numkvars + dobreak = true + else + kvv = kvars[kvind] + end + else # kvv > ivv + v = ivv + ck = zero(eltype(kcoeffs)) + ci = icoeffs[ivind] + ivind += 1 + if ivind > numivars + dobreak = true + else + ivv = ivars[ivind] + end + end + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + p1 = Base.Checked.checked_mul(pivot, ci) + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) + if !iszero(ci) + push!(tmp_incidence, v) + push!(tmp_coeffs, ci) + end + end + dobreak && break + end + elseif numivars == 0 + ivind = 1 + kvv = kvars[kvind += 1] + else # numkvars == 0 + kvind = 1 + ivv = ivars[ivind += 1] + end + if kvind <= numkvars + v = kvv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + ck = kcoeffs[kvind] + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_sub(0, p2), last_pivot) + if !iszero(ci) + push!(tmp_incidence, v) + push!(tmp_coeffs, ci) + end + end + (kvind == numkvars) && break + v = kvars[kvind += 1] + end + elseif ivind <= numivars + v = ivv + while true + if _debug_mode + @assert v == vars[vi += 1] + end + if v != vpivot + p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind]) + ci = exactdiv(p1, last_pivot) + if !iszero(ci) + push!(tmp_incidence, v) + push!(tmp_coeffs, ci) + end + end + (ivind == numivars) && break + v = ivars[ivind += 1] end end eadj[ei] = tmp_incidence From a06fe2b3da3ec25712d4df5668df267ed21aba58 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 27 Dec 2023 07:28:38 -0500 Subject: [PATCH 3/6] minor microoptimization --- src/systems/alias_elimination.jl | 4 +- src/systems/sparsematrixclil.jl | 109 +++++++++++++++++++++++++------ 2 files changed, 90 insertions(+), 23 deletions(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 6af95b5902..efb79aae50 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -153,7 +153,7 @@ the `constraint`. mask, constraint) eadj = M.row_cols - for i in range + @inbounds for i in range vertices = eadj[i] if constraint(length(vertices)) for (j, v) in enumerate(vertices) @@ -170,7 +170,7 @@ end range, mask, constraint) - for i in range + @inbounds for i in range row = @view M[i, :] if constraint(count(!iszero, row)) for (v, val) in enumerate(row) diff --git a/src/systems/sparsematrixclil.jl b/src/systems/sparsematrixclil.jl index 92f12ca926..af25615dad 100644 --- a/src/systems/sparsematrixclil.jl +++ b/src/systems/sparsematrixclil.jl @@ -129,6 +129,74 @@ end # build something that works for us here and worry about it later. nonzerosmap(a::CLILVector) = NonZeros(a) +findfirstequal(vpivot, ivars) = findfirst(isequal(vpivot), ivars) +function findfirstequal(vpivot::Int64, ivars::AbstractVector{Int64}) + GC.@preserve ivars begin + ret = Base.llvmcall((""" + declare i8 @llvm.cttz.i8(i8, i1); + define i64 @entry(i64 %0, i64 %1, i64 %2) #0 { + top: + %ivars = inttoptr i64 %1 to i64* + %btmp = insertelement <8 x i64> undef, i64 %0, i64 0 + %var = shufflevector <8 x i64> %btmp, <8 x i64> undef, <8 x i32> zeroinitializer + %lenm7 = add nsw i64 %2, -7 + %dosimditer = icmp ugt i64 %2, 7 + br i1 %dosimditer, label %L9.lr.ph, label %L32 + + L9.lr.ph: + %len8 = and i64 %2, 9223372036854775800 + br label %L9 + + L9: + %i = phi i64 [ 0, %L9.lr.ph ], [ %vinc, %L30 ] + %ivarsi = getelementptr inbounds i64, i64* %ivars, i64 %i + %vpvi = bitcast i64* %ivarsi to <8 x i64>* + %v = load <8 x i64>, <8 x i64>* %vpvi, align 8 + %m = icmp eq <8 x i64> %v, %var + %mu = bitcast <8 x i1> %m to i8 + %matchnotfound = icmp eq i8 %mu, 0 + br i1 %matchnotfound, label %L30, label %L17 + + L17: + %tz8 = call i8 @llvm.cttz.i8(i8 %mu, i1 true) + %tz64 = zext i8 %tz8 to i64 + %vis = add nuw i64 %i, %tz64 + br label %common.ret + + common.ret: + %retval = phi i64 [ %vis, %L17 ], [ -1, %L32 ], [ %si, %L51 ], [ -1, %L67 ] + ret i64 %retval + + L30: + %vinc = add nuw nsw i64 %i, 8 + %continue = icmp slt i64 %vinc, %lenm7 + br i1 %continue, label %L9, label %L32 + + L32: + %cumi = phi i64 [ 0, %top ], [ %len8, %L30 ] + %done = icmp eq i64 %cumi, %2 + br i1 %done, label %common.ret, label %L51 + + L51: + %si = phi i64 [ %inc, %L67 ], [ %cumi, %L32 ] + %spi = getelementptr inbounds i64, i64* %ivars, i64 %si + %svi = load i64, i64* %spi, align 8 + %match = icmp eq i64 %svi, %0 + br i1 %match, label %common.ret, label %L67 + + L67: + %inc = add i64 %si, 1 + %dobreak = icmp eq i64 %inc, %2 + br i1 %dobreak, label %common.ret, label %L51 + + } + attributes #0 = { alwaysinline } + """, "entry"), Int64, Tuple{Int64, Ptr{Int64}, Int64}, vpivot, pointer(ivars), + length(ivars)) + end + ret < 0 ? nothing : ret + 1 +end + function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot, last_pivot; pivot_equal_optimization = true) # for ei in nzrows(>= k) @@ -168,12 +236,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap # conservative, we leave it at this, as this captures the most important # case for MTK (where most pivots are `1` or `-1`). pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot) - @inbounds for ei in (k + 1):size(M, 1) # eliminate `v` coeff = 0 ivars = eadj[ei] - vj = findfirst(isequal(vpivot), ivars) + vj = findfirstequal(vpivot, ivars) if vj !== nothing coeff = old_cadj[ei][vj] deleteat!(old_cadj[ei], vj) @@ -189,12 +256,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap ivars = eadj[ei] icoeffs = old_cadj[ei] - tmp_incidence = similar(eadj[ei], 0) - tmp_coeffs = similar(old_cadj[ei], 0) - # TODO: We know both ivars and kvars are sorted, we could just write - # a quick iterator here that does this without allocation/faster. numkvars = length(kvars) numivars = length(ivars) + tmp_incidence = similar(eadj[ei], numkvars + numivars) + tmp_coeffs = similar(old_cadj[ei], numkvars + numivars) + tmp_len = 0 kvind = ivind = 0 if _debug_mode # in debug mode, we at least check to confirm we're iterating over @@ -223,19 +289,22 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap else ivv = ivars[ivind] end + p1 = Base.Checked.checked_mul(pivot, ci) + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) elseif kvv < ivv v = kvv ck = kcoeffs[kvind] - ci = zero(eltype(icoeffs)) kvind += 1 if kvind > numkvars dobreak = true else kvv = kvars[kvind] end + p2 = Base.Checked.checked_mul(coeff, ck) + ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot) else # kvv > ivv v = ivv - ck = zero(eltype(kcoeffs)) ci = icoeffs[ivind] ivind += 1 if ivind > numivars @@ -243,18 +312,14 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap else ivv = ivars[ivind] end + ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot) end if _debug_mode @assert v == vars[vi += 1] end - if v != vpivot - p1 = Base.Checked.checked_mul(pivot, ci) - p2 = Base.Checked.checked_mul(coeff, ck) - ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot) - if !iszero(ci) - push!(tmp_incidence, v) - push!(tmp_coeffs, ci) - end + if v != vpivot && !iszero(ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci end dobreak && break end @@ -274,10 +339,10 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap if v != vpivot ck = kcoeffs[kvind] p2 = Base.Checked.checked_mul(coeff, ck) - ci = exactdiv(Base.Checked.checked_sub(0, p2), last_pivot) + ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot) if !iszero(ci) - push!(tmp_incidence, v) - push!(tmp_coeffs, ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci end end (kvind == numkvars) && break @@ -293,14 +358,16 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind]) ci = exactdiv(p1, last_pivot) if !iszero(ci) - push!(tmp_incidence, v) - push!(tmp_coeffs, ci) + tmp_incidence[tmp_len += 1] = v + tmp_coeffs[tmp_len] = ci end end (ivind == numivars) && break v = ivars[ivind += 1] end end + resize!(tmp_incidence, tmp_len) + resize!(tmp_coeffs, tmp_len) eadj[ei] = tmp_incidence old_cadj[ei] = tmp_coeffs end From b435f941ba15fc3d1d3656041e36d64a9ab48fcb Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 27 Dec 2023 07:54:21 -0500 Subject: [PATCH 4/6] fix a couple branches --- src/systems/sparsematrixclil.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/sparsematrixclil.jl b/src/systems/sparsematrixclil.jl index af25615dad..66400e4fd5 100644 --- a/src/systems/sparsematrixclil.jl +++ b/src/systems/sparsematrixclil.jl @@ -323,10 +323,10 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap end dobreak && break end - elseif numivars == 0 + elseif numkvars > 0 ivind = 1 kvv = kvars[kvind += 1] - else # numkvars == 0 + elseif numivars > 0 kvind = 1 ivv = ivars[ivind += 1] end From c31d58cafea675766fdb19d88fef5f8560c9aa0a Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Tue, 2 Jan 2024 09:20:37 -0500 Subject: [PATCH 5/6] Update src/systems/sparsematrixclil.jl Co-authored-by: Yingbo Ma --- src/systems/sparsematrixclil.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/sparsematrixclil.jl b/src/systems/sparsematrixclil.jl index 66400e4fd5..2137480093 100644 --- a/src/systems/sparsematrixclil.jl +++ b/src/systems/sparsematrixclil.jl @@ -130,7 +130,7 @@ end nonzerosmap(a::CLILVector) = NonZeros(a) findfirstequal(vpivot, ivars) = findfirst(isequal(vpivot), ivars) -function findfirstequal(vpivot::Int64, ivars::AbstractVector{Int64}) +function findfirstequal(vpivot::Int64, ivars::Vector{Int64}) GC.@preserve ivars begin ret = Base.llvmcall((""" declare i8 @llvm.cttz.i8(i8, i1); From 6000ef6f0f0335875e3b64c1fbb5ce826603be66 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 24 Jan 2024 09:47:31 -0500 Subject: [PATCH 6/6] FindFirstFunctions --- Project.toml | 2 + src/systems/sparsematrixclil.jl | 68 +-------------------------------- 2 files changed, 3 insertions(+), 67 deletions(-) diff --git a/Project.toml b/Project.toml index 5b9a9bf0ac..9511605162 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" +FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -74,6 +75,7 @@ Distributions = "0.23, 0.24, 0.25" DocStringExtensions = "0.7, 0.8, 0.9" DomainSets = "0.6" DynamicQuantities = "0.8, 0.9, 0.10" +FindFirstFunctions = "1" ForwardDiff = "0.10.3" FunctionWrappersWrappers = "0.1" Graphs = "1.5.2" diff --git a/src/systems/sparsematrixclil.jl b/src/systems/sparsematrixclil.jl index 2137480093..cddf316084 100644 --- a/src/systems/sparsematrixclil.jl +++ b/src/systems/sparsematrixclil.jl @@ -129,73 +129,7 @@ end # build something that works for us here and worry about it later. nonzerosmap(a::CLILVector) = NonZeros(a) -findfirstequal(vpivot, ivars) = findfirst(isequal(vpivot), ivars) -function findfirstequal(vpivot::Int64, ivars::Vector{Int64}) - GC.@preserve ivars begin - ret = Base.llvmcall((""" - declare i8 @llvm.cttz.i8(i8, i1); - define i64 @entry(i64 %0, i64 %1, i64 %2) #0 { - top: - %ivars = inttoptr i64 %1 to i64* - %btmp = insertelement <8 x i64> undef, i64 %0, i64 0 - %var = shufflevector <8 x i64> %btmp, <8 x i64> undef, <8 x i32> zeroinitializer - %lenm7 = add nsw i64 %2, -7 - %dosimditer = icmp ugt i64 %2, 7 - br i1 %dosimditer, label %L9.lr.ph, label %L32 - - L9.lr.ph: - %len8 = and i64 %2, 9223372036854775800 - br label %L9 - - L9: - %i = phi i64 [ 0, %L9.lr.ph ], [ %vinc, %L30 ] - %ivarsi = getelementptr inbounds i64, i64* %ivars, i64 %i - %vpvi = bitcast i64* %ivarsi to <8 x i64>* - %v = load <8 x i64>, <8 x i64>* %vpvi, align 8 - %m = icmp eq <8 x i64> %v, %var - %mu = bitcast <8 x i1> %m to i8 - %matchnotfound = icmp eq i8 %mu, 0 - br i1 %matchnotfound, label %L30, label %L17 - - L17: - %tz8 = call i8 @llvm.cttz.i8(i8 %mu, i1 true) - %tz64 = zext i8 %tz8 to i64 - %vis = add nuw i64 %i, %tz64 - br label %common.ret - - common.ret: - %retval = phi i64 [ %vis, %L17 ], [ -1, %L32 ], [ %si, %L51 ], [ -1, %L67 ] - ret i64 %retval - - L30: - %vinc = add nuw nsw i64 %i, 8 - %continue = icmp slt i64 %vinc, %lenm7 - br i1 %continue, label %L9, label %L32 - - L32: - %cumi = phi i64 [ 0, %top ], [ %len8, %L30 ] - %done = icmp eq i64 %cumi, %2 - br i1 %done, label %common.ret, label %L51 - - L51: - %si = phi i64 [ %inc, %L67 ], [ %cumi, %L32 ] - %spi = getelementptr inbounds i64, i64* %ivars, i64 %si - %svi = load i64, i64* %spi, align 8 - %match = icmp eq i64 %svi, %0 - br i1 %match, label %common.ret, label %L67 - - L67: - %inc = add i64 %si, 1 - %dobreak = icmp eq i64 %inc, %2 - br i1 %dobreak, label %common.ret, label %L51 - - } - attributes #0 = { alwaysinline } - """, "entry"), Int64, Tuple{Int64, Ptr{Int64}, Int64}, vpivot, pointer(ivars), - length(ivars)) - end - ret < 0 ? nothing : ret + 1 -end +using FindFirstFunctions: findfirstequal function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot, last_pivot; pivot_equal_optimization = true)