Skip to content

Commit

Permalink
Add example for BCFW
Browse files Browse the repository at this point in the history
  • Loading branch information
JannisHal committed Dec 4, 2023
1 parent 9e49a22 commit d34aeed
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
34 changes: 19 additions & 15 deletions examples/block_coordinate_frank_wolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ using LinearAlgebra

include("plot_utils.jl")

f(x) = dot(x[:, 1] - x[:, 2], x[:, 1] - x[:, 2])
f(x) = dot(x.blocks[1] - x.blocks[2], x.blocks[1] - x.blocks[2])

function grad!(storage, x)
g = 2 * hcat(x[:, 1] - x[:, 2], x[:, 2] - x[:, 1])
g = copy(x)
g.blocks = [x.blocks[1] - x.blocks[2], x.blocks[2] - x.blocks[1]]
@. storage = g
end

Expand All @@ -16,7 +17,7 @@ lmo1 = FrankWolfe.ScaledBoundLInfNormBall(-ones(n), zeros(n))
lmo2 = FrankWolfe.ProbabilitySimplexOracle(1.0)
prod_lmo = FrankWolfe.ProductLMO((lmo1, lmo2))

x0 = compute_extreme_point(prod_lmo, ones(n, 2))
x0 = FrankWolfe.BlockVector([-ones(n), [i == 1 ? 1 : 0 for i in 1:n]], [(n,), (n,)], 2 * n)

trajectories = []

Expand Down Expand Up @@ -49,18 +50,21 @@ end
labels = ["Full update", "Cyclic order", "Stochstic order", "Custom order"]
display(plot_trajectories(trajectories, labels, xscalelog=true))

# Example for running BCFW with different update methods
trajectories = []

# Example for running BCFW with different update methods and linesearches on different blocks
for us in [(FrankWolfe.BPCGStep(), FrankWolfe.FrankWolfeStep()), (FrankWolfe.FrankWolfeStep(), FrankWolfe.BPCGStep()), FrankWolfe.BPCGStep(), FrankWolfe.FrankWolfeStep()]

_, _, _, _, traj_data = FrankWolfe.block_coordinate_frank_wolfe(
f,
grad!,
prod_lmo,
x0;
verbose=true,
trajectory=true,
line_search = [FrankWolfe.Agnostic(), FrankWolfe.Adaptive()],
update_step = [FrankWolfe.BPCGStep(), FrankWolfe.FrankWolfeStep()]
)
_, _, _, _, traj_data = FrankWolfe.block_coordinate_frank_wolfe(
f,
grad!,
prod_lmo,
x0;
verbose=true,
trajectory=true,
update_step=us,
)
push!(trajectories, traj_data)
end

display(plot_trajectories([traj_Data], [""], xscalelog=true))
display(plot_trajectories(trajectories, ["BPCG FW", "FW BPCG", "BPCG", "FW"], xscalelog=true))
2 changes: 1 addition & 1 deletion examples/plot_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ end
markershape --> :auto
x := sx
y := sy
return z_order := 1
z_order := 1
end

function plot_trajectories(
Expand Down
5 changes: 3 additions & 2 deletions src/block_coordinate_algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ function block_coordinate_frank_wolfe(
"MEMORY_MODE: $memory_mode STEPSIZE: $line_search_type EPSILON: $epsilon MAXITERATION: $max_iteration TYPE: $num_type",
)
grad_type = typeof(gradient)
println("MOMENTUM: $momentum GRADIENTTYPE: $grad_type UPDATE_ORDER: $update_order")
update_step_type = [typeof(s) for s in update_step]
println("MOMENTUM: $momentum GRADIENTTYPE: $grad_type UPDATE_ORDER: $update_order UPDATE_STEP: $update_step_type")
if memory_mode isa InplaceEmphasis
@info("In memory_mode memory iterates are written back into x0!")
end
Expand Down Expand Up @@ -517,7 +518,7 @@ function block_coordinate_frank_wolfe(
v = compute_extreme_point(lmo, gradient)

primal = f(x)
dual_gap = fast_dot(x - v, gradient)
dual_gap = fast_dot(x, gradient) - fast_dot(v, gradient)

tot_time = (time_ns() - time_start) / 1.0e9

Expand Down

0 comments on commit d34aeed

Please sign in to comment.