Skip to content

Commit

Permalink
Fix SVD
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 31, 2023
1 parent f2d65d2 commit 38b32a6
Showing 1 changed file with 13 additions and 31 deletions.
44 changes: 13 additions & 31 deletions NDTensors/src/linearalgebra/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,36 +93,18 @@ end
svd of an order-2 DenseTensor
"""
function svd(T::DenseTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT}
truncate = haskey(kwargs, :maxdim) || haskey(kwargs, :cutoff)

#
# Keyword argument deprecations
#
use_absolute_cutoff = false
if haskey(kwargs, :absoluteCutoff)
@warn "In svd, keyword argument absoluteCutoff is deprecated in favor of use_absolute_cutoff"
use_absolute_cutoff = get(kwargs, :absoluteCutoff, use_absolute_cutoff)
end

use_relative_cutoff = true
if haskey(kwargs, :doRelCutoff)
@warn "In svd, keyword argument doRelCutoff is deprecated in favor of use_relative_cutoff"
use_relative_cutoff = get(kwargs, :doRelCutoff, use_relative_cutoff)
end

if haskey(kwargs, :fastsvd) || haskey(kwargs, :fastSVD)
error(
"In svd, fastsvd/fastSVD keyword arguments are removed in favor of alg, see documentation for more details.",
)
end

maxdim::Int = get(kwargs, :maxdim, minimum(dims(T)))
mindim::Int = get(kwargs, :mindim, 1)
cutoff = get(kwargs, :cutoff, 0.0)
use_absolute_cutoff::Bool = get(kwargs, :use_absolute_cutoff, use_absolute_cutoff)
use_relative_cutoff::Bool = get(kwargs, :use_relative_cutoff, use_relative_cutoff)
alg::String = get(kwargs, :alg, "divide_and_conquer")
function svd(
T::DenseTensor{ElT,2,IndsT};
mindim=default_mindim(T),
maxdim=nothing,
cutoff=nothing,
alg=default_svd_alg(T),
use_absolute_cutoff=default_use_absolute_cutoff(T),
use_relative_cutoff=default_use_relative_cutoff(T),
) where {ElT,IndsT}
truncate = !isnothing(maxdim) || !isnothing(cutoff)
maxdim = isnothing(maxdim) ? default_maxdim(T) : maxdim
cutoff = isnothing(cutoff) ? default_cutoff(T) : cutoff

#@timeit_debug timer "dense svd" begin
if alg == "divide_and_conquer"
Expand Down Expand Up @@ -168,7 +150,7 @@ function svd(T::DenseTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT}
P = MS .^ 2
if truncate
P, truncerr, _ = truncate!!(
P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs...
P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff
)
else
truncerr = 0.0
Expand Down

0 comments on commit 38b32a6

Please sign in to comment.