Skip to content

Commit

Permalink
KaHyPar selector and greedy merge (#46)
Browse files Browse the repository at this point in the history
* greedymerge

* size_reduction

* size_reduction

* spaces

* spaces

* add test on covered_by

* fix test

* fix test

* update

* update-greedy-implementation

* merge and fix bugs

* fix doc

* save

* kahypar selector

* add more test

* rm simulated annealing

* speedup greedy method (#47)

* save

* update

* update

* update

* fix test

* use priority queue

* update

* add inbounds

* update

* ugly fix

* fix tests

* update

* new visualization

* update

* change name

---------

Co-authored-by: nzy <[email protected]>

---------

Co-authored-by: GiggleLiu <[email protected]>
Co-authored-by: nzy <[email protected]>
  • Loading branch information
3 people authored Jan 18, 2025
1 parent 9f425a3 commit dd34198
Show file tree
Hide file tree
Showing 24 changed files with 625 additions and 53 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ lib/OptimalBranchingMIS/docs/build/
lib/OptimalBranchingCore/docs/build/

docs/src/generated/

report.typ
29 changes: 28 additions & 1 deletion examples/rule_discovery.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,31 @@ branching_region = SimpleGraph(Graphs.SimpleEdge.(edges))
# Generate the tree-like N3 neighborhood of R
graph = tree_like_N3_neighborhood(branching_region)

solve_opt_rule(branching_region, graph, vs)
solve_opt_rule(branching_region, graph, vs)


# ## Generating rules for large-scale problems
# For large-scale problems, we can use the greedy merge rule to generate rules, which avoids generating all candidate clauses.
function solve_greedy_rule(branching_region, graph, vs)
    ## Use default solver and measure
    m = D3Measure()
    table_solver = TensorNetworkSolver(; prune_by_env=true)

    ## Pruning irrelevant entries
    ovs = OptimalBranchingMIS.open_vertices(graph, vs)
    subg, _ = induced_subgraph(graph, vs)
    @info "solving the branching table..."
    tbl = OptimalBranchingMIS.reduced_alpha_configs(table_solver, subg, Int[findfirst(==(v), vs) for v in ovs])
    @info "the length of the truth_table after pruning irrelevant entries: $(length(tbl.table))"

    ## Merge the table rows greedily instead of enumerating every candidate clause
    @info "generating the optimal branching rule via greedy merge..."
    candidates = OptimalBranchingCore.bit_clauses(tbl)
    result = OptimalBranchingMIS.OptimalBranchingCore.greedymerge(candidates, MISProblem(graph), vs, m)
    @info "the greedily minimized gamma: $(result.γ)"

    @info "the branching rule on R:"
    viz_dnf(result.optimal_rule, vs)

    ## Return last, so that the report and the visualization above are actually executed
    return result
end

result = solve_greedy_rule(branching_region, graph, vs)
2 changes: 2 additions & 0 deletions lib/OptimalBranchingCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ version = "0.1.1"

[deps]
BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"

[compat]
BitBasis = "0.9"
DataStructures = "0.18.20"
HiGHS = "1.12"
JuMP = "1.23"
julia = "1.10"
Expand Down
5 changes: 4 additions & 1 deletion lib/OptimalBranchingCore/src/OptimalBranchingCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module OptimalBranchingCore

using JuMP, HiGHS
using BitBasis
using DataStructures

# logic expressions
export Clause, BranchingTable, DNF, booleans, , , ¬, covered_by, literals, is_true_literal, is_false_literal
Expand All @@ -16,7 +17,7 @@ export AbstractProblem, branch_and_reduce, BranchingStrategy
# variable selector interface
export select_variable, AbstractSelector
# branching table solver interface
export branching_table, AbstractTableSolver
export branching_table, AbstractTableSolver, NaiveBranch, GreedyMerge
# measure interface
export measure, AbstractMeasure
# reducer interface
Expand All @@ -30,5 +31,7 @@ include("interfaces.jl")
include("branching_table.jl")
include("setcovering.jl")
include("branch.jl")
include("greedymerge.jl")
include("mockproblem.jl")

end
54 changes: 37 additions & 17 deletions lib/OptimalBranchingCore/src/branch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@ A [`OptimalBranchingResult`](@ref) object representing the optimal branching rul
"""
function optimal_branching_rule(table::BranchingTable, variables::Vector, problem::AbstractProblem, m::AbstractMeasure, solver::AbstractSetCoverSolver)
    # enumerate the candidate clauses of the table
    candidates = candidate_clauses(table)
    # measure decrease obtained by applying each candidate to the problem
    size_reductions = [size_reduction(problem, m, candidate, variables) for candidate in candidates]
    # pick the clause subset minimizing γ, starting the search from γ0 = 2.0
    return minimize_γ(table, candidates, size_reductions, solver; γ0 = 2.0)
end

"""
    size_reduction(p::AbstractProblem, m::AbstractMeasure, cl::Clause, variables::Vector)

Return the decrease of the measure `m` when the clause `cl` (defined over `variables`)
is applied to the problem `p`, i.e. the measure of `p` minus the measure of the
sub-problem returned first by `apply_branch`.
"""
function size_reduction(p::AbstractProblem, m::AbstractMeasure, cl::Clause{INT}, variables::Vector) where {INT}
    size_before = measure(p, m)
    # `apply_branch` returns the sub-problem first; only that part is needed here
    subproblem = first(apply_branch(p, cl, variables))
    return size_before - measure(subproblem, m)
end


"""
BranchingStrategy
BranchingStrategy(; kwargs...)
Expand All @@ -31,23 +36,23 @@ A struct representing the configuration for a solver, including the reducer and
- `selector::AbstractSelector`: The selector to select the next branching variable or decision.
- `measure::AbstractMeasure`: The measure to evaluate the performance of the branching strategy.
"""
@kwdef struct BranchingStrategy{TS <: AbstractTableSolver, SCS <: AbstractSetCoverSolver, SL <: AbstractSelector, M <: AbstractMeasure}
    # solver that turns a branching table into a rule; the only field with a default
    set_cover_solver::SCS = IPSolver()
    # solver that produces the branching table for the selected variables
    table_solver::TS
    # chooses the variables to branch on
    selector::SL
    # problem-size measure used to evaluate branches
    measure::M
end
# Print the strategy as a small tree, one line per configured component.
# Defined once only: a second identical definition overwrites the method.
Base.show(io::IO, config::BranchingStrategy) = print(io,
"""
BranchingStrategy
├── table_solver - $(config.table_solver)
├── set_cover_solver - $(config.set_cover_solver)
├── selector - $(config.selector)
└── measure - $(config.measure)
""")

"""
branch_and_reduce(problem::AbstractProblem, config::BranchingStrategy; reducer::AbstractReducer=NoReducer(), result_type=Int)
branch_and_reduce(problem::AbstractProblem, config::BranchingStrategy; reducer::AbstractReducer=NoReducer(), result_type=Int, show_progress=false)
Branch the given problem using the specified solver configuration.
Expand All @@ -62,19 +67,34 @@ Branch the given problem using the specified solver configuration.
### Returns
The resulting value, which may have different type depending on the `result_type`.
"""
function branch_and_reduce(problem::AbstractProblem, config::BranchingStrategy, reducer::AbstractReducer, result_type; show_progress=false, tag=Tuple{Int,Int}[])
    @debug "Branching and reducing problem" problem
    has_zero_size(problem) && return zero(result_type)

    # reduce the problem; if the reducer changed it, recurse on the reduced problem.
    # `show_progress` is forwarded together with `tag`, otherwise the progress display
    # would silently stop after the first successful reduction.
    rp, reducedvalue = reduce_problem(result_type, problem, reducer)
    rp !== problem && return branch_and_reduce(rp, config, reducer, result_type; tag, show_progress) * reducedvalue

    # branch the problem
    variables = select_variables(rp, config.measure, config.selector)  # select a subset of variables
    tbl = branching_table(rp, config.table_solver, variables)          # compute the BranchingTable
    result = optimal_branching_rule(tbl, variables, rp, config.measure, config.set_cover_solver)  # compute the optimal branching rule
    clauses = get_clauses(result)
    nbranches = length(clauses)
    return sum(enumerate(clauses)) do (i, branch)  # branch and recurse
        show_progress && (print_sequence(stdout, tag); println(stdout))
        subproblem, localvalue = apply_branch(rp, branch, variables)
        # extend the path only when it is going to be displayed, to avoid needless allocation
        subtag = show_progress ? [tag..., (i, nbranches)] : tag
        branch_and_reduce(subproblem, config, reducer, result_type; tag=subtag, show_progress) * result_type(localvalue) * reducedvalue
    end
end

# Print one marker per `(i, n)` pair of `sequence` to `io`. `branch_and_reduce` builds
# these pairs as (index of the branch taken, number of branches at that level).
# NOTE(review): all three string literals are empty as shown here, so nothing is
# printed; presumably three distinct glyphs were lost when this text was extracted —
# confirm against the repository before relying on this function.
function print_sequence(io::IO, sequence::Vector{Tuple{Int,Int}})
for (i, n) in sequence
if i == n
# last branch of its level
print(io, "")
elseif i == 1
# first branch of a level that has more than one branch
print(io, "")
else
# any branch in between
print(io, "")
end
end
end
78 changes: 78 additions & 0 deletions lib/OptimalBranchingCore/src/greedymerge.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Rule solver tag: `optimal_branching_rule` dispatches on it to build the rule with
# `greedymerge` on the per-row clauses returned by `bit_clauses`.
struct GreedyMerge <: AbstractSetCoverSolver end
# Rule solver tag: `optimal_branching_rule` dispatches on it to branch on the first
# clause of every table row as-is, without merging rows.
struct NaiveBranch <: AbstractSetCoverSolver end
# Greedy-merge variant: turn every table row into its fully specified clauses and let
# `greedymerge` combine rows.
function optimal_branching_rule(table::BranchingTable, variables::Vector, problem::AbstractProblem, m::AbstractMeasure, solver::GreedyMerge)
    return greedymerge(bit_clauses(table), problem, variables, m)
end

# Naive variant: branch on the first fully specified clause of every table row,
# without merging any rows.
function optimal_branching_rule(table::BranchingTable, variables::Vector, problem::AbstractProblem, m::AbstractMeasure, solver::NaiveBranch)
    representatives = first.(bit_clauses(table))
    size_reductions = [Float64(size_reduction(problem, m, cl, variables)) for cl in representatives]
    return OptimalBranchingResult(DNF(representatives), size_reductions, complexity_bv(size_reductions))
end

"""
    bit_clauses(tbl::BranchingTable)

Convert each row of `tbl` into a vector of clauses, one per bit string in that row.
Every clause uses the mask built by `bmask(INT, 1:tbl.bit_length)` and the bit string
as its value.
"""
function bit_clauses(tbl::BranchingTable{INT}) where {INT}
    full_mask = bmask(INT, 1:tbl.bit_length)
    return map(tbl.table) do row
        [Clause(full_mask, bs) for bs in row]
    end
end

"""
    greedymerge(cls, problem, variables, m)

Greedily merge the rows of a branching table into a branching rule.

`cls[r]` holds the candidate clauses of row `r`. Two rows are merged by replacing them
with the clause `gather2` builds from one clause of each row, keeping the combination
with the largest size reduction. A merge is accepted only if it lowers `Σ γ^(-Δρ)` over
the rows by more than `1e-12`; accepted merges are processed from the most negative
change upwards. After each sweep γ is recomputed with `complexity_bv`, and the loop
stops when no pair of rows can be merged profitably.

Returns an `OptimalBranchingResult` built from the first clause of every remaining row,
the matching size reductions, and the final γ.
"""
function greedymerge(cls::Vector{Vector{Clause{INT}}}, problem::AbstractProblem, variables::Vector, m::AbstractMeasure) where {INT}
    # Best way to merge one clause of `cli` with one clause of `clj`: returns the merged
    # clause, the indices of the two clauses and the size reduction. A merged clause with
    # zero mask is skipped; if nothing qualifies the reduction stays 0.0.
    function reduction_merge(cli, clj)
        clmax, iimax, jjmax, reductionmax = Clause(zero(INT), zero(INT)), -1, -1, 0.0
        @inbounds for ii in eachindex(cli), jj in eachindex(clj)
            cl12 = gather2(length(variables), cli[ii], clj[jj])
            iszero(cl12.mask) && continue
            reduction = Float64(size_reduction(problem, m, cl12, variables))
            if reduction > reductionmax
                clmax, iimax, jjmax, reductionmax = cl12, ii, jj, reduction
            end
        end
        return clmax, iimax, jjmax, reductionmax
    end

    cls = copy(cls)  # rows are replaced below; do not touch the caller's vector
    size_reductions = [Float64(size_reduction(problem, m, first(candidate), variables)) for candidate in cls]
    # Monotone counter used as a tie-breaker for queue priorities. The row loops below
    # deliberately use a different variable name: a `for` loop variable is always a fresh
    # local in Julia, so reusing the counter's name would shadow it and the increment
    # would be lost.
    counter = 0
    @inbounds while true
        nc = length(cls)
        mask = trues(nc)  # rows still alive in this sweep
        γ = complexity_bv(size_reductions)
        weights = map(s -> γ^(-s), size_reductions)
        queue = PriorityQueue{NTuple{2, Int}, Float64}()  # from small to large
        for i in 1:nc, j in (i + 1):nc
            _, _, _, reduction = reduction_merge(cls[i], cls[j])
            # change of Σ γ^(-Δρ) if rows i and j are replaced by their merge
            dE = γ^(-reduction) - weights[i] - weights[j]
            dE <= -1e-12 && enqueue!(queue, (i, j), dE - 1e-12 * (counter += 1; counter))
        end
        isempty(queue) && return OptimalBranchingResult(DNF(first.(cls)), size_reductions, γ)
        while !isempty(queue)
            (i, j) = dequeue!(queue)
            # remove i, j-th row together with every queued pair that involves them
            for rowid in (i, j)
                mask[rowid] = false
                for other in 1:nc
                    if mask[other]
                        a, b = minmax(rowid, other)
                        haskey(queue, (a, b)) && delete!(queue, (a, b))
                    end
                end
            end
            # add i-th row: it now holds the merged clause
            mask[i] = true
            clij, _, _, size_reductions[i] = reduction_merge(cls[i], cls[j])
            cls[i] = [clij]
            weights[i] = γ^(-size_reductions[i])
            # re-evaluate merging the new row with every row still alive
            for other in 1:nc
                if i !== other && mask[other]
                    a, b = minmax(i, other)
                    _, _, _, reduction = reduction_merge(cls[a], cls[b])
                    dE = γ^(-reduction) - weights[a] - weights[b]
                    dE <= -1e-12 && enqueue!(queue, (a, b), dE - 1e-12 * (counter += 1; counter))
                end
            end
        end
        # drop the rows that were merged away and start a new sweep with an updated γ
        cls, size_reductions = cls[mask], size_reductions[mask]
    end
end
76 changes: 76 additions & 0 deletions lib/OptimalBranchingCore/src/mockproblem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# A mock problem defined by one Boolean per variable. `apply_branch` removes the
# variables fixed by a clause from `optimal` and counts how many of the fixed values
# agree with it.
struct MockProblem <: AbstractProblem
optimal::BitVector
end

"""
NumOfVariables
A struct representing a measure that counts the number of variables in a problem.
Each variable is counted as 1.
"""
struct NumOfVariables <: AbstractMeasure end
# For a `MockProblem` the measure is the number of entries of `optimal`.
measure(p::MockProblem, ::NumOfVariables) = length(p.optimal)


"""
struct RandomSelector <: AbstractSelector
The `RandomSelector` struct represents a strategy for selecting a subset of variables randomly.
# Fields
- `n::Int`: The number of variables to select.
"""
struct RandomSelector <: AbstractSelector
# upper bound on the number of variables returned by `select_variables`
n::Int
end
# Pick `selector.n` distinct variable indices uniformly at random (all of them if the
# problem has fewer variables), by taking the head of a random permutation.
function select_variables(p::MockProblem, ::NumOfVariables, selector::RandomSelector)
    total = length(p.optimal)
    shuffled = sortperm(rand(total))
    return shuffled[1:min(total, selector.n)]
end

"""
struct MockTableSolver <: AbstractTableSolver
The `MockTableSolver` randomly generates a branching table with a given number of rows.
Each row must have at least one variable to be covered by the branching rule.
### Fields
- `n::Int`: The number of rows in the branching table.
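- `p::Float64 = 0.0`: The probability of appending one more random configuration to a row. It is re-drawn after every addition (at most 100 extras per row), so the number of extra configurations follows a geometric rather than a Poisson distribution.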
"""
struct MockTableSolver <: AbstractTableSolver
# number of random rows to draw (duplicates are removed in `branching_table`)
n::Int
# probability of appending one more random configuration to a row
p::Float64
end
# Convenience constructor: with `p = 0.0` no extra configurations are ever appended.
MockTableSolver(n::Int) = MockTableSolver(n, 0.0)
# Build a random branching table over `variables` for a `MockProblem`.
# The table has up to `table_solver.n` distinct random rows plus one row holding the
# values of `p.optimal` on `variables`, so that configuration is always present.
function branching_table(p::MockProblem, table_solver::MockTableSolver, variables)
    # random independent set on 1D chain: a bit may be set (with probability 0.5)
    # only when the previous bit is not set
    function rand_fib()
        bs = falses(length(variables))
        for i in eachindex(bs)
            if rand() < min(0.5, i == 1 ? 1.0 : 1 - bs[i - 1])
                bs[i] = true
            end
        end
        return bs
    end
    random_rows = [[rand_fib()] for _ in 1:table_solver.n]
    # `union` keeps the first occurrence of every distinct row, in order
    rows = union(random_rows, [[p.optimal[variables]]])
    for row in rows
        # append extra configurations while a coin with bias `table_solver.p` comes up
        # heads, at most 100 times
        for _ in 1:100
            rand() < table_solver.p || break
            push!(row, rand_fib())
        end
    end
    return BranchingTable(length(variables), unique!.(rows))
end

# Apply `clause` to the mock problem: every variable fixed by the clause mask is removed
# from the problem. The second return value is the number of fixed variables whose
# clause value equals the corresponding entry of `p.optimal`.
function apply_branch(p::MockProblem, clause::Clause{INT}, variables::Vector{T}) where {INT<:Integer, T<:Integer}
    remain_mask = trues(length(p.optimal))
    nmatched = 0
    for (i, v) in enumerate(variables)
        isone(readbit(clause.mask, i)) || continue  # variable not touched by the clause
        remain_mask[v] = false
        readbit(clause.val, i) == p.optimal[v] && (nmatched += 1)
    end
    return MockProblem(p.optimal[remain_mask]), nmatched
end
# A mock problem has zero size once no variable is left.
function has_zero_size(p::MockProblem)
    return iszero(measure(p, NumOfVariables()))
end
14 changes: 7 additions & 7 deletions lib/OptimalBranchingCore/src/setcovering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,18 @@ end
The result type for the optimal branching rule.
### Fields
- `selected_ids::Vector{Int}`: The indices of the selected rows in the branching table.
- `optimal_rule::DNF{INT}`: The optimal branching rule.
- `branching_vector::Vector{T<:Real}`: The branching vector that records the size reduction in each subproblem.
- `γ::Float64`: The optimal γ value (the complexity of the branching rule).
"""
struct OptimalBranchingResult{INT <: Integer, T <: Real}
    # the branching rule as a disjunction of clauses
    optimal_rule::DNF{INT}
    # size reduction of each sub-problem, one entry per clause
    branching_vector::Vector{T}
    # complexity of the branching rule
    γ::Float64
end
# Multi-line summary of the result. Defined once, and only over fields that exist on
# `OptimalBranchingResult` (rule, branching vector, γ).
Base.show(io::IO, results::OptimalBranchingResult{INT, T}) where {INT, T} = print(io, "OptimalBranchingResult{$INT, $T}:\n optimal_rule: $(results.optimal_rule)\n branching_vector: $(results.branching_vector)\n γ: $(results.γ)")
# Clauses to branch on: the clauses of the rule stored in a result object ...
function get_clauses(results::OptimalBranchingResult)
    return results.optimal_rule.clauses
end
# ... or the array itself when the rule is already given as an array.
function get_clauses(res::AbstractArray)
    return res
end

"""
minimize_γ(table::BranchingTable, candidates::Vector{Clause}, Δρ::Vector, solver)
Expand Down Expand Up @@ -140,7 +140,7 @@ function minimize_γ(table::BranchingTable, candidates::Vector{Clause{INT}}, Δ

# Note: the following instance is captured for time saving, and also for it may cause IP solver to fail
for (k, subset) in enumerate(subsets)
(length(subset) == num_items) && return OptimalBranchingResult([k], DNF([candidates[k]]), [Δρ[k]], 1.0)
(length(subset) == num_items) && return OptimalBranchingResult(DNF([candidates[k]]), [Δρ[k]], 1.0)
end

cx_old = cx = γ0
Expand All @@ -153,7 +153,7 @@ function minimize_γ(table::BranchingTable, candidates::Vector{Clause{INT}}, Δ
cx cx_old && break # convergence
cx_old = cx
end
return OptimalBranchingResult(picked_scs, DNF([candidates[i] for i in picked_scs]), Δρ[picked_scs], cx)
return OptimalBranchingResult(DNF([candidates[i] for i in picked_scs]), Δρ[picked_scs], cx)
end

# TODO: we need to extend this function to trim the candidate clauses
Expand Down Expand Up @@ -204,7 +204,7 @@ function gather2(n::Int, c1::Clause{INT}, c2::Clause{INT}) where INT
return Clause(mask, val)
end

function is_solved(xs::Vector{T}, sets_id::Vector{Vector{Int}}, num_items::Int) where{T}
function is_solved_by(xs::Vector{T}, sets_id::Vector{Vector{Int}}, num_items::Int) where{T}
for i in 1:num_items
flag = sum(xs[j] for j in sets_id[i])
((flag < 1) && !(flag 1)) && return false
Expand Down Expand Up @@ -247,7 +247,7 @@ function weighted_minimum_set_cover(solver::LPSolver, weights::AbstractVector, s

optimize!(model)
xs = value.(x)
@assert is_solved(xs, sets_id, num_items)
@assert is_solved_by(xs, sets_id, num_items)
return pick_sets(xs, subsets, num_items)
end

Expand Down
2 changes: 1 addition & 1 deletion lib/OptimalBranchingCore/test/branching_table.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using OptimalBranchingCore, GenericTensorNetworks
using BitBasis
using OptimalBranchingCore.BitBasis
using Test

@testset "branching table" begin
Expand Down
Loading

0 comments on commit dd34198

Please sign in to comment.