Skip to content

Commit

Permalink
refactor: move tracing functionality to NonlinearSolveBase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent 5ca1074 commit a30f80b
Show file tree
Hide file tree
Showing 11 changed files with 289 additions and 233 deletions.
4 changes: 2 additions & 2 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ module NonlinearSolveLeastSquaresOptimExt

using ConcreteStructs: @concrete
using LeastSquaresOptim: LeastSquaresOptim
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL, TraceMinimal
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal, get_tolerance
using NonlinearSolve: NonlinearSolve, LeastSquaresOptimJL
using SciMLBase: SciMLBase, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode

const LSO = LeastSquaresOptim
Expand Down
6 changes: 3 additions & 3 deletions ext/NonlinearSolveNLsolveExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module NonlinearSolveNLsolveExt

using LineSearches: Static
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
using NonlinearSolve: NonlinearSolve, NLsolveJL, TraceMinimal
using NonlinearSolveBase: NonlinearSolveBase, TraceMinimal, get_tolerance
using NonlinearSolve: NonlinearSolve, NLsolveJL
using NLsolve: NLsolve, OnceDifferentiable, nlsolve
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode

Expand Down Expand Up @@ -32,7 +32,7 @@ function SciMLBase.__solve(
abstol = get_tolerance(abstol, eltype(u0))
show_trace = ShT
store_trace = StT
extended_trace = !(trace_level isa TraceMinimal)
extended_trace = !(trace_level.trace_mode isa Val{:minimal})

linesearch = alg.linesearch === missing ? Static() : alg.linesearch

Expand Down
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
Expand Down Expand Up @@ -64,6 +65,7 @@ LinearSolve = "2.36.1"
Markdown = "1.10"
MaybeInplace = "0.1.4"
Preferences = "1.4"
Printf = "1.10"
RecursiveArrayTools = "3"
SciMLBase = "2.50"
SciMLJacobianOperators = "0.1.1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@ Utils.maybe_symmetric(x::AbstractSparseMatrix) = x

Utils.make_sparse(x) = sparse(x)

Utils.condition_number(J::AbstractSparseMatrix) = Utils.condition_number(Matrix(J))

Utils.maybe_pinv!!_workspace(A::AbstractSparseMatrix) = Matrix(A)

end
4 changes: 4 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind
using Markdown: @doc_str
using MaybeInplace: @bb
using Preferences: @load_preference
using Printf: @printf
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
AbstractNonlinearAlgorithm, AbstractNonlinearFunction,
Expand All @@ -39,6 +40,7 @@ include("autodiff.jl")
include("jacobian.jl")
include("linear_solve.jl")
include("timer_outputs.jl")
include("tracing.jl")

include("descent/common.jl")
include("descent/newton.jl")
Expand All @@ -59,6 +61,8 @@ include("descent/geodesic_acceleration.jl")
@compat(public, (construct_linear_solver, needs_square_A, needs_concrete_A))
@compat(public, (construct_jacobian_cache,))

export TraceMinimal, TraceWithJacobianConditionNumber, TraceAll

export RelTerminationMode, AbsTerminationMode,
NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode,
RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
Expand Down
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ concrete_jac(::Val{true}) = true

abstract type AbstractNonlinearSolveCache <: AbstractNonlinearSolveBaseAPI end

function get_u end
function get_fu end

"""
AbstractLinearSolverCache
Expand Down
224 changes: 224 additions & 0 deletions lib/NonlinearSolveBase/src/tracing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
@concrete struct NonlinearSolveTracing
trace_mode <: Union{Val{:minimal}, Val{:condition_number}, Val{:all}}
print_frequency::Int
store_frequency::Int
end

"""
TraceMinimal(freq)
TraceMinimal(; print_frequency = 1, store_frequency::Int = 1)
Trace Minimal Information
1. Iteration Number
2. f(u) inf-norm
3. Step 2-norm
See also [`TraceWithJacobianConditionNumber`](@ref) and [`TraceAll`](@ref).
"""
function TraceMinimal(; print_frequency = 1, store_frequency::Int = 1)
return NonlinearSolveTracing(Val(:minimal), print_frequency, store_frequency)
end

"""
TraceWithJacobianConditionNumber(freq)
TraceWithJacobianConditionNumber(; print_frequency = 1, store_frequency::Int = 1)
[`TraceMinimal`](@ref) + Print the Condition Number of the Jacobian.
See also [`TraceMinimal`](@ref) and [`TraceAll`](@ref).
"""
function TraceWithJacobianConditionNumber(; print_frequency = 1, store_frequency::Int = 1)
return NonlinearSolveTracing(Val(:condition_number), print_frequency, store_frequency)
end

"""
TraceAll(freq)
TraceAll(; print_frequency = 1, store_frequency::Int = 1)
[`TraceWithJacobianConditionNumber`](@ref) + Store the Jacobian, u, f(u), and δu.
!!! warning
This is very expensive and makes copyies of the Jacobian, u, f(u), and δu.
See also [`TraceMinimal`](@ref) and [`TraceWithJacobianConditionNumber`](@ref).
"""
function TraceAll(; print_frequency = 1, store_frequency::Int = 1)
return NonlinearSolveTracing(Val(:all), print_frequency, store_frequency)
end

for Tr in (:TraceMinimal, :TraceWithJacobianConditionNumber, :TraceAll)
@eval $(Tr)(freq) = $(Tr)(; print_frequency = freq, store_frequency = freq)
end

# NonlinearSolve Tracing Utilities
@concrete struct NonlinearSolveTraceEntry
iteration::Int
fnorm
stepnorm
condJ
storage
norm_type::Symbol
end

function Base.getproperty(entry::NonlinearSolveTraceEntry, sym::Symbol)
hasfield(typeof(entry), sym) && return getfield(entry, sym)
return getproperty(entry.storage, sym)
end

function print_top_level(io::IO, entry::NonlinearSolveTraceEntry)
if entry.condJ === nothing
@printf io "%-8s\t%-20s\t%-20s\n" "----" "-------------" "-----------"
if entry.norm_type === :L2
@printf io "%-8s\t%-20s\t%-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm"
else
@printf io "%-8s\t%-20s\t%-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
end
@printf io "%-8s\t%-20s\t%-20s\n" "----" "-------------" "-----------"
else
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "----" "-------------" "-----------" "-------"
if entry.norm_type === :L2
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm" "cond(J)"
else
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
end
@printf io "%-8s\t%-20s\t%-20s\t%-20s\n" "----" "-------------" "-----------" "-------"
end
end

function Base.show(io::IO, ::MIME"text/plain", entry::NonlinearSolveTraceEntry)
entry.iteration == 0 && print_top_level(io, entry)
if entry.iteration < 0 # Special case for final entry
@printf io "%-8s\t%-20.8e\n" "Final" entry.fnorm
@printf io "%-28s\n" "----------------------"
elseif entry.condJ === nothing
@printf io "%-8d\t%-20.8e\t%-20.8e\n" entry.iteration entry.fnorm entry.stepnorm
else
@printf io "%-8d\t%-20.8e\t%-20.8e\t%-20.8e\n" entry.iteration entry.fnorm entry.stepnorm entry.condJ
end
end

function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J, u)
norm_type = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
fnorm = prob isa NonlinearLeastSquaresProblem ? L2_NORM(fu) : Linf_NORM(fu)
condJ = J !== missing ? Utils.condition_number(J) : nothing
storage = u === missing ? nothing :
(; u = copy(u), fu = copy(fu), δu = copy(δu), J = copy(J))
return NonlinearSolveTraceEntry(
iteration, fnorm, L2_NORM(δu), condJ, storage, norm_type
)
end

@concrete struct NonlinearSolveTrace
show_trace <: Union{Val{false}, Val{true}}
store_trace <: Union{Val{false}, Val{true}}
history
trace_level <: NonlinearSolveTracing
prob
end

reset!(trace::NonlinearSolveTrace) = reset!(trace.history)
reset!(::Nothing) = nothing
reset!(history::Vector) = empty!(history)

function Base.show(io::IO, ::MIME"text/plain", trace::NonlinearSolveTrace)
if trace.history !== nothing
foreach(trace.history) do entry
show(io, MIME"text/plain"(), entry)
end
else
print(io, "Tracing Disabled")
end
end

function init_nonlinearsolve_trace(
prob, alg, u, fu, J, δu; show_trace::Val = Val(false),
trace_level::NonlinearSolveTracing = TraceMinimal(), store_trace::Val = Val(false),
uses_jac_inverse = Val(false), kwargs...
)
return init_nonlinearsolve_trace(
prob, alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse
)
end

function init_nonlinearsolve_trace(
prob::AbstractNonlinearProblem, alg, show_trace::Val,
trace_level::NonlinearSolveTracing, store_trace::Val, u, fu, J, δu,
uses_jac_inverse::Val
)
if show_trace isa Val{true}
print("\nAlgorithm: ")
Base.printstyled(alg, "\n\n"; color = :green, bold = true)
end
J = uses_jac_inverse isa Val{true} ?
(trace_level.trace_mode isa Val{:minimal} ? J : LinearAlgebra.pinv(J)) : J
history = init_trace_history(prob, show_trace, trace_level, store_trace, u, fu, J, δu)
return NonlinearSolveTrace(show_trace, store_trace, history, trace_level, prob)
end

function init_trace_history(
prob::AbstractNonlinearProblem, show_trace::Val, trace_level,
store_trace::Val, u, fu, J, δu
)
store_trace isa Val{false} && show_trace isa Val{false} && return nothing
entry = if trace_level.trace_mode isa Val{:minimal}
NonlinearSolveTraceEntry(prob, 0, fu, δu, missing, missing)
elseif trace_level.trace_mode isa Val{:condition_number}
NonlinearSolveTraceEntry(prob, 0, fu, δu, J, missing)
else
NonlinearSolveTraceEntry(prob, 0, fu, δu, J, u)
end
show_trace isa Val{true} && show(stdout, MIME"text/plain"(), entry)
store_trace isa Val{true} && return NonlinearSolveTraceEntry[entry]
return nothing
end

function update_trace!(
trace::NonlinearSolveTrace, iter, u, fu, J, δu, α = true; last::Val = Val(false)
)
trace.store_trace isa Val{false} && trace.show_trace isa Val{false} && return nothing

if last isa Val{true}
norm_type = ifelse(trace.prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
fnorm = trace.prob isa NonlinearLeastSquaresProblem ? L2_NORM(fu) : Linf_NORM(fu)
entry = NonlinearSolveTraceEntry(-1, fnorm, NaN32, nothing, nothing, norm_type)
trace.show_trace isa Val{true} && show(stdout, MIME"text/plain"(), entry)
return trace
end

show_now = trace.show_trace isa Val{true} &&
(mod1(iter, trace.trace_level.print_frequency) == 1)
store_now = trace.store_trace isa Val{true} &&
(mod1(iter, trace.trace_level.store_frequency) == 1)
if show_now || store_now
entry = if trace.trace_level.trace_mode isa Val{:minimal}
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, missing, missing)
elseif trace.trace_level.trace_mode isa Val{:condition_number}
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, missing)
else
NonlinearSolveTraceEntry(trace.prob, iter, fu, δu .* α, J, u)
end
show_now && show(stdout, MIME"text/plain"(), entry)
store_now && push!(trace.history, entry)
end
return trace
end

function update_trace!(cache, α = true)
trace = Utils.safe_getproperty(cache, Val(:trace))
trace === missing && return nothing

J = Utils.safe_getproperty(cache, Val(:J))
if J === missing
update_trace!(
trace, cache.nsteps + 1, get_u(cache), get_fu(cache), nothing, cache.du, α
)
# XXX: Implement
# elseif cache isa ApproximateJacobianSolveCache && store_inverse_jacobian(cache)
# update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache),
# ApplyArray(__safe_inv, J), cache.du, α)
else
update_trace!(trace, cache.nsteps + 1, get_u(cache), get_fu(cache), J, cache.du, α)
end
end
45 changes: 44 additions & 1 deletion lib/NonlinearSolveBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module Utils

using ArrayInterface: ArrayInterface
using FastClosures: @closure
using LinearAlgebra: Symmetric, norm, dot
using LinearAlgebra: LinearAlgebra, Diagonal, Symmetric, norm, dot, cond, diagind, pinv
using MaybeInplace: @bb
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using SciMLOperators: AbstractSciMLOperator
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NonlinearFunction
Expand Down Expand Up @@ -146,4 +147,46 @@ end

function make_sparse end

condition_number(J::AbstractMatrix) = cond(J)
function condition_number(J::AbstractVector)
if !ArrayInterface.can_setindex(J)
J′ = similar(J)
copyto!(J′, J)
J = J′
end
return cond(Diagonal(J))
end
condition_number(::Any) = -1

# XXX: Move to NonlinearSolveQuasiNewton
# compute `pinv` if `inv` won't work
maybe_pinv!!_workspace(A) = nothing

maybe_pinv!!(workspace, A::Union{Number, AbstractMatrix}) = pinv(A)
function maybe_pinv!!(workspace, A::Diagonal)
D = A.diag
@bb @. D = pinv(D)
return Diagonal(D)
end
maybe_pinv!!(workspace, A::AbstractVector) = maybe_pinv!!(workspace, Diagonal(A))
function maybe_pinv!!(workspace, A::StridedMatrix)
LinearAlgebra.checksquare(A)
if LinearAlgebra.istriu(A)
issingular = any(iszero, @view(A[diagind(A)]))
A_ = UpperTriangular(A)
!issingular && return triu!(parent(inv(A_)))
elseif LinearAlgebra.istril(A)
A_ = LowerTriangular(A)
issingular = any(iszero, @view(A_[diagind(A_)]))
!issingular && return tril!(parent(inv(A_)))
else
F = LinearAlgebra.lu(A; check = false)
if issuccess(F)
Ai = LinearAlgebra.inv!(F)
return convert(typeof(parent(Ai)), Ai)
end
end
return pinv(A)
end

end
6 changes: 3 additions & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ using NonlinearSolveBase: NonlinearSolveBase,
DescentResult,
SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogleg,
GeodesicAcceleration,
reset_timer!, @static_timeit
reset_timer!, @static_timeit,
init_nonlinearsolve_trace, update_trace!, reset!

# XXX: Remove
import NonlinearSolveBase: InternalAPI, concrete_jac, supports_line_search,
supports_trust_region, last_step_accepted, get_linear_solver,
AbstractDampingFunction, AbstractDampingFunctionCache,
requires_normal_form_jacobian, requires_normal_form_rhs,
returns_norm_form_damping, get_timer_output
returns_norm_form_damping, get_timer_output, get_u, get_fu

using Printf: @printf
using Preferences: Preferences, set_preferences!
Expand Down Expand Up @@ -74,7 +75,6 @@ include("timer_outputs.jl")
include("internal/helpers.jl")

include("internal/termination.jl")
include("internal/tracing.jl")
include("internal/approximate_initialization.jl")

include("globalization/line_search.jl")
Expand Down
Loading

0 comments on commit a30f80b

Please sign in to comment.