Skip to content

Commit

Permalink
Removed support for Vector{ITensor} in contract.jl functions for now
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Apr 10, 2024
1 parent 6c91629 commit fa0f8a7
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 132 deletions.
82 changes: 82 additions & 0 deletions src/ITensorsExtensions/ITensorsExtensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module ITensorsExtensions
using LinearAlgebra: LinearAlgebra, eigen, pinv
using ITensors:
ITensor,
Index,
commonind,
dag,
hasqns,
inds,
isdiag,
itensor,
map_diag,
noncommonind,
noprime,
replaceinds,
space,
sqrt_decomp
using ITensors.NDTensors:
NDTensors,
Block,
Tensor,
blockdim,
blockoffsets,
denseblocks,
diaglength,
getdiagindex,
nzblocks,
setdiagindex!,
svd,
tensor,
DiagBlockSparseTensor,
DenseTensor,
BlockOffsets
using Observers: update!, insert_function!

function NDTensors.blockoffsets(dense::DenseTensor)
return BlockOffsets{ndims(dense)}([Block(ntuple(Returns(1), ndims(dense)))], [0])
end
function NDTensors.nzblocks(dense::DenseTensor)
return nzblocks(blockoffsets(dense))
end
NDTensors.blockdim(ind::Int, ::Block{1}) = ind
NDTensors.blockdim(i::Index{Int}, b::Integer) = blockdim(i, Block(b))
NDTensors.blockdim(i::Index{Int}, b::Block) = blockdim(space(i), b)

LinearAlgebra.isdiag(it::ITensor) = isdiag(tensor(it))

# Convenience functions
sqrt_diag(it::ITensor) = map_diag(sqrt, it)
inv_diag(it::ITensor) = map_diag(inv, it)
invsqrt_diag(it::ITensor) = map_diag(inv sqrt, it)
pinv_diag(it::ITensor) = map_diag(pinv, it)
pinvsqrt_diag(it::ITensor) = map_diag(pinv sqrt, it)

function map_itensor(
f::Function, A::ITensor, lind=first(inds(A)); regularization=nothing, kwargs...
)
USV = svd(A, lind; kwargs...)
U, S, V, spec, u, v = USV
S = map_diag(s -> f(s + regularization), S)
sqrtDL, δᵤᵥ, sqrtDR = sqrt_decomp(S, u, v)
sqrtDR = denseblocks(sqrtDR) * denseblocks(δᵤᵥ)
L, R = U * sqrtDL, V * sqrtDR
return L * R
end

# Analagous to `denseblocks`.
# Extract the diagonal entries into a diagonal tensor.
function diagblocks(D::Tensor)
nzblocksD = nzblocks(D)
T = DiagBlockSparseTensor(eltype(D), nzblocksD, inds(D))
for b in nzblocksD
for n in 1:diaglength(D)
setdiagindex!(T, getdiagindex(D, n), n)
end
end
return T
end

diagblocks(it::ITensor) = itensor(diagblocks(tensor(it)))

end
3 changes: 2 additions & 1 deletion src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,10 @@ function inner_network(
return BilinearFormNetwork(A, x, y; kwargs...)
end

# TODO: We should make this pass to inner_network and then to BiLinearForm.
# TODO: We should make this use the QuadraticFormNetwork constructor here.
# Parts of the code (tests relying on norm_sqr being two layer and the gauging code
# which relies on specific message tensors) currently would break in that case so we need to resolve
# We could have the option in the Form constructors to pre-contract the operator into the bra or ket
function norm_sqr_network::AbstractITensorNetwork)
return disjoint_union("bra" => dag(prime(ψ; sites=[])), "ket" => ψ)
end
Expand Down
4 changes: 2 additions & 2 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,11 @@ end
function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)
incoming_mts = environment(bp_cache, [pv])
local_state = factor(bp_cache, pv)
return scalar(vcat(incoming_mts, local_state))
return contract(vcat(incoming_mts, local_state))[]
end

function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
return scalar(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))
return contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
end

function vertex_scalars(
Expand Down
33 changes: 4 additions & 29 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ function NDTensors.contract(
return contract(Vector{ITensor}(tn); sequence=sequence_linear_index, kwargs...)
end

function NDTensors.contract(alg::Algorithm"exact", tensors::Vector{ITensor}; kwargs...)
return contract(tensors; kwargs...)
end

function NDTensors.contract(
alg::Union{Algorithm"density_matrix",Algorithm"ttn_svd"},
tn::AbstractITensorNetwork;
Expand All @@ -28,40 +24,19 @@ function NDTensors.contract(
return approx_tensornetwork(alg, tn, output_structure; kwargs...)
end

function contract_density_matrix(
contract_list::Vector{ITensor}; normalize=true, contractor_kwargs...
)
tn, _ = contract(
ITensorNetwork(contract_list); alg="density_matrix", contractor_kwargs...
)
out = Vector{ITensor}(tn)
if normalize
out .= normalize!.(copy.(out))
end
return out
end

function ITensors.scalar(
alg::Algorithm, tn::Union{AbstractITensorNetwork,Vector{ITensor}}; kwargs...
)
function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)
return contract(alg, tn; kwargs...)[]
end

function ITensors.scalar(
tn::Union{AbstractITensorNetwork,Vector{ITensor}}; alg="exact", kwargs...
)
function ITensors.scalar(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return scalar(Algorithm(alg), tn; kwargs...)
end

function logscalar(
tn::Union{AbstractITensorNetwork,Vector{ITensor}}; alg="exact", kwargs...
)
function logscalar(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return logscalar(Algorithm(alg), tn; kwargs...)
end

function logscalar(
alg::Algorithm"exact", tn::Union{AbstractITensorNetwork,Vector{ITensor}}; kwargs...
)
function logscalar(alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...)
s = scalar(alg, tn; kwargs...)
s = real(s) < 0 ? complex(s) : s
return log(s)
Expand Down
Loading

0 comments on commit fa0f8a7

Please sign in to comment.