Skip to content

Commit

Permalink
refactor: delete more code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent 5ed4fe4 commit 8d40ab4
Show file tree
Hide file tree
Showing 18 changed files with 359 additions and 512 deletions.
3 changes: 1 addition & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
NonlinearSolveFirstOrder = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d"
NonlinearSolveQuasiNewton = "9a2c21bd-3a47-402d-9113-8faf9a0ee114"
NonlinearSolveSpectralMethods = "26075421-4e9a-44e1-8bd1-420ed7ad02b2"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand Down Expand Up @@ -81,7 +81,6 @@ FixedPointAcceleration = "0.3"
ForwardDiff = "0.10.36"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LazyArrays = "1.8.2, 2"
LeastSquaresOptim = "0.8.5"
LineSearch = "0.1.4"
LineSearches = "7.3"
Expand Down
6 changes: 4 additions & 2 deletions lib/NonlinearSolveBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
end

"""
step!(cache::AbstractNonlinearSolveCache;
recompute_jacobian::Union{Nothing, Bool} = nothing)
step!(
cache::AbstractNonlinearSolveCache;
recompute_jacobian::Union{Nothing, Bool} = nothing
)
Performs one step of the nonlinear solver.
Expand Down
4 changes: 2 additions & 2 deletions lib/NonlinearSolveBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ function clean_sprint_struct(x)
name = nameof(typeof(x))
for field in fieldnames(typeof(x))
val = getfield(x, field)
if field === :name
if field === :name && val isa Symbol && val !== :unknown
name = val
continue
end
Expand All @@ -268,7 +268,7 @@ function clean_sprint_struct(x, indent::Int)
name = nameof(typeof(x))
for field in fieldnames(typeof(x))
val = getfield(x, field)
if field === :name
if field === :name && val isa Symbol && val !== :unknown
name = val
continue
end
Expand Down
37 changes: 37 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,40 @@
module NonlinearSolveFirstOrder

using Reexport: @reexport
using PrecompileTools: @compile_workload, @setup_workload

using ArrayInterface: ArrayInterface
using CommonSolve: CommonSolve
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
using LinearAlgebra: LinearAlgebra, Diagonal, dot, inv, diag
using LinearSolve: LinearSolve # Trigger Linear Solve extension in NonlinearSolveBase
using MaybeInplace: @bb
using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
AbstractNonlinearSolveCache, AbstractResetCondition,
AbstractResetConditionCache, AbstractApproximateJacobianStructure,
AbstractJacobianCache, AbstractJacobianInitialization,
AbstractApproximateJacobianUpdateRule, AbstractDescentDirection,
AbstractApproximateJacobianUpdateRuleCache,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM, NewtonDescent
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode
using SciMLOperators: AbstractSciMLOperator
using StaticArraysCore: StaticArray, Size, MArray

include("raphson.jl")
include("gauss_newton.jl")
include("levenberg_marquardt.jl")
include("trust_region.jl")
include("pseudo_transient.jl")

include("solve.jl")

@reexport using SciMLBase, NonlinearSolveBase

export NewtonRaphson, PseudoTransient
export GaussNewton, LevenbergMarquardt, TrustRegion

export GeneralizedFirstOrderAlgorithm

end
Empty file.
Empty file.
Empty file.
Empty file.
303 changes: 303 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/solve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
"""
GeneralizedFirstOrderAlgorithm(;
descent, linesearch = missing,
trustregion = missing, autodiff = nothing, vjp_autodiff = nothing,
jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int),
concrete_jac = Val(false), name::Symbol = :unknown
)
This is a Generalization of First-Order (uses Jacobian) Nonlinear Solve Algorithms. The most
common example of this is Newton-Raphson Method.
First Order here refers to the order of differentiation, and should not be confused with the
order of convergence.
### Keyword Arguments
- `trustregion`: Globalization using a Trust Region Method. This needs to follow the
[`NonlinearSolve.AbstractTrustRegionMethod`](@ref) interface.
- `descent`: The descent method to use to compute the step. This needs to follow the
[`NonlinearSolve.AbstractDescentAlgorithm`](@ref) interface.
- `max_shrink_times`: The maximum number of times the trust region radius can be shrunk
before the algorithm terminates.
"""
@concrete struct GeneralizedFirstOrderAlgorithm <: AbstractNonlinearSolveAlgorithm
linesearch
trustregion
descent
max_shrink_times::Int

autodiff
vjp_autodiff
jvp_autodiff

concrete_jac <: Union{Val{false}, Val{true}}
name::Symbol
end

function GeneralizedFirstOrderAlgorithm(;
descent, linesearch = missing, trustregion = missing, autodiff = nothing,
vjp_autodiff = nothing, jvp_autodiff = nothing, max_shrink_times::Int = typemax(Int),
concrete_jac = Val(false), name::Symbol = :unknown)
return GeneralizedFirstOrderAlgorithm(
linesearch, trustregion, descent, max_shrink_times,
autodiff, vjp_autodiff, jvp_autodiff,
concrete_jac, name
)
end

@concrete mutable struct GeneralizedFirstOrderAlgorithmCache <: AbstractNonlinearSolveCache
# Basic Requirements
fu
u
u_cache
p
du # Aliased to `get_du(descent_cache)`
J # Aliased to `jac_cache.J`
alg <: GeneralizedFirstOrderAlgorithm
prob <: AbstractNonlinearProblem
globalization <: Union{Val{:LineSearch}, Val{:TrustRegion}, Val{:None}}

# Internal Caches
jac_cache
descent_cache
linesearch_cache
trustregion_cache

# Counters
stats::NLStats
nsteps::Int
maxiters::Int
maxtime
max_shrink_times::Int

# Timer
timer
total_time::Float64

# State Affect
make_new_jacobian::Bool

# Termination & Tracking
termination_cache
trace
retcode::ReturnCode.T
force_stop::Bool
kwargs
end

# XXX: Implement
# function __reinit_internal!(
# cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...; p = cache.p, u0 = cache.u,
# alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip}
# if iip
# recursivecopy!(cache.u, u0)
# cache.prob.f(cache.fu, cache.u, p)
# else
# cache.u = __maybe_unaliased(u0, alias_u0)
# set_fu!(cache, cache.prob.f(cache.u, p))
# end
# cache.p = p

# __reinit_internal!(cache.stats)
# cache.nsteps = 0
# cache.maxiters = maxiters
# cache.maxtime = maxtime
# cache.total_time = 0.0
# cache.force_stop = false
# cache.retcode = ReturnCode.Default
# cache.make_new_jacobian = true

# reset!(cache.trace)
# reinit!(cache.termination_cache, get_fu(cache), get_u(cache); kwargs...)
# reset_timer!(cache.timer)
# end

NonlinearSolveBase.@internal_caches(GeneralizedFirstOrderAlgorithmCache,
:jac_cache, :descent_cache, :linesearch_cache, :trustregion_cache)

# function SciMLBase.__init(
# prob::AbstractNonlinearProblem{uType, iip}, alg::GeneralizedFirstOrderAlgorithm,
# args...; stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
# abstol = nothing, reltol = nothing, maxtime = nothing,
# termination_condition = nothing, internalnorm = L2_NORM,
# linsolve_kwargs = (;), kwargs...) where {uType, iip}
# autodiff = select_jacobian_autodiff(prob, alg.autodiff)
# jvp_autodiff = if alg.jvp_autodiff === nothing && alg.autodiff !== nothing &&
# (ADTypes.mode(alg.autodiff) isa ADTypes.ForwardMode ||
# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode)
# select_forward_mode_autodiff(prob, alg.autodiff)
# else
# select_forward_mode_autodiff(prob, alg.jvp_autodiff)
# end
# vjp_autodiff = if alg.vjp_autodiff === nothing && alg.autodiff !== nothing &&
# (ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode ||
# ADTypes.mode(alg.autodiff) isa ADTypes.ForwardOrReverseMode)
# select_reverse_mode_autodiff(prob, alg.autodiff)
# else
# select_reverse_mode_autodiff(prob, alg.vjp_autodiff)
# end

# timer = get_timer_output()
# @static_timeit timer "cache construction" begin
# (; f, u0, p) = prob
# u = __maybe_unaliased(u0, alias_u0)
# fu = evaluate_f(prob, u)
# @bb u_cache = copy(u)

# linsolve = get_linear_solver(alg.descent)

# abstol, reltol, termination_cache = NonlinearSolveBase.init_termination_cache(
# prob, abstol, reltol, fu, u, termination_condition, Val(:regular))
# linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)

# jac_cache = construct_jacobian_cache(
# prob, alg, f, fu, u, p; stats, autodiff, linsolve, jvp_autodiff, vjp_autodiff)
# J = jac_cache(nothing)

# descent_cache = __internal_init(prob, alg.descent, J, fu, u; stats, abstol,
# reltol, internalnorm, linsolve_kwargs, timer)
# du = get_du(descent_cache)

# has_linesearch = alg.linesearch !== missing && alg.linesearch !== nothing
# has_trustregion = alg.trustregion !== missing && alg.trustregion !== nothing

# if has_trustregion && has_linesearch
# error("TrustRegion and LineSearch methods are algorithmically incompatible.")
# end

# GB = :None
# linesearch_cache = nothing
# trustregion_cache = nothing

# if has_trustregion
# supports_trust_region(alg.descent) || error("Trust Region not supported by \
# $(alg.descent).")
# trustregion_cache = __internal_init(
# prob, alg.trustregion, f, fu, u, p; stats, internalnorm, kwargs...,
# autodiff, jvp_autodiff, vjp_autodiff)
# GB = :TrustRegion
# end

# if has_linesearch
# supports_line_search(alg.descent) || error("Line Search not supported by \
# $(alg.descent).")
# linesearch_cache = init(
# prob, alg.linesearch, fu, u; stats, autodiff = jvp_autodiff, kwargs...)
# GB = :LineSearch
# end

# trace = init_nonlinearsolve_trace(
# prob, alg, u, fu, ApplyArray(__zero, J), du; kwargs...)

# return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
# fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
# trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
# 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs)
# end
# end

# function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
# recompute_jacobian::Union{Nothing, Bool} = nothing, kwargs...) where {iip, GB}
# @static_timeit cache.timer "jacobian" begin
# if (recompute_jacobian === nothing || recompute_jacobian) && cache.make_new_jacobian
# J = cache.jac_cache(cache.u)
# new_jacobian = true
# else
# J = cache.jac_cache(nothing)
# new_jacobian = false
# end
# end

# @static_timeit cache.timer "descent" begin
# if cache.trustregion_cache !== nothing &&
# hasfield(typeof(cache.trustregion_cache), :trust_region)
# descent_result = __internal_solve!(
# cache.descent_cache, J, cache.fu, cache.u; new_jacobian,
# trust_region = cache.trustregion_cache.trust_region, cache.kwargs...)
# else
# descent_result = __internal_solve!(
# cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...)
# end
# end

# if !descent_result.linsolve_success
# if new_jacobian
# # Jacobian Information is current and linear solve failed terminate the solve
# cache.retcode = ReturnCode.InternalLinearSolveFailed
# cache.force_stop = true
# return
# else
# # Jacobian Information is not current and linear solve failed, recompute
# # Jacobian
# if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose]
# @warn "Linear Solve Failed but Jacobian Information is not current. \
# Retrying with updated Jacobian."
# end
# # In the 2nd call the `new_jacobian` is guaranteed to be `true`.
# cache.make_new_jacobian = true
# __step!(cache; recompute_jacobian = true, kwargs...)
# return
# end
# end

# δu, descent_intermediates = descent_result.δu, descent_result.extras

# if descent_result.success
# cache.make_new_jacobian = true
# if GB === :LineSearch
# @static_timeit cache.timer "linesearch" begin
# linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
# linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
# α = linesearch_sol.step_size
# end
# if linesearch_failed
# cache.retcode = ReturnCode.InternalLineSearchFailed
# cache.force_stop = true
# end
# @static_timeit cache.timer "step" begin
# @bb axpy!(α, δu, cache.u)
# evaluate_f!(cache, cache.u, cache.p)
# end
# elseif GB === :TrustRegion
# @static_timeit cache.timer "trustregion" begin
# tr_accepted, u_new, fu_new = __internal_solve!(
# cache.trustregion_cache, J, cache.fu,
# cache.u, δu, descent_intermediates)
# if tr_accepted
# @bb copyto!(cache.u, u_new)
# @bb copyto!(cache.fu, fu_new)
# α = true
# else
# α = false
# cache.make_new_jacobian = false
# end
# if hasfield(typeof(cache.trustregion_cache), :shrink_counter) &&
# cache.trustregion_cache.shrink_counter > cache.max_shrink_times
# cache.retcode = ReturnCode.ShrinkThresholdExceeded
# cache.force_stop = true
# end
# end
# elseif GB === :None
# @static_timeit cache.timer "step" begin
# @bb axpy!(1, δu, cache.u)
# evaluate_f!(cache, cache.u, cache.p)
# end
# α = true
# else
# error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \
# :TrustRegion, :None)")
# end
# check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
# else
# α = false
# cache.make_new_jacobian = false
# end

# update_trace!(cache, α)
# @bb copyto!(cache.u_cache, cache.u)

# callback_into_cache!(cache)

# return nothing
# end
Empty file.
Loading

0 comments on commit 8d40ab4

Please sign in to comment.