Skip to content

Commit

Permalink
refactor: move JacobianCache into NonlinearSolveBase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 26, 2024
1 parent 5549c42 commit 685319b
Show file tree
Hide file tree
Showing 23 changed files with 374 additions and 317 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/CI_NonlinearSolveBase.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ jobs:
run: |
import Pkg
Pkg.Registry.update()
# Install packages present in subdirectories
dev_pks = Pkg.PackageSpec[]
for path in ("lib/SciMLJacobianOperators",)
push!(dev_pks, Pkg.PackageSpec(; path))
end
Pkg.develop(dev_pks)
Pkg.instantiate()
Pkg.test(; coverage=true)
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/NonlinearSolveBase {0}
Expand Down
5 changes: 5 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

Expand All @@ -27,12 +28,14 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[extensions]
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"

[compat]
ADTypes = "1.9"
Expand All @@ -56,8 +59,10 @@ Markdown = "1.10"
MaybeInplace = "0.1.4"
RecursiveArrayTools = "3"
SciMLBase = "2.50"
SciMLJacobianOperators = "0.1.1"
SciMLOperators = "0.3.10"
SparseArrays = "1.10"
SparseMatrixColorings = "0.4.8"
StaticArraysCore = "1.4"
Test = "1.10"
julia = "1.10"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module NonlinearSolveBaseLinearSolveExt
using ArrayInterface: ArrayInterface
using CommonSolve: CommonSolve, init, solve!
using LinearAlgebra: ColumnNorm
using LinearSolve: LinearSolve, QRFactorization
using LinearSolve: LinearSolve, QRFactorization, SciMLLinearSolveAlgorithm
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
using SciMLBase: ReturnCode, LinearProblem

Expand Down Expand Up @@ -81,7 +81,12 @@ function (cache::LinearSolveJLCache)(;
return LinearSolveResult(; linres.u)
end

NonlinearSolveBase.needs_square_A(linsolve, ::Any) = LinearSolve.needs_square_A(linsolve)
function NonlinearSolveBase.needs_square_A(linsolve::SciMLLinearSolveAlgorithm, ::Any)
return LinearSolve.needs_square_A(linsolve)
end
function NonlinearSolveBase.needs_concrete_A(linsolve::SciMLLinearSolveAlgorithm)
return LinearSolve.needs_concrete_A(linsolve)
end

update_A!(cache::LinearSolveJLCache, ::Nothing, reuse) = cache
function update_A!(cache::LinearSolveJLCache, A, reuse)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
module NonlinearSolveBaseSparseArraysExt

using NonlinearSolveBase: NonlinearSolveBase
using SparseArrays: AbstractSparseMatrixCSC, nonzeros
using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros

function NonlinearSolveBase.NAN_CHECK(x::AbstractSparseMatrixCSC)
return any(NonlinearSolveBase.NAN_CHECK, nonzeros(x))
end

NonlinearSolveBase.sparse_or_structured_prototype(::AbstractSparseMatrix) = true

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module NonlinearSolveBaseSparseMatrixColoringsExt

using ADTypes: ADTypes, AbstractADType
using NonlinearSolveBase: NonlinearSolveBase, Utils
using SciMLBase: SciMLBase, NonlinearFunction
using SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm,
LargestFirst

Utils.is_extension_loaded(::Val{:SparseMatrixColorings}) = true

function NonlinearSolveBase.select_fastest_coloring_algorithm(::Val{:SparseMatrixColorings},
prototype, f::NonlinearFunction, ad::AbstractADType)
prototype === nothing && return GreedyColoringAlgorithm(LargestFirst())
if SciMLBase.has_colorvec(f)
return ConstantColoringAlgorithm{ifelse(
ADTypes.mode(ad) isa ADTypes.ReverseMode, :row, :column)}(
prototype, f.colorvec)
end
return GreedyColoringAlgorithm(LargestFirst())
end

end
14 changes: 8 additions & 6 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
module NonlinearSolveBase

using ADTypes: ADTypes, AbstractADType
using ADTypes: ADTypes, AbstractADType, AutoSparse, NoSparsityDetector,
KnownJacobianSparsityDetector
using Adapt: WrappedArray
using ArrayInterface: ArrayInterface
using CommonSolve: CommonSolve, init
using Compat: @compat
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using DifferentiationInterface: DifferentiationInterface, Constant
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using FunctionProperties: hasbranching
Expand All @@ -17,8 +18,8 @@ using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
AbstractNonlinearAlgorithm, AbstractNonlinearFunction,
NonlinearProblem, NonlinearLeastSquaresProblem, StandardNonlinearProblem,
NullParameters, NLStats, LinearProblem, isinplace, warn_paramtype,
@add_kwonly
NonlinearFunction, NullParameters, NLStats, LinearProblem
using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using StaticArraysCore: StaticArray, SMatrix, SArray, MArray

Expand All @@ -44,9 +45,10 @@ include("linear_solve.jl")
(select_forward_mode_autodiff, select_reverse_mode_autodiff,
select_jacobian_autodiff))

