-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
BPCG with direct solve extension (#507)
* added direct solve feature BPCG via LP solver * adjusted for arbitrary LP solver * fixed deps * cleanup and comment * minor * added reporting of direct solve step * chose highs as standard solver * added sparsification * added sparsification code * cleanup * minor cleanup * minor * added generalized direct_solve * clean up, docu, additional direct_solve * docstrings fixed? * sparsifier active set (#508) * sparsifier active set * fix typo * added sparsifying tests * generic tolerane * remove sparsification * format * HiGHS dep * Quadratic solve structure (#511) * sparsifier active set * start working on LP AS * first working quadratic * remove quadratic LP from current * cleanup * HiGHS in test deps * working reworked LP quadratic * working version generic quadratic * slow version generic quadratic * faster term manipulation * copy sufficient * remove comment * added test for quadratic * minor * simplify example * clean up code, verify error with ASQuad * Add update_weights! to fix direct solve with active_set_quadratic * remove direct solve from BPCG * rng changed --------- Co-authored-by: Sébastien Designolle <[email protected]> * update example * format * clean up example * fix callback --------- Co-authored-by: Mathieu Besançon <[email protected]> Co-authored-by: Sébastien Designolle <[email protected]>
- Loading branch information
1 parent
96e7ec9
commit 318d2b9
Showing
16 changed files
with
1,206 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
#= | ||
This example demonstrates the use of the Blended Pairwise Conditional Gradient algorithm | ||
with direct solve steps for a quadratic optimization problem over a sparse polytope. | ||
Note the special structure f(x) = norm(x - xp)^2 (squared distance to a fixed point xp) that we assume here. | ||
The example showcases how the algorithm balances between: | ||
- Pairwise steps for efficient optimization | ||
- Periodic direct solves for handling the quadratic objective | ||
- Lazy (approximate) linear minimization steps for improved iteration complexity | ||
It also demonstrates how to set up custom callbacks for tracking algorithm progress. | ||
=# | ||
|
||
using FrankWolfe | ||
using LinearAlgebra | ||
using Random | ||
|
||
import HiGHS | ||
import MathOptInterface as MOI | ||
|
||
include("../examples/plot_utils.jl") | ||
|
||
n = Int(1e4) | ||
k = 10_000 | ||
|
||
s = 10 | ||
@info "Seed $s" | ||
Random.seed!(s) | ||
|
||
xpi = rand(n); | ||
total = sum(xpi); | ||
|
||
const xp = xpi ./ total; | ||
|
||
# Objective: squared Euclidean distance from `x` to the fixed point `xp`. | ||
function f(x) | ||
    return norm(x - xp)^2 | ||
end | ||
# In-place gradient of `f`: overwrites `storage` with 2 * (x - xp) and returns it. | ||
function grad!(storage, x) | ||
    storage .= 2 .* (x .- xp) | ||
    return storage | ||
end | ||
|
||
lmo = FrankWolfe.KSparseLMO(5, 1.0) | ||
|
||
const x00 = FrankWolfe.compute_extreme_point(lmo, rand(n)) | ||
|
||
# Build a callback that appends one record per invocation to `trajectory_arr`: | ||
# the fields of `FrankWolfe.callback_state(state)` followed by the active-set length. | ||
function build_callback(trajectory_arr) | ||
    function callback(state, active_set, args...) | ||
        record = (FrankWolfe.callback_state(state)..., length(active_set)) | ||
        return push!(trajectory_arr, record) | ||
    end | ||
    return callback | ||
end | ||
|
||
|
||
# Baseline run: BPCG started from a copy of x00 (a plain starting point, not a pre-built active set). | ||
# The callback stores one record per call in trajectoryBPCG_standard. | ||
trajectoryBPCG_standard = [] | ||
@time x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
copy(x00), | ||
max_iteration=k, | ||
line_search=FrankWolfe.Shortstep(2.0), | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_standard), | ||
); | ||
|
||
# BPCG on an ActiveSetQuadratic seeded with x00; the constructor is given the | ||
# quadratic data of f as the matrix 2 * I and the vector -2xp. | ||
trajectoryBPCG_quadratic = [] | ||
as_quad = FrankWolfe.ActiveSetQuadratic([(1.0, copy(x00))], 2 * LinearAlgebra.I, -2xp) | ||
@time x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
as_quad, | ||
max_iteration=k, | ||
line_search=FrankWolfe.Shortstep(2.0), | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic), | ||
); | ||
|
||
# Fresh ActiveSetQuadratic, built with the same constructor arguments as before. | ||
as_quad = FrankWolfe.ActiveSetQuadratic([(1.0, copy(x00))], 2 * LinearAlgebra.I, -2xp) | ||
|
||
# with quadratic active set | ||
# NOTE(review): apart from the callback target, this call matches the previous run — confirm the repetition is intended. | ||
trajectoryBPCG_quadratic_as = [] | ||
@time x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
as_quad, | ||
max_iteration=k, | ||
line_search=FrankWolfe.Shortstep(2.0), | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic_as), | ||
); | ||
|
||
# ActiveSetQuadraticLinearSolve seeded with x00: same quadratic data (2 * I, -2xp) plus | ||
# a HiGHS optimizer instantiated through MOI with output silenced. | ||
as_quad_direct = FrankWolfe.ActiveSetQuadraticLinearSolve( | ||
[(1.0, copy(x00))], | ||
2 * LinearAlgebra.I, | ||
-2xp, | ||
MOI.instantiate(MOI.OptimizerWithAttributes(HiGHS.Optimizer, MOI.Silent() => true)), | ||
) | ||
|
||
# with LP acceleration (matrix passed as the lazy identity scaling 2 * I) | ||
trajectoryBPCG_quadratic_direct = [] | ||
@time x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
as_quad_direct, | ||
max_iteration=k, | ||
line_search=FrankWolfe.Shortstep(2.0), | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic_direct), | ||
); | ||
|
||
# Same construction as as_quad_direct, except the matrix is an explicit Diagonal | ||
# with 2 on the diagonal instead of 2 * I. | ||
as_quad_direct_generic = FrankWolfe.ActiveSetQuadraticLinearSolve( | ||
[(1.0, copy(x00))], | ||
2 * Diagonal(ones(length(xp))), | ||
-2xp, | ||
MOI.instantiate(MOI.OptimizerWithAttributes(HiGHS.Optimizer, MOI.Silent() => true)), | ||
) | ||
|
||
# with LP acceleration (matrix passed as an explicit Diagonal) | ||
trajectoryBPCG_quadratic_direct_generic = [] | ||
@time x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
as_quad_direct_generic, | ||
max_iteration=k, | ||
line_search=FrankWolfe.Shortstep(2.0), | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic_direct_generic), | ||
); | ||
|
||
# Same quadratic data and optimizer, but the first argument is a FrankWolfe.ActiveSet | ||
# object rather than a vector of tuples. | ||
as_quad_direct_basic_as = FrankWolfe.ActiveSetQuadraticLinearSolve( | ||
FrankWolfe.ActiveSet([1.0], [copy(x00)], collect(x00)), | ||
2 * LinearAlgebra.I, | ||
-2xp, | ||
MOI.instantiate(MOI.OptimizerWithAttributes(HiGHS.Optimizer, MOI.Silent() => true)), | ||
) | ||
|
||
# with LP acceleration (wrapping a plain ActiveSet) | ||
trajectoryBPCG_quadratic_noqas = [] | ||
|
||
@time x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
as_quad_direct_basic_as, | ||
max_iteration=k, | ||
line_search=FrankWolfe.Shortstep(2.0), | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic_noqas), | ||
); | ||
|
||
|
||
# Update the data and labels for plotting | ||
data_trajectories = [ | ||
trajectoryBPCG_standard, | ||
trajectoryBPCG_quadratic, | ||
trajectoryBPCG_quadratic_as, | ||
trajectoryBPCG_quadratic_direct, | ||
trajectoryBPCG_quadratic_direct_generic, | ||
trajectoryBPCG_quadratic_noqas, | ||
] | ||
labels_trajectories = [ | ||
"BPCG (Standard)", | ||
"BPCG (Specific Direct)", | ||
"AS_Quad", | ||
"Reloaded", | ||
"Reloaded_generic", | ||
"Reloaded_noqas", | ||
] | ||
|
||
# Plot trajectories | ||
plot_trajectories(data_trajectories, labels_trajectories, xscalelog=false) |
146 changes: 146 additions & 0 deletions
146
examples/blended_pairwise_with_direct_non-standard-quadratic.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
#= | ||
This example demonstrates the use of the Blended Pairwise Conditional Gradient algorithm | ||
with direct solve steps for a quadratic optimization problem over a sparse polytope which is not standard quadratic. | ||
The example showcases how the algorithm balances between: | ||
- Pairwise steps for efficient optimization | ||
- Periodic direct solves for handling the quadratic objective | ||
- Lazy (approximate) linear minimization steps for improved iteration complexity | ||
It also demonstrates how to set up custom callbacks for tracking algorithm progress. | ||
=# | ||
|
||
using FrankWolfe | ||
using LinearAlgebra | ||
using Random | ||
|
||
import HiGHS | ||
import MathOptInterface as MOI | ||
|
||
include("../examples/plot_utils.jl") | ||
|
||
n = Int(1e2) | ||
k = 10000 | ||
|
||
# s = rand(1:100) | ||
s = 10 | ||
@info "Seed $s" | ||
Random.seed!(s) | ||
|
||
# Random symmetric matrix B' * B defining the quadratic objective; asserted positive definite below. | ||
# Declared `const` (like `y`) so that `f` and `grad!`, which read it as a global, | ||
# do not go through an untyped global binding on every call. | ||
const A = let B = randn(n, n) | ||
    B' * B | ||
end | ||
# Sanity check for this example: the objective must be strictly convex. | ||
@assert isposdef(A) | ||
|
||
const y = Random.rand(Bool, n) * 0.6 .+ 0.3 | ||
|
||
# Objective: the quadratic form (x - y)' * A * (x - y). | ||
f(x) = (offset = x - y; dot(offset, A, offset)) | ||
|
||
# In-place gradient of `f`: leaves 2 * A * (x - y) in `storage` and returns it. | ||
function grad!(storage, x) | ||
    mul!(storage, A, x)           # storage = A * x | ||
    mul!(storage, A, y, -2, 2)    # storage = -2 * A * y + 2 * storage | ||
    return storage | ||
end | ||
|
||
|
||
# lmo = FrankWolfe.KSparseLMO(5, 1000.0) | ||
|
||
## other LMOs to try | ||
# lmo_big = FrankWolfe.KSparseLMO(100, big"1.0") | ||
# lmo = FrankWolfe.LpNormLMO{Float64,5}(100.0) | ||
# lmo = FrankWolfe.ProbabilitySimplexOracle(100.0); | ||
lmo = FrankWolfe.UnitSimplexOracle(10000.0); | ||
|
||
x00 = FrankWolfe.compute_extreme_point(lmo, rand(n)) | ||
|
||
|
||
# Build a callback that appends one record per invocation to `trajectory_arr`: | ||
# the fields of `FrankWolfe.callback_state(state)` followed by the active-set length. | ||
function build_callback(trajectory_arr) | ||
    function callback(state, active_set, args...) | ||
        record = (FrankWolfe.callback_state(state)..., length(active_set)) | ||
        return push!(trajectory_arr, record) | ||
    end | ||
    return callback | ||
end | ||
|
||
|
||
# Baseline run: BPCG from a copy of x00 with the Adaptive line search and | ||
# InplaceEmphasis memory mode; print_iter is set to k / 10. | ||
trajectoryBPCG_standard = [] | ||
callback = build_callback(trajectoryBPCG_standard) | ||
|
||
x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
copy(x00), | ||
max_iteration=k, | ||
line_search=FrankWolfe.Adaptive(), | ||
print_iter=k / 10, | ||
memory_mode=FrankWolfe.InplaceEmphasis(), | ||
verbose=true, | ||
trajectory=true, | ||
callback=callback, | ||
); | ||
|
||
# ActiveSetQuadraticLinearSolve built from the gradient function itself (no explicit | ||
# matrix/vector), a silent HiGHS optimizer, and a LogScheduler passed as `scheduler` | ||
# (presumably controlling when the direct solves are triggered — TODO confirm). | ||
active_set_quadratic_automatic = FrankWolfe.ActiveSetQuadraticLinearSolve( | ||
[(1.0, copy(x00))], | ||
grad!, | ||
MOI.instantiate(MOI.OptimizerWithAttributes(HiGHS.Optimizer, MOI.Silent() => true)), | ||
scheduler=FrankWolfe.LogScheduler(start_time=100, scaling_factor=1.2, max_interval=100), | ||
) | ||
trajectoryBPCG_quadratic_automatic = [] | ||
x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
active_set_quadratic_automatic, | ||
max_iteration=k, | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic_automatic), | ||
); | ||
|
||
# Same setup as active_set_quadratic_automatic, with a different scheduler: | ||
# LogScheduler(start_time=10, scaling_factor=2) and no max_interval argument. | ||
active_set_quadratic_automatic2 = FrankWolfe.ActiveSetQuadraticLinearSolve( | ||
[(1.0, copy(x00))], | ||
grad!, | ||
MOI.instantiate(MOI.OptimizerWithAttributes(HiGHS.Optimizer, MOI.Silent() => true)), | ||
scheduler=FrankWolfe.LogScheduler(start_time=10, scaling_factor=2), | ||
) | ||
trajectoryBPCG_quadratic_automatic2 = [] | ||
x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
active_set_quadratic_automatic2, | ||
max_iteration=k, | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic_automatic2), | ||
); | ||
|
||
|
||
# Same scheduler as active_set_quadratic_automatic2, but the first argument is a | ||
# FrankWolfe.ActiveSet built from the tuple list rather than the tuple list itself. | ||
active_set_quadratic_automatic_standard = FrankWolfe.ActiveSetQuadraticLinearSolve( | ||
FrankWolfe.ActiveSet([(1.0, copy(x00))]), | ||
grad!, | ||
MOI.instantiate(MOI.OptimizerWithAttributes(HiGHS.Optimizer, MOI.Silent() => true)), | ||
scheduler=FrankWolfe.LogScheduler(start_time=10, scaling_factor=2), | ||
) | ||
trajectoryBPCG_quadratic_automatic_standard = [] | ||
x, v, primal, dual_gap, _ = FrankWolfe.blended_pairwise_conditional_gradient( | ||
f, | ||
grad!, | ||
lmo, | ||
active_set_quadratic_automatic_standard, | ||
max_iteration=k, | ||
verbose=true, | ||
callback=build_callback(trajectoryBPCG_quadratic_automatic_standard), | ||
); | ||
|
||
|
||
# Trajectories and labels handed to the plotting helper. | ||
# NOTE(review): trajectoryBPCG_quadratic_automatic2 is recorded above but is not | ||
# included here — confirm whether leaving it out of the plot is intentional. | ||
dataSparsity = [ | ||
trajectoryBPCG_standard, | ||
trajectoryBPCG_quadratic_automatic, | ||
trajectoryBPCG_quadratic_automatic_standard, | ||
] | ||
labelSparsity = ["BPCG (Standard)", "AS_Quad", "AS_Standard"] | ||
|
||
|
||
# Plot trajectories | ||
plot_trajectories(dataSparsity, labelSparsity, xscalelog=false) |
Oops, something went wrong.