Skip to content

Commit

Permalink
some moi and optimisers updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Aug 25, 2024
1 parent 66d9577 commit 9f36c85
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
19 changes: 19 additions & 0 deletions lib/OptimizationMOI/src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,25 @@ function MOI.eval_constraint_jacobian(evaluator::MOIOptimizationNLPEvaluator, j,
return
end

function MOI.eval_constraint_jacobian_product(evaluator::Evaluator, y, x, w)
start = time()
MOI.eval_constraint_jacobian_product(evaluator.backend, y, x, w)
evaluator.eval_constraint_jacobian_timer += time() - start
return
end

function MOI.eval_constraint_jacobian_transpose_product(
evaluator::Evaluator,
y,
x,
w,
)
start = time()
MOI.eval_constraint_jacobian_transpose_product(evaluator.backend, y, x, w)
evaluator.eval_constraint_jacobian_timer += time() - start
return
end

function MOI.hessian_lagrangian_structure(evaluator::MOIOptimizationNLPEvaluator)
lagh = evaluator.f.lag_h !== nothing
if evaluator.f.lag_hess_prototype !== nothing
Expand Down
6 changes: 3 additions & 3 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Optimization.SciMLBase

SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
SciMLBase.requiresgradient(opt::AbstractRule) = true
SciMLBase.allowsfg(opt::AbstractRule) = true

function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, opt::AbstractRule,
data = Optimization.DEFAULT_DATA; save_best = true,
Expand Down Expand Up @@ -55,7 +56,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
else
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
if maxiters === nothing
throw(ArgumentError("The number of iterations must be specified as the maxiters kwarg."))
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end
data = Optimization.take(cache.data, maxiters)
end
Expand All @@ -74,8 +75,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
Optimization.@withprogress cache.progress name="Training" begin
for _ in 1:maxiters
for (i, d) in enumerate(data)
cache.f.grad(G, θ, d...)
x = cache.f(θ, cache.p, d...)
x = cache.f.fg(G, θ, d...)
opt_state = Optimization.OptimizationState(iter = i,
u = θ,
objective = x[1],
Expand Down

0 comments on commit 9f36c85

Please sign in to comment.