From dfe11ac93272d89809ac783095a5896125bbfe82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Fuhrmann?= Date: Mon, 4 Nov 2024 16:10:21 +0100 Subject: [PATCH] At once, format everything --- docs/src/basics/FAQ.md | 2 +- ext/LinearSolvePardisoExt.jl | 7 +++- src/LinearSolve.jl | 75 ++++++++++++++++++------------------ src/common.jl | 25 ++++++------ src/default.jl | 8 ++-- src/extension_algs.jl | 2 +- test/basictests.jl | 33 ++++++++-------- test/enzyme.jl | 6 +-- test/resolve.jl | 3 +- test/retcodes.jl | 2 +- 10 files changed, 84 insertions(+), 79 deletions(-) diff --git a/docs/src/basics/FAQ.md b/docs/src/basics/FAQ.md index 8e07a995..293468c6 100644 --- a/docs/src/basics/FAQ.md +++ b/docs/src/basics/FAQ.md @@ -83,5 +83,5 @@ Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(we Pr = Diagonal(weights) prob = LinearProblem(A, b) -sol = solve(prob, KrylovJL_GMRES(precs=Returns((Pl,Pr)))) +sol = solve(prob, KrylovJL_GMRES(precs = Returns((Pl, Pr)))) ``` diff --git a/ext/LinearSolvePardisoExt.jl b/ext/LinearSolvePardisoExt.jl index 0318bb8a..3f7db0c9 100644 --- a/ext/LinearSolvePardisoExt.jl +++ b/ext/LinearSolvePardisoExt.jl @@ -134,11 +134,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs if cache.isfresh phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT Pardiso.set_phase!(cache.cacheval, phase) - Pardiso.pardiso(cache.cacheval, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), eltype(A)[]) + Pardiso.pardiso(cache.cacheval, + SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), + eltype(A)[]) cache.isfresh = false end Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE) - Pardiso.pardiso(cache.cacheval, u, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), b) + Pardiso.pardiso(cache.cacheval, u, + SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), b) return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 4e6c9e25..ef86452b 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -5,47 +5,48 @@ if isdefined(Base, :Experimental) && end import PrecompileTools - using ArrayInterface - using RecursiveFactorization - using Base: cache_dependencies, Bool - using LinearAlgebra - using SparseArrays - using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr - using LazyArrays: @~, BroadcastArray - using SciMLBase: AbstractLinearAlgorithm - using SciMLOperators - using SciMLOperators: AbstractSciMLOperator, IdentityOperator - using Setfield - using UnPack - using KLU - using Sparspak - using FastLapackInterface - using DocStringExtensions - using EnumX - using Markdown - using ChainRulesCore - import InteractiveUtils - - import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix - - using LinearAlgebra: BlasInt, LU - using LinearAlgebra.LAPACK: require_one_based_indexing, - chkfinite, chkstride1, - @blasfunc, chkargsok - - import GPUArraysCore - import Preferences - import ConcreteStructs: @concrete - - # wrap - import Krylov - using SciMLBase - import Preferences +using ArrayInterface +using RecursiveFactorization +using Base: cache_dependencies, Bool +using LinearAlgebra +using SparseArrays +using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, getcolptr +using LazyArrays: @~, BroadcastArray +using SciMLBase: AbstractLinearAlgorithm +using SciMLOperators +using SciMLOperators: AbstractSciMLOperator, IdentityOperator +using Setfield +using UnPack +using KLU +using Sparspak +using FastLapackInterface +using DocStringExtensions +using EnumX +using Markdown +using ChainRulesCore +import InteractiveUtils + +import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix + +using LinearAlgebra: BlasInt, LU +using LinearAlgebra.LAPACK: require_one_based_indexing, + chkfinite, chkstride1, + @blasfunc, chkargsok + +import GPUArraysCore +import Preferences +import ConcreteStructs: @concrete + +# wrap +import Krylov +using SciMLBase +import Preferences const CRC = ChainRulesCore @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 - if Preferences.@load_preference("LoadMKL_JLL", !occursin("EPYC", Sys.cpu_info()[1].model)) + if Preferences.@load_preference("LoadMKL_JLL", + !occursin("EPYC", Sys.cpu_info()[1].model)) using MKL_jll const usemkl = MKL_jll.is_available() else diff --git a/src/common.jl b/src/common.jl index 3f145fc6..bcaaab1f 100644 --- a/src/common.jl +++ b/src/common.jl @@ -150,7 +150,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, assumptions = OperatorAssumptions(issquare(prob.A)), sensealg = LinearSolveAdjoint(), kwargs...) - (;A, b, u0, p) = prob + (; A, b, u0, p) = prob A = if alias_A || A isa SMatrix A @@ -206,22 +206,21 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), - typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, + typeof(sensealg)}( + A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) return cache end - function SciMLBase.reinit!(cache::LinearCache; - A = nothing, - b = cache.b, - u = cache.u, - p = nothing, - reinit_cache = false, - reuse_precs = false) + A = nothing, + b = cache.b, + u = cache.u, + p = nothing, + reinit_cache = false, + reuse_precs = false) (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache - isfresh = !isnothing(A) precsisfresh = !reuse_precs && (isfresh || !isnothing(p)) isfresh |= cache.isfresh @@ -234,9 +233,11 @@ function SciMLBase.reinit!(cache::LinearCache; Pl = cache.Pl Pr = cache.Pr if reinit_cache - return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval), + return LinearCache{ + typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval), typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), - typeof(sensealg)}(A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol, + typeof(sensealg)}( + A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) else cache.A = A diff --git a/src/default.jl b/src/default.jl index 525c5360..a0ebb705 100644 --- a/src/default.jl +++ b/src/default.jl @@ -179,8 +179,8 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool}) __conditioning(assump) === OperatorCondition.WellConditioned) if length(b) <= 10 DefaultAlgorithmChoice.RFLUFactorization - elseif appleaccelerate_isavailable() && b isa Array && - eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64} + elseif appleaccelerate_isavailable() && b isa Array && + eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64} DefaultAlgorithmChoice.AppleAccelerateLUFactorization elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) || (usemkl && length(b) <= 200)) && @@ -189,8 +189,8 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool}) DefaultAlgorithmChoice.RFLUFactorization #elseif A === nothing || A isa Matrix # alg = FastLUFactorization() - elseif usemkl && b isa Array && - eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64} + elseif usemkl && b isa Array && + eltype(b) <: Union{Float32, Float64, ComplexF32, ComplexF64} DefaultAlgorithmChoice.MKLLUFactorization else DefaultAlgorithmChoice.LUFactorization diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 2559a210..da444912 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -217,7 +217,7 @@ All values default to `nothing` and the solver internally determines the values given the input types, and these keyword arguments are only for overriding the default handling process. This should not be required by most users. """ -struct PardisoJL{T1, T2} <: AbstractSparseFactorization +struct PardisoJL{T1, T2} <: AbstractSparseFactorization nprocs::Union{Int, Nothing} solver_type::T1 matrix_type::T2 diff --git a/test/basictests.jl b/test/basictests.jl index e9492e4d..f9ddd302 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -267,12 +267,12 @@ end @testset "KrylovJL" begin kwargs = (; gmres_restart = 5) - precs = (A,p=nothing) -> (BlockJacobiPreconditioner(A, 2), I) + precs = (A, p = nothing) -> (BlockJacobiPreconditioner(A, 2), I) algorithms = ( ("Default", KrylovJL(kwargs...)), ("CG", KrylovJL_CG(kwargs...)), ("GMRES", KrylovJL_GMRES(kwargs...)), - ("GMRES_prec", KrylovJL_GMRES(;precs, ldiv=false, kwargs...)), + ("GMRES_prec", KrylovJL_GMRES(; precs, ldiv = false, kwargs...)), # ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)), ("MINRES", KrylovJL_MINRES(kwargs...)) ) @@ -579,28 +579,27 @@ end # test default algorithn @time "solve MySparseMatrixCSC" u=solve(pr) @test norm(u - u0, Inf) < 1.0e-13 - + # test Krylov algorithm with reinit! pr = LinearProblem(B, b) - solver=KrylovJL_CG() - cache=init(pr,solver,maxiters=1000,reltol=1.0e-10) - u=solve!(cache) + solver = KrylovJL_CG() + cache = init(pr, solver, maxiters = 1000, reltol = 1.0e-10) + u = solve!(cache) A1 = spdiagm(1 => -ones(N - 1), 0 => fill(100.0, N), -1 => -ones(N - 1)) - b1=A1*u0 - B1= MySparseMatrixCSC(A1) + b1 = A1 * u0 + B1 = MySparseMatrixCSC(A1) @test norm(u - u0, Inf) < 1.0e-8 - reinit!(cache; A=B1, b=b1) - u=solve!(cache) + reinit!(cache; A = B1, b = b1) + u = solve!(cache) @test norm(u - u0, Inf) < 1.0e-8 - + # test factorization with reinit! pr = LinearProblem(B, b) - solver=SparspakFactorization() - cache=init(pr,solver) - u=solve!(cache) + solver = SparspakFactorization() + cache = init(pr, solver) + u = solve!(cache) @test norm(u - u0, Inf) < 1.0e-8 - reinit!(cache; A=B1, b=b1) - u=solve!(cache) + reinit!(cache; A = B1, b = b1) + u = solve!(cache) @test norm(u - u0, Inf) < 1.0e-8 - end diff --git a/test/enzyme.jl b/test/enzyme.jl index b09c0de5..1650c453 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -207,9 +207,9 @@ end @show fd_jac en_jac = map(onehot(A)) do dA - return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, - Duplicated(A, dA), Const(b1), Const(alg))) - end |> collect + return only(Enzyme.autodiff(set_runtime_activity(Forward), fnice, + Duplicated(A, dA), Const(b1), Const(alg))) + end |> collect @show en_jac @test en_jac≈fd_jac rtol=1e-4 diff --git a/test/resolve.jl b/test/resolve.jl index a567a90b..9143aecb 100644 --- a/test/resolve.jl +++ b/test/resolve.jl @@ -1,7 +1,8 @@ using LinearSolve, LinearAlgebra, SparseArrays, InteractiveUtils, Test using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization -for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization),InteractiveUtils.subtypes(AbstractSparseFactorization)) +for alg in vcat(InteractiveUtils.subtypes(AbstractDenseFactorization), + InteractiveUtils.subtypes(AbstractSparseFactorization)) @show alg if !(alg in [ DiagonalFactorization, diff --git a/test/retcodes.jl b/test/retcodes.jl index 61b4ecaa..c75442d3 100644 --- a/test/retcodes.jl +++ b/test/retcodes.jl @@ -20,7 +20,7 @@ alglist = ( AppleAccelerateLUFactorization, MKLLUFactorization, KrylovJL_CRAIGMR, - KrylovJL_LSMR, + KrylovJL_LSMR ) @testset "Success" begin