Skip to content

Commit

Permalink
Actually works now
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 28, 2023
1 parent 3edfce6 commit a52eb0f
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 93 deletions.
65 changes: 41 additions & 24 deletions lib/OptimizationSolvers/src/bfgs.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

struct BFGS
ϵ::Float64
m::Int
@kwdef struct BFGS
ϵ::Float64=1e-6
end

SciMLBase.supports_opt_cache_interface(opt::BFGS) = true
Expand Down Expand Up @@ -50,49 +49,68 @@ function SciMLBase.__solve(cache::OptimizationCache{
end
opt = cache.opt
θ = copy(cache.u0)
G = zeros(length(θ))
g₀ = zeros(length(θ))
f = cache.f

_f = (θ) -> first(f.f(θ, cache.p))

ϕ(α) = _f(θ .+ α.*s)
function dϕ(α)
f.grad(G, θ .+ α.*s)
return dot(G, s)
function ϕ(u, du)
function ϕ_internal(α)
u_ = u - α * du
_fu = _f(u_)
return dot(_fu, _fu) / 2
end
return ϕ_internal
end

function dϕ(u, du)
function dϕ_internal(α)
u_ = u - α * du
_fu = _f(u_)
f.grad(g₀, u_)
return dot(g₀, -du)
end
return dϕ_internal
end
function ϕdϕ(α)
phi = _f(θ .+ α.*s)
f.grad(G, θ .+ α.*s)
dphi = dot(G, s)
return (phi, dphi)

function ϕdϕ(u, du)
function ϕdϕ_internal(α)
u_ = u - α * du
_fu = _f(u_)
f.grad(g₀, u_)
return dot(_fu, _fu) / 2, dot(g₀, -du)
end
return ϕdϕ_internal
end

Hₖ⁻¹= zeros(length(θ), length(θ))
f.hess(Hₖ⁻¹, θ)
println(Hₖ⁻¹)
Hₖ⁻¹ = inv(I(length(θ)) .+ Hₖ⁻¹)
G = zeros(length(θ))
f.grad(G, θ)
s = -1 * Hₖ⁻¹ * G

t0 = time()
for i in 1:maxiters
println(i, " ", θ, " Objective: ", f(θ, cache.p))
q = copy(G)
@show q
pₖ = -Hₖ⁻¹* G
fx = _f(θ)
dir = G' * pₖ
println(dir)
dir = -pₖ

if isnan(dir) || dir > 0
if all(isnan.(dir)) || all(dir .> 0)
pₖ = -G
dir = -G'*G
dir = G
end

αₖ = let
try
(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)[1]
_ϕ = ϕ(θ, dir)
_dϕ = dϕ(θ, dir)
_ϕdϕ = ϕdϕ(θ, dir)

ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ)))
(HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1]
catch err
println(err)
1.0
end
end
Expand All @@ -103,15 +121,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
G = zeros(length(θ))
f.grad(G, θ)
zₖ = G - q
@show G
ρₖ = 1/dot(zₖ, s)
Hₖ⁻¹ = (I - ρₖ*s*zₖ')*Hₖ⁻¹*(I - ρₖ*zₖ*s') + ρₖ*(s*s')
if norm(G) <= opt.ϵ
println(i)
break
end
end


t1 = time()

SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0)
Expand Down
155 changes: 90 additions & 65 deletions lib/OptimizationSolvers/src/lbfgs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,38 @@ function SciMLBase.__solve(cache::OptimizationCache{
end
opt = cache.opt
θ = copy(cache.u0)
G = zeros(length(θ))
g₀ = zeros(length(θ))
f = cache.f

_f = (θ) -> first(f.f(θ, cache.p))

ϕ(α) = _f(θ .+ α.*s)
function dϕ(α)
f.grad(G, θ .+ α.*s)
return dot(G, s)
function ϕ(u, du)
function ϕ_internal(α)
u_ = u - α * du
_fu = _f(u_)
return dot(_fu, _fu) / 2
end
return ϕ_internal
end
function ϕdϕ(α)
phi = _f(θ .+ α.*s)
f.grad(G, θ .+ α.*s)
dphi = dot(G, s)
return (phi, dphi)

function dϕ(u, du)
function dϕ_internal(α)
u_ = u - α * du
_fu = _f(u_)
f.grad(g₀, u_)
return dot(g₀, -du)
end
return dϕ_internal
end

function ϕdϕ(u, du)
function ϕdϕ_internal(α)
u_ = u - α * du
_fu = _f(u_)
f.grad(g₀, u_)
return dot(_fu, _fu) / 2, dot(g₀, -du)
end
return ϕdϕ_internal
end

Sₖ = zeros(length(θ), opt.m)
Expand All @@ -71,27 +88,29 @@ function SciMLBase.__solve(cache::OptimizationCache{
Dₖ = zeros(opt.m)

Hₖ⁻¹= zeros(length(θ), length(θ))
println(Hₖ⁻¹)
Hₖ⁻¹ = I(length(θ))
G = zeros(length(θ))
f.grad(G, θ)
s = -1 * Hₖ⁻¹ * G
t0 = time()
conv = false
for k in 1:opt.m
# println(k, " ", θ, " Objective: ", f(θ, cache.p))
q = copy(G)
pₖ = -Hₖ⁻¹* G
fx = _f(θ)
dir = dot(G, pₖ)
# println(fx, " ", dir)
if !isnan(dir) && dir > 0
dir = -pₖ
if all(isnan.(dir)) || all(dir .> 0)
pₖ = -G
dir = dot(G, pₖ)
else
dir = -G
dir = G
end
αₖ = let
try
[(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...]
_ϕ = ϕ(θ, dir)
_dϕ = dϕ(θ, dir)
_ϕdϕ = ϕdϕ(θ, dir)

ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ)))
(HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1]
catch err
αₖ = [1.0]
end
Expand All @@ -104,64 +123,70 @@ function SciMLBase.__solve(cache::OptimizationCache{
Hₖ⁻¹ = (I - (s*zₖ')/dot(zₖ, s))*Hₖ⁻¹*(I - (zₖ*s')/dot(zₖ, s)) + (s*s')/dot(zₖ, s)
Sₖ[:, k] .= s
Yₖ[:, k] .= zₖ
if norm(G) < 1e-6
if norm(G) < opt.ϵ
conv = true
break
end
end

for j in 1:opt.m
for i in 1:j
Rₖ[i, j] = dot(Sₖ[:, i], Yₖ[:, j])
if !conv
for j in 1:opt.m
for i in 1:j
Rₖ[i, j] = dot(Sₖ[:, i], Yₖ[:, j])
end
Dₖ[j] = dot(Sₖ[:, j], Yₖ[:, j])
end
Dₖ[j] = dot(Sₖ[:, j], Yₖ[:, j])
end

m = opt.m
for i in opt.m+1:maxiters
_G = copy(G)
fx = _f(θ)
println(i, " ", θ, " Objective: ", fx)
γₖ = dot(Sₖ[:, m], Yₖ[:, m])/dot(Yₖ[:, m], Yₖ[:, m])
Rinv = let
try
inv(Rₖ)
catch
println(i, " ", Rₖ)
println("Inversion failed")
break
m = opt.m
for i in opt.m+1:maxiters
_G = copy(G)
fx = _f(θ)
γₖ = dot(Sₖ[:, m], Yₖ[:, m])/dot(Yₖ[:, m], Yₖ[:, m])
Rinv = let
try
inv(Rₖ)
catch
println("Inversion failed")
break
end
end
end

p = [Rinv'*(diagm(Dₖ) + γₖ * Yₖ'*Yₖ)*Rinv*(Sₖ'*G) - γₖ * Rinv*(Yₖ'*G); -Rinv*(Sₖ'*G)]
p = -1 .* (γₖ * G + hcat(Sₖ, γₖ*Yₖ)*p)
p = dot(p, G)
αₖ = let
try
[(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, p)...]
catch err
println(err)
break
αₖ = [1.0]
p = [Rinv'*(diagm(Dₖ) + γₖ * Yₖ'*Yₖ)*Rinv*(Sₖ'*G) - γₖ * Rinv*(Yₖ'*G); -Rinv*(Sₖ'*G)]
p = -1 .* (γₖ * G + hcat(Sₖ, γₖ*Yₖ)*p)
dir = -p
αₖ = let
try
_ϕ = ϕ(θ, dir)
_dϕ = dϕ(θ, dir)
_ϕdϕ = ϕdϕ(θ, dir)

ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ)))
(HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1]
catch err
αₖ = 1.0
end
end
end

θ = θ .+ αₖ.*p
s = αₖ.*p
Sₖ[:, 1:end-1] .= Sₖ[:, 2:end]
Yₖ[:, 1:end-1] .= Yₖ[:, 2:end]
Dₖ[1:end-1] .= Dₖ[2:end]
Rₖ[1:end-1, 1:end-1] .= Rₖ[2:end, 2:end]
Sₖ[:, end] .= s
G = zeros(length(θ))
f.grad(G, θ)
zₖ = G - _G
Yₖ[:, end] .= zₖ
for i in 1:m
Rₖ[i, m] = dot(Sₖ[:, i], Yₖ[:, m])
θ = θ .+ αₖ.*p
s = αₖ.*p
Sₖ[:, 1:end-1] .= Sₖ[:, 2:end]
Yₖ[:, 1:end-1] .= Yₖ[:, 2:end]
Dₖ[1:end-1] .= Dₖ[2:end]
Rₖ[1:end-1, 1:end-1] .= Rₖ[2:end, 2:end]
Sₖ[:, end] .= s
G = zeros(length(θ))
f.grad(G, θ)
zₖ = G - _G
Yₖ[:, end] .= zₖ
for i in 1:m
Rₖ[i, m] = dot(Sₖ[:, i], Yₖ[:, m])
end
Dₖ[m] = dot(Sₖ[:, m], Yₖ[:, m])
if norm(G) < opt.ϵ
break
end
end
Dₖ[m] = dot(Sₖ[:, m], Yₖ[:, m])
end

t1 = time()

SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0)
Expand Down
12 changes: 8 additions & 4 deletions lib/OptimizationSolvers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@ using Zygote

sol = Optimization.solve(prob,
OptimizationSolvers.LBFGS(1e-3, 10),
maxiters = 10)
maxiters = 50)

x0 = zeros(2)
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
l1 = rosenbrock(x0)
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, x0)
prob = OptimizationProblem(optf, [-1.2, 1.0])
sol = Optimization.solve(prob,
OptimizationSolvers.BFGS(1e-3, 5),
maxiters = 1000)
OptimizationSolvers.BFGS(1e-5),
maxiters = 100)
@test 10 * sol.objective < l1

sol = Optimization.solve(prob,
OptimizationSolvers.LBFGS(1e-3, 10),
maxiters = 50)


end

0 comments on commit a52eb0f

Please sign in to comment.