Skip to content

Commit

Permalink
ALM and BCFW update (#450)
Browse files Browse the repository at this point in the history
* Update plot_sparsity

* Remove redundant CallbackStateBlockCoordinateMetho

* Isolate update step from BCFW

* SpectraplexLMO add field maxiters

* BCFW allow different stepsize strategies

* Adapt example

* ALM and BCFW for blockvectors

* Small changes

* Small progress on BPCG steps

* FIX BPCG step

* ALM inital point consistency

* Adapt tests

* Add collect for BlockVector

* Remove old CallbackState from Alternating project.

* Format

* Add tests for ALM and BCFW

* Update documentation

* BCFW example

* Add example for BCFW

* Update documentation

* Small documentation fix

* Documentation update
  • Loading branch information
JannisHal authored Jan 15, 2024
1 parent c8b81ad commit 01b9349
Show file tree
Hide file tree
Showing 10 changed files with 571 additions and 218 deletions.
11 changes: 11 additions & 0 deletions docs/src/reference/3_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ FrankWolfe.CyclicUpdate
FrankWolfe.StochasticUpdate
```

## Update step for block-coordinate Frank-Wolfe

Block-coordinate Frank-Wolfe (BCFW) can run different FW algorithms on different blocks. All update steps are subtypes of [`FrankWolfe.UpdateStep`](@ref) and implement [`FrankWolfe.update_iterate`](@ref) which defines one iteration of the corresponding method.

```@docs
FrankWolfe.UpdateStep
FrankWolfe.update_iterate
FrankWolfe.FrankWolfeStep
FrankWolfe.BPCGStep
```

## Index

```@index
Expand Down
5 changes: 3 additions & 2 deletions examples/alm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using FrankWolfe

n = Int(1e4)
n = Int(1e1)

xpi = rand(1:100, n)
total = sum(xpi)
Expand Down Expand Up @@ -29,7 +29,8 @@ for pair in lmo_pairs
grad!,
pair,
zeros(n);
lambda=1.0,
update_order=FrankWolfe.FullUpdate(),
verbose=true,
update_step=FrankWolfe.BPCGStep(),
)
end
13 changes: 6 additions & 7 deletions examples/alm_sdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ include("../examples/plot_utils.jl")
f(x) = 0.0

function grad!(storage, x)
@. storage = 0
@. storage = zero(x)
end

dim = 30
dim = 10

m = JuMP.Model(GLPK.Optimizer)
@variable(m, x[1:dim, 1:dim])
Expand All @@ -22,23 +22,22 @@ m = JuMP.Model(GLPK.Optimizer)
@constraint(m, x .>= 0)


lmos = (FrankWolfe.SpectraplexLMO(1.0, dim, true), FrankWolfe.MathOptLMO(m.moi_backend))
x0 = rand(dim, dim)
lmos = (FrankWolfe.SpectraplexLMO(1.0, dim), FrankWolfe.MathOptLMO(m.moi_backend))
x0 = (zeros(dim, dim), Matrix(I(dim) ./ dim))

trajectories = []

for order in [FrankWolfe.FullUpdate(), FrankWolfe.CyclicUpdate(), FrankWolfe.StochasticUpdate()]

_, _, _, _, _, traj_data = FrankWolfe.alternating_linear_minimization(
FrankWolfe.block_coordinate_frank_wolfe,
f,
grad!,
lmos,
x0;
update_order=order,
line_search=FrankWolfe.Adaptive(relaxed_smoothness=true),
verbose=true,
trajectory=true,
update_step=FrankWolfe.BPCGStep(),
)
push!(trajectories, traj_data)
end
Expand All @@ -51,7 +50,7 @@ fp = plot_trajectories(
legend_position=:best,
xscalelog=true,
reduce_size=true,
marker_shapes=[:dtriangle, :rect, :circle],
marker_shapes=[:dtriangle, :rect, :circle, :dtriangle, :rect, :circle],
)

display(fp)
28 changes: 24 additions & 4 deletions examples/block_coordinate_frank_wolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ using LinearAlgebra

include("plot_utils.jl")

f(x) = dot(x[:, 1] - x[:, 2], x[:, 1] - x[:, 2])
f(x) = dot(x.blocks[1] - x.blocks[2], x.blocks[1] - x.blocks[2])

function grad!(storage, x)
g = 2 * hcat(x[:, 1] - x[:, 2], x[:, 2] - x[:, 1])
g = copy(x)
g.blocks = [x.blocks[1] - x.blocks[2], x.blocks[2] - x.blocks[1]]
@. storage = g
end

Expand All @@ -16,7 +17,7 @@ lmo1 = FrankWolfe.ScaledBoundLInfNormBall(-ones(n), zeros(n))
lmo2 = FrankWolfe.ProbabilitySimplexOracle(1.0)
prod_lmo = FrankWolfe.ProductLMO((lmo1, lmo2))

x0 = compute_extreme_point(prod_lmo, ones(n, 2))
x0 = FrankWolfe.BlockVector([-ones(n), [i == 1 ? 1 : 0 for i in 1:n]], [(n,), (n,)], 2 * n)

trajectories = []

