Skip to content

Commit

Permalink
Fix tessts
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 9, 2023
1 parent 2edd92f commit 37b883a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ __init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false}; kwargs...)
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
kwargs...) where {needsJᵀJ}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = JacobianWrapper{false}(f, p)
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end
14 changes: 6 additions & 8 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ end
end

precs = [
NonlinearSolve.DEFAULT_PRECS,
(args...) -> (Diagonal(rand!(similar(u0))), nothing),
(u0) -> NonlinearSolve.DEFAULT_PRECS,
u0 -> ((args...) -> (Diagonal(rand!(similar(u0))), nothing)),
]

@testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
Expand All @@ -60,13 +60,13 @@ end
if prec === :Random
prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
end
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec,
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec(u0),
linesearch)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
NewtonRaphson(; linsolve, precs = prec), abstol = 1e-9)
NewtonRaphson(; linsolve, precs = prec(u0)), abstol = 1e-9)
@test (@ballocated solve!($cache)) 64
end
end
Expand All @@ -91,8 +91,7 @@ end
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p)
1 / (2 * sqrt(p))
p) 1 / (2 * sqrt(p))
end
end

Expand Down Expand Up @@ -189,8 +188,7 @@ end
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
oftype(p, 1.0),
p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
oftype(p, 1.0), p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
end
end

Expand Down

0 comments on commit 37b883a

Please sign in to comment.