Skip to content

Commit

Permalink
Clean up orthogonalize code
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Nov 7, 2024
1 parent f0e8bf5 commit f5d3aa4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 32 deletions.
31 changes: 8 additions & 23 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
using LinearAlgebra: LinearAlgebra, factorize
using MacroTools: @capture
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
using NamedGraphs.GraphsExtensions:
, directed_graph, incident_edges, rename_vertices, vertextype
using NDTensors: NDTensors, dim
Expand Down Expand Up @@ -617,30 +617,15 @@ end

# Orthogonalize an ITensorNetwork towards a region, treating
# the network as a tree spanned by a spanning tree.
# TODO: Rename `tree_orthogonalize`.
function ITensorMPS.orthogonalize::AbstractITensorNetwork, region::Vector)
spanning_tree_edges = post_order_dfs_edges_region(bfs_tree(ψ, first(region)), region)
return orthogonalize_path(ψ, spanning_tree_edges)
function tree_orthogonalize::AbstractITensorNetwork, region::Vector)
region = collect(vertices(steiner_tree(underlying_graph(ψ), region)))
path = post_order_dfs_edges(bfs_tree(ψ, first(region)), first(region))
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
return orthogonalize_path(ψ, path)
end

function ITensorMPS.orthogonalize::AbstractITensorNetwork, region)
return orthogonalize(ψ, [region])
end

function ITensorMPS.orthogonalize::AbstractITensorNetwork, edges::Vector{<:AbstractEdge})
return orthogonalize(ψ, unique(vcat([src(e) for e in edges], [dst(e) for e in edges])))
end

function ITensorMPS.orthogonalize::AbstractITensorNetwork, edges::Vector{<:Pair})
return orthogonalize(ψ, edgetype(ψ).(edges))
end

function ITensorMPS.orthogonalize::AbstractITensorNetwork, edge::AbstractEdge)
return orthogonalize(ψ, [edge])
end

function ITensorMPS.orthogonalize::AbstractITensorNetwork, edge::Pair)
return orthogonalize(ψ, edgetype(ψ)(edge))
function tree_orthogonalize::AbstractITensorNetwork, region)
return tree_orthogonalize(ψ, [region])
end

# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
Expand Down
5 changes: 0 additions & 5 deletions src/edge_sequences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,3 @@ end
@traitfn function edge_sequence(::Algorithm"parallel", g::::(!IsDirected))
return [[e] for e in vcat(edges(g), reverse.(edges(g)))]
end

function post_order_dfs_edges_region(g::AbstractGraph, region)
es = post_order_dfs_edges(g, first(region))
return filter(e -> !((src(e) region) && (dst(e) region)), es)
end
12 changes: 8 additions & 4 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using NamedGraphs.GraphsExtensions:
post_order_dfs_edges,
post_order_dfs_vertices,
a_star
using NamedGraphs: namedgraph_a_star
using NamedGraphs: namedgraph_a_star, steiner_tree
using IsApprox: IsApprox, Approx
using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev
using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize
Expand Down Expand Up @@ -36,15 +36,19 @@ function set_ortho_region(tn::AbstractTTN, new_region)
end

function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
new_path = post_order_dfs_edges_region(ttn, region)
existing_path = post_order_dfs_edges_region(ttn, ortho_region(ttn))
path = setdiff(new_path, existing_path)
st = steiner_tree(ttn, union(region, ortho_region(ttn)))
path = post_order_dfs_edges(st, first(region))
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
if !isempty(path)
ttn = typeof(ttn)(orthogonalize_path(ITensorNetwork(ttn), path; kwargs...))
end
return set_ortho_region(ttn, region)
end

function ITensorMPS.orthogonalize(ttn::AbstractTTN, region; kwargs...)
return orthogonalize(ttn, [region]; kwargs...)
end

#
# Truncation
#
Expand Down

0 comments on commit f5d3aa4

Please sign in to comment.