# public for NonlinearSolve.jl to use
# public for NonlinearSolve.jl and subpackages to use
@compat(public, (InternalAPI, supports_line_search, supports_trust_region, set_du!))
@compat(public, (construct_linear_solver, needs_square_A))
@compat(public, (construct_linear_solver, needs_square_A, needs_concrete_A))
@compat(public, (construct_jacobian_cache,))

export RelTerminationMode, AbsTerminationMode,
NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode,
Expand Down
23 changes: 22 additions & 1 deletion lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module InternalAPI

function init end
function solve! end
function reinit! end

end

Expand Down Expand Up @@ -32,14 +33,34 @@ abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name))

"""
concrete_jac(alg::AbstractNonlinearSolveAlgorithm)::Bool
Whether the algorithm uses a concrete Jacobian.
"""
function concrete_jac(alg::AbstractNonlinearSolveAlgorithm)
return concrete_jac(Utils.safe_getproperty(alg, Val(:concrete_jac)))
end
concrete_jac(::Missing) = missing
concrete_jac(::Missing) = false
concrete_jac(::Nothing) = false
concrete_jac(v::Bool) = v
concrete_jac(::Val{false}) = false
concrete_jac(::Val{true}) = true

abstract type AbstractNonlinearSolveCache end

"""
AbstractLinearSolverCache
Abstract Type for all Linear Solvers used in NonlinearSolve. Subtypes of these are
meant to be constructured via [`construct_linear_solver`](@ref).
"""
abstract type AbstractLinearSolverCache end

"""
AbstractJacobianCache
Abstract Type for all Jacobian Caches used in NonlinearSolve. Subtypes of these are
meant to be constructured via [`construct_jacobian_cache`](@ref).
"""
abstract type AbstractJacobianCache end
15 changes: 7 additions & 8 deletions lib/NonlinearSolveBase/src/immutable_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ struct ImmutableNonlinearProblem{uType, iip, P, F, K, PT} <:
problem_type::PT
kwargs::K

@add_kwonly function ImmutableNonlinearProblem{iip}(
SciMLBase.@add_kwonly function ImmutableNonlinearProblem{iip}(
f::AbstractNonlinearFunction{iip}, u0, p = NullParameters(),
problem_type = StandardNonlinearProblem(); kwargs...) where {iip}
if haskey(kwargs, :p)
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to \
`NonlinearProblem`. This is not supported.")
end
warn_paramtype(p)
SciMLBase.warn_paramtype(p)
return new{
typeof(u0), iip, typeof(p), typeof(f), typeof(kwargs), typeof(problem_type)}(
f, u0, p, problem_type, kwargs)
Expand All @@ -31,27 +31,26 @@ struct ImmutableNonlinearProblem{uType, iip, P, F, K, PT} <:
end

"""
Define a nonlinear problem using an instance of
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
Define a nonlinear problem using an instance of [`AbstractNonlinearFunction`](@ref).
"""
function ImmutableNonlinearProblem(
f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
return ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
return ImmutableNonlinearProblem{SciMLBase.isinplace(f)}(f, u0, p; kwargs...)
end

function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
return ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
end

"""
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
Define a ImmutableNonlinearProblem problem from SteadyStateProblem.
"""
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
return ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(prob.f, prob.u0, prob.p)
end

function Base.convert(
::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
return ImmutableNonlinearProblem{isinplace(prob)}(
return ImmutableNonlinearProblem{SciMLBase.isinplace(prob)}(
prob.f, prob.u0, prob.p, prob.problem_type; prob.kwargs...)
end
Loading

0 comments on commit 685319b

Please sign in to comment.