Skip to content

Commit

Permalink
Use lbfgsb as the default solver
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Mar 27, 2024
1 parent efe6038 commit d080318
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LBFGSB = "5be7bae1-8223-5378-bac3-9e7378a2f6e6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Expand Down
1 change: 1 addition & 0 deletions src/Optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export ObjSense, MaxSense, MinSense

include("utils.jl")
include("state.jl")
include("lbfgsb.jl")

export solve

Expand Down
79 changes: 79 additions & 0 deletions src/lbfgsb.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Optimization.SciMLBase, LBFGSB

@kwdef struct LBFGS
m::Int=10
end

SciMLBase.supports_opt_cache_interface(::LBFGS) = true
SciMLBase.allowsbounds(::LBFGS) = true
# SciMLBase.requiresgradient(::LBFGS) = true

function SciMLBase.__init(prob::SciMLBase.OptimizationProblem,
opt::LBFGS,
data = Optimization.DEFAULT_DATA; save_best = true,
callback = (args...) -> (false),
progress = false, kwargs...)
return OptimizationCache(prob, opt, data; save_best, callback, progress,
kwargs...)
end

function SciMLBase.__solve(cache::OptimizationCache{
F,
RC,
LB,
UB,
LC,
UC,
S,
O,
D,
P,
C
}) where {
F,
RC,
LB,
UB,
LC,
UC,
S,
O <:
LBFGS,
D,
P,
C
}
if cache.data != Optimization.DEFAULT_DATA
maxiters = length(cache.data)
data = cache.data
else
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
data = Optimization.take(cache.data, maxiters)
end

local x

_loss = function (θ)
x = cache.f(θ, cache.p)
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
end
return x[1]
end

t0 = time()
if cache.lb !== nothing && cache.ub !== nothing
res = lbfgsb(_loss, cache.f.grad, cache.u0; m = cache.opt.m, maxiter = maxiters,
lb = cache.lb, ub = cache.ub)
else
res = lbfgsb(_loss, cache.f.grad, cache.u0; m = cache.opt.m, maxiter = maxiters)
end

t1 = time()
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = t1 - t0, fevals = maxiters, gevals = maxiters)

return SciMLBase.build_solution(cache, cache.opt, res[2], res[1], stats = stats)
end

19 changes: 19 additions & 0 deletions test/lbfgsb.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Optimization
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
using ModelingToolkit, Enzyme, Random

x0 = zeros(2)
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
l1 = rosenbrock(x0)

optf = OptimizationFunction(rosenbrock, AutoForwardDiff())
prob = OptimizationProblem(optf, x0)
res = solve(prob, Optimization.LBFGS(), maxiters = 100)

@test res.u [1.0, 1.0] atol=1e-3

optf = OptimizationFunction(rosenbrock, AutoZygote())
prob = OptimizationProblem(optf, x0, lb = [0.0, 0.0], ub = [0.3, 0.3])
res = solve(prob, Optimization.LBFGS(), maxiters = 100)

@test res.u [0.3, 0.09] atol=1e-3

0 comments on commit d080318

Please sign in to comment.