Skip to content

Commit

Permalink
added test
Browse files Browse the repository at this point in the history
  • Loading branch information
matbesancon committed Sep 29, 2023
1 parent 738eedf commit 975f5ab
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ build_linesearch_workspace(::LineSearchMethod, x, gradient) = nothing
"""
Computes step size: `l/(l + t)` at iteration `t`, given `l > 0`.
See:
Using `l ≥ 4` is advised only for strongly convex sets, see:
> Acceleration of Frank-Wolfe Algorithms with Open-Loop Step-Sizes, Wirth, Kerdreux, Pokutta, 2023.
"""
struct Agnostic{T<:Real} <: LineSearchMethod
l::Int
end

Agnostic() = Agnostic{Float64}(2)
Agnostic(l::Int) = Agnostic{Float64}(l)

Agnostic{T}() where {T} = Agnostic{T}(2)

Expand Down
91 changes: 91 additions & 0 deletions test/trajectory_tests/open_loop_parametric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using FrankWolfe

using Test
using LinearAlgebra

@testset "Open-loop FW on polytope" begin
n = Int(1e2)
k = Int(1e4)

xp = ones(n)
f(x) = norm(x - xp)^2
function grad!(storage, x)
@. storage = 2 * (x - xp)
end

lmo = FrankWolfe.KSparseLMO(40, 1.0)

x0 = FrankWolfe.compute_extreme_point(lmo, zeros(n))

res_2 = FrankWolfe.frank_wolfe(
f,
grad!,
lmo,
copy(x0),
max_iteration=k,
line_search=FrankWolfe.Agnostic(2),
print_iter=k / 10,
epsilon=1e-5,
verbose=true,
trajectory=true,
)

res_10 = FrankWolfe.frank_wolfe(
f,
grad!,
lmo,
copy(x0),
max_iteration=k,
line_search=FrankWolfe.Agnostic(10),
print_iter=k / 10,
epsilon=1e-5,
verbose=true,
trajectory=true,
)

@test res_2[4] 0.004799839951985518
@test res_10[4] 0.02399919272834694

# strongly convex set
xp2 = 10 * ones(n)
diag_term = 100 * rand(n)
covariance_matrix = LinearAlgebra.Diagonal(diag_term)
lmo2 = FrankWolfe.EllipsoidLMO(covariance_matrix)

f2(x) = norm(x - xp2)^2
function grad2!(storage, x)
@. storage = 2 * (x - xp2)
end

x0 = FrankWolfe.compute_extreme_point(lmo2, randn(n))

res_2 = FrankWolfe.frank_wolfe(
f2,
grad2!,
lmo2,
copy(x0),
max_iteration=k,
line_search=FrankWolfe.Agnostic(2),
print_iter=k / 10,
epsilon=1e-5,
verbose=true,
trajectory=true,
)

res_10 = FrankWolfe.frank_wolfe(
f2,
grad2!,
lmo2,
copy(x0),
max_iteration=k,
line_search=FrankWolfe.Agnostic(10),
print_iter=k / 10,
epsilon=1e-5,
verbose=true,
trajectory=true,
)

@test length(res_10[end]) <= 8
@test length(res_2[end]) <= 71

end

0 comments on commit 975f5ab

Please sign in to comment.