From fcf5b989a7089e8007805ddc4e83713be9397e3f Mon Sep 17 00:00:00 2001 From: Joey Date: Fri, 8 Nov 2024 08:48:51 -0500 Subject: [PATCH] Improve orthogonalize method efficiency --- src/abstractitensornetwork.jl | 6 ++++-- src/solvers/extract/extract.jl | 3 ++- src/treetensornetworks/abstracttreetensornetwork.jl | 5 +++++ test/test_itensornetwork.jl | 5 +++-- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index 7f771c8e..f7c148d0 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -7,6 +7,7 @@ using Graphs: add_edge!, add_vertex!, bfs_tree, + center, dst, edges, edgetype, @@ -618,8 +619,9 @@ end # Orthogonalize an ITensorNetwork towards a region, treating # the network as a tree spanned by a spanning tree. function tree_orthogonalize(ψ::AbstractITensorNetwork, region::Vector) - region = collect(vertices(steiner_tree(underlying_graph(ψ), region))) - path = post_order_dfs_edges(bfs_tree(ψ, first(region)), first(region)) + region_centre = + length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region) + path = post_order_dfs_edges(bfs_tree(ψ, region_centre), region_centre) path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) return orthogonalize_path(ψ, path) end diff --git a/src/solvers/extract/extract.jl b/src/solvers/extract/extract.jl index 431394ea..1013d1bd 100644 --- a/src/solvers/extract/extract.jl +++ b/src/solvers/extract/extract.jl @@ -7,12 +7,13 @@ # insert_local_tensors takes that tensor and factorizes it back # apart and puts it back into the network. # + function default_extracter(state, projected_operator, region; internal_kwargs) if isa(region, AbstractEdge) + # TODO: add functionality for orthogonalizing onto a bond so that can be called instead vsrc, vdst = src(region), dst(region) state = orthogonalize(state, vsrc) left_inds = uniqueinds(state[vsrc], state[vdst]) - #ToDo: replace with call to factorize U, S, V = svd( state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region) ) diff --git a/src/treetensornetworks/abstracttreetensornetwork.jl b/src/treetensornetworks/abstracttreetensornetwork.jl index 5f079299..d3f608b0 100644 --- a/src/treetensornetworks/abstracttreetensornetwork.jl +++ b/src/treetensornetworks/abstracttreetensornetwork.jl @@ -36,6 +36,11 @@ function set_ortho_region(tn::AbstractTTN, new_region) end function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...) + return orthogonalize_ttn(ttn, region; kwargs...) +end + +function orthogonalize_ttn(ttn::AbstractTTN, region::Vector; kwargs...) + issetequal(region, ortho_region(ttn)) && return ttn st = steiner_tree(ttn, union(region, ortho_region(ttn))) path = post_order_dfs_edges(st, first(region)) path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) diff --git a/test/test_itensornetwork.jl b/test/test_itensornetwork.jl index ba3caa01..53e2928f 100644 --- a/test/test_itensornetwork.jl +++ b/test/test_itensornetwork.jl @@ -51,6 +51,7 @@ using ITensorNetworks: orthogonalize, random_tensornetwork, siteinds, + tree_orthogonalize, ttn using LinearAlgebra: factorize using NamedGraphs: NamedEdge @@ -287,13 +288,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test nv(tn_ortho) == 5 @test nv(tn) == 4 @test Z ≈ Z̃ - tn_ortho = orthogonalize(tn, 4 => 3) + tn_ortho = tree_orthogonalize(tn, [3, 4]) Z̃ = norm_sqr(tn_ortho) @test nv(tn_ortho) == 4 @test nv(tn) == 4 @test Z ≈ Z̃ - tn_ortho = orthogonalize(tn, 1) + tn_ortho = tree_orthogonalize(tn, 1) Z̃ = norm_sqr(tn_ortho) @test Z ≈ Z̃ Z̃ = inner(tn_ortho, tn)