From 4f632dd083eaa3b50f27803b093fd3801cbcd212 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 Oct 2023 18:57:27 -0400 Subject: [PATCH] Add more solvers to GPU testing --- src/broyden.jl | 4 ++-- src/gaussnewton.jl | 2 +- test/gpu.jl | 19 ++++++++++++++++++- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/broyden.jl b/src/broyden.jl index 3232a2d9f..1b391d9a9 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -66,8 +66,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde alg.reset_tolerance reset_check = x -> abs(x) ≤ reset_tolerance return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu), - zero(fu), p, J⁻¹, zero(_vec(fu)'), _mutable_zero(u), false, 0, alg.max_resets, - maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance, + zero(fu), p, J⁻¹, zero(reshape(fu, 1, :)), _mutable_zero(u), false, 0, + alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) end diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index bce757e3a..11a82008b 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -82,7 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray) + if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u0 isa StaticArray) linsolve_with_JᵀJ = Val(false) else linsolve_with_JᵀJ = Val(true) diff --git a/test/gpu.jl b/test/gpu.jl index 465cd2141..6a3ae386d 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -1,5 +1,7 @@ using CUDA, NonlinearSolve +CUDA.allowscalar(false) + A = cu(rand(4, 4)) u0 = cu(rand(4)) b = cu(rand(4)) @@ -9,4 +11,19 @@ function f(du, u, p) end prob = NonlinearProblem(f, u0) -sol = solve(prob, NewtonRaphson()) + +# TrustRegion is broken +for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()), + PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden()) + @test_nowarn sol = solve(prob, alg; abstol = 1.0f-8, reltol = 1.0f-8) +end + +f(u, p) = A * u .+ b + +prob = NonlinearProblem{false}(f, u0) + +# TrustRegion is broken +for alg in (NewtonRaphson(), LevenbergMarquardt(; linsolve = QRFactorization()), + PseudoTransient(; alpha_initial = 10.0f0), GeneralKlement(), GeneralBroyden()) + @test_nowarn sol = solve(prob, alg; abstol = 1.0f-8, reltol = 1.0f-8) +end