Skip to content

Commit

Permalink
Use symmetric linear solve if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2023
1 parent 8f68ef1 commit 564950a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 25 deletions.
22 changes: 11 additions & 11 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
adkwargs...)
GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
Expand All @@ -22,9 +22,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
for example for a preconditioner, `concrete_jac = true` can be passed in order to force
the construction of the Jacobian.
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm used for the
linear solves within the Newton method. Defaults to `nothing`, which means it uses the
LinearSolve.jl default algorithm choice. For more information on available algorithm
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
linear solves within the Gauss-Newton method. Defaults to `CholeskyFactorization()`. For
more information on available algorithm choices, see the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
Expand All @@ -41,8 +41,8 @@ for large-scale and numerically-difficult nonlinear least squares problems.
precs
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
adkwargs...)
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),

Check warning on line 44 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L44

Added line #L44 was not covered by tests
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
end
Expand Down Expand Up @@ -97,8 +97,8 @@ function perform_step!(cache::GaussNewtonCache{true})
__matmul!(Jᵀf, J', fu1)

Check warning on line 97 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L96-L97

Added lines #L96 - L97 were not covered by tests

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),

Check warning on line 100 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L100

Added line #L100 was not covered by tests
linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)
Expand All @@ -125,8 +125,8 @@ function perform_step!(cache::GaussNewtonCache{false})
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
linu = _vec(cache.du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),

Check warning on line 128 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L128

Added line #L128 was not covered by tests
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
Expand Down
12 changes: 8 additions & 4 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
Jᵀfu = J' * fu
end

linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
u0 = _vec(du))
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

weight = similar(u)
recursivefill!(weight, true)

Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
nothing, nothing, nothing, nothing, nothing)..., weight)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

Expand All @@ -119,6 +119,10 @@ __init_JᵀJ(J::Number) = zero(J)
# Array Jacobian: build the initial `JᵀJ` value by forming the product `J' * J`.
function __init_JᵀJ(J::AbstractArray)
    Jᵀ = J'
    return Jᵀ * J
end

Check warning on line 119 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L119

Added line #L119 was not covered by tests
# Static-array Jacobian: allocate an `MArray` of size `size(J, 2) × size(J, 2)` with the
# element type of `J`. It is constructed with `undef`, so the contents are uninitialized
# and must be written before they are read.
function __init_JᵀJ(J::StaticArray)
    return MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
end

# `__maybe_symmetric(x)`: tag `x` as symmetric where that is done in this file.
#
# - Generic fallback: wrap `x` in a `Symmetric` view.
# - `Number`: returned unchanged.
# - Sparse matrices: returned unchanged.
#
# NOTE(review): presumably the wrapper exists so that symmetric-aware linear solvers
# (e.g. a Cholesky factorization) can be used on `JᵀJ`; the reason sparse matrices are
# exempt is not visible here — confirm against the linear-solver backend.
function __maybe_symmetric(x)
    return Symmetric(x)
end

function __maybe_symmetric(x::Number)
    return x
end

function __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix)
    return x
end

Check warning on line 124 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L122-L124

Added lines #L122 - L124 were not covered by tests

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
Expand Down
20 changes: 10 additions & 10 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
LevenbergMarquardt(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
precs = DEFAULT_PRECS, damping_initial::Real = 1.0,
damping_increase_factor::Real = 2.0, damping_decrease_factor::Real = 3.0,
finite_diff_step_geodesic::Real = 0.1, α_geodesic::Real = 0.75,
Expand All @@ -22,9 +22,9 @@ numerically-difficult nonlinear systems.
for example for a preconditioner, `concrete_jac = true` can be passed in order to force
the construction of the Jacobian.
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm used for the
linear solves within the Newton method. Defaults to `nothing`, which means it uses the
LinearSolve.jl default algorithm choice. For more information on available algorithm
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
linear solves within the Levenberg-Marquardt method. Defaults to `CholeskyFactorization()`.
For more information on available algorithm choices, see the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
Expand Down Expand Up @@ -86,7 +86,7 @@ numerically-difficult nonlinear systems.
min_damping_D::T
end

function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
function LevenbergMarquardt(; concrete_jac = nothing, linsolve = CholeskyFactorization(),

Check warning on line 89 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L89

Added line #L89 was not covered by tests
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
Expand Down Expand Up @@ -203,8 +203,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
mul!(cache.u_tmp, J', fu1)
@. cache.mat_tmp = JᵀJ + λ * DᵀD
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),

Check warning on line 206 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L206

Added line #L206 was not covered by tests
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.v = -cache.du

Expand Down Expand Up @@ -280,8 +280,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)
else
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, b = _vec(J' * fu1),
linu = _vec(cache.v), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),

Check warning on line 283 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L283

Added line #L283 was not covered by tests
b = _vec(J' * fu1), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end

Expand All @@ -291,7 +291,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.a = -cache.mat_tmp \
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
else
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp,
linres = dolinsolve(alg.precs, linsolve;

Check warning on line 294 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L294

Added line #L294 was not covered by tests
b = _mutable(_vec(J' *
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
linu = _vec(cache.a), p, reltol = cache.abstol)
Expand Down

0 comments on commit 564950a

Please sign in to comment.