From b25c8bc728b6981eb20777d363f18882c3c765b5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 Oct 2023 12:50:56 -0400 Subject: [PATCH] Default to QR for GaussNewton --- Project.toml | 2 +- src/gaussnewton.jl | 60 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 244b21ad2..9fb70f62b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "2.4.0" +version = "2.5.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 6dcb0f835..c30ca6b36 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -46,7 +46,11 @@ function set_ad(alg::GaussNewton{CJ}, ad) where {CJ} return GaussNewton{CJ}(ad, alg.linsolve, alg.precs) end -function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(), +function set_linsolve(alg::GaussNewton{CJ}, linsolve) where {CJ} + return GaussNewton{CJ}(alg.ad, linsolve, alg.precs) +end + +function GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS, adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs) @@ -81,6 +85,15 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob + + # Use QR if the user did not specify a linear solver + if alg.linsolve === nothing + alg = set_linsolve(alg, QRFactorization(ColumnNorm(), 16, true)) + linsolve_with_JᵀJ = Val(false) + else + linsolve_with_JᵀJ = Val(true) + end + u = alias_u0 ? u0 : deepcopy(u0) if iip fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype @@ -88,8 +101,15 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: else fu1 = f(u, p) end - uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip); - linsolve_with_JᵀJ = Val(true)) + + if SciMLBase._unwrap_val(linsolve_with_JᵀJ) + uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, + Val(iip); linsolve_with_JᵀJ) + else + uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, + Val(iip); linsolve_with_JᵀJ) + JᵀJ, Jᵀf = nothing, nothing + end return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, @@ -99,12 +119,20 @@ end function perform_step!(cache::GaussNewtonCache{true}) @unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache jacobian!!(J, cache) - __matmul!(JᵀJ, J', J) - __matmul!(Jᵀf, J', fu1) - # u = u - J \ fu - linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf), - linu = _vec(du), p, reltol = cache.abstol) + if JᵀJ !== nothing + __matmul!(JᵀJ, J', J) + __matmul!(Jᵀf, J', fu1) + end + + # u = u - JᵀJ \ Jᵀfu + if cache.JᵀJ === nothing + linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du), + p, reltol = cache.abstol) + else + linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf), + linu = _vec(du), p, reltol = cache.abstol) + end cache.linsolve = linres.cache @. u = u - du f(cache.fu_new, u, p) @@ -125,14 +153,22 @@ function perform_step!(cache::GaussNewtonCache{false}) cache.J = jacobian!!(cache.J, cache) - cache.JᵀJ = cache.J' * cache.J - cache.Jᵀf = cache.J' * fu1 + if cache.JᵀJ !== nothing + cache.JᵀJ = cache.J' * cache.J + cache.Jᵀf = cache.J' * fu1 + end + # u = u - J \ fu if linsolve === nothing cache.du = fu1 / cache.J else - linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ), - b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol) + if cache.JᵀJ === nothing + linres = dolinsolve(alg.precs, linsolve; A = cache.J, b = _vec(fu1), + linu = _vec(cache.du), p, reltol = cache.abstol) + else + linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ), + b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol) + end cache.linsolve = linres.cache end cache.u = @. u - cache.du # `u` might not support mutation