Skip to content

Commit

Permalink
Use symmetric linear solve if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2023
1 parent 8f68ef1 commit f43a52d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ function perform_step!(cache::GaussNewtonCache{true})
__matmul!(Jᵀf, J', fu1)

Check warning on line 97 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L96-L97

Added lines #L96 - L97 were not covered by tests

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),

Check warning on line 100 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L100

Added line #L100 was not covered by tests
linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)
Expand All @@ -125,8 +125,8 @@ function perform_step!(cache::GaussNewtonCache{false})
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
linu = _vec(cache.du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),

Check warning on line 128 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L128

Added line #L128 was not covered by tests
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
Expand Down
12 changes: 8 additions & 4 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
Jᵀfu = J' * fu
end

linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
u0 = _vec(du))
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

weight = similar(u)
recursivefill!(weight, true)

Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
nothing, nothing, nothing, nothing, nothing)..., weight)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

Expand All @@ -119,6 +119,10 @@ __init_JᵀJ(J::Number) = zero(J)
__init_JᵀJ(J::AbstractArray) = J' * J

Check warning on line 119 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L119

Added line #L119 was not covered by tests
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)

__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x

Check warning on line 124 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L122-L124

Added lines #L122 - L124 were not covered by tests

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
Expand Down
8 changes: 4 additions & 4 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
mul!(cache.u_tmp, J', fu1)
@. cache.mat_tmp = JᵀJ + λ * DᵀD
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),

Check warning on line 206 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L206

Added line #L206 was not covered by tests
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.v = -cache.du

Expand Down Expand Up @@ -280,8 +280,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)
else
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, b = _vec(J' * fu1),
linu = _vec(cache.v), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),

Check warning on line 283 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L283

Added line #L283 was not covered by tests
b = _vec(J' * fu1), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end

Expand Down

0 comments on commit f43a52d

Please sign in to comment.