From 38b32a695ff8664b3840385dc63b65fa734a3c40 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 31 Oct 2023 12:03:49 -0400 Subject: [PATCH] Fix SVD --- NDTensors/src/linearalgebra/linearalgebra.jl | 44 ++++++-------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/NDTensors/src/linearalgebra/linearalgebra.jl b/NDTensors/src/linearalgebra/linearalgebra.jl index 20dec6b3a9..ae11b1eaff 100644 --- a/NDTensors/src/linearalgebra/linearalgebra.jl +++ b/NDTensors/src/linearalgebra/linearalgebra.jl @@ -93,36 +93,18 @@ end svd of an order-2 DenseTensor """ -function svd(T::DenseTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT} - truncate = haskey(kwargs, :maxdim) || haskey(kwargs, :cutoff) - - # - # Keyword argument deprecations - # - use_absolute_cutoff = false - if haskey(kwargs, :absoluteCutoff) - @warn "In svd, keyword argument absoluteCutoff is deprecated in favor of use_absolute_cutoff" - use_absolute_cutoff = get(kwargs, :absoluteCutoff, use_absolute_cutoff) - end - - use_relative_cutoff = true - if haskey(kwargs, :doRelCutoff) - @warn "In svd, keyword argument doRelCutoff is deprecated in favor of use_relative_cutoff" - use_relative_cutoff = get(kwargs, :doRelCutoff, use_relative_cutoff) - end - - if haskey(kwargs, :fastsvd) || haskey(kwargs, :fastSVD) - error( - "In svd, fastsvd/fastSVD keyword arguments are removed in favor of alg, see documentation for more details.", - ) - end - - maxdim::Int = get(kwargs, :maxdim, minimum(dims(T))) - mindim::Int = get(kwargs, :mindim, 1) - cutoff = get(kwargs, :cutoff, 0.0) - use_absolute_cutoff::Bool = get(kwargs, :use_absolute_cutoff, use_absolute_cutoff) - use_relative_cutoff::Bool = get(kwargs, :use_relative_cutoff, use_relative_cutoff) - alg::String = get(kwargs, :alg, "divide_and_conquer") +function svd( + T::DenseTensor{ElT,2,IndsT}; + mindim=default_mindim(T), + maxdim=nothing, + cutoff=nothing, + alg=default_svd_alg(T), + use_absolute_cutoff=default_use_absolute_cutoff(T), + use_relative_cutoff=default_use_relative_cutoff(T), +) where {ElT,IndsT} + truncate = !isnothing(maxdim) || !isnothing(cutoff) + maxdim = isnothing(maxdim) ? default_maxdim(T) : maxdim + cutoff = isnothing(cutoff) ? default_cutoff(T) : cutoff #@timeit_debug timer "dense svd" begin if alg == "divide_and_conquer" @@ -168,7 +150,7 @@ function svd(T::DenseTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT} P = MS .^ 2 if truncate P, truncerr, _ = truncate!!( - P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff, kwargs... + P; mindim, maxdim, cutoff, use_absolute_cutoff, use_relative_cutoff ) else truncerr = 0.0