diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl
index fdc479c7..a0efd7ff 100644
--- a/src/ITensorNetworks.jl
+++ b/src/ITensorNetworks.jl
@@ -42,7 +42,7 @@ include("solvers/local_solvers/dmrg_x.jl")
 include("solvers/local_solvers/contract.jl")
 include("solvers/local_solvers/linsolve.jl")
 include("treetensornetworks/abstracttreetensornetwork.jl")
-include("treetensornetworks/ttn.jl")
+include("treetensornetworks/treetensornetwork.jl")
 include("treetensornetworks/opsum_to_ttn.jl")
 include("treetensornetworks/projttns/abstractprojttn.jl")
 include("treetensornetworks/projttns/projttn.jl")
diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl
index 4961cfe8..2dc983ee 100644
--- a/src/abstractitensornetwork.jl
+++ b/src/abstractitensornetwork.jl
@@ -617,7 +617,11 @@ function neighbor_vertices(ψ::AbstractITensorNetwork, T::ITensor)
 end
 
 function linkinds_combiners(tn::AbstractITensorNetwork; edges=edges(tn))
-  combiners = DataGraph(directed_graph(underlying_graph(tn)); vertex_data_eltype=ITensor, edge_data_eltype=ITensor)
+  combiners = DataGraph(
+    directed_graph(underlying_graph(tn));
+    vertex_data_eltype=ITensor,
+    edge_data_eltype=ITensor,
+  )
   for e in edges
     C = combiner(linkinds(tn, e); tags=edge_tag(e))
     combiners[e] = C
diff --git a/src/edge_sequences.jl b/src/edge_sequences.jl
index 65e8ac8f..e9086ce2 100644
--- a/src/edge_sequences.jl
+++ b/src/edge_sequences.jl
@@ -1,7 +1,7 @@
 using Graphs: IsDirected, connected_components, edges, edgetype
 using ITensors.NDTensors: Algorithm, @Algorithm_str
 using NamedGraphs: NamedGraphs
-using NamedGraphs.GraphsExtensions: undirected_graph
+using NamedGraphs.GraphsExtensions: GraphsExtensions, undirected_graph
 using NamedGraphs.PartitionedGraphs: PartitionEdge, PartitionedGraph, partitioned_graph
 using SimpleTraits: SimpleTraits, @traitfn, Not
 using SimpleTraits
@@ -26,7 +26,9 @@ end
 end
 
 @traitfn function edge_sequence(
-  ::Algorithm"forest_cover", g::::(!IsDirected); root_vertex=NamedGraphs.default_root_vertex
+  ::Algorithm"forest_cover",
+  g::::(!IsDirected);
+  root_vertex=GraphsExtensions.default_root_vertex,
 )
   forests = NamedGraphs.forest_cover(g)
   edges = edgetype(g)[]
diff --git a/src/solvers/alternating_update/alternating_update.jl b/src/solvers/alternating_update/alternating_update.jl
index 8fae9532..3b0f6b77 100644
--- a/src/solvers/alternating_update/alternating_update.jl
+++ b/src/solvers/alternating_update/alternating_update.jl
@@ -1,5 +1,6 @@
 using ITensors: state
 using ITensors.ITensorMPS: linkind
+using NamedGraphs.GraphsExtensions: GraphsExtensions
 using Observers: Observers
 
 function alternating_update(
@@ -13,7 +14,7 @@ function alternating_update(
   sweep_printer=nothing,
   (sweep_observer!)=nothing,
   (region_observer!)=nothing,
-  root_vertex=default_root_vertex(init_state),
+  root_vertex=GraphsExtensions.default_root_vertex(init_state),
   extracter_kwargs=(;),
   extracter=default_extracter(),
   updater_kwargs=(;),
diff --git a/src/solvers/sweep_plans/sweep_plans.jl b/src/solvers/sweep_plans/sweep_plans.jl
index 208f9bce..69221995 100644
--- a/src/solvers/sweep_plans/sweep_plans.jl
+++ b/src/solvers/sweep_plans/sweep_plans.jl
@@ -1,3 +1,6 @@
+using Graphs: AbstractEdge, dst, src
+using NamedGraphs.GraphsExtensions: GraphsExtensions
+
 direction(step_number) = isodd(step_number) ? Base.Forward : Base.Reverse
 
 function overlap(edge_a::AbstractEdge, edge_b::AbstractEdge)
@@ -58,7 +61,7 @@ end
 function forward_sweep(
   dir::Base.ForwardOrdering,
   graph::AbstractGraph;
-  root_vertex=default_root_vertex(graph),
+  root_vertex=GraphsExtensions.default_root_vertex(graph),
   region_kwargs,
   reverse_kwargs=region_kwargs,
   reverse_step=false,
@@ -141,7 +144,10 @@ function default_sweep_plans(
 end
 
 function default_sweep_plan(
-  graph::AbstractGraph; root_vertex=default_root_vertex(graph), region_kwargs, nsites::Int
+  graph::AbstractGraph;
+  root_vertex=GraphsExtensions.default_root_vertex(graph),
+  region_kwargs,
+  nsites::Int,
 )
   return vcat(
     [
@@ -158,7 +164,7 @@ end
 
 function tdvp_sweep_plan(
   graph::AbstractGraph;
-  root_vertex=default_root_vertex(graph),
+  root_vertex=GraphsExtensions.default_root_vertex(graph),
   region_kwargs,
   reverse_step=true,
   order::Int,
diff --git a/src/solvers/tdvp.jl b/src/solvers/tdvp.jl
index 1b70015e..7a58fe1b 100644
--- a/src/solvers/tdvp.jl
+++ b/src/solvers/tdvp.jl
@@ -1,3 +1,5 @@
+using NamedGraphs.GraphsExtensions: GraphsExtensions
+
 #ToDo: Cleanup _compute_nsweeps, maybe restrict flexibility to simplify code
 function _compute_nsweeps(nsweeps::Int, t::Number, time_step::Number)
   return error("Cannot specify both nsweeps and time_step in tdvp")
@@ -101,7 +103,7 @@ function tdvp(
   sweep_printer=nothing,
   (sweep_observer!)=nothing,
   (region_observer!)=nothing,
-  root_vertex=default_root_vertex(init_state),
+  root_vertex=GraphsExtensions.default_root_vertex(init_state),
   reverse_step=true,
   extracter_kwargs=(;),
   extracter=default_extracter(), # ToDo: extracter could be inside extracter_kwargs, at the cost of having to extract it in region_update
diff --git a/src/treetensornetworks/abstracttreetensornetwork.jl b/src/treetensornetworks/abstracttreetensornetwork.jl
index d24c4e95..7b74f0d9 100644
--- a/src/treetensornetworks/abstracttreetensornetwork.jl
+++ b/src/treetensornetworks/abstracttreetensornetwork.jl
@@ -1,6 +1,6 @@
 using Graphs: has_vertex
 using NamedGraphs.GraphsExtensions:
-  edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
+  GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
 using IsApprox: IsApprox, Approx
 using ITensors: @Algorithm_str, directsum, hasinds, permute, plev
 using ITensors.ITensorMPS: linkind, loginner, lognorm, orthogonalize
@@ -21,11 +21,6 @@ end
 ITensorNetwork(tn::AbstractTTN) = error("Not implemented")
 ortho_region(tn::AbstractTTN) = error("Not implemented")
 
-function default_root_vertex(gs::AbstractGraph...)
-  # @assert all(is_tree.(gs))
-  return first(leaf_vertices(gs[end]))
-end
-
 #
 # Orthogonality center
 #
@@ -64,7 +59,9 @@ end
 # Truncation
 #
 
-function Base.truncate(tn::AbstractTTN; root_vertex=default_root_vertex(tn), kwargs...)
+function Base.truncate(
+  tn::AbstractTTN; root_vertex=GraphsExtensions.default_root_vertex(tn), kwargs...
+)
   for e in post_order_dfs_edges(tn, root_vertex)
     # always orthogonalize towards source first to make truncations controlled
     tn = orthogonalize(tn, src(e))
@@ -84,7 +81,9 @@ end
 #
 
 # TODO: decide on contraction order: reverse dfs vertices or forward dfs edges?
-function NDTensors.contract(tn::AbstractTTN, root_vertex=default_root_vertex(tn); kwargs...)
+function NDTensors.contract(
+  tn::AbstractTTN, root_vertex=GraphsExtensions.default_root_vertex(tn); kwargs...
+)
   tn = copy(tn)
   # reverse post order vertices
   traversal_order = reverse(post_order_dfs_vertices(tn, root_vertex))
@@ -98,7 +97,7 @@ function NDTensors.contract(tn::AbstractTTN, root_vertex=default_root_vertex(tn)
 end
 
 function ITensors.inner(
-  x::AbstractTTN, y::AbstractTTN; root_vertex=default_root_vertex(x, y)
+  x::AbstractTTN, y::AbstractTTN; root_vertex=GraphsExtensions.default_root_vertex(x)
 )
   xᴴ = sim(dag(x); sites=[])
   y = sim(y; sites=[])
@@ -186,7 +185,7 @@ end
 
 # TODO: stick with this traversal or find optimal contraction sequence?
 function ITensorMPS.loginner(
-  tn1::AbstractTTN, tn2::AbstractTTN; root_vertex=default_root_vertex(tn1, tn2)
+  tn1::AbstractTTN, tn2::AbstractTTN; root_vertex=GraphsExtensions.default_root_vertex(tn1)
 )
   N = nv(tn1)
   if nv(tn2) != N
@@ -228,14 +227,16 @@ function Base.:+(
   ::Algorithm"densitymatrix",
   tns::AbstractTTN...;
   cutoff=1e-15,
-  root_vertex=default_root_vertex(tns...),
+  root_vertex=GraphsExtensions.default_root_vertex(first(tns)),
   kwargs...,
 )
   return error("Not implemented (yet) for trees.")
 end
 
 function Base.:+(
-  ::Algorithm"directsum", tns::AbstractTTN...; root_vertex=default_root_vertex(tns...)
+  ::Algorithm"directsum",
+  tns::AbstractTTN...;
+  root_vertex=GraphsExtensions.default_root_vertex(first(tns)),
 )
   @assert all(tn -> nv(first(tns)) == nv(tn), tns)
 
@@ -302,7 +303,10 @@ end
 
 # TODO: implement using multi-graph disjoint union
 function ITensors.inner(
-  y::AbstractTTN, A::AbstractTTN, x::AbstractTTN; root_vertex=default_root_vertex(x, A, y)
+  y::AbstractTTN,
+  A::AbstractTTN,
+  x::AbstractTTN;
+  root_vertex=GraphsExtensions.default_root_vertex(x),
 )
   traversal_order = reverse(post_order_dfs_vertices(x, root_vertex))
   ydag = sim(dag(y); sites=[])
@@ -320,7 +324,7 @@ function ITensors.inner(
   y::AbstractTTN,
   A::AbstractTTN,
   x::AbstractTTN;
-  root_vertex=default_root_vertex(B, y, A, x),
+  root_vertex=GraphsExtensions.default_root_vertex(B),
 )
   N = nv(B)
   if nv(y) != N || nv(x) != N || nv(A) != N
@@ -349,8 +353,8 @@ function ITensorMPS.expect(
   operator::String,
   state::AbstractTTN;
   vertices=vertices(state),
-  # ToDo: verify that this is a sane default
-  root_vertex=default_root_vertex(siteinds(state)),
+  # TODO: verify that this is a sane default
+  root_vertex=GraphsExtensions.default_root_vertex(state),
 )
   # TODO: Optimize this with proper caching.
   state /= norm(state)
diff --git a/src/treetensornetworks/opsum_to_ttn.jl b/src/treetensornetworks/opsum_to_ttn.jl
index b4bda869..e257c953 100644
--- a/src/treetensornetworks/opsum_to_ttn.jl
+++ b/src/treetensornetworks/opsum_to_ttn.jl
@@ -4,7 +4,8 @@ using ITensors.ITensorMPS: ITensorMPS, cutoff, linkdims, truncate!
 using ITensors.LazyApply: Prod, Sum, coefficient
 using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
 using ITensors.Ops: Op, OpSum
-using NamedGraphs.GraphsExtensions: boundary_edges, degrees, is_leaf, vertex_path
+using NamedGraphs.GraphsExtensions:
+  GraphsExtensions, boundary_edges, degrees, is_leaf, vertex_path
 using StaticArrays: MVector
 
 # convert ITensors.OpSum to TreeTensorNetwork
@@ -522,29 +523,17 @@ Convert an OpSum object `os` to a TreeTensorNetwork, with indices given by `site
 function ttn(
   os::OpSum,
   sites::IndsNetwork;
-  root_vertex=default_root_vertex(sites),
-  splitblocks=false,
-  algorithm="svd",
+  root_vertex=GraphsExtensions.default_root_vertex(sites),
   kwargs...,
-)::TTN
+)
   length(ITensors.terms(os)) == 0 && error("OpSum has no terms")
   is_tree(sites) || error("Site index graph must be a tree.")
   is_leaf(sites, root_vertex) || error("Tree root must be a leaf vertex.")
 
   os = deepcopy(os)
   os = sorteachterm(os, sites, root_vertex)
-  os = ITensorMPS.sortmergeterms(os) # not exported
-  if algorithm == "svd"
-    T = ttn_svd(os, sites, root_vertex; kwargs...)
-  else
-    error("Currently only SVD is supported as TTN constructor backend.")
-  end
-
-  if splitblocks
-    error("splitblocks not yet implemented for AbstractTreeTensorNetwork.")
-    T = ITensors.splitblocks(linkinds, T) # TODO: make this work
-  end
-  return T
+  os = ITensorMPS.sortmergeterms(os)
+  return ttn_svd(os, sites, root_vertex; kwargs...)
 end
 
 function mpo(os::OpSum, external_inds::Vector; kwargs...)
diff --git a/src/treetensornetworks/ttn.jl b/src/treetensornetworks/treetensornetwork.jl
similarity index 94%
rename from src/treetensornetworks/ttn.jl
rename to src/treetensornetworks/treetensornetwork.jl
index a8b3a301..4bca217d 100644
--- a/src/treetensornetworks/ttn.jl
+++ b/src/treetensornetworks/treetensornetwork.jl
@@ -1,7 +1,7 @@
 using Graphs: path_graph
 using ITensors: ITensor
 using LinearAlgebra: factorize, normalize
-using NamedGraphs: vertextype
+using NamedGraphs.GraphsExtensions: GraphsExtensions, vertextype
 
 """
     TreeTensorNetwork{V} <: AbstractTreeTensorNetwork{V}
@@ -76,7 +76,12 @@ function mps(f, is::Vector{<:Index}; kwargs...)
 end
 
 # Construct from dense ITensor, using IndsNetwork of site indices.
-function ttn(a::ITensor, is::IndsNetwork; ortho_region=[default_root_vertex(is)], kwargs...)
+function ttn(
+  a::ITensor,
+  is::IndsNetwork;
+  ortho_region=[GraphsExtensions.default_root_vertex(is)],
+  kwargs...,
+)
  for v in vertices(is)
    @assert hasinds(a, is[v])
  end
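
For reference, a minimal sketch of the call-site migration: the package-local `default_root_vertex(gs::AbstractGraph...)` helper is removed, and callers now qualify the single-graph `GraphsExtensions.default_root_vertex` provided by NamedGraphs. The graph constructor used below (`named_comb_tree`) is only an assumed stand-in for whatever tree graph is already at hand; it is not part of this patch.

# Sketch only, not part of the patch: default_root_vertex call-site migration.
using NamedGraphs.NamedGraphGenerators: named_comb_tree  # assumed generator location; any tree graph works
using NamedGraphs.GraphsExtensions: GraphsExtensions

g = named_comb_tree((2, 3))                     # small tree standing in for a TTN / IndsNetwork graph
# Old: the local default_root_vertex(gs...) returned the first leaf of the last graph.
# New: qualify the NamedGraphs-provided function and pass a single graph.
root = GraphsExtensions.default_root_vertex(g)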
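
The `ttn(::OpSum, ::IndsNetwork)` constructor also drops its `algorithm` and `splitblocks` keywords; the SVD backend is the only remaining code path, and leftover keyword arguments are forwarded to `ttn_svd`. Below is a hedged usage sketch: the `siteinds`/`named_comb_tree` helpers and the Ising-style OpSum loop follow common ITensorNetworks usage and are illustrative assumptions, not code from this patch.

# Hedged sketch of the simplified OpSum -> TTN entry point.
using Graphs: edges, src, dst
using ITensors: OpSum
using ITensorNetworks: siteinds, ttn
using NamedGraphs.NamedGraphGenerators: named_comb_tree  # assumed generator location

# Build a simple nearest-neighbor OpSum over the edges of a tree graph.
function ising_opsum(g)
  os = OpSum()
  for e in edges(g)
    os += "Sz", src(e), "Sz", dst(e)  # one coupling term per tree edge
  end
  return os
end

g = named_comb_tree((2, 3))
s = siteinds("S=1/2", g)  # IndsNetwork of site indices on the tree
# `algorithm="svd"` and `splitblocks=...` are no longer accepted keywords;
# `root_vertex` defaults to GraphsExtensions.default_root_vertex(s) and must be a leaf.
H = ttn(ising_opsum(g), s)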