Skip to content

Commit

Permalink
reapplying formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ArnoStrouwen committed Feb 22, 2024
1 parent 763ad4f commit c8104cf
Show file tree
Hide file tree
Showing 38 changed files with 718 additions and 649 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
style = "sciml"
format_markdown = true
format_docstrings = true
annotate_untyped_fields_with_any = false
2 changes: 1 addition & 1 deletion benchmarks/applelu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ algs = [
GenericLUFactorization(),
RFLUFactorization(),
AppleAccelerateLUFactorization(),
MetalLUFactorization(),
MetalLUFactorization()
]
res = [Float32[] for i in 1:length(algs)]

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ algs = [
RFLUFactorization(),
MKLLUFactorization(),
FastLUFactorization(),
SimpleLUFactorization(),
SimpleLUFactorization()
]
res = [Float64[] for i in 1:length(algs)]

Expand Down
5 changes: 3 additions & 2 deletions benchmarks/sparselu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ algs = [
UMFPACKFactorization(),
KLUFactorization(),
MKLPardisoFactorize(),
SparspakFactorization(),
SparspakFactorization()
]
cols = [:red, :blue, :green, :magenta, :turqoise] # one color per alg
lst = [:dash, :solid, :dashdot] # one line style per dim
Expand Down Expand Up @@ -65,7 +65,8 @@ function run_and_plot(; dims = [1, 2, 3], kmax = 12)
u0 = rand(rng, n)

for j in 1:length(algs)
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy($A),
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(
copy($A),
copy($b);
u0 = copy($u0),
alias_A = true,
Expand Down
6 changes: 3 additions & 3 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

pages = ["index.md",
"Tutorials" => Any["tutorials/linear.md"
"tutorials/caching_interface.md"],
"tutorials/caching_interface.md"],
"Basics" => Any["basics/LinearProblem.md",
"basics/common_solver_opts.md",
"basics/OperatorAssumptions.md",
"basics/Preconditioners.md",
"basics/FAQ.md"],
"Solvers" => Any["solvers/solvers.md"],
"Advanced" => Any["advanced/developing.md"
"advanced/custom.md"],
"Release Notes" => "release_notes.md",
"advanced/custom.md"],
"Release Notes" => "release_notes.md"
]
2 changes: 1 addition & 1 deletion docs/src/advanced/developing.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ basic machinery. A simplified version is:
struct MyLUFactorization{P} <: SciMLBase.AbstractLinearAlgorithm end

function init_cacheval(alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol,
verbose)
verbose)
lu!(convert(AbstractMatrix, A))
end

Expand Down
11 changes: 6 additions & 5 deletions ext/LinearSolveBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module LinearSolveBandedMatricesExt

using BandedMatrices, LinearAlgebra, LinearSolve
import LinearSolve: defaultalg,
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice
do_factorization, init_cacheval, DefaultLinearSolver,
DefaultAlgorithmChoice

# Defaults for BandedMatrices
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions{Bool})
Expand Down Expand Up @@ -35,14 +36,14 @@ for alg in (:SVDFactorization, :MKLLUFactorization, :DiagonalFactorization,
:AppleAccelerateLUFactorization, :CholeskyFactorization)
@eval begin
function init_cacheval(::$(alg), ::BandedMatrix, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
return nothing
end
end
end

function init_cacheval(::LUFactorization, A::BandedMatrix, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
return lu(similar(A, 0, 0))
end

Expand All @@ -54,8 +55,8 @@ for alg in (:SVDFactorization, :MKLLUFactorization, :DiagonalFactorization,
:AppleAccelerateLUFactorization, :QRFactorization, :LUFactorization)
@eval begin
function init_cacheval(::$(alg), ::Symmetric{<:Number, <:BandedMatrix}, b, u, Pl,
Pr, maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Pr, maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
return nothing
end
end
Expand Down
2 changes: 1 addition & 1 deletion ext/LinearSolveBlockDiagonalsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module LinearSolveBlockDiagonalsExt
using LinearSolve, BlockDiagonals

function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, args...;
kwargs...)
kwargs...)
@assert ndims(A)==2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2."
# We need to perform this check even when `zeroinit == true`, since the type of the
# cache is dependent on whether we are able to use the specialized dispatch.
Expand Down
6 changes: 3 additions & 3 deletions ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterfa
using SciMLBase: AbstractSciMLOperator

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
kwargs...)
kwargs...)
if cache.isfresh
fact = qr(CUDA.CuArray(cache.A))
cache.cacheval = fact
Expand All @@ -18,8 +18,8 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
end

function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
qr(CUDA.CuArray(A))
end

Expand Down
35 changes: 24 additions & 11 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
function EnzymeCore.EnzymeRules.forward(

Check warning on line 11 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L11

Added line #L11 was not covered by tests
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@assert !(prob isa Const)
res = func.val(prob.val, alg.val; kwargs...)
if RT <: Const
Expand All @@ -26,11 +27,13 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, :
error("Unsupported return type $RT")
end

function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},

Check warning on line 30 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L30

Added line #L30 was not covered by tests
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@assert !(linsolve isa Const)

res = func.val(linsolve.val; kwargs...)

if RT <: Const
return res
end
Expand All @@ -56,7 +59,10 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
return Duplicated(res, dres)
end

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
function EnzymeCore.EnzymeRules.augmented_primal(

Check warning on line 62 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L62

Added line #L62 was not covered by tests
config, func::Const{typeof(LinearSolve.init)},
::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
func.val(prob.dval, alg.val; kwargs...)
Expand All @@ -77,7 +83,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
(dval.b for dval in dres)
end


prob_d_A = if EnzymeRules.width(config) == 1
prob.dval.A
else
Expand All @@ -92,7 +97,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
function EnzymeCore.EnzymeRules.reverse(

Check warning on line 100 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L100

Added line #L100 was not covered by tests
config, func::Const{typeof(LinearSolve.init)}, ::Type{RT},
cache, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
d_A, d_b, prob_d_A, prob_d_b = cache

if EnzymeRules.width(config) == 1
Expand All @@ -105,7 +113,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i
d_b .= 0
end
else
for (_prob_d_A,_d_A,_prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
for (_prob_d_A, _d_A, _prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)

Check warning on line 116 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L116

Added line #L116 was not covered by tests
if _d_A !== _prob_d_A
_prob_d_A .+= _d_A
_d_A .= 0
Expand All @@ -123,7 +131,10 @@ end
# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
function EnzymeCore.EnzymeRules.augmented_primal(

Check warning on line 134 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L134

Added line #L134 was not covered by tests
config, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
res = func.val(linsolve.val; kwargs...)

dres = if EnzymeRules.width(config) == 1
Expand Down Expand Up @@ -176,7 +187,9 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},

Check warning on line 190 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L190

Added line #L190 was not covered by tests
::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve, dAs, dbs = cache

@assert !(linsolve isa Const)
Expand All @@ -202,7 +215,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
LinearSolve.defaultalg_adjoint_eval(_linsolve, dy)
else
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
end
end

dA .-= z * transpose(y)
db .+= z
Expand Down
3 changes: 2 additions & 1 deletion ext/LinearSolveFastAlmostBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module LinearSolveFastAlmostBandedMatricesExt

using FastAlmostBandedMatrices, LinearAlgebra, LinearSolve
import LinearSolve: defaultalg,
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice
do_factorization, init_cacheval, DefaultLinearSolver,
DefaultAlgorithmChoice

function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions{Bool})
if oa.issq
Expand Down
44 changes: 22 additions & 22 deletions ext/LinearSolveHYPREExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using LinearAlgebra
using HYPRE.LibHYPRE: HYPRE_Complex
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning
OperatorAssumptions, default_tol, init_cacheval, __issquare,
__conditioning
using SciMLBase: LinearProblem, SciMLBase
using UnPack: @unpack
using Setfield: @set!
Expand All @@ -21,8 +21,8 @@ mutable struct HYPRECache
end

function LinearSolve.init_cacheval(alg::HYPREAlgorithm, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
return HYPRECache(nothing, nothing, nothing, nothing, true, true, true)
end

Expand Down Expand Up @@ -54,21 +54,21 @@ end
# fill!(similar(b, size(A, 2)), false) since HYPREArrays are not AbstractArrays.

function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
args...;
alias_A = false, alias_b = false,
# TODO: Implement eltype for HYPREMatrix in HYPRE.jl? Looks useful
# even if it is not AbstractArray.
abstol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
reltol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
# TODO: Implement length() for HYPREVector in HYPRE.jl?
maxiters::Int = prob.b isa HYPREVector ? 1000 : length(prob.b),
verbose::Bool = false,
Pl = LinearAlgebra.I,
Pr = LinearAlgebra.I,
assumptions = OperatorAssumptions(),
kwargs...)
args...;
alias_A = false, alias_b = false,
# TODO: Implement eltype for HYPREMatrix in HYPRE.jl? Looks useful
# even if it is not AbstractArray.
abstol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
reltol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex :
eltype(prob.A)),
# TODO: Implement length() for HYPREVector in HYPRE.jl?
maxiters::Int = prob.b isa HYPREVector ? 1000 : length(prob.b),
verbose::Bool = false,
Pl = LinearAlgebra.I,
Pr = LinearAlgebra.I,
assumptions = OperatorAssumptions(),
kwargs...)
@unpack A, b, u0, p = prob

A = A isa HYPREMatrix ? A : HYPREMatrix(A)
Expand All @@ -89,7 +89,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol),
typeof(__issquare(assumptions)),
typeof(__issquare(assumptions))
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters,
verbose, assumptions)
Expand Down Expand Up @@ -219,8 +219,8 @@ end

# HYPREArrays are not AbstractArrays so perform some type-piracy
function SciMLBase.LinearProblem(A::HYPREMatrix, b::HYPREVector,
p = SciMLBase.NullParameters();
u0::Union{HYPREVector, Nothing} = nothing, kwargs...)
p = SciMLBase.NullParameters();
u0::Union{HYPREVector, Nothing} = nothing, kwargs...)
return LinearProblem{true}(A, b, p; u0 = u0, kwargs)
end

Expand Down
17 changes: 9 additions & 8 deletions ext/LinearSolveIterativeSolversExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ else
end

function LinearSolve.IterativeSolversJL(args...;
generate_iterator = IterativeSolvers.gmres_iterable!,
gmres_restart = 0, kwargs...)
generate_iterator = IterativeSolvers.gmres_iterable!,
gmres_restart = 0, kwargs...)
return IterativeSolversJL(generate_iterator, gmres_restart,
args, kwargs)
end
Expand Down Expand Up @@ -49,9 +49,9 @@ LinearSolve.default_alias_A(::IterativeSolversJL, ::Any, ::Any) = true
LinearSolve.default_alias_b(::IterativeSolversJL, ::Any, ::Any) = true

function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
restart = (alg.gmres_restart == 0) ? min(20, size(A, 1)) : alg.gmres_restart
s = :idrs_s in keys(alg.kwargs) ? alg.kwargs.idrs_s : 4 # shadow space

Expand All @@ -69,10 +69,10 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max
elseif alg.generate_iterator === IterativeSolvers.idrs_iterable!
!!LinearSolve._isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
history = IterativeSolvers.ConvergenceHistory(partial=true)
history = IterativeSolvers.ConvergenceHistory(partial = true)
history[:abstol] = abstol
history[:reltol] = reltol
IterativeSolvers.idrs_iterable!(history, u, A, b, s, Pl, abstol, reltol, maxiters;
IterativeSolvers.idrs_iterable!(history, u, A, b, s, Pl, abstol, reltol, maxiters;
alg.kwargs...)
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
!!LinearSolve._isidentity_struct(Pr) &&
Expand Down Expand Up @@ -110,7 +110,8 @@ function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...
end
cache.verbose && println()

resid = cache.cacheval isa IterativeSolvers.IDRSIterable ? cache.cacheval.R : cache.cacheval.residual
resid = cache.cacheval isa IterativeSolvers.IDRSIterable ? cache.cacheval.R :
cache.cacheval.residual
if resid isa IterativeSolvers.Residual
resid = resid.current
end
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolveKrylovKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using LinearSolve, KrylovKit, LinearAlgebra
using LinearSolve: LinearCache

function LinearSolve.KrylovKitJL(args...;
KrylovAlg = KrylovKit.GMRES, gmres_restart = 0,
kwargs...)
KrylovAlg = KrylovKit.GMRES, gmres_restart = 0,
kwargs...)
return KrylovKitJL(KrylovAlg, gmres_restart, args, kwargs)
end

Expand Down
Loading

0 comments on commit c8104cf

Please sign in to comment.