From e6ff3aa80301f4ea1b84f3e632d2deeba8da9939 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Mar 2024 15:23:52 -0400 Subject: [PATCH] Make __findmin type stable --- Project.toml | 2 +- src/utils.jl | 27 ++++++++++++++++++++------- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 51aa68f54..2bfe146e1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.8.1" +version = "3.8.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/utils.jl b/src/utils.jl index d649d334c..73f134c0e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -94,16 +94,29 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x) @inline __is_complex(::Type{Complex}) = true @inline __is_complex(::Type{T}) where {T} = false -@inline __findmin_caches(f, caches) = __findmin(f ∘ get_fu, caches) +@inline __findmin_caches(f::F, caches) where {F} = __findmin(f ∘ get_fu, caches) # FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`) -@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x) -@inline function __findmin(f, x) +@generated function __findmin(f::F, x) where {F} + # JET shows dynamic dispatch if this is not written as a generated function + if F === typeof(DEFAULT_NORM) + return :(return __findmin_impl(Base.Fix1(maximum, abs), x)) + end + return :(return __findmin_impl(f, x)) +end +@inline @views function __findmin_impl(f::F, x) where {F} + idx = findfirst(Base.Fix2(!==, nothing), x) + # This is an internal function so we assume that inputs are consistent and there is + # atleast one non-`nothing` value + fx_idx = f(x[idx]) + idx == length(x) && return fx_idx, idx fmin = @closure xᵢ -> begin - xᵢ === nothing && return Inf + xᵢ === nothing && return oftype(fx_idx, Inf) fx = f(xᵢ) - return ifelse(isnan(fx), Inf, fx) + return ifelse(isnan(fx), oftype(fx, Inf), fx) end - return findmin(fmin, x) + x_min, x_min_idx = findmin(fmin, x[(idx + 1):length(x)]) + x_min < fx_idx && return x_min, x_min_idx + idx + return fx_idx, idx end @inline __can_setindex(x) = can_setindex(x) @@ -130,7 +143,7 @@ Statistics from the nonlinear equation solver about the solution process. - nf: Number of function evaluations. - njacs: Number of Jacobians created during the solve. - nfactors: Number of factorzations of the jacobian required for the solve. - - nsolve: Number of linear solves `W\b` required for the solve. + - nsolve: Number of linear solves `W \\ b` required for the solve. - nsteps: Total number of iterations for the nonlinear solver. """ struct ImmutableNLStats