From f5d3aa42b7211d453beb7d18e47e7bea22d1b186 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 7 Nov 2024 16:50:40 -0500 Subject: [PATCH] Clean up orthogonalize code --- src/abstractitensornetwork.jl | 31 +++++-------------- src/edge_sequences.jl | 5 --- .../abstracttreetensornetwork.jl | 12 ++++--- 3 files changed, 16 insertions(+), 32 deletions(-) diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index 09d1eab9..7f771c8e 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -40,7 +40,7 @@ using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rename_vertices, vertextype using NDTensors: NDTensors, dim @@ -617,30 +617,15 @@ end # Orthogonalize an ITensorNetwork towards a region, treating # the network as a tree spanned by a spanning tree. -# TODO: Rename `tree_orthogonalize`. -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, region::Vector) - spanning_tree_edges = post_order_dfs_edges_region(bfs_tree(ψ, first(region)), region) - return orthogonalize_path(ψ, spanning_tree_edges) +function tree_orthogonalize(ψ::AbstractITensorNetwork, region::Vector) + region = collect(vertices(steiner_tree(underlying_graph(ψ), region))) + path = post_order_dfs_edges(bfs_tree(ψ, first(region)), first(region)) + path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) + return orthogonalize_path(ψ, path) end -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, region) - return orthogonalize(ψ, [region]) -end - -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}) - return orthogonalize(ψ, unique(vcat([src(e) for e in edges], [dst(e) for e in edges]))) -end - -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edges::Vector{<:Pair}) - return orthogonalize(ψ, edgetype(ψ).(edges)) -end - -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edge::AbstractEdge) - return orthogonalize(ψ, [edge]) -end - -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edge::Pair) - return orthogonalize(ψ, edgetype(ψ)(edge)) +function tree_orthogonalize(ψ::AbstractITensorNetwork, region) + return tree_orthogonalize(ψ, [region]) end # TODO: decide whether to use graph mutating methods when resulting graph is unchanged? diff --git a/src/edge_sequences.jl b/src/edge_sequences.jl index 9b385c68..9dab9fff 100644 --- a/src/edge_sequences.jl +++ b/src/edge_sequences.jl @@ -44,8 +44,3 @@ end @traitfn function edge_sequence(::Algorithm"parallel", g::::(!IsDirected)) return [[e] for e in vcat(edges(g), reverse.(edges(g)))] end - -function post_order_dfs_edges_region(g::AbstractGraph, region) - es = post_order_dfs_edges(g, first(region)) - return filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), es) -end diff --git a/src/treetensornetworks/abstracttreetensornetwork.jl b/src/treetensornetworks/abstracttreetensornetwork.jl index 0c424cd6..5f079299 100644 --- a/src/treetensornetworks/abstracttreetensornetwork.jl +++ b/src/treetensornetworks/abstracttreetensornetwork.jl @@ -6,7 +6,7 @@ using NamedGraphs.GraphsExtensions: post_order_dfs_edges, post_order_dfs_vertices, a_star -using NamedGraphs: namedgraph_a_star +using NamedGraphs: namedgraph_a_star, steiner_tree using IsApprox: IsApprox, Approx using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize @@ -36,15 +36,19 @@ function set_ortho_region(tn::AbstractTTN, new_region) end function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...) - new_path = post_order_dfs_edges_region(ttn, region) - existing_path = post_order_dfs_edges_region(ttn, ortho_region(ttn)) - path = setdiff(new_path, existing_path) + st = steiner_tree(ttn, union(region, ortho_region(ttn))) + path = post_order_dfs_edges(st, first(region)) + path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) if !isempty(path) ttn = typeof(ttn)(orthogonalize_path(ITensorNetwork(ttn), path; kwargs...)) end return set_ortho_region(ttn, region) end +function ITensorMPS.orthogonalize(ttn::AbstractTTN, region; kwargs...) + return orthogonalize(ttn, [region]; kwargs...) +end + # # Truncation #