diff --git a/src/jacobian.jl b/src/jacobian.jl index bd4575fcc..fe12d7a86 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -119,8 +119,10 @@ __init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}( ## Special Handling for Scalars function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p, - ::Val{false}; kwargs...) + ::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false), + kwargs...) where {needsJᵀJ} # NOTE: Scalar `u` assumes scalar output from `f` uf = JacobianWrapper{false}(f, p) + needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u return uf, nothing, u, nothing, nothing, u end diff --git a/test/basictests.jl b/test/basictests.jl index acfeff6c8..1e1ded563 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -50,8 +50,8 @@ end end precs = [ - NonlinearSolve.DEFAULT_PRECS, - (args...) -> (Diagonal(rand!(similar(u0))), nothing), + (u0) -> NonlinearSolve.DEFAULT_PRECS, + u0 -> ((args...) -> (Diagonal(rand!(similar(u0))), nothing)), ] @testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([ @@ -60,13 +60,13 @@ end if prec === :Random prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing) end - sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec, + sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec(u0), linesearch) @test SciMLBase.successful_retcode(sol) @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0), - NewtonRaphson(; linsolve, precs = prec), abstol = 1e-9) + NewtonRaphson(; linsolve, precs = prec(u0)), abstol = 1e-9) @test (@ballocated solve!($cache)) ≤ 64 end end @@ -91,8 +91,7 @@ end res.u ≈ res_true end @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, - p) ≈ - 1 / (2 * sqrt(p)) + p) ≈ 1 / (2 * sqrt(p)) end end @@ -189,8 +188,7 @@ end res.u ≈ res_true end @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - oftype(p, 1.0), - p; radius_update_scheme).u, p) ≈ 1 / (2 * sqrt(p)) + oftype(p, 1.0), p; radius_update_scheme).u, p) ≈ 1 / (2 * sqrt(p)) end end