Skip to content

Commit

Permalink
Fix ITensor to TTN conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Apr 5, 2024
1 parent 99f61de commit 33f15fb
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 22 deletions.
6 changes: 0 additions & 6 deletions src/solvers/alternating_update/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,13 @@ function alternating_update(
end

function alternating_update(operator::AbstractTTN, init_state::AbstractTTN; kwargs...)
# Permute the indices to have a better memory layout
# and minimize permutations
operator = ITensors.permute(operator, (linkind, siteinds, linkind))
projected_operator = ProjTTN(operator)
return alternating_update(projected_operator, init_state; kwargs...)
end

function alternating_update(
operator::AbstractTTN, init_state::AbstractTTN, sweep_plans; kwargs...
)
# Permute the indices to have a better memory layout
# and minimize permutations
operator = ITensors.permute(operator, (linkind, siteinds, linkind))
projected_operator = ProjTTN(operator)
return alternating_update(projected_operator, init_state, sweep_plans; kwargs...)
end
Expand Down
19 changes: 3 additions & 16 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ end
#

function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...)
ortho_center ortho_region(tn) && return tn
if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn))
return tn
end
# TODO: Rewrite this in a more general way.
if isone(length(ortho_region(tn)))
edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center)
Expand Down Expand Up @@ -279,21 +281,6 @@ function ITensors.add(tn1::AbstractTTN, tn2::AbstractTTN; kwargs...)
return +(tn1, tn2; kwargs...)
end

# TODO: Delete this
function ITensors.permute(
tn::AbstractTTN, ::Tuple{typeof(linkind),typeof(siteinds),typeof(linkind)}
)
tn = copy(tn)
for v in vertices(tn)
ls = [only(linkinds(tn, n => v)) for n in neighbors(tn, v)] # TODO: won't work for multiple indices per link...
ss = TupleTools.sort(Tuple(siteinds(tn, v)); by=plev)
setindex_preserve_graph!(
tn, permute(tn[v], filter(!isnothing, (ls[1], ss..., ls[2:end]...))), v
)
end
return set_ortho_region(tn, ortho_region(tn))
end

function Base.isapprox(
x::AbstractTTN,
y::AbstractTTN;
Expand Down

0 comments on commit 33f15fb

Please sign in to comment.