Commit

more fixes
Vaibhavdixit02 committed Sep 11, 2024
1 parent 951d661 commit 03c2708
Showing 6 changed files with 23 additions and 16 deletions.
4 changes: 0 additions & 4 deletions docs/src/index.md
@@ -203,20 +203,16 @@ versioninfo() # hide
```@raw html
</details>
```
-
```@raw html
<details><summary>A more complete overview of all dependencies and their versions is also provided.</summary>
```
-
```@example
using Pkg # hide
Pkg.status(; mode = PKGMODE_MANIFEST) # hide
```
-
```@raw html
</details>
```
-
```@eval
using TOML
using Markdown
12 changes: 9 additions & 3 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
@@ -13,9 +13,6 @@ internal state.
abstract type AbstractManoptOptimizer end

SciMLBase.supports_opt_cache_interface(opt::AbstractManoptOptimizer) = true
-SciMLBase.requiresgradient(opt::Union{GradientDescentOptimizer, ConjugateGradientDescentOptimizer, QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer}) = true
-SciMLBase.requireshessian(opt::Union{AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer}) = true
-

function __map_optimizer_args!(cache::OptimizationCache,
opt::AbstractManoptOptimizer;
@@ -329,6 +326,15 @@ function call_manopt_optimizer(M::ManifoldsBase.AbstractManifold,
end

## Optimization.jl stuff
+function SciMLBase.requiresgradient(opt::Union{
+        GradientDescentOptimizer, ConjugateGradientDescentOptimizer,
+        QuasiNewtonOptimizer, ConvexBundleOptimizer, FrankWolfeOptimizer})
+    true
+end
+function SciMLBase.requireshessian(opt::Union{
+        AdaptiveRegularizationCubicOptimizer, TrustRegionsOptimizer})
+    true
+end

function build_loss(f::OptimizationFunction, prob, cb)
function (::AbstractManifold, θ)
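For context (not part of the commit, and the `Demo*` names below are made up): the relocated `requiresgradient`/`requireshessian` definitions follow the standard SciMLBase trait pattern, where each solver wrapper opts into the derivatives it needs and the caching layer queries the trait before deciding which AD-generated oracles to build. A minimal self-contained sketch of that pattern:

```julia
# Hypothetical stand-ins, not the Manopt wrapper types: gradient-based solvers
# opt in via a trait, and the preparation step queries it before building an
# AD gradient.
abstract type DemoOptimizer end
struct DemoGradientDescent <: DemoOptimizer end   # needs a gradient
struct DemoNelderMead <: DemoOptimizer end        # derivative-free

requiresgradient(::DemoOptimizer) = false         # conservative default
requiresgradient(::DemoGradientDescent) = true    # per-solver opt-in

prepare_derivatives(opt::DemoOptimizer) =
    requiresgradient(opt) ? :build_gradient_via_AD : :objective_only

prepare_derivatives(DemoGradientDescent())  # => :build_gradient_via_AD
prepare_derivatives(DemoNelderMead())       # => :objective_only
```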
2 changes: 1 addition & 1 deletion lib/OptimizationOptimJL/test/runtests.jl
@@ -1,6 +1,6 @@
using OptimizationOptimJL,
OptimizationOptimJL.Optim, Optimization, ForwardDiff, Zygote, ReverseDiff.
-Random, ModelingToolkit, Optimization.OptimizationBase.DifferentiationInterface
+Random, ModelingToolkit, Optimization.OptimizationBase.DifferentiationInterface
using Test

struct CallbackTester
13 changes: 9 additions & 4 deletions lib/OptimizationPRIMA/src/OptimizationPRIMA.jl
@@ -15,8 +15,7 @@ SciMLBase.supports_opt_cache_interface(::PRIMASolvers) = true
SciMLBase.allowsconstraints(::Union{LINCOA, COBYLA}) = true
SciMLBase.allowsbounds(opt::Union{BOBYQA, LINCOA, COBYLA}) = true
SciMLBase.requiresconstraints(opt::COBYLA) = true
-SciMLBase.requiresgradient(opt::Union{BOBYQA, LINCOA, COBYLA}) = true
-SciMLBase.requiresconsjac(opt::Union{LINCOA, COBYLA}) = true
+SciMLBase.requiresconsjac(opt::COBYLA) = true
SciMLBase.requiresconshess(opt::COBYLA) = true

function Optimization.OptimizationCache(prob::SciMLBase.OptimizationProblem,
@@ -34,8 +33,14 @@ function Optimization.OptimizationCache(prob::SciMLBase.OptimizationProblem,
throw("We evaluate the jacobian and hessian of the constraints once to automatically detect
linear and nonlinear constraints, please provide a valid AD backend for using COBYLA.")
else
-f = Optimization.instantiate_function(
-    prob.f, reinit_cache.u0, prob.f.adtype, reinit_cache.p, num_cons)
+if opt isa COBYLA
+    f = Optimization.instantiate_function(
+        prob.f, reinit_cache.u0, prob.f.adtype, reinit_cache.p, num_cons,
+        cons_j = true, cons_h = true)
+else
+    f = Optimization.instantiate_function(
+        prob.f, reinit_cache.u0, prob.f.adtype, reinit_cache.p, num_cons)
+end
end

return Optimization.OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons,
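For context (not part of the commit): the COBYLA branch exists because the wrapper evaluates the constraint Jacobian and Hessian once to classify constraints as linear or nonlinear, so a COBYLA problem must carry an AD backend. A rough usage sketch, assuming the standard Optimization.jl constraint interface and the exported `COBYLA` solver; the Rosenbrock setup below is purely illustrative:

```julia
using Optimization, OptimizationPRIMA, ForwardDiff

rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
cons(res, u, p) = (res .= [u[1]^2 + u[2]^2])   # one nonlinear constraint, in-place

# The AD backend lets instantiate_function build cons_j/cons_h for COBYLA.
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff(); cons = cons)
prob = OptimizationProblem(optf, [0.5, 0.5], [1.0, 100.0];
    lcons = [-Inf], ucons = [1.0])             # u₁² + u₂² ≤ 1
sol = solve(prob, COBYLA())
```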
2 changes: 1 addition & 1 deletion src/sophia.jl
@@ -78,7 +78,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
for _ in 1:maxiters
for (i, d) in enumerate(data)
f.grad(gₜ, θ, d)
-x = cache.f(θ, cache.p, d...)
+x = cache.f(θ, d)
opt_state = Optimization.OptimizationState(; iter = i,
u = θ,
objective = first(x),
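For context (not part of the commit; `demo_epoch!` and everything below are made-up names): the one-line change above switches the objective call to the same `f(θ, batch)` convention already used for the gradient, so each minibatch is passed straight through as the second argument and `first(x)` becomes the objective reported in the per-iteration state. A stripped-down sketch of that loop shape:

```julia
# Minimal stand-in for the data loop: gradient and objective both receive the
# current batch `d`, and the per-batch objective is logged each iteration.
function demo_epoch!(θ, f, grad!, data; lr = 0.01)
    g = similar(θ)
    for (i, d) in enumerate(data)
        grad!(g, θ, d)                 # gradient w.r.t. θ for this batch
        x = f(θ, d)                    # objective called as f(θ, batch)
        @info "batch" iter = i objective = first(x)
        θ .-= lr .* g                  # plain gradient step, just for the demo
    end
    return θ
end

# Toy usage: pull θ toward two "batches" in turn.
f(θ, d) = sum(abs2, θ .- d)
grad!(g, θ, d) = (g .= 2 .* (θ .- d))
demo_epoch!([0.0, 0.0], f, grad!, ([1.0, 1.0], [2.0, 0.0]))
```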
6 changes: 3 additions & 3 deletions test/minibatch.jl
@@ -59,9 +59,9 @@ optfun = OptimizationFunction(loss_adjoint,
optprob = OptimizationProblem(optfun, pp, train_loader)

sol = Optimization.solve(optprob,
-Optimization.Sophia(; η = 0.5,
-λ = 0.0),
-maxiters = 1000)
+Optimization.Sophia(; η = 0.5,
+λ = 0.0),
+maxiters = 1000)
@test 10res1.objective < l1

optfun = OptimizationFunction(loss_adjoint,
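For context (not part of the commit): a rough end-to-end sketch of the minibatch pattern this test exercises, under the assumptions that `MLUtils.DataLoader` is the data iterator, `AutoZygote` is the AD backend, and the callback signature is `callback(state, loss)`; the data loader is passed in the parameter slot of the problem, and Sophia draws one batch per inner iteration:

```julia
using Optimization, MLUtils
import Zygote  # assumed AD backend for gradients/second-order information

xs = collect(range(0, 1; length = 32))
ys = 2 .* xs .+ 0.1 .* randn(32)
train_loader = DataLoader((xs, ys); batchsize = 8, shuffle = true)

# θ = [slope, intercept]; d is one (x-batch, y-batch) tuple from the loader.
loss(θ, d) = sum(abs2, d[2] .- (θ[1] .* d[1] .+ θ[2])) / length(d[1])

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, [0.0, 0.0], train_loader)

cb = (state, l) -> (println("iter $(state.iter): objective = $l"); false)
sol = solve(prob, Optimization.Sophia(; η = 0.01); maxiters = 100, callback = cb)
```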
