From cf6f468d75659a721f500e04607deb1e501efaa9 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Wed, 11 Oct 2023 11:17:33 -0400 Subject: [PATCH 1/8] init --- lib/OptimizationSolvers/LICENSE | 19 +++ lib/OptimizationSolvers/Project.toml | 26 ++++ .../src/OptimizationSolvers.jl | 72 +++++++++++ lib/OptimizationSolvers/src/sophia.jl | 116 ++++++++++++++++++ lib/OptimizationSolvers/test/runtests.jl | 53 ++++++++ 5 files changed, 286 insertions(+) create mode 100644 lib/OptimizationSolvers/LICENSE create mode 100644 lib/OptimizationSolvers/Project.toml create mode 100644 lib/OptimizationSolvers/src/OptimizationSolvers.jl create mode 100644 lib/OptimizationSolvers/src/sophia.jl create mode 100644 lib/OptimizationSolvers/test/runtests.jl diff --git a/lib/OptimizationSolvers/LICENSE b/lib/OptimizationSolvers/LICENSE new file mode 100644 index 000000000..9ab96cacd --- /dev/null +++ b/lib/OptimizationSolvers/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2022 Julia Computing + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/OptimizationSolvers/Project.toml b/lib/OptimizationSolvers/Project.toml new file mode 100644 index 000000000..ed00f04c4 --- /dev/null +++ b/lib/OptimizationSolvers/Project.toml @@ -0,0 +1,26 @@ +name = "OptimizationOptimisers" +uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +authors = ["Vaibhav Dixit and contributors"] +version = "0.1.5" + +[deps] +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" + +[compat] +julia = "1" +Optimisers = "0.2" +ProgressLogging = "0.1" +Reexport = "1.2" +Optimization = "3.15" + +[extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[targets] +test = ["ForwardDiff","Test","Zygote"] diff --git a/lib/OptimizationSolvers/src/OptimizationSolvers.jl b/lib/OptimizationSolvers/src/OptimizationSolvers.jl new file mode 100644 index 000000000..8eec6da06 --- /dev/null +++ b/lib/OptimizationSolvers/src/OptimizationSolvers.jl @@ -0,0 +1,72 @@ +module OptimizationSolvers + +using Reexport, Printf, ProgressLogging +@reexport using Optimization +using Optimization.SciMLBase + +SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true +include("sophia.jl") + +function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::AbstractRule, + data = Optimization.DEFAULT_DATA; save_best = true, + callback = (args...) -> (false), + progress = false, kwargs...) + return OptimizationCache(prob, opt, data; save_best, callback, progress, + kwargs...) +end + +struct BFGS end + +function SciMLBase.__solve(cache::OptimizationCache{ + F, + RC, + LB, + UB, + LC, + UC, + S, + O, + D, + P, + C, +}) where { + F, + RC, + LB, + UB, + LC, + UC, + S, + O <:BFGS, + D, + P, + C, +} + if cache.data != Optimization.DEFAULT_DATA + maxiters = length(cache.data) + data = cache.data + else + maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + data = Optimization.take(cache.data, maxiters) + end + opt = cache.opt + θ = copy(cache.u0) + G = copy(θ) + + H = I(length(θ)) * γ + + t0 = time() + Optimization.@withprogress cache.progress name="Training" begin + for (i, d) in enumerate(data) + #bfgs + + end + end + + t1 = time() + + SciMLBase.build_solution(cache, cache.opt, θ, x[1], solve_time = t1 - t0) + # here should be build_solution to create the output message +end + +end diff --git a/lib/OptimizationSolvers/src/sophia.jl b/lib/OptimizationSolvers/src/sophia.jl new file mode 100644 index 000000000..fa9fa827b --- /dev/null +++ b/lib/OptimizationSolvers/src/sophia.jl @@ -0,0 +1,116 @@ +using Optimization.LinearAlgebra + +struct Sophia + η::Float64 + βs::Tuple{Float64, Float64} + ϵ::Float64 + λ::Float64 + k::Integer + ρ::Float64 +end + +SciMLBase.supports_opt_cache_interface(opt::Sophia) = true + +function Sophia(; η = 1e-3, βs = (0.9, 0.999), ϵ = 1e-8, λ = 1e-1, k = 10, + ρ = 0.04) + Sophia(η, βs, ϵ, λ, k, ρ) +end + +clip(z, ρ) = max(min(z, ρ), -ρ) + +function SciMLBase.__init(prob::OptimizationProblem, opt::Sophia, + data = Optimization.DEFAULT_DATA; + maxiters::Number = 1000, callback = (args...) -> (false), + progress = false, save_best = true, kwargs...) + return OptimizationCache(prob, opt, data; maxiters, callback, progress, + save_best, kwargs...) +end + +function SciMLBase.__solve(cache::OptimizationCache{ + F, + RC, + LB, + UB, + LC, + UC, + S, + O, + D, + P, + C, +}) where { + F, + RC, + LB, + UB, + LC, + UC, + S, + O <: + Sophia, + D, + P, + C, +} + local x, cur, state + uType = eltype(cache.u0) + η = uType(cache.opt.η) + βs = uType.(cache.opt.βs) + ϵ = uType(cache.opt.ϵ) + λ = uType(cache.opt.λ) + ρ = uType(cache.opt.ρ) + + if cache.data != Optimization.DEFAULT_DATA + maxiters = length(cache.data) + data = cache.data + else + maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + data = Optimization.take(cache.data, maxiters) + end + + maxiters = Optimization._check_and_convert_maxiters(maxiters) + + _loss = function (θ) + if cache.callback === Optimization.DEFAULT_CALLBACK && data === Optimization.DEFAULT_DATA + return first(cache.f(θ, cache.p)) + elseif cache.callback === Optimization.DEFAULT_CALLBACK + return first(cache.f(θ, cache.p, cur...)) + elseif data === Optimization.DEFAULT_DATA + x = cache.f(θ, cache.p) + return first(x) + else + x = cache.f(θ, cache.p, cur...) + return first(x) + end + end + f = cache.f + θ = copy(cache.u0) + gₜ = zero(θ) + mₜ = zero(θ) + hₜ = zero(θ) + for (i, d) in enumerate(data) + f.grad(gₜ, θ, d...) + x = cache.f(θ, cache.p, d...) + cb_call = cache.callback(θ, x...) + if !(typeof(cb_call) <: Bool) + error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.") + elseif cb_call + break + end + mₜ = βs[1] .* mₜ + (1 - βs[1]) .* gₜ + + if i % cache.opt.k == 1 + hₜ₋₁ = copy(hₜ) + u = randn(uType, length(θ)) + f.hv(hₜ, θ, u, d...) + hₜ = βs[2] .* hₜ₋₁ + (1 - βs[2]) .* (u .* hₜ) + end + θ = θ .- η * λ .* θ + θ = θ .- + η .* clip.(mₜ ./ max.(hₜ, Ref(ϵ)), Ref(ρ)) + end + + return SciMLBase.build_solution(cache, cache.opt, + θ, + x) +end diff --git a/lib/OptimizationSolvers/test/runtests.jl b/lib/OptimizationSolvers/test/runtests.jl new file mode 100644 index 000000000..523309d60 --- /dev/null +++ b/lib/OptimizationSolvers/test/runtests.jl @@ -0,0 +1,53 @@ +using OptimizationOptimisers, ForwardDiff, Optimization +using Test +using Zygote + +@testset "OptimizationOptimisers.jl" begin + rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + x0 = zeros(2) + _p = [1.0, 100.0] + l1 = rosenbrock(x0, _p) + + optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote()) + + prob = OptimizationProblem(optprob, x0, _p) + + sol = Optimization.solve(prob, + OptimizationOptimisers.Sophia(; η = 0.5, + λ = 0.0), + maxiters = 1000) + @test 10 * sol.objective < l1 + + prob = OptimizationProblem(optprob, x0, _p) + sol = solve(prob, Optimisers.Adam(), maxiters = 1000, progress = false) + @test 10 * sol.objective < l1 + + x0 = 2 * ones(ComplexF64, 2) + _p = ones(2) + sumfunc(x0, _p) = sum(abs2, (x0 - _p)) + l1 = sumfunc(x0, _p) + + optprob = OptimizationFunction(sumfunc, Optimization.AutoZygote()) + + prob = OptimizationProblem(optprob, x0, _p) + + sol = solve(prob, Optimisers.Adam(), maxiters = 1000) + @test 10 * sol.objective < l1 + + @testset "cache" begin + objective(x, p) = (p[1] - x[1])^2 + x0 = zeros(1) + p = [1.0] + + prob = OptimizationProblem(OptimizationFunction(objective, + Optimization.AutoForwardDiff()), x0, + p) + cache = Optimization.init(prob, Optimisers.Adam(0.1), maxiters = 1000) + sol = Optimization.solve!(cache) + @test sol.u≈[1.0] atol=1e-3 + + cache = Optimization.reinit!(cache; p = [2.0]) + sol = Optimization.solve!(cache) + @test sol.u≈[2.0] atol=1e-3 + end +end From 26367a9040e8e1cbfa681e895142f358543fd3c8 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 13 Oct 2023 10:33:37 -0400 Subject: [PATCH 2/8] wip --- lib/OptimizationSolvers/Manifest.toml | 645 ++++++++++++++++++ lib/OptimizationSolvers/Project.toml | 15 +- .../src/OptimizationSolvers.jl | 101 ++- 3 files changed, 740 insertions(+), 21 deletions(-) create mode 100644 lib/OptimizationSolvers/Manifest.toml diff --git a/lib/OptimizationSolvers/Manifest.toml b/lib/OptimizationSolvers/Manifest.toml new file mode 100644 index 000000000..d6e8a5e29 --- /dev/null +++ b/lib/OptimizationSolvers/Manifest.toml @@ -0,0 +1,645 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.9.2" +manifest_format = "2.0" +project_hash = "a5e45d00dc655835e569dd2c44dec3a7d78a978f" + +[[deps.ADTypes]] +git-tree-sha1 = "5d2e21d7b0d8c22f67483ef95ebdc39c0e6b6003" +uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +version = "0.2.4" + +[[deps.AbstractTrees]] +git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.4" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.6.2" + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + + [deps.Adapt.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.4.11" + + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.16.0" + +[[deps.CommonSolve]] +git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c" +uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +version = "0.2.4" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["UUIDs"] +git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.10.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.5+0" + +[[deps.ConsoleProgressMonitor]] +deps = ["Logging", "ProgressMeter"] +git-tree-sha1 = "3ab7b2136722890b9af903859afcf457fa3059e8" +uuid = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" +version = "0.1.2" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "c53fc348ca4d40d7b371e71fd52251839080cbc9" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.4" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.DataAPI]] +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.15.0" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.EnumX]] +git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" +uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +version = "1.0.4" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random"] +git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.6.1" +weakdeps = ["SparseArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + +[[deps.FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "c6e4a1fbe73b31a3dea94b1da449503b8830c306" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.21.1" + + [deps.FiniteDiff.extensions] + FiniteDiffBandedMatricesExt = "BandedMatrices" + FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffStaticArraysExt = "StaticArrays" + + [deps.FiniteDiff.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + + [deps.ForwardDiff.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.3" + +[[deps.FunctionWrappersWrappers]] +deps = ["FunctionWrappers"] +git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8" +uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" +version = "0.1.3" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.5" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.Lazy]] +deps = ["MacroTools"] +git-tree-sha1 = "1370f8202dac30758f3c345f9909b97f53d87d3f" +uuid = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" +version = "0.15.1" + +[[deps.LeftChildRightSiblingTrees]] +deps = ["AbstractTrees"] +git-tree-sha1 = "fb6803dafae4a5d62ea5cab204b1e657d9737e7f" +uuid = "1d6d02ad-be62-4b6b-8a6d-2f90e265016e" +version = "0.2.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LineSearches]] +deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] +git-tree-sha1 = "7bbea35cec17305fc70a0e5b4641477dc0789d9d" +uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +version = "7.2.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.26" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.11" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+0" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.10.11" + +[[deps.NLSolversBase]] +deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] +git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" +uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" +version = "7.8.3" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.21+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimization]] +deps = ["ADTypes", "ArrayInterface", "ConsoleProgressMonitor", "DocStringExtensions", "LinearAlgebra", "Logging", "LoggingExtras", "Pkg", "Printf", "ProgressLogging", "Reexport", "Requires", "SciMLBase", "SparseArrays", "TerminalLoggers"] +git-tree-sha1 = "f59193f79f8310e5eaad309ccfa8f2fa06b6703b" +uuid = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +version = "3.19.1" + + [deps.Optimization.extensions] + OptimizationEnzymeExt = "Enzyme" + OptimizationFiniteDiffExt = "FiniteDiff" + OptimizationForwardDiffExt = "ForwardDiff" + OptimizationMTKExt = "ModelingToolkit" + OptimizationReverseDiffExt = "ReverseDiff" + OptimizationSparseDiffExt = ["SparseDiffTools", "Symbolics", "ReverseDiff"] + OptimizationTrackerExt = "Tracker" + OptimizationZygoteExt = "Zygote" + + [deps.Optimization.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" + Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.OrderedCollections]] +git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.2" + +[[deps.Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.3" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.9.2" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.1" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "00099623ffee15972c16111bcf84c58a0051257c" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.9.0" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecursiveArrayTools]] +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "fa453b42ba1623bd2e70260bf44dac850a3430a7" +uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" +version = "2.39.0" + + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.RuntimeGeneratedFunctions]] +deps = ["ExprTools", "SHA", "Serialization"] +git-tree-sha1 = "6aacc5eefe8415f47b3e34214c1d79d2674a0ba2" +uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +version = "0.5.12" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.SciMLBase]] +deps = ["ADTypes", "ArrayInterface", "ChainRulesCore", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces", "ZygoteRules"] +git-tree-sha1 = "317f77cb31f7a0275cdd045aa7b3526ebc15c817" +uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +version = "2.4.0" + + [deps.SciMLBase.extensions] + SciMLBasePyCallExt = "PyCall" + SciMLBasePythonCallExt = "PythonCall" + SciMLBaseRCallExt = "RCall" + SciMLBaseZygoteExt = "Zygote" + + [deps.SciMLBase.weakdeps] + PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" + PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + RCall = "6f49c342-dc21-5d91-9882-a32aef131414" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.SciMLOperators]] +deps = ["ArrayInterface", "DocStringExtensions", "Lazy", "LinearAlgebra", "Setfield", "SparseArrays", "StaticArraysCore", "Tricks"] +git-tree-sha1 = "65c2e6ced6f62ea796af251eb292a0e131a3613b" +uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" +version = "0.3.6" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.3.1" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+6" + +[[deps.SymbolicIndexingInterface]] +deps = ["DocStringExtensions"] +git-tree-sha1 = "f8ab052bfcbdb9b48fad2c80c873aa0d0344dfe5" +uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +version = "0.2.2" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "a1f34829d5ac0ef499f6d84428bd6b4c71f02ead" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.11.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TerminalLoggers]] +deps = ["LeftChildRightSiblingTrees", "Logging", "Markdown", "Printf", "ProgressLogging", "UUIDs"] +git-tree-sha1 = "f133fab380933d042f6796eda4e130272ba520ca" +uuid = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" +version = "0.1.7" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.Tricks]] +git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" +uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" +version = "0.1.8" + +[[deps.TruncatedStacktraces]] +deps = ["InteractiveUtils", "MacroTools", "Preferences"] +git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1" +uuid = "781d530d-4396-4725-bb49-402e4bee1e77" +version = "1.4.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+0" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.3" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/lib/OptimizationSolvers/Project.toml b/lib/OptimizationSolvers/Project.toml index ed00f04c4..4946da535 100644 --- a/lib/OptimizationSolvers/Project.toml +++ b/lib/OptimizationSolvers/Project.toml @@ -1,21 +1,20 @@ -name = "OptimizationOptimisers" -uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +name = "OptimizationSolvers" +uuid = "1b6bb0ae-4728-4f07-8c53-ec2138650943" authors = ["Vaibhav Dixit and contributors"] -version = "0.1.5" +version = "0.0.1" [deps] -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" [compat] -julia = "1" -Optimisers = "0.2" +Optimization = "3.15" ProgressLogging = "0.1" Reexport = "1.2" -Optimization = "3.15" +julia = "1" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -23,4 +22,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ForwardDiff","Test","Zygote"] +test = ["ForwardDiff", "Test", "Zygote"] diff --git a/lib/OptimizationSolvers/src/OptimizationSolvers.jl b/lib/OptimizationSolvers/src/OptimizationSolvers.jl index 8eec6da06..c97288c9a 100644 --- a/lib/OptimizationSolvers/src/OptimizationSolvers.jl +++ b/lib/OptimizationSolvers/src/OptimizationSolvers.jl @@ -2,12 +2,17 @@ module OptimizationSolvers using Reexport, Printf, ProgressLogging @reexport using Optimization -using Optimization.SciMLBase +using Optimization.SciMLBase, LineSearches -SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true +struct BFGS + ϵ::Float64 + m::Int +end + +SciMLBase.supports_opt_cache_interface(opt::BFGS) = true include("sophia.jl") -function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::AbstractRule, +function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::BFGS, data = Optimization.DEFAULT_DATA; save_best = true, callback = (args...) -> (false), progress = false, kwargs...) @@ -15,7 +20,6 @@ function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::AbstractRule kwargs...) end -struct BFGS end function SciMLBase.__solve(cache::OptimizationCache{ F, @@ -51,21 +55,92 @@ function SciMLBase.__solve(cache::OptimizationCache{ end opt = cache.opt θ = copy(cache.u0) - G = copy(θ) + G = zeros(length(θ)) + f = cache.f + + _f = (θ) -> first(f.f(θ, cache.p)) - H = I(length(θ)) * γ + ϕ(α) = _f(θ .+ α.*s) + function dϕ(α) + f.grad(G, θ .+ α.*s) + return dot(G, s) + end + function ϕdϕ(α) + phi = _f(θ .+ α.*s) + f.grad(G, θ .+ α.*s) + dphi = dot(G, s) + return (phi, dphi) + end + Hₖ⁻¹= zeros(length(θ), length(θ)) + f.hess(Hₖ⁻¹, θ) + Hₖ⁻¹ = inv(Hₖ⁻¹) + f.grad(G, θ) + s = -1 * Hₖ⁻¹ * G + # m = opt.m + # α = Vector{typeof(θ)}(undef, m) + # β = zeros(m) + # ss = Vector{typeof(θ)}(undef, m) + # y = Vector{typeof(θ)}(undef, m) + # ρ = Vector{Float64}(undef, m) + # ρ[1] = 1.0 + # k = 1 + # t0 = time() + # ss[1] = θ + # f.grad(G, θ) + # y[1] = G + # α[1] = α0 + # γ = dot(ss[1], y[1])/dot(y[1], y[1]) + # Hₖ = I(length(θ)) * γ + # ρ[1] = 1/dot(y[1], ss[1]) + + for i in 1:maxiters + println(i, " ", θ, " Objective: ", f(θ, cache.p)) + # println(ss, " ", y, " ", γ) + + q = copy(G) + # if k > 1 + # y[k-1] = q - y[k-1] + # γ = dot(ss[k-1], y[k-1])/dot(y[k-1], y[k-1]) + # Hₖ = I(length(θ)) * γ - t0 = time() - Optimization.@withprogress cache.progress name="Training" begin - for (i, d) in enumerate(data) - #bfgs - - end + # ρ[k] = 1/dot(y[k-1], ss[k-1]) + # end + + # for j in 1:min(m,i-1) + # α[j] = ρ[j]*dot(ss[j], G) + # G = G - α[j]*y[j] + # end + # r = Hₖ*G + # for j in min(m,i-1):1 + # β[j] = ρ[j]*dot(y[j], r) + # r = r + ss[j]*(α[j] - β[j]) + # end + pₖ = -Hₖ⁻¹* G + fx = _f(θ) + dir = dot(G, pₖ) + αₖ = [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...] + # α[k] = αₖ + θ = θ .+ αₖ.*pₖ + s = αₖ.*pₖ + # if k > m + # ss[1:end-1] = ss[2:end] + # y[1:end-1] = y[2:end] + # k = m + # ss[k] = α[k-2]*pₖ + # y[k] = q + # α[1:end-1] = α[2:end] + # end + # k+=1 + G = zeros(length(θ)) + f.grad(G, θ) + zₖ = G - q + Hₖ⁻¹ = (I - (s*zₖ')/dot(zₖ, s))*Hₖ⁻¹*(I - (zₖ*s')/dot(zₖ, s)) + (s*s')/dot(zₖ, s) end + t1 = time() - SciMLBase.build_solution(cache, cache.opt, θ, x[1], solve_time = t1 - t0) + SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0) # here should be build_solution to create the output message end From 2b7375bcd46deed8f75eeebdd00f73bbe949c80f Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 13 Oct 2023 17:13:51 -0400 Subject: [PATCH 3/8] simple bfgs simple problem --- .../src/OptimizationSolvers.jl | 8 ++- lib/OptimizationSolvers/test/runtests.jl | 56 +++++-------------- 2 files changed, 20 insertions(+), 44 deletions(-) diff --git a/lib/OptimizationSolvers/src/OptimizationSolvers.jl b/lib/OptimizationSolvers/src/OptimizationSolvers.jl index c97288c9a..f380dc46f 100644 --- a/lib/OptimizationSolvers/src/OptimizationSolvers.jl +++ b/lib/OptimizationSolvers/src/OptimizationSolvers.jl @@ -73,7 +73,8 @@ function SciMLBase.__solve(cache::OptimizationCache{ end Hₖ⁻¹= zeros(length(θ), length(θ)) f.hess(Hₖ⁻¹, θ) - Hₖ⁻¹ = inv(Hₖ⁻¹) + println(Hₖ⁻¹) + Hₖ⁻¹ = inv(I(length(θ)) .+ Hₖ⁻¹) f.grad(G, θ) s = -1 * Hₖ⁻¹ * G # m = opt.m @@ -93,6 +94,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ # Hₖ = I(length(θ)) * γ # ρ[1] = 1/dot(y[1], ss[1]) + t0 = time() for i in 1:maxiters println(i, " ", θ, " Objective: ", f(θ, cache.p)) # println(ss, " ", y, " ", γ) @@ -118,6 +120,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ pₖ = -Hₖ⁻¹* G fx = _f(θ) dir = dot(G, pₖ) + println(fx, " ", dir) αₖ = [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...] # α[k] = αₖ θ = θ .+ αₖ.*pₖ @@ -135,6 +138,9 @@ function SciMLBase.__solve(cache::OptimizationCache{ f.grad(G, θ) zₖ = G - q Hₖ⁻¹ = (I - (s*zₖ')/dot(zₖ, s))*Hₖ⁻¹*(I - (zₖ*s')/dot(zₖ, s)) + (s*s')/dot(zₖ, s) + if norm(G) < 1e-6 + break + end end diff --git a/lib/OptimizationSolvers/test/runtests.jl b/lib/OptimizationSolvers/test/runtests.jl index 523309d60..e5f9f87a8 100644 --- a/lib/OptimizationSolvers/test/runtests.jl +++ b/lib/OptimizationSolvers/test/runtests.jl @@ -1,53 +1,23 @@ -using OptimizationOptimisers, ForwardDiff, Optimization +using OptimizationSolvers, ForwardDiff, Optimization using Test using Zygote @testset "OptimizationOptimisers.jl" begin - rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 - x0 = zeros(2) - _p = [1.0, 100.0] - l1 = rosenbrock(x0, _p) - - optprob = OptimizationFunction(rosenbrock, Optimization.AutoZygote()) - - prob = OptimizationProblem(optprob, x0, _p) + + function objf(x, p) + return x[1]^2 + x[2]^2 + 2*x[1]* x[2] + end + optprob = OptimizationFunction(objf, Optimization.AutoZygote()) + x0 = zeros(2) .+ 1 + x0[1] = 0.5 + prob = OptimizationProblem(optprob, x0) + sol = Optimization.solve(prob, - OptimizationOptimisers.Sophia(; η = 0.5, - λ = 0.0), + OptimizationSolvers.BFGS(1e-3, 5), maxiters = 1000) @test 10 * sol.objective < l1 - prob = OptimizationProblem(optprob, x0, _p) - sol = solve(prob, Optimisers.Adam(), maxiters = 1000, progress = false) - @test 10 * sol.objective < l1 - - x0 = 2 * ones(ComplexF64, 2) - _p = ones(2) - sumfunc(x0, _p) = sum(abs2, (x0 - _p)) - l1 = sumfunc(x0, _p) - - optprob = OptimizationFunction(sumfunc, Optimization.AutoZygote()) - - prob = OptimizationProblem(optprob, x0, _p) - - sol = solve(prob, Optimisers.Adam(), maxiters = 1000) - @test 10 * sol.objective < l1 - - @testset "cache" begin - objective(x, p) = (p[1] - x[1])^2 - x0 = zeros(1) - p = [1.0] - - prob = OptimizationProblem(OptimizationFunction(objective, - Optimization.AutoForwardDiff()), x0, - p) - cache = Optimization.init(prob, Optimisers.Adam(0.1), maxiters = 1000) - sol = Optimization.solve!(cache) - @test sol.u≈[1.0] atol=1e-3 - - cache = Optimization.reinit!(cache; p = [2.0]) - sol = Optimization.solve!(cache) - @test sol.u≈[2.0] atol=1e-3 - end + prob = OptimizationProblem(optprob, x0) + end From 392152b5c864171c517f7f43cf069f03cfabd086 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 13 Oct 2023 18:35:14 -0400 Subject: [PATCH 4/8] hack lineseacrh --- .../src/OptimizationSolvers.jl | 16 +++++++++++++++- lib/OptimizationSolvers/test/runtests.jl | 3 +-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/lib/OptimizationSolvers/src/OptimizationSolvers.jl b/lib/OptimizationSolvers/src/OptimizationSolvers.jl index f380dc46f..0cfbfeae6 100644 --- a/lib/OptimizationSolvers/src/OptimizationSolvers.jl +++ b/lib/OptimizationSolvers/src/OptimizationSolvers.jl @@ -121,10 +121,24 @@ function SciMLBase.__solve(cache::OptimizationCache{ fx = _f(θ) dir = dot(G, pₖ) println(fx, " ", dir) - αₖ = [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...] + + if dir > 0 + pₖ = -G + dir = dot(G, pₖ) + end + + αₖ = let + try + [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...] + catch err + αₖ = [1.0] + end + end # α[k] = αₖ + θ = θ .+ αₖ.*pₖ s = αₖ.*pₖ + # if k > m # ss[1:end-1] = ss[2:end] # y[1:end-1] = y[2:end] diff --git a/lib/OptimizationSolvers/test/runtests.jl b/lib/OptimizationSolvers/test/runtests.jl index e5f9f87a8..2309443e0 100644 --- a/lib/OptimizationSolvers/test/runtests.jl +++ b/lib/OptimizationSolvers/test/runtests.jl @@ -5,12 +5,11 @@ using Zygote @testset "OptimizationOptimisers.jl" begin function objf(x, p) - return x[1]^2 + x[2]^2 + 2*x[1]* x[2] + return x[1]^2 + x[2]^2 + x[1]* x[2] end optprob = OptimizationFunction(objf, Optimization.AutoZygote()) x0 = zeros(2) .+ 1 - x0[1] = 0.5 prob = OptimizationProblem(optprob, x0) sol = Optimization.solve(prob, From 3edfce614ef641d1d93c9602f3eb457a0762e03a Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 26 Oct 2023 09:45:02 -0400 Subject: [PATCH 5/8] bugs --- .../src/OptimizationSolvers.jl | 161 +---------------- lib/OptimizationSolvers/src/bfgs.jl | 119 ++++++++++++ lib/OptimizationSolvers/src/lbfgs.jl | 169 ++++++++++++++++++ lib/OptimizationSolvers/test/runtests.jl | 25 ++- 4 files changed, 309 insertions(+), 165 deletions(-) create mode 100644 lib/OptimizationSolvers/src/bfgs.jl create mode 100644 lib/OptimizationSolvers/src/lbfgs.jl diff --git a/lib/OptimizationSolvers/src/OptimizationSolvers.jl b/lib/OptimizationSolvers/src/OptimizationSolvers.jl index 0cfbfeae6..0adbed0e1 100644 --- a/lib/OptimizationSolvers/src/OptimizationSolvers.jl +++ b/lib/OptimizationSolvers/src/OptimizationSolvers.jl @@ -4,164 +4,7 @@ using Reexport, Printf, ProgressLogging @reexport using Optimization using Optimization.SciMLBase, LineSearches -struct BFGS - ϵ::Float64 - m::Int -end - -SciMLBase.supports_opt_cache_interface(opt::BFGS) = true include("sophia.jl") - -function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::BFGS, - data = Optimization.DEFAULT_DATA; save_best = true, - callback = (args...) -> (false), - progress = false, kwargs...) - return OptimizationCache(prob, opt, data; save_best, callback, progress, - kwargs...) -end - - -function SciMLBase.__solve(cache::OptimizationCache{ - F, - RC, - LB, - UB, - LC, - UC, - S, - O, - D, - P, - C, -}) where { - F, - RC, - LB, - UB, - LC, - UC, - S, - O <:BFGS, - D, - P, - C, -} - if cache.data != Optimization.DEFAULT_DATA - maxiters = length(cache.data) - data = cache.data - else - maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) - data = Optimization.take(cache.data, maxiters) - end - opt = cache.opt - θ = copy(cache.u0) - G = zeros(length(θ)) - f = cache.f - - _f = (θ) -> first(f.f(θ, cache.p)) - - ϕ(α) = _f(θ .+ α.*s) - function dϕ(α) - f.grad(G, θ .+ α.*s) - return dot(G, s) - end - function ϕdϕ(α) - phi = _f(θ .+ α.*s) - f.grad(G, θ .+ α.*s) - dphi = dot(G, s) - return (phi, dphi) - end - Hₖ⁻¹= zeros(length(θ), length(θ)) - f.hess(Hₖ⁻¹, θ) - println(Hₖ⁻¹) - Hₖ⁻¹ = inv(I(length(θ)) .+ Hₖ⁻¹) - f.grad(G, θ) - s = -1 * Hₖ⁻¹ * G - # m = opt.m - # α = Vector{typeof(θ)}(undef, m) - # β = zeros(m) - # ss = Vector{typeof(θ)}(undef, m) - # y = Vector{typeof(θ)}(undef, m) - # ρ = Vector{Float64}(undef, m) - # ρ[1] = 1.0 - # k = 1 - # t0 = time() - # ss[1] = θ - # f.grad(G, θ) - # y[1] = G - # α[1] = α0 - # γ = dot(ss[1], y[1])/dot(y[1], y[1]) - # Hₖ = I(length(θ)) * γ - # ρ[1] = 1/dot(y[1], ss[1]) - - t0 = time() - for i in 1:maxiters - println(i, " ", θ, " Objective: ", f(θ, cache.p)) - # println(ss, " ", y, " ", γ) - - q = copy(G) - # if k > 1 - # y[k-1] = q - y[k-1] - # γ = dot(ss[k-1], y[k-1])/dot(y[k-1], y[k-1]) - # Hₖ = I(length(θ)) * γ - - # ρ[k] = 1/dot(y[k-1], ss[k-1]) - # end - - # for j in 1:min(m,i-1) - # α[j] = ρ[j]*dot(ss[j], G) - # G = G - α[j]*y[j] - # end - # r = Hₖ*G - # for j in min(m,i-1):1 - # β[j] = ρ[j]*dot(y[j], r) - # r = r + ss[j]*(α[j] - β[j]) - # end - pₖ = -Hₖ⁻¹* G - fx = _f(θ) - dir = dot(G, pₖ) - println(fx, " ", dir) - - if dir > 0 - pₖ = -G - dir = dot(G, pₖ) - end - - αₖ = let - try - [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...] - catch err - αₖ = [1.0] - end - end - # α[k] = αₖ - - θ = θ .+ αₖ.*pₖ - s = αₖ.*pₖ - - # if k > m - # ss[1:end-1] = ss[2:end] - # y[1:end-1] = y[2:end] - # k = m - # ss[k] = α[k-2]*pₖ - # y[k] = q - # α[1:end-1] = α[2:end] - # end - # k+=1 - G = zeros(length(θ)) - f.grad(G, θ) - zₖ = G - q - Hₖ⁻¹ = (I - (s*zₖ')/dot(zₖ, s))*Hₖ⁻¹*(I - (zₖ*s')/dot(zₖ, s)) + (s*s')/dot(zₖ, s) - if norm(G) < 1e-6 - break - end - end - - - t1 = time() - - SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0) - # here should be build_solution to create the output message -end - +include("bfgs.jl") +include("lbfgs.jl") end diff --git a/lib/OptimizationSolvers/src/bfgs.jl b/lib/OptimizationSolvers/src/bfgs.jl new file mode 100644 index 000000000..009ad82f4 --- /dev/null +++ b/lib/OptimizationSolvers/src/bfgs.jl @@ -0,0 +1,119 @@ + +struct BFGS + ϵ::Float64 + m::Int +end + +SciMLBase.supports_opt_cache_interface(opt::BFGS) = true + + +function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::BFGS, + data = Optimization.DEFAULT_DATA; save_best = true, + callback = (args...) -> (false), + progress = false, kwargs...) + return OptimizationCache(prob, opt, data; save_best, callback, progress, + kwargs...) +end + + +function SciMLBase.__solve(cache::OptimizationCache{ + F, + RC, + LB, + UB, + LC, + UC, + S, + O, + D, + P, + C, +}) where { + F, + RC, + LB, + UB, + LC, + UC, + S, + O <:BFGS, + D, + P, + C, +} + if cache.data != Optimization.DEFAULT_DATA + maxiters = length(cache.data) + data = cache.data + else + maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + data = Optimization.take(cache.data, maxiters) + end + opt = cache.opt + θ = copy(cache.u0) + G = zeros(length(θ)) + f = cache.f + + _f = (θ) -> first(f.f(θ, cache.p)) + + ϕ(α) = _f(θ .+ α.*s) + function dϕ(α) + f.grad(G, θ .+ α.*s) + return dot(G, s) + end + function ϕdϕ(α) + phi = _f(θ .+ α.*s) + f.grad(G, θ .+ α.*s) + dphi = dot(G, s) + return (phi, dphi) + end + Hₖ⁻¹= zeros(length(θ), length(θ)) + f.hess(Hₖ⁻¹, θ) + println(Hₖ⁻¹) + Hₖ⁻¹ = inv(I(length(θ)) .+ Hₖ⁻¹) + f.grad(G, θ) + s = -1 * Hₖ⁻¹ * G + + t0 = time() + for i in 1:maxiters + println(i, " ", θ, " Objective: ", f(θ, cache.p)) + q = copy(G) + @show q + pₖ = -Hₖ⁻¹* G + fx = _f(θ) + dir = G' * pₖ + println(dir) + + if isnan(dir) || dir > 0 + pₖ = -G + dir = -G'*G + end + + αₖ = let + try + (HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)[1] + catch err + println(err) + 1.0 + end + end + + θ = θ .+ αₖ.*pₖ + s = αₖ.*pₖ + + G = zeros(length(θ)) + f.grad(G, θ) + zₖ = G - q + @show G + ρₖ = 1/dot(zₖ, s) + Hₖ⁻¹ = (I - ρₖ*s*zₖ')*Hₖ⁻¹*(I - ρₖ*zₖ*s') + ρₖ*(s*s') + if norm(G) <= opt.ϵ + break + end + end + + + t1 = time() + + SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0) + # here should be build_solution to create the output message +end diff --git a/lib/OptimizationSolvers/src/lbfgs.jl b/lib/OptimizationSolvers/src/lbfgs.jl new file mode 100644 index 000000000..2cba50627 --- /dev/null +++ b/lib/OptimizationSolvers/src/lbfgs.jl @@ -0,0 +1,169 @@ +struct LBFGS + ϵ::Float64 + m::Int +end + +SciMLBase.supports_opt_cache_interface(opt::LBFGS) = true + +function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::LBFGS, + data = Optimization.DEFAULT_DATA; save_best = true, + callback = (args...) -> (false), + progress = false, kwargs...) + return OptimizationCache(prob, opt, data; save_best, callback, progress, + kwargs...) +end + + +function SciMLBase.__solve(cache::OptimizationCache{ + F, + RC, + LB, + UB, + LC, + UC, + S, + O, + D, + P, + C, +}) where { + F, + RC, + LB, + UB, + LC, + UC, + S, + O <:LBFGS, + D, + P, + C, +} + if cache.data != Optimization.DEFAULT_DATA + maxiters = length(cache.data) + data = cache.data + else + maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters) + data = Optimization.take(cache.data, maxiters) + end + opt = cache.opt + θ = copy(cache.u0) + G = zeros(length(θ)) + f = cache.f + + _f = (θ) -> first(f.f(θ, cache.p)) + + ϕ(α) = _f(θ .+ α.*s) + function dϕ(α) + f.grad(G, θ .+ α.*s) + return dot(G, s) + end + function ϕdϕ(α) + phi = _f(θ .+ α.*s) + f.grad(G, θ .+ α.*s) + dphi = dot(G, s) + return (phi, dphi) + end + + Sₖ = zeros(length(θ), opt.m) + Yₖ = zeros(length(θ), opt.m) + Rₖ = zeros(opt.m, opt.m) + Dₖ = zeros(opt.m) + + Hₖ⁻¹= zeros(length(θ), length(θ)) + println(Hₖ⁻¹) + Hₖ⁻¹ = I(length(θ)) + f.grad(G, θ) + s = -1 * Hₖ⁻¹ * G + t0 = time() + for k in 1:opt.m + # println(k, " ", θ, " Objective: ", f(θ, cache.p)) + q = copy(G) + pₖ = -Hₖ⁻¹* G + fx = _f(θ) + dir = dot(G, pₖ) + # println(fx, " ", dir) + if !isnan(dir) && dir > 0 + pₖ = -G + dir = dot(G, pₖ) + else + dir = -G + end + αₖ = let + try + [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...] + catch err + αₖ = [1.0] + end + end + θ = θ .+ αₖ.*pₖ + s = αₖ.*pₖ + G = zeros(length(θ)) + f.grad(G, θ) + zₖ = G - q + Hₖ⁻¹ = (I - (s*zₖ')/dot(zₖ, s))*Hₖ⁻¹*(I - (zₖ*s')/dot(zₖ, s)) + (s*s')/dot(zₖ, s) + Sₖ[:, k] .= s + Yₖ[:, k] .= zₖ + if norm(G) < 1e-6 + break + end + end + + for j in 1:opt.m + for i in 1:j + Rₖ[i, j] = dot(Sₖ[:, i], Yₖ[:, j]) + end + Dₖ[j] = dot(Sₖ[:, j], Yₖ[:, j]) + end + + m = opt.m + for i in opt.m+1:maxiters + _G = copy(G) + fx = _f(θ) + println(i, " ", θ, " Objective: ", fx) + γₖ = dot(Sₖ[:, m], Yₖ[:, m])/dot(Yₖ[:, m], Yₖ[:, m]) + Rinv = let + try + inv(Rₖ) + catch + println(i, " ", Rₖ) + println("Inversion failed") + break + end + end + + p = [Rinv'*(diagm(Dₖ) + γₖ * Yₖ'*Yₖ)*Rinv*(Sₖ'*G) - γₖ * Rinv*(Yₖ'*G); -Rinv*(Sₖ'*G)] + p = -1 .* (γₖ * G + hcat(Sₖ, γₖ*Yₖ)*p) + p = dot(p, G) + αₖ = let + try + [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, p)...] + catch err + println(err) + break + αₖ = [1.0] + end + end + + θ = θ .+ αₖ.*p + s = αₖ.*p + Sₖ[:, 1:end-1] .= Sₖ[:, 2:end] + Yₖ[:, 1:end-1] .= Yₖ[:, 2:end] + Dₖ[1:end-1] .= Dₖ[2:end] + Rₖ[1:end-1, 1:end-1] .= Rₖ[2:end, 2:end] + Sₖ[:, end] .= s + G = zeros(length(θ)) + f.grad(G, θ) + zₖ = G - _G + Yₖ[:, end] .= zₖ + for i in 1:m + Rₖ[i, m] = dot(Sₖ[:, i], Yₖ[:, m]) + end + Dₖ[m] = dot(Sₖ[:, m], Yₖ[:, m]) + end + + t1 = time() + + SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0) + # here should be build_solution to create the output message +end \ No newline at end of file diff --git a/lib/OptimizationSolvers/test/runtests.jl b/lib/OptimizationSolvers/test/runtests.jl index 2309443e0..df6fddb92 100644 --- a/lib/OptimizationSolvers/test/runtests.jl +++ b/lib/OptimizationSolvers/test/runtests.jl @@ -2,21 +2,34 @@ using OptimizationSolvers, ForwardDiff, Optimization using Test using Zygote -@testset "OptimizationOptimisers.jl" begin - + +@testset "OptimizationSolvers.jl" begin function objf(x, p) - return x[1]^2 + x[2]^2 + x[1]* x[2] + return x[1]^2 + x[2]^2 end optprob = OptimizationFunction(objf, Optimization.AutoZygote()) x0 = zeros(2) .+ 1 prob = OptimizationProblem(optprob, x0) - + l1 = objf(x0, nothing) + sol = Optimization.solve(prob, + OptimizationSolvers.BFGS(1e-3, 10), + maxiters = 10) + @test 10 * sol.objective < l1 + + sol = Optimization.solve(prob, + OptimizationSolvers.LBFGS(1e-3, 10), + maxiters = 10) + + x0 = zeros(2) + rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2 + l1 = rosenbrock(x0) + optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote()) + prob = OptimizationProblem(optf, x0) sol = Optimization.solve(prob, OptimizationSolvers.BFGS(1e-3, 5), maxiters = 1000) @test 10 * sol.objective < l1 - prob = OptimizationProblem(optprob, x0) - + end From a52eb0f8bc06027f62b2ca9ccb783cbb1b18fc6b Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 27 Oct 2023 20:30:51 -0400 Subject: [PATCH 6/8] Actually works now --- lib/OptimizationSolvers/src/bfgs.jl | 65 ++++++---- lib/OptimizationSolvers/src/lbfgs.jl | 155 +++++++++++++---------- lib/OptimizationSolvers/test/runtests.jl | 12 +- 3 files changed, 139 insertions(+), 93 deletions(-) diff --git a/lib/OptimizationSolvers/src/bfgs.jl b/lib/OptimizationSolvers/src/bfgs.jl index 009ad82f4..9373ccf0b 100644 --- a/lib/OptimizationSolvers/src/bfgs.jl +++ b/lib/OptimizationSolvers/src/bfgs.jl @@ -1,7 +1,6 @@ -struct BFGS - ϵ::Float64 - m::Int +@kwdef struct BFGS + ϵ::Float64=1e-6 end SciMLBase.supports_opt_cache_interface(opt::BFGS) = true @@ -50,49 +49,68 @@ function SciMLBase.__solve(cache::OptimizationCache{ end opt = cache.opt θ = copy(cache.u0) - G = zeros(length(θ)) + g₀ = zeros(length(θ)) f = cache.f _f = (θ) -> first(f.f(θ, cache.p)) - ϕ(α) = _f(θ .+ α.*s) - function dϕ(α) - f.grad(G, θ .+ α.*s) - return dot(G, s) + function ϕ(u, du) + function ϕ_internal(α) + u_ = u - α * du + _fu = _f(u_) + return dot(_fu, _fu) / 2 + end + return ϕ_internal + end + + function dϕ(u, du) + function dϕ_internal(α) + u_ = u - α * du + _fu = _f(u_) + f.grad(g₀, u_) + return dot(g₀, -du) + end + return dϕ_internal end - function ϕdϕ(α) - phi = _f(θ .+ α.*s) - f.grad(G, θ .+ α.*s) - dphi = dot(G, s) - return (phi, dphi) + + function ϕdϕ(u, du) + function ϕdϕ_internal(α) + u_ = u - α * du + _fu = _f(u_) + f.grad(g₀, u_) + return dot(_fu, _fu) / 2, dot(g₀, -du) + end + return ϕdϕ_internal end + Hₖ⁻¹= zeros(length(θ), length(θ)) f.hess(Hₖ⁻¹, θ) - println(Hₖ⁻¹) Hₖ⁻¹ = inv(I(length(θ)) .+ Hₖ⁻¹) + G = zeros(length(θ)) f.grad(G, θ) s = -1 * Hₖ⁻¹ * G t0 = time() for i in 1:maxiters - println(i, " ", θ, " Objective: ", f(θ, cache.p)) q = copy(G) - @show q pₖ = -Hₖ⁻¹* G fx = _f(θ) - dir = G' * pₖ - println(dir) + dir = -pₖ - if isnan(dir) || dir > 0 + if all(isnan.(dir)) || all(dir .> 0) pₖ = -G - dir = -G'*G + dir = G end αₖ = let try - (HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)[1] + _ϕ = ϕ(θ, dir) + _dϕ = dϕ(θ, dir) + _ϕdϕ = ϕdϕ(θ, dir) + + ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ))) + (HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] catch err - println(err) 1.0 end end @@ -103,15 +121,14 @@ function SciMLBase.__solve(cache::OptimizationCache{ G = zeros(length(θ)) f.grad(G, θ) zₖ = G - q - @show G ρₖ = 1/dot(zₖ, s) Hₖ⁻¹ = (I - ρₖ*s*zₖ')*Hₖ⁻¹*(I - ρₖ*zₖ*s') + ρₖ*(s*s') if norm(G) <= opt.ϵ + println(i) break end end - t1 = time() SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0) diff --git a/lib/OptimizationSolvers/src/lbfgs.jl b/lib/OptimizationSolvers/src/lbfgs.jl index 2cba50627..0e640289a 100644 --- a/lib/OptimizationSolvers/src/lbfgs.jl +++ b/lib/OptimizationSolvers/src/lbfgs.jl @@ -48,21 +48,38 @@ function SciMLBase.__solve(cache::OptimizationCache{ end opt = cache.opt θ = copy(cache.u0) - G = zeros(length(θ)) + g₀ = zeros(length(θ)) f = cache.f _f = (θ) -> first(f.f(θ, cache.p)) - ϕ(α) = _f(θ .+ α.*s) - function dϕ(α) - f.grad(G, θ .+ α.*s) - return dot(G, s) + function ϕ(u, du) + function ϕ_internal(α) + u_ = u - α * du + _fu = _f(u_) + return dot(_fu, _fu) / 2 + end + return ϕ_internal end - function ϕdϕ(α) - phi = _f(θ .+ α.*s) - f.grad(G, θ .+ α.*s) - dphi = dot(G, s) - return (phi, dphi) + + function dϕ(u, du) + function dϕ_internal(α) + u_ = u - α * du + _fu = _f(u_) + f.grad(g₀, u_) + return dot(g₀, -du) + end + return dϕ_internal + end + + function ϕdϕ(u, du) + function ϕdϕ_internal(α) + u_ = u - α * du + _fu = _f(u_) + f.grad(g₀, u_) + return dot(_fu, _fu) / 2, dot(g₀, -du) + end + return ϕdϕ_internal end Sₖ = zeros(length(θ), opt.m) @@ -71,27 +88,29 @@ function SciMLBase.__solve(cache::OptimizationCache{ Dₖ = zeros(opt.m) Hₖ⁻¹= zeros(length(θ), length(θ)) - println(Hₖ⁻¹) Hₖ⁻¹ = I(length(θ)) + G = zeros(length(θ)) f.grad(G, θ) s = -1 * Hₖ⁻¹ * G t0 = time() + conv = false for k in 1:opt.m - # println(k, " ", θ, " Objective: ", f(θ, cache.p)) q = copy(G) pₖ = -Hₖ⁻¹* G fx = _f(θ) - dir = dot(G, pₖ) - # println(fx, " ", dir) - if !isnan(dir) && dir > 0 + dir = -pₖ + if all(isnan.(dir)) || all(dir .> 0) pₖ = -G - dir = dot(G, pₖ) - else - dir = -G + dir = G end αₖ = let try - [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, dir)...] + _ϕ = ϕ(θ, dir) + _dϕ = dϕ(θ, dir) + _ϕdϕ = ϕdϕ(θ, dir) + + ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ))) + (HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] catch err αₖ = [1.0] end @@ -104,64 +123,70 @@ function SciMLBase.__solve(cache::OptimizationCache{ Hₖ⁻¹ = (I - (s*zₖ')/dot(zₖ, s))*Hₖ⁻¹*(I - (zₖ*s')/dot(zₖ, s)) + (s*s')/dot(zₖ, s) Sₖ[:, k] .= s Yₖ[:, k] .= zₖ - if norm(G) < 1e-6 + if norm(G) < opt.ϵ + conv = true break end end - for j in 1:opt.m - for i in 1:j - Rₖ[i, j] = dot(Sₖ[:, i], Yₖ[:, j]) + if !conv + for j in 1:opt.m + for i in 1:j + Rₖ[i, j] = dot(Sₖ[:, i], Yₖ[:, j]) + end + Dₖ[j] = dot(Sₖ[:, j], Yₖ[:, j]) end - Dₖ[j] = dot(Sₖ[:, j], Yₖ[:, j]) - end - m = opt.m - for i in opt.m+1:maxiters - _G = copy(G) - fx = _f(θ) - println(i, " ", θ, " Objective: ", fx) - γₖ = dot(Sₖ[:, m], Yₖ[:, m])/dot(Yₖ[:, m], Yₖ[:, m]) - Rinv = let - try - inv(Rₖ) - catch - println(i, " ", Rₖ) - println("Inversion failed") - break + m = opt.m + for i in opt.m+1:maxiters + _G = copy(G) + fx = _f(θ) + γₖ = dot(Sₖ[:, m], Yₖ[:, m])/dot(Yₖ[:, m], Yₖ[:, m]) + Rinv = let + try + inv(Rₖ) + catch + println("Inversion failed") + break + end end - end - p = [Rinv'*(diagm(Dₖ) + γₖ * Yₖ'*Yₖ)*Rinv*(Sₖ'*G) - γₖ * Rinv*(Yₖ'*G); -Rinv*(Sₖ'*G)] - p = -1 .* (γₖ * G + hcat(Sₖ, γₖ*Yₖ)*p) - p = dot(p, G) - αₖ = let - try - [(HagerZhang())(ϕ, dϕ, ϕdϕ, 1.0, fx, p)...] - catch err - println(err) - break - αₖ = [1.0] + p = [Rinv'*(diagm(Dₖ) + γₖ * Yₖ'*Yₖ)*Rinv*(Sₖ'*G) - γₖ * Rinv*(Yₖ'*G); -Rinv*(Sₖ'*G)] + p = -1 .* (γₖ * G + hcat(Sₖ, γₖ*Yₖ)*p) + dir = -p + αₖ = let + try + _ϕ = ϕ(θ, dir) + _dϕ = dϕ(θ, dir) + _ϕdϕ = ϕdϕ(θ, dir) + + ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ))) + (HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] + catch err + αₖ = 1.0 + end end - end - θ = θ .+ αₖ.*p - s = αₖ.*p - Sₖ[:, 1:end-1] .= Sₖ[:, 2:end] - Yₖ[:, 1:end-1] .= Yₖ[:, 2:end] - Dₖ[1:end-1] .= Dₖ[2:end] - Rₖ[1:end-1, 1:end-1] .= Rₖ[2:end, 2:end] - Sₖ[:, end] .= s - G = zeros(length(θ)) - f.grad(G, θ) - zₖ = G - _G - Yₖ[:, end] .= zₖ - for i in 1:m - Rₖ[i, m] = dot(Sₖ[:, i], Yₖ[:, m]) + θ = θ .+ αₖ.*p + s = αₖ.*p + Sₖ[:, 1:end-1] .= Sₖ[:, 2:end] + Yₖ[:, 1:end-1] .= Yₖ[:, 2:end] + Dₖ[1:end-1] .= Dₖ[2:end] + Rₖ[1:end-1, 1:end-1] .= Rₖ[2:end, 2:end] + Sₖ[:, end] .= s + G = zeros(length(θ)) + f.grad(G, θ) + zₖ = G - _G + Yₖ[:, end] .= zₖ + for i in 1:m + Rₖ[i, m] = dot(Sₖ[:, i], Yₖ[:, m]) + end + Dₖ[m] = dot(Sₖ[:, m], Yₖ[:, m]) + if norm(G) < opt.ϵ + break + end end - Dₖ[m] = dot(Sₖ[:, m], Yₖ[:, m]) end - t1 = time() SciMLBase.build_solution(cache, cache.opt, θ, f(θ, cache.p), solve_time = t1 - t0) diff --git a/lib/OptimizationSolvers/test/runtests.jl b/lib/OptimizationSolvers/test/runtests.jl index df6fddb92..c655e07c8 100644 --- a/lib/OptimizationSolvers/test/runtests.jl +++ b/lib/OptimizationSolvers/test/runtests.jl @@ -19,17 +19,21 @@ using Zygote sol = Optimization.solve(prob, OptimizationSolvers.LBFGS(1e-3, 10), - maxiters = 10) + maxiters = 50) x0 = zeros(2) rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2 l1 = rosenbrock(x0) optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote()) - prob = OptimizationProblem(optf, x0) + prob = OptimizationProblem(optf, [-1.2, 1.0]) sol = Optimization.solve(prob, - OptimizationSolvers.BFGS(1e-3, 5), - maxiters = 1000) + OptimizationSolvers.BFGS(1e-5), + maxiters = 100) @test 10 * sol.objective < l1 + sol = Optimization.solve(prob, + OptimizationSolvers.LBFGS(1e-3, 10), + maxiters = 50) + end From ed6f4dfd751b70ff015aa20212782d4be30de3ab Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 30 Oct 2023 12:14:05 -0400 Subject: [PATCH 7/8] wip --- lib/OptimizationSolvers/src/bfgs.jl | 12 ++++++------ lib/OptimizationSolvers/src/lbfgs.jl | 13 ++++++------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/lib/OptimizationSolvers/src/bfgs.jl b/lib/OptimizationSolvers/src/bfgs.jl index 9373ccf0b..d5a9f05c8 100644 --- a/lib/OptimizationSolvers/src/bfgs.jl +++ b/lib/OptimizationSolvers/src/bfgs.jl @@ -51,8 +51,9 @@ function SciMLBase.__solve(cache::OptimizationCache{ θ = copy(cache.u0) g₀ = zeros(length(θ)) f = cache.f - - _f = (θ) -> first(f.f(θ, cache.p)) + p = cache.p + ls = HagerZhang() + _f = (θ) -> first(f.f(θ, p)) function ϕ(u, du) function ϕ_internal(α) @@ -105,11 +106,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ αₖ = let try _ϕ = ϕ(θ, dir) - _dϕ = dϕ(θ, dir) _ϕdϕ = ϕdϕ(θ, dir) ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ))) - (HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] + ls(_ϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] catch err 1.0 end @@ -122,8 +122,8 @@ function SciMLBase.__solve(cache::OptimizationCache{ f.grad(G, θ) zₖ = G - q ρₖ = 1/dot(zₖ, s) - Hₖ⁻¹ = (I - ρₖ*s*zₖ')*Hₖ⁻¹*(I - ρₖ*zₖ*s') + ρₖ*(s*s') - if norm(G) <= opt.ϵ + Hₖ⁻¹ = mul!((I - ρₖ*s*zₖ'), Hₖ⁻¹, (I - ρₖ*zₖ*s')) + ρₖ*(s*s') + if norm(G, Inf) <= opt.ϵ println(i) break end diff --git a/lib/OptimizationSolvers/src/lbfgs.jl b/lib/OptimizationSolvers/src/lbfgs.jl index 0e640289a..5559237f7 100644 --- a/lib/OptimizationSolvers/src/lbfgs.jl +++ b/lib/OptimizationSolvers/src/lbfgs.jl @@ -50,8 +50,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ θ = copy(cache.u0) g₀ = zeros(length(θ)) f = cache.f - - _f = (θ) -> first(f.f(θ, cache.p)) + p = cache.p + ls = HagerZhang() + + _f = (θ) -> first(f.f(θ, p)) function ϕ(u, du) function ϕ_internal(α) @@ -87,11 +89,9 @@ function SciMLBase.__solve(cache::OptimizationCache{ Rₖ = zeros(opt.m, opt.m) Dₖ = zeros(opt.m) - Hₖ⁻¹= zeros(length(θ), length(θ)) Hₖ⁻¹ = I(length(θ)) G = zeros(length(θ)) f.grad(G, θ) - s = -1 * Hₖ⁻¹ * G t0 = time() conv = false for k in 1:opt.m @@ -106,11 +106,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ αₖ = let try _ϕ = ϕ(θ, dir) - _dϕ = dϕ(θ, dir) _ϕdϕ = ϕdϕ(θ, dir) ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ))) - (HagerZhang())(_ϕ, _dϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] + ls(_ϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] catch err αₖ = [1.0] end @@ -123,7 +122,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ Hₖ⁻¹ = (I - (s*zₖ')/dot(zₖ, s))*Hₖ⁻¹*(I - (zₖ*s')/dot(zₖ, s)) + (s*s')/dot(zₖ, s) Sₖ[:, k] .= s Yₖ[:, k] .= zₖ - if norm(G) < opt.ϵ + if norm(G, Inf) < opt.ϵ conv = true break end From 0e2d62ec0b4643c7224f2cd87851afef73acc4d0 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 14 Dec 2023 18:04:22 -0500 Subject: [PATCH 8/8] misc --- lib/OptimizationSolvers/src/OptimizationSolvers.jl | 4 ++-- lib/OptimizationSolvers/src/bfgs.jl | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/OptimizationSolvers/src/OptimizationSolvers.jl b/lib/OptimizationSolvers/src/OptimizationSolvers.jl index 0adbed0e1..38970a58d 100644 --- a/lib/OptimizationSolvers/src/OptimizationSolvers.jl +++ b/lib/OptimizationSolvers/src/OptimizationSolvers.jl @@ -5,6 +5,6 @@ using Reexport, Printf, ProgressLogging using Optimization.SciMLBase, LineSearches include("sophia.jl") -include("bfgs.jl") -include("lbfgs.jl") +# include("bfgs.jl") +# include("lbfgs.jl") end diff --git a/lib/OptimizationSolvers/src/bfgs.jl b/lib/OptimizationSolvers/src/bfgs.jl index d5a9f05c8..1640b62fc 100644 --- a/lib/OptimizationSolvers/src/bfgs.jl +++ b/lib/OptimizationSolvers/src/bfgs.jl @@ -107,7 +107,6 @@ function SciMLBase.__solve(cache::OptimizationCache{ try _ϕ = ϕ(θ, dir) _ϕdϕ = ϕdϕ(θ, dir) - ϕ₀, dϕ₀ = _ϕdϕ(zero(eltype(θ))) ls(_ϕ, _ϕdϕ, 1.0, ϕ₀, dϕ₀)[1] catch err