Skip to content

Commit

Permalink
Handle non-vector inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 20, 2023
1 parent 2fd9480 commit 1d0c424
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ PrecompileTools = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
SciMLBase = "2.4"
SimpleNonlinearSolve = "0.1.22"
SimpleNonlinearSolve = "0.1.23"
SparseDiffTools = "2.6"
StaticArraysCore = "1.4"
UnPack = "1.0"
Expand Down
1 change: 1 addition & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_m
end

using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
import ArrayInterface: restructure
import ForwardDiff

import ADTypes: AbstractFiniteDifferencesMode
Expand Down
2 changes: 1 addition & 1 deletion src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
if needsJᵀJ
JᵀJ = __init_JᵀJ(J)
# FIXME: This needs to be handled better for JacVec Operator
Jᵀfu = J' * fu
Jᵀfu = J' * _vec(fu)

Check warning on line 95 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L95

Added line #L95 was not covered by tests
end

linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
Expand Down
6 changes: 3 additions & 3 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
else
d = similar(u)
d .= min_damping_D
DᵀD = Diagonal(d)
DᵀD = Diagonal(_vec(d))

Check warning on line 175 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L175

Added line #L175 was not covered by tests
end

loss = internalnorm(fu1)
Expand Down Expand Up @@ -289,7 +289,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.v = -cache.mat_tmp \ (J' * fu1)
else
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
b = _vec(J' * fu1), linu = _vec(cache.v), p, reltol = cache.abstol)
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end

Expand All @@ -301,7 +301,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
else
linres = dolinsolve(alg.precs, linsolve;
b = _mutable(_vec(J' *
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
_vec(((2 / h) .* (_vec(f(u .+ h .* _restructure(u,v), p)) .- _vec(fu1) ./ h .- J * _vec(v)))))),
linu = _vec(cache.a), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
Expand Down
12 changes: 6 additions & 6 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ function perform_step!(cache::TrustRegionCache{true})
if cache.make_new_J
jacobian!!(J, cache)
mul!(cache.H, J', J)
mul!(cache.g, J', fu)
mul!(_vec(cache.g), J', _vec(fu))

Check warning on line 350 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L350

Added line #L350 was not covered by tests
cache.stats.njacs += 1

# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
Expand Down Expand Up @@ -378,9 +378,9 @@ function perform_step!(cache::TrustRegionCache{false})
if make_new_J
J = jacobian!!(cache.J, cache)
cache.H = J' * J
cache.g = J' * fu
cache.g = _restructure(fu, J' * _vec(fu))
cache.stats.njacs += 1
cache.u_gauss_newton = -1 .* (cache.H \ cache.g)
cache.u_gauss_newton = -1 .* _restructure(cache.g, cache.H \ _vec(cache.g))
end

# Compute the Newton step.
Expand Down Expand Up @@ -419,7 +419,7 @@ function trust_region_step!(cache::TrustRegionCache)
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (dot(du, g) + dot(du, H, du) / 2)
cache.r = -(loss - cache.loss_new) / (dot(_vec(du), _vec(g)) + dot(_vec(du), H, _vec(du)) / 2)
@unpack r = cache

if radius_update_scheme === RadiusUpdateSchemes.Simple
Expand Down Expand Up @@ -597,7 +597,7 @@ function dogleg!(cache::TrustRegionCache{true})

# Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
l_grad = norm(cache.g) # length of the gradient
d_cauchy = l_grad^3 / dot(cache.g, cache.H, cache.g) # distance of the cauchy point from the current iterate
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate

Check warning on line 600 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L600

Added line #L600 was not covered by tests
if d_cauchy >= trust_r
@. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
return
Expand Down Expand Up @@ -627,7 +627,7 @@ function dogleg!(cache::TrustRegionCache{false})

## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
l_grad = norm(cache.g)
d_cauchy = l_grad^3 / dot(cache.g, cache.H, cache.g) # distance of the cauchy point from the current iterate
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
if d_cauchy > trust_r # cauchy point lies outside of trust region
cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
return
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ end
@inline _vec(v::Number) = v
@inline _vec(v::AbstractVector) = v

@inline _restructure(y,x) = restructure(y,x)
@inline _restructure(y::Number,x::Number) = x

Check warning on line 78 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L78

Added line #L78 was not covered by tests

DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing

function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
Expand Down
21 changes: 21 additions & 0 deletions test/matrix_resizing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using NonlinearSolve, Test

ff(u, p) = u .* u .- p
u0 = rand(2,2)
p = 2.0
vecprob = NonlinearProblem(ff, vec(u0), p)
prob = NonlinearProblem(ff, u0, p)

for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg())
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
end

fiip(du, u, p) = (du .= u .* u .- p)
u0 = rand(2,2)
p = 2.0
vecprob = NonlinearProblem(fiip, vec(u0), p)
prob = NonlinearProblem(fiip, u0, p)

for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(), RobustMultiNewton(), FastShortcutNonlinearPolyalg())
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end
@time @safetestset "Basic Tests + Some AD" include("basictests.jl")
@time @safetestset "Sparsity Tests" include("sparse.jl")
@time @safetestset "Polyalgs" include("polyalgs.jl")
@time @safetestset "Matrix Resizing" include("matrix_resizing.jl")
@time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
end

Expand Down

0 comments on commit 1d0c424

Please sign in to comment.