Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

weak separation oracle #406

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
9 changes: 8 additions & 1 deletion docs/src/reference/2_lmo.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The Linear Minimization Oracle (LMO) is a key component called at each iteration
v\in \argmin_{x\in \mathcal{C}} \langle d,x \rangle.
```

See [Combettes, Pokutta 2021](https://arxiv.org/abs/2101.10040) for references on most LMOs
See [Combettes, Pokutta 2021](https://arxiv.org/abs/2101.10040) for references on the essential LMOs
implemented in the package and their comparison with projection operators.

## Interface and wrappers
Expand All @@ -19,6 +19,13 @@ All of them are subtypes of [`FrankWolfe.LinearMinimizationOracle`](@ref) and im
compute_extreme_point
```

Optionally, an LMO can implement a weak separation procedure based either on a heuristic or on an approximation algorithm:
```@docs
compute_weak_separation_point
```

Weak separation procedures are used only in methods that maintain an active set and in lazified variants.

We also provide some meta-LMOs wrapping another one with extended behavior:
```@docs
FrankWolfe.CachedLinearMinimizationOracle
Expand Down
113 changes: 108 additions & 5 deletions src/abstract_oracles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@
"""
function compute_extreme_point end

"""
compute_weak_separation_point(lmo, direction, max_value) -> (vertex, gap)

Weak separation algorithm for a given oracle.
Unlike `compute_extreme_point`, `compute_weak_separation_point` may provide a suboptimal `vertex` with the following conditions:
- `vertex` is still a valid extreme point of the polytope.
IF an inexact vertex is computed:
- `⟨v, d⟩ ≤ max_value`, the pre-specified required improvement.
- `⟨v, d⟩ ≤ ⟨v_opt, d⟩ + gap`, with `v_opt` a vertex computed with the exact oracle.
- If the algorithm used to compute the inexact vertex provides a bound on optimality, the `gap` value must be valid.
- Otherwise, e.g. if the vertex is computed with a heuristic, `gap = ∞`.
ELSE (the oracle computes an optimal vertex):
- `⟨v, d⟩` may be greater than `max_value`, `gap` must be 0.
"""
function compute_weak_separation_point(lmo, direction, max_value; kwargs...) end

Check warning on line 34 in src/abstract_oracles.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_oracles.jl#L34

Added line #L34 was not covered by tests

# default to computing an exact vertex.
# Fallback: any LMO without a dedicated weak-separation method answers exactly.
# The vertex is optimal, so the reported gap is an exact zero of the promoted
# element type of `vertex` and `direction`.
function compute_weak_separation_point(lmo::LinearMinimizationOracle, direction, max_value; kwargs...)
    vertex = compute_extreme_point(lmo, direction; kwargs...)
    exact_gap = zero(eltype(vertex)) * zero(eltype(direction))
    return vertex, exact_gap
end

"""
CachedLinearMinimizationOracle{LMO}

Expand All @@ -43,11 +66,29 @@
mutable struct SingleLastCachedLMO{LMO,A} <: CachedLinearMinimizationOracle{LMO}
last_vertex::Union{Nothing,A}
inner::LMO
store_cache::Bool
end

# initializes with no cache by default
SingleLastCachedLMO(lmo::LMO) where {LMO<:LinearMinimizationOracle} =
SingleLastCachedLMO{LMO,AbstractVector}(nothing, lmo)
SingleLastCachedLMO{LMO,AbstractVector}(nothing, lmo, true)

# Weak separation for the single-vertex cache: if the cached vertex already
# meets the required decrease `max_value`, return it with `gap = ∞` (no
# optimality bound is known for a cache hit); otherwise call the inner oracle
# exactly and return its vertex with `gap = 0`.
function compute_weak_separation_point(lmo::SingleLastCachedLMO, direction, max_value; kwargs...)
    if lmo.last_vertex !== nothing && isfinite(max_value)
        # cached vertex is a sufficiently-decreasing direction
        if fast_dot(lmo.last_vertex, direction) ≤ max_value
            T = promote_type(eltype(lmo.last_vertex), eltype(direction))
            return lmo.last_vertex, T(Inf)
        end
    end
    # cache miss: query the exact inner oracle.
    # note: kwargs must be forwarded after `;` — splatting them positionally
    # passes Pair objects as positional arguments and raises a MethodError.
    v = compute_extreme_point(lmo.inner, direction; kwargs...)
    if lmo.store_cache
        lmo.last_vertex = v
    end
    T = promote_type(eltype(v), eltype(direction))
    return v, zero(T)
end

function compute_extreme_point(
lmo::SingleLastCachedLMO,
Expand All @@ -62,7 +103,7 @@
return lmo.last_vertex
end
end
v = compute_extreme_point(lmo.inner, direction, kwargs...)
v = compute_extreme_point(lmo.inner, direction, v=v, kwargs...)
if store_cache
lmo.last_vertex = v
end
Expand Down Expand Up @@ -188,14 +229,17 @@
CachedLinearMinimizationOracle{LMO}
vertices::Vector{VT}
inner::LMO
store_cache::Bool
greedy::Bool
weak_separation::Bool
end

# Construct an empty cache around `lmo` with defaults:
# store_cache=true, greedy=false, weak_separation=false.
function VectorCacheLMO{LMO,VT}(lmo::LMO) where {VT,LMO<:LinearMinimizationOracle}
    return VectorCacheLMO{LMO,VT}(VT[], lmo, true, false, false)
end

# Untyped convenience constructor: defaults the vertex type to Vector{Float64}.
# Defaults: store_cache=true, greedy=false, weak_separation=false.
function VectorCacheLMO(lmo::LMO) where {LMO<:LinearMinimizationOracle}
    # build the vertex list with the concrete eltype directly instead of
    # converting an empty Vector{AbstractVector}.
    return VectorCacheLMO{LMO,Vector{Float64}}(Vector{Float64}[], lmo, true, false, false)
end

function Base.empty!(lmo::VectorCacheLMO)
Expand All @@ -205,6 +249,65 @@

Base.length(lmo::VectorCacheLMO) = length(lmo.vertices)

# Weak separation for the vector cache.
# Returns a cached vertex meeting the threshold `max_value` with `gap = ∞`
# (first hit if `lmo.greedy`, best hit otherwise); on a cache miss, queries the
# inner oracle (weakly if `lmo.weak_separation`), optionally stores the new
# vertex, and returns it with the oracle's gap.
function compute_weak_separation_point(lmo::VectorCacheLMO, direction, max_value; kwargs...)
    # query the inner oracle, weakly if enabled; returns (vertex, gap)
    query_inner() = if lmo.weak_separation
        compute_weak_separation_point(lmo.inner, direction, max_value; kwargs...)
    else
        v = compute_extreme_point(lmo.inner, direction; kwargs...)
        (v, zero(eltype(v)))
    end
    if isempty(lmo.vertices)
        v, gap = query_inner()
        T = promote_type(eltype(v), eltype(direction))
        if lmo.store_cache
            push!(lmo.vertices, v)
        end
        return v, T(gap)
    end
    best_idx = -1
    best_val = Inf
    best_v = nothing
    for idx in reverse(eachindex(lmo.vertices))
        @inbounds v = lmo.vertices[idx]
        new_val = fast_dot(v, direction)
        if new_val ≤ max_value
            T = promote_type(eltype(v), eltype(direction))
            # stop and return the first acceptable vertex
            if lmo.greedy
                return v, T(Inf)
            end
            # otherwise, compare to incumbent
            if new_val < best_val
                best_v = v
                best_val = new_val
                best_idx = idx
            end
        end
    end
    if best_idx > 0
        T = promote_type(eltype(best_v), eltype(direction))
        return best_v, T(Inf)
    end
    # no satisfactory vertex found, call the inner oracle
    v, gap = query_inner()
    if lmo.store_cache
        # note: we do not check for duplicates. hence you might end up with more vertices,
        # in fact up to number of dual steps many, that might be already in the cache
        # in order to reach this point, if v was already in the cache it must not meet the threshold (otherwise we would have returned it)
        # and it is the best possible, hence we will perform a dual step on the outside.
        #
        # note: another possibility could be to test against that in the if statement but then you might end up recalculating the same vertex a few times.
        # as such this might be a better tradeoff, i.e., to not check the set for duplicates and potentially accept #dual_steps many duplicates.
        push!(lmo.vertices, v)
    end
    T = promote_type(eltype(v), eltype(direction))
    return v, T(gap)
end

function compute_extreme_point(
lmo::VectorCacheLMO,
direction;
Expand All @@ -228,7 +331,7 @@
@inbounds v = lmo.vertices[idx]
new_val = fast_dot(v, direction)
if new_val ≤ threshold
# stop, store and return
# stop and return
if greedy
return v
end
Expand Down
64 changes: 46 additions & 18 deletions src/afw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function away_frank_wolfe(
use_extra_vertex_storage=false,
linesearch_workspace=nothing,
recompute_last_vertex=true,
weak_separation=false,
)
# add the first vertex to active set from initialization
active_set = ActiveSet([(1.0, x0)])
Expand Down Expand Up @@ -64,6 +65,7 @@ function away_frank_wolfe(
use_extra_vertex_storage=use_extra_vertex_storage,
linesearch_workspace=linesearch_workspace,
recompute_last_vertex=recompute_last_vertex,
weak_separation=weak_separation,
)
end

Expand Down Expand Up @@ -95,6 +97,7 @@ function away_frank_wolfe(
use_extra_vertex_storage=false,
linesearch_workspace=nothing,
recompute_last_vertex=true,
weak_separation=false,
)
# format string for output of the algorithm
format_string = "%6s %13s %14e %14e %14e %14e %14e %14i\n"
Expand Down Expand Up @@ -153,7 +156,7 @@ function away_frank_wolfe(
)
grad_type = typeof(gradient)
println(
"GRADIENTTYPE: $grad_type LAZY: $lazy lazy_tolerance: $lazy_tolerance MOMENTUM: $momentum AWAYSTEPS: $away_steps",
"GRADIENT TYPE: $grad_type LAZY: $lazy LAZY_TOLERANCE: $lazy_tolerance WEAK_SEPARATION: $weak_separation MOMENTUM: $momentum AWAYSTEPS: $away_steps",
)
println("Linear Minimization Oracle: $(typeof(lmo))")
if (use_extra_vertex_storage || add_dropped_vertices) && extra_vertex_storage === nothing
Expand Down Expand Up @@ -223,10 +226,16 @@ function away_frank_wolfe(
extra_vertex_storage=extra_vertex_storage,
lazy_tolerance=lazy_tolerance,
memory_mode=memory_mode,
weak_separation=weak_separation,
)
else
d, vertex, index, gamma_max, phi_value, away_step_taken, fw_step_taken, tt =
afw_step(x, gradient, lmo, active_set, epsilon, d, memory_mode=memory_mode)
afw_step(
x, gradient, lmo, active_set, epsilon, d,
memory_mode=memory_mode,
weak_separation=weak_separation,
lazy_tolerance=lazy_tolerance,
)
end
else
d, vertex, index, gamma_max, phi_value, away_step_taken, fw_step_taken, tt =
Expand All @@ -248,7 +257,7 @@ function away_frank_wolfe(
memory_mode,
)

gamma = min(gamma_max, gamma)
gamma = min(gamma_max, gamma)
# cleanup and renormalize every x iterations. Only for the fw steps.
renorm = mod(t, renorm_interval) == 0
if away_step_taken
Expand Down Expand Up @@ -368,9 +377,9 @@ function away_frank_wolfe(
return x, v, primal, dual_gap, traj_data, active_set
end

function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_vertex_storage=false, extra_vertex_storage=nothing, lazy_tolerance=2.0, memory_mode::MemoryEmphasis=InplaceEmphasis())
function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_vertex_storage=false, extra_vertex_storage=nothing, lazy_tolerance=2.0, memory_mode::MemoryEmphasis=InplaceEmphasis(), weak_separation::Bool=true)
_, v, v_loc, _, a_lambda, a, a_loc, _, _ = active_set_argminmax(active_set, gradient)
#Do lazy FW step
# do lazy FW step
grad_dot_lazy_fw_vertex = fast_dot(v, gradient)
grad_dot_x = fast_dot(x, gradient)
grad_dot_a = fast_dot(a, gradient)
Expand All @@ -385,7 +394,7 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
fw_step_taken = true
index = v_loc
else
#Do away step, as it promises enough progress.
# do away step, as it promises enough progress.
if grad_dot_a - grad_dot_x > grad_dot_x - grad_dot_lazy_fw_vertex &&
grad_dot_a - grad_dot_x >= phi / lazy_tolerance
tt = away
Expand All @@ -395,7 +404,7 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
away_step_taken = true
fw_step_taken = false
index = a_loc
#Resort to calling the LMO
# resort to calling the LMO
else
# optionally: try vertex storage
if use_extra_vertex_storage
Expand All @@ -406,26 +415,34 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
@debug("Found acceptable lazy vertex in storage")
v = new_forward_vertex
tt = lazylazy
end
else
found_better_vertex = false
end
if !found_better_vertex
# compute new vertex with normal or weak oracle
if weak_separation
lazy_threshold = fast_dot(gradient, x) - phi / lazy_tolerance
(v, gap) = compute_weak_separation_point(lmo, gradient, lazy_threshold)
tt = gap == 0.0 ? regular : weaksep
else
v = compute_extreme_point(lmo, gradient)
gap = zero(eltype(v))
tt = regular
end
else
v = compute_extreme_point(lmo, gradient)
tt = regular
end
# Real dual gap promises enough progress.
grad_dot_fw_vertex = fast_dot(v, gradient)
dual_gap = grad_dot_x - grad_dot_fw_vertex
# Real dual gap promises enough progress.
if dual_gap >= phi / lazy_tolerance
gamma_max = one(a_lambda)
d = muladd_memory_mode(memory_mode, d, x, v)
vertex = v
away_step_taken = false
fw_step_taken = true
index = -1
#Lower our expectation for progress.
else
else # lower our expectation for progress.
@assert tt != weaksep
tt = dualstep
phi = min(dual_gap, phi / 2.0)
gamma_max = zero(a_lambda)
Expand All @@ -439,14 +456,25 @@ function lazy_afw_step(x, gradient, lmo, active_set, phi, epsilon, d; use_extra_
return d, vertex, index, gamma_max, phi, away_step_taken, fw_step_taken, tt
end

function afw_step(x, gradient, lmo, active_set, epsilon, d; memory_mode::MemoryEmphasis=InplaceEmphasis())
function afw_step(x, gradient, lmo, active_set, epsilon, d; memory_mode::MemoryEmphasis=InplaceEmphasis(), weak_separation::Bool=false, lazy_tolerance=2.0)
_, _, _, _, a_lambda, a, a_loc = active_set_argminmax(active_set, gradient)
v = compute_extreme_point(lmo, gradient)
grad_dot_x = fast_dot(x, gradient)
away_gap = fast_dot(a, gradient) - grad_dot_x
(v, gap) = if weak_separation
# Condition for taking a FW step
# ⟨∇f, x-v⟩ ≥ gₐ <=>
# ⟨∇f, v⟩ ≤ ⟨∇f, x⟩ - gₐ
# We ask for a bit more progress on the FW step
# to promote away steps when we can (and therefore sparsity)
# ⟨∇f, v⟩ ≤ ⟨∇f, x⟩ - K gₐ
lazy_threshold = grad_dot_x - lazy_tolerance * away_gap
compute_weak_separation_point(lmo, gradient, lazy_threshold)
else
(compute_extreme_point(lmo, gradient), zero(away_gap))
end
dual_gap = grad_dot_x - fast_dot(v, gradient)
if dual_gap >= away_gap && dual_gap >= epsilon
tt = regular
if dual_gap > away_gap && dual_gap >= epsilon
tt = gap == 0.0 ? regular : weaksep
gamma_max = one(a_lambda)
d = muladd_memory_mode(memory_mode, d, x, v)
vertex = v
Expand All @@ -469,7 +497,7 @@ function afw_step(x, gradient, lmo, active_set, epsilon, d; memory_mode::MemoryE
fw_step_taken = false
index = a_loc
end
return d, vertex, index, gamma_max, dual_gap, away_step_taken, fw_step_taken, tt
return d, vertex, index, gamma_max, dual_gap + gap, away_step_taken, fw_step_taken, tt
end

function fw_step(x, gradient, lmo, d; memory_mode::MemoryEmphasis = InplaceEmphasis())
Expand Down
2 changes: 2 additions & 0 deletions src/defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct OutplaceEmphasis <: MemoryEmphasis end
away = 6
pairwise = 7
drop = 8
weaksep = 9
simplex_descent = 101
gap_step = 102
last = 1000
Expand All @@ -34,6 +35,7 @@ const st = (
away="A",
pairwise="P",
drop="D",
weaksep="W",
simplex_descent="SD",
gap_step="GS",
last="Last",
Expand Down
Loading