Skip to content

Commit

Permalink
Merge pull request #112 from vpuri3/vp-ldiv
Browse files Browse the repository at this point in the history
custom linear solve function
  • Loading branch information
ChrisRackauckas authored Apr 7, 2022
2 parents 5a75de5 + b692761 commit c960f8f
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ makedocs(
],
"Advanced" => Any[
"advanced/developing.md"
"advanced/custom.md"
]
]
)
Expand Down
55 changes: 55 additions & 0 deletions docs/src/advanced/custom.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Passing in a Custom Linear Solver
Julia users are building a wide variety of applications in the SciML ecosystem,
often requiring problem-specific handling of their linear solves. As existing solvers in `LinearSolve.jl` may not
be optimally suited for novel applications, it is essential for the linear solve
interface to be easily extendable by users. To that end, the linear solve algorithm
`LinearSolveFunction()` accepts a user-defined function for handling the solve. A
user can pass in their custom linear solve function, say `my_linsolve`, to
`LinearSolveFunction()`. A contrived example of solving a linear system with a custom solver is below.
```julia
using LinearSolve, LinearAlgebra

function my_linsolve(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
if verbose == true
println("solving Ax=b")
end
u = A \ b
return u
end

prob = LinearProblem(Diagonal(rand(4)), rand(4))
alg = LinearSolveFunction(my_linsolve),
sol = solve(prob, alg)
```
The inputs to the function are as follows:
- `A`, the linear operator
- `b`, the right-hand-side
- `u`, the solution initialized as `zero(b)`,
- `p`, a set of parameters
- `newA`, a `Bool` which is `true` if `A` has been modified since last solve
- `Pl`, left-preconditioner
- `Pr`, right-preconditioner
- `solverdata`, solver cache set to `nothing` if solver hasn't been initialized
- `kwargs`, standard SciML keyword arguments such as `verbose`, `maxiters`,
`abstol`, `reltol`
The function `my_linsolve` must accept the above specified arguments, and return
the solution, `u`. As memory for `u` is already allocated, the user may choose
to modify `u` in place as follows:
```julia
function my_linsolve!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
if verbose == true
println("solving Ax=b")
end
u .= A \ b # in place
return u
end

alg = LinearSolveFunction(my_linsolve!)
sol = solve(prob, alg)
```
Finally, note that `LinearSolveFunction()` dispatches to the default linear solve
algorithm handling if no arguments are passed in.
```julia
alg = LinearSolveFunction()
sol = solve(prob, alg) # same as solve(prob, nothing)
```
10 changes: 9 additions & 1 deletion docs/src/solvers/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@ Pardiso.jl's methods are also known to be very efficient sparse linear solvers.

As sparse matrices get larger, iterative solvers tend to get more efficient than
factorization methods if a lower tolerance of the solution is required.
Krylov.jl works with CPUs and GPUs and tends to be more efficient than other

IterativeSolvers.jl uses a low-rank Q update in its GMRES so it tends to be
faster than Krylov.jl for CPU-based arrays, but it's only compatible with
CPU-based arrays while Krylov.jl is more general and will support accelerators
like CUDA. Krylov.jl works with CPUs and GPUs and tends to be more efficient than other
Krylov-based methods.

Finally, a user can pass a custom function ofr the linear solve using
`LinearSolveFunction()` if existing solvers are not optimal for their application.
The interface is detailed [here](#passing-in-a-custom-linear-solver)

## Full List of Methods

### RecursiveFactorization.jl
Expand Down
6 changes: 6 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ using Reexport
abstract type SciMLLinearSolveAlgorithm <: SciMLBase.AbstractLinearAlgorithm end
abstract type AbstractFactorization <: SciMLLinearSolveAlgorithm end
abstract type AbstractKrylovSubspaceMethod <: SciMLLinearSolveAlgorithm end
abstract type AbstractSolveFunction <: SciMLLinearSolveAlgorithm end

# Traits

needs_concrete_A(alg::AbstractFactorization) = true
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Code

Expand All @@ -39,6 +41,7 @@ include("factorization.jl")
include("simplelu.jl")
include("iterative_wrappers.jl")
include("preconditioners.jl")
include("solve_function.jl")
include("default.jl")
include("init.jl")

Expand All @@ -48,6 +51,9 @@ isopenblas() = IS_OPENBLAS[]
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization,
UMFPACKFactorization, KLUFactorization

export LinearSolveFunction

export KrylovJL, KrylovJL_CG, KrylovJL_GMRES, KrylovJL_BICGSTAB, KrylovJL_MINRES,
IterativeSolversJL, IterativeSolversJL_CG, IterativeSolversJL_GMRES,
IterativeSolversJL_BICGSTAB, IterativeSolversJL_MINRES,
Expand Down
2 changes: 1 addition & 1 deletion src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function set_cacheval(cache::LinearCache, alg_cache)
return cache
end

init_cacheval(alg::SciMLLinearSolveAlgorithm, A, b, u) = nothing
init_cacheval(alg::SciMLLinearSolveAlgorithm, args...) = nothing

SciMLBase.init(prob::LinearProblem, args...; kwargs...) = SciMLBase.init(prob,nothing,args...;kwargs...)

Expand Down
15 changes: 15 additions & 0 deletions src/solve_function.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
struct LinearSolveFunction{F} <: AbstractSolveFunction
solve_func::F
end

function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,
args...; kwargs...)
@unpack A,b,u,p,isfresh,Pl,Pr,cacheval = cache
@unpack solve_func = alg

u = solve_func(A,b,u,p,isfresh,Pl,Pr,cacheval;kwargs...)
cache = set_u(cache, u)

return SciMLBase.build_linear_solution(alg,cache.u,nothing,cache)
end
31 changes: 31 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,35 @@ end
@test sol13.u sol33.u
end

@testset "Solve Function" begin

A1 = rand(n) |> Diagonal; b1 = rand(n); x1 = zero(b1)
A2 = rand(n) |> Diagonal; b2 = rand(n); x2 = zero(b1)

function sol_func(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
if verbose == true
println("out-of-place solve")
end
u = A \ b
end

function sol_func!(A,b,u,p,newA,Pl,Pr,solverdata;verbose=true, kwargs...)
if verbose == true
println("in-place solve")
end
ldiv!(u,A,b)
end

prob1 = LinearProblem(A1, b1; u0=x1)
prob2 = LinearProblem(A1, b1; u0=x1)

for alg in (
LinearSolveFunction(sol_func),
LinearSolveFunction(sol_func!),
)

test_interface(alg, prob1, prob2)
end
end

end # testset

0 comments on commit c960f8f

Please sign in to comment.