diff --git a/NDTensors/src/linearalgebra/svd.jl b/NDTensors/src/linearalgebra/svd.jl index 873aca779c..59b5a795fc 100644 --- a/NDTensors/src/linearalgebra/svd.jl +++ b/NDTensors/src/linearalgebra/svd.jl @@ -1,9 +1,10 @@ -function checkSVDDone(S::AbstractArray, thresh::Float64) - return checkSVDDone(leaf_parenttype(S), S, thresh) +# The state of the `svd_recursive` algorithm. +function svd_recursive_state(S::AbstractArray, thresh::Float64) + return svd_recursive_state(leaf_parenttype(S), S, thresh) end # CPU version. -function checkSVDDone(::Type{<:Array}, S::AbstractArray, thresh::Float64) +function svd_recursive_state(::Type{<:Array}, S::AbstractArray, thresh::Float64) N = length(S) (N <= 1 || thresh < 0.0) && return (true, 1) S1t = S[1] * thresh @@ -20,8 +21,8 @@ end # Convert to CPU to avoid slow scalar indexing # on GPU. -function checkSVDDone(::Type{<:AbstractArray}, S::AbstractArray, thresh::Float64) - return checkSVDDone(Array, cpu(S), thresh) +function svd_recursive_state(::Type{<:AbstractArray}, S::AbstractArray, thresh::Float64) + return svd_recursive_state(Array, cpu(S), thresh) end function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int=2) @@ -44,7 +45,7 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int= V, R = qr_positive(V) D[1:Nd] = diag(R)[1:Nd] - (done, start) = checkSVDDone(D, thresh) + (done, start) = svd_recursive_state(D, thresh) done && return U, D, V