Expand Down Expand Up @@ -47,4 +48,23 @@ for order in [
end

labels = ["Full update", "Cyclic order", "Stochstic order", "Custom order"]
plot_trajectories(trajectories, labels, xscalelog=true)
display(plot_trajectories(trajectories, labels, xscalelog=true))

# Example for running BCFW with different update methods
trajectories = []

for us in [(FrankWolfe.BPCGStep(), FrankWolfe.FrankWolfeStep()), (FrankWolfe.FrankWolfeStep(), FrankWolfe.BPCGStep()), FrankWolfe.BPCGStep(), FrankWolfe.FrankWolfeStep()]

_, _, _, _, traj_data = FrankWolfe.block_coordinate_frank_wolfe(
f,
grad!,
prod_lmo,
x0;
verbose=true,
trajectory=true,
update_step=us,
)
push!(trajectories, traj_data)
end

display(plot_trajectories(trajectories, ["BPCG FW", "FW BPCG", "BPCG", "FW"], xscalelog=true))
146 changes: 64 additions & 82 deletions examples/plot_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,95 +334,77 @@ function plot_trajectories(
end

function plot_sparsity(
data, label;
filename=nothing, xscalelog=false,
legend_position=:topright, yscalelog=true,
lstyle=fill(:solid, length(data)),
marker_shapes=nothing,
n_markers=10,
empty_marker=false,
)
data,
label;
filename=nothing,
xscalelog=false,
legend_position=:topright,
yscalelog=true,
lstyle=fill(:solid, length(data)),
marker_shapes=nothing,
n_markers=10,
empty_marker=false,
reduce_size=false,
)
Plots.gr()

x = []
y = []
ps = nothing
ds = nothing
offset = 2
xscale = xscalelog ? :log : :identity
yscale = yscalelog ? :log : :identity
for i in eachindex(data)
trajectory = data[i]
x = [trajectory[j][6] for j in offset:length(trajectory)]
y = [trajectory[j][2] for j in offset:length(trajectory)]
if marker_shapes !== nothing && n_markers >= 2
marker_args = Dict(
:st => :samplemarkers,
:n_markers => n_markers,
:shape => marker_shapes[i],
:log => xscalelog,
:startmark => 5+20*(i-1),
:markercolor => empty_marker ? :white : :match,
:markerstrokecolor => empty_marker ? i : :match,
)
else
marker_args = Dict()
end
if i == 1
ps = plot(
x,
y;
label=label[i],
xaxis=xscale,
yaxis=yscale,
ylabel="Primal",
legend=legend_position,
yguidefontsize=8,
xguidefontsize=8,
legendfontsize=8,
linestyle=lstyle[i],
marker_args...
)
else
plot!(x, y; label=label[i], linestyle=lstyle[i], marker_args...)
end
end
for i in eachindex(data)
trajectory = data[i]
x = [trajectory[j][6] for j in offset:length(trajectory)]
y = [trajectory[j][4] for j in offset:length(trajectory)]
if marker_shapes !== nothing && n_markers >= 2
marker_args = Dict(
:st => :samplemarkers,
:n_markers => n_markers,
:shape => marker_shapes[i],
:log => xscalelog,
:startmark => 5+20*(i-1),
:markercolor => empty_marker ? :white : :match,
:markerstrokecolor => empty_marker ? i : :match,
)
else
marker_args = Dict()
end
if i == 1
ds = plot(
x,
y;
label=label[i],
legend=false,
xaxis=xscale,
yaxis=yscale,
ylabel="FW gap",
yguidefontsize=8,
xguidefontsize=8,
linestyle=lstyle[i],
marker_args...
)
else
plot!(x, y; label=label[i], linestyle=lstyle[i], marker_args...)
offset = 2

function subplot(idx_x, idx_y, ylabel)

fig = nothing
for (i, trajectory) in enumerate(data)

l = length(trajectory)
if reduce_size && l > 1000
indices = Int.(round.(collect(1:l/1000:l)))
trajectory = trajectory[indices]
end


x = [trajectory[j][idx_x] for j in offset:length(trajectory)]
y = [trajectory[j][idx_y] for j in offset:length(trajectory)]
if marker_shapes !== nothing && n_markers >= 2
marker_args = Dict(
:st => :samplemarkers,
:n_markers => n_markers,
:shape => marker_shapes[i],
:log => xscalelog,
:startmark => 5 + 20 * (i - 1),
:markercolor => empty_marker ? :white : :match,
:markerstrokecolor => empty_marker ? i : :match,
)
else
marker_args = Dict()
end
if i == 1
fig = plot(
x,
y;
label=label[i],
xaxis=xscale,
yaxis=yscale,
ylabel=ylabel,
legend=legend_position,
yguidefontsize=8,
xguidefontsize=8,
legendfontsize=8,
linestyle=lstyle[i],
marker_args...,
)
else
plot!(x, y; label=label[i], linestyle=lstyle[i], marker_args...)
end
end

return fig
end

ps = subplot(6, 2, "Primal")
ds = subplot(6, 4, "FW gap")

fp = plot(ps, ds, layout=(1, 2)) # layout = @layout([A{0.01h}; [B C; D E]]))
plot!(size=(600, 200))
if filename !== nothing
Expand Down
Loading

0 comments on commit 01b9349

Please sign in to comment.