Skip to content

Commit

Permalink
More work on kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 31, 2023
1 parent d714089 commit f2d65d2
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 35 deletions.
6 changes: 1 addition & 5 deletions NDTensors/src/DiagonalArrays/examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
#
# A Julia `DiagonalArray` type.

using NDTensors.DiagonalArrays:
DiagonalArray,
DiagIndex,
DiagIndices,
densearray
using NDTensors.DiagonalArrays: DiagonalArray, DiagIndex, DiagIndices, densearray

d = DiagonalArray([1.0, 2, 3], 3, 4, 5)
@show d[1, 1, 1] == 1
Expand Down
1 change: 1 addition & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ include("exports.jl")
#####################################
# General functionality
#
include("default_kwargs.jl")
include("algorithm.jl")
include("aliasstyle.jl")
include("abstractarray/set_types.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function contract(
array2::MatrixOrArrayStorage,
labels2,
labelsR=contract_labels(labels1, labels2);
kwargs...
kwargs...,
)
output_array = contraction_output(array1, labels1, array2, labels2, labelsR)
contract!(output_array, labelsR, array1, labels1, array2, labels2; kwargs...)
Expand Down
14 changes: 0 additions & 14 deletions NDTensors/src/arraystorage/arraystorage/tensor/svd.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
default_maxdim(a) = minimum(size(a))
default_mindim(a) = true
default_cutoff(a) = zero(eltype(a))
default_svd_alg(a) = "divide_and_conquer"
default_use_absolute_cutoff(a) = false
default_use_relative_cutoff(a) = true

# TODO: Rewrite this function to be more modern:
# 1. Output `Spectrum` as a keyword argument that gets overwritten.
# 2. Dispatch on `alg`.
Expand All @@ -22,13 +15,6 @@ function svd(
alg=default_svd_alg(T),
use_absolute_cutoff=default_use_absolute_cutoff(T),
use_relative_cutoff=default_use_relative_cutoff(T),
# These are getting passed erroneously.
# TODO: Make sure they don't get passed down
# to here.
## which_decomp=nothing,
## tags=nothing,
## eigen_perturbation=nothing,
## normalize=nothing,
)
truncate = !isnothing(maxdim) || !isnothing(cutoff)
maxdim = isnothing(maxdim) ? default_maxdim(T) : maxdim
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function contract!(
for i in 1:min_dim
c₁ += A[DiagIndex(i)] * B[DiagIndex(i)]
end
C[DiagIndex(1)] = α * c₁ + β * C[DiagIndex(1)]
C[DiagIndex(1)] = α * c₁ + β * C[DiagIndex(1)]
else
# not all indices are summed over, set the diagonals of the result
# to the product of the diagonals of A and B
Expand Down
6 changes: 6 additions & 0 deletions NDTensors/src/default_kwargs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Shared defaults for truncation/decomposition keyword arguments.
# Each helper takes the array/tensor `a` so the default can depend on
# its size or element type; value-independent defaults ignore `a`.

# Largest kept bond dimension: the smallest extent of `a`.
function default_maxdim(a)
  return minimum(size(a))
end

# Smallest kept bond dimension (`true` behaves as 1 in arithmetic contexts).
function default_mindim(a)
  return true
end

# Truncation cutoff: zero of `a`'s element type, so no truncation by default.
function default_cutoff(a)
  return zero(eltype(a))
end

# Default LAPACK-style SVD algorithm selector.
function default_svd_alg(a)
  return "divide_and_conquer"
end

# By default the cutoff is interpreted relatively, not absolutely.
function default_use_absolute_cutoff(a)
  return false
end

function default_use_relative_cutoff(a)
  return true
end
9 changes: 6 additions & 3 deletions NDTensors/test/arraytensor/diagonalarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ using NDTensors: storage, storagetype
A = randn(3, 4, 5)

for convert_to_dense in (true, false)
@test contract(D, (-1, -2, -3), A, (-1, -2, -3); convert_to_dense) contract(Dᵈ, (-1, -2, -3), A, (-1, -2, -3))
@test contract(D, (-1, -2, 1), A, (-1, -2, 2); convert_to_dense) contract(Dᵈ, (-1, -2, 1), A, (-1, -2, 2))
@test contract(D, (-1, -2, -3), A, (-1, -2, -3); convert_to_dense)
contract(Dᵈ, (-1, -2, -3), A, (-1, -2, -3))
@test contract(D, (-1, -2, 1), A, (-1, -2, 2); convert_to_dense)
contract(Dᵈ, (-1, -2, 1), A, (-1, -2, 2))
end

# Tensor tests
Dᵗ = tensor(D, size(D))
Dᵈᵗ = tensor(Dᵈ, size(D))
Aᵗ = tensor(A, size(A))
@test contract(Dᵗ, (-1, -2, -3), Aᵗ, (-1, -2, -3)) contract(Dᵈᵗ, (-1, -2, -3), Aᵗ, (-1, -2, -3))
@test contract(Dᵗ, (-1, -2, -3), Aᵗ, (-1, -2, -3))
contract(Dᵈᵗ, (-1, -2, -3), Aᵗ, (-1, -2, -3))
end
43 changes: 32 additions & 11 deletions src/tensor_operations/matrix_decomposition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,23 @@ Utrunc2, Strunc2, Vtrunc2 = svd(A, i, k; cutoff=1e-10);
See also: [`factorize`](@ref), [`eigen`](@ref)
"""
function svd(A::ITensor, Linds...; leftdir=nothing, rightdir=nothing, kwargs...)
utags::TagSet = get(kwargs, :lefttags, get(kwargs, :utags, "Link,u"))
vtags::TagSet = get(kwargs, :righttags, get(kwargs, :vtags, "Link,v"))

# Keyword argument deprecations
#if haskey(kwargs, :utags) || haskey(kwargs, :vtags)
# @warn "Keyword arguments `utags` and `vtags` are deprecated in favor of `leftags` and `righttags`."
#end

function svd(
A::ITensor,
Linds...;
leftdir=nothing,
rightdir=nothing,
lefttags="Link,u",
righttags="Link,v",
mindim=NDTensors.default_mindim(A),
maxdim=nothing,
cutoff=nothing,
alg=NDTensors.default_svd_alg(A),
use_absolute_cutoff=NDTensors.default_use_absolute_cutoff(A),
use_relative_cutoff=NDTensors.default_use_relative_cutoff(A),
# Deprecated
utags=lefttags,
vtags=righttags,
)
Lis = commoninds(A, indices(Linds...))
Ris = uniqueinds(A, Lis)

Expand Down Expand Up @@ -142,7 +150,9 @@ function svd(A::ITensor, Linds...; leftdir=nothing, rightdir=nothing, kwargs...)
AC = permute(AC, cL, cR)
end

USVT = svd(tensor(AC); kwargs...)
USVT = svd(
tensor(AC); mindim, maxdim, cutoff, alg, use_absolute_cutoff, use_relative_cutoff
)
if isnothing(USVT)
return nothing
end
Expand Down Expand Up @@ -564,7 +574,18 @@ function factorize_svd(
normalize=nothing,
)
leftdir, rightdir = -dir, -dir
USV = svd(A, Linds...; leftdir, rightdir, alg=svd_alg, mindim, maxdim, cutoff, tags)
USV = svd(
A,
Linds...;
leftdir,
rightdir,
alg=svd_alg,
mindim,
maxdim,
cutoff,
lefttags=tags,
righttags=tags,
)
if isnothing(USV)
return nothing
end
Expand Down

0 comments on commit f2d65d2

Please sign in to comment.