Skip to content

Commit

Permalink
Rename checkSVDDone
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 23, 2023
1 parent f56a3b3 commit db828c7
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions NDTensors/src/linearalgebra/svd.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
function checkSVDDone(S::AbstractArray, thresh::Float64)
return checkSVDDone(leaf_parenttype(S), S, thresh)
# The state of the `svd_recursive` algorithm.
function svd_recursive_state(S::AbstractArray, thresh::Float64)
return svd_recursive_state(leaf_parenttype(S), S, thresh)
end

# CPU version.
function checkSVDDone(::Type{<:Array}, S::AbstractArray, thresh::Float64)
function svd_recursive_state(::Type{<:Array}, S::AbstractArray, thresh::Float64)
N = length(S)
(N <= 1 || thresh < 0.0) && return (true, 1)
S1t = S[1] * thresh
Expand All @@ -20,8 +21,8 @@ end

# Convert to CPU to avoid slow scalar indexing
# on GPU.
function checkSVDDone(::Type{<:AbstractArray}, S::AbstractArray, thresh::Float64)
return checkSVDDone(Array, cpu(S), thresh)
function svd_recursive_state(::Type{<:AbstractArray}, S::AbstractArray, thresh::Float64)
return svd_recursive_state(Array, cpu(S), thresh)
end

function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int=2)
Expand All @@ -44,7 +45,7 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int=
V, R = qr_positive(V)
D[1:Nd] = diag(R)[1:Nd]

(done, start) = checkSVDDone(D, thresh)
(done, start) = svd_recursive_state(D, thresh)

done && return U, D, V

Expand Down

0 comments on commit db828c7

Please sign in to comment.