Skip to content

Commit

Permalink
try simpler gpu code
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Oct 4, 2023
1 parent 9ea818d commit ee0ecf1
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@ using SciMLBase: AbstractSciMLOperator
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
kwargs...)
if cache.isfresh
fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u)
fact = lu(CUDA.CuArray(cache.A))
cache.cacheval = fact
cache.isfresh = false
end

copyto!(cache.u, cache.b)
y = Array(ldiv!(cache.cacheval, CUDA.CuArray(cache.u)))
y = Array(ldiv!(cache.u, cache.cacheval, CUDA.CuArray(cache.u)))
SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
fact = lu(CUDA.CuArray(A))
return fact
function init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(CUDA.CuArray(A))
end

end

0 comments on commit ee0ecf1

Please sign in to comment.