diff --git a/Project.toml b/Project.toml index a52aa610..ff911a1e 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ ITensorNetworksEinExprsExt = "EinExprs" AbstractTrees = "0.4.4" Combinatorics = "1" Compat = "3, 4" -DataGraphs = "0.1.7" +DataGraphs = "0.1.13" DataStructures = "0.18" Dictionaries = "0.4" Distributions = "0.25.86" diff --git a/src/treetensornetworks/opsum_to_ttn.jl b/src/treetensornetworks/opsum_to_ttn.jl index 9b555547..a3923dd8 100644 --- a/src/treetensornetworks/opsum_to_ttn.jl +++ b/src/treetensornetworks/opsum_to_ttn.jl @@ -6,36 +6,25 @@ using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks using ITensors.Ops: Op, OpSum using NamedGraphs: degrees, is_leaf, vertex_path using StaticArrays: MVector - +using NamedGraphs: boundary_edges # convert ITensors.OpSum to TreeTensorNetwork # # Utility methods # -# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root -function find_index_in_tree(site, g::AbstractGraph, root_vertex) - ordering = reverse(post_order_dfs_vertices(g, root_vertex)) - return findfirst(x -> x == site, ordering) -end -function find_index_in_tree(o::Op, g::AbstractGraph, root_vertex) - return find_index_in_tree(ITensors.site(o), g, root_vertex) +function align_edges(edges, reference_edges) + return intersect(Iterators.flatten(zip(edges, reverse.(edges))), reference_edges) end -# determine 'support' of product operator on tree graph -function span(t::Scaled{C,Prod{Op}}, g::AbstractGraph) where {C} - spn = eltype(g)[] - nterms = length(t) - for i in 1:nterms, j in i:nterms - path = vertex_path(g, ITensors.site(t[i]), ITensors.site(t[j])) - spn = union(spn, path) - end - return spn +function align_and_reorder_edges(edges, reference_edges) + return intersect(reference_edges, align_edges(edges, reference_edges)) end -# determine whether an operator string crosses a given graph vertex -function crosses_vertex(t::Scaled{C,Prod{Op}}, g::AbstractGraph, v) where {C} - return v ∈ span(t, g) +function split_at_vertex(g::AbstractGraph, v) + g = copy(g) + rem_vertex!(g, v) + return Set.(connected_components(g)) end # @@ -45,7 +34,7 @@ end """ ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex, kwargs...) -Construct a dense TreeTensorNetwork from a symbolic OpSum representation of a +Construct a TreeTensorNetwork from a symbolic OpSum representation of a Hamiltonian, compressing shared interaction channels. """ function ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex; kwargs...) @@ -71,9 +60,9 @@ function ttn_svd( thishasqns = any(v -> hasqns(sites[v]), vertices(sites)) # traverse tree outwards from root vertex - vs = reverse(post_order_dfs_vertices(sites, root_vertex)) # store vertices in fixed ordering relative to root + vs = _default_vertex_ordering(sites, root_vertex) # ToDo: Add check in ttn_svd that the ordering matches that of find_index_in_tree, which is used in sorteachterm #fermion-sign! - es = reverse(reverse.(post_order_dfs_edges(sites, root_vertex))) # store edges in fixed ordering relative to root + es = _default_edge_ordering(sites, root_vertex) # store edges in fixed ordering relative to root # some things to keep track of degrees = Dict(v => degree(sites, v) for v in vs) # rank of every TTN tensor in network Vs = Dict(e => Dict{QN,Matrix{coefficient_type}}() for e in es) # link isometries for SVD compression of TTN @@ -105,6 +94,8 @@ function ttn_svd( for v in vs is_internal[v] = isempty(sites[v]) if isempty(sites[v]) + # FIXME: This logic only works for trivial flux, breaks for nonzero flux + # ToDo: add assert or fix and add test! sites[v] = [Index(Hflux => 1)] end end @@ -118,35 +109,65 @@ function ttn_svd( # build compressed finite state machine representation for v in vs # for every vertex, find all edges that contain this vertex - edges = filter(e -> dst(e) == v || src(e) == v, es) + edges = align_and_reorder_edges(incident_edges(sites, v), es) + # use the corresponding ordering as index order for tensor elements at this site dim_in = findfirst(e -> dst(e) == v, edges) edge_in = (isnothing(dim_in) ? [] : edges[dim_in]) dims_out = findall(e -> src(e) == v, edges) edges_out = edges[dims_out] + # for every site w except v, determine the incident edge to v that lies + # in the edge_path(w,v) + subgraphs = split_at_vertex(sites, v) + _boundary_edges = align_edges( + [only(boundary_edges(underlying_graph(sites), subgraph)) for subgraph in subgraphs], + edges, + ) + which_incident_edge = Dict( + Iterators.flatten([ + subgraphs[i] .=> ((_boundary_edges[i]),) for i in eachindex(subgraphs) + ]), + ) + # sanity check, leaves only have single incoming or outgoing edge @assert !isempty(dims_out) || !isnothing(dim_in) (isempty(dims_out) || isnothing(dim_in)) && @assert is_leaf(sites, v) for term in os # loop over OpSum and pick out terms that act on current vertex - crosses_vertex(term, sites, v) || continue + + factors = ITensors.terms(term) + if v in ITensors.site.(factors) + crosses_vertex = true + else + crosses_vertex = + !isone( + length(Set([which_incident_edge[site] for site in ITensors.site.(factors)])) + ) + end + #if term doesn't cross vertex, skip it + crosses_vertex || continue + + # filter out factor that acts on current vertex + onsite = filter(t -> (ITensors.site(t) == v), factors) + not_onsite_factors = setdiff(factors, onsite) # filter out factors that come in from the direction of the incoming edge incoming = filter( - t -> edge_in ∈ edge_path(sites, ITensors.site(t), v), ITensors.terms(term) + t -> which_incident_edge[ITensors.site(t)] == edge_in, not_onsite_factors ) + # also store all non-incoming factors in standard order, used for channel merging not_incoming = filter( - t -> edge_in ∉ edge_path(sites, ITensors.site(t), v), ITensors.terms(term) + t -> (ITensors.site(t) == v) || which_incident_edge[ITensors.site(t)] != edge_in, + factors, ) - # filter out factor that acts on current vertex - onsite = filter(t -> (ITensors.site(t) == v), ITensors.terms(term)) + # for every outgoing edge, filter out factors that go out along that edge outgoing = Dict( - e => filter(t -> e ∈ edge_path(sites, v, ITensors.site(t)), ITensors.terms(term)) - for e in edges_out + e => filter(t -> which_incident_edge[ITensors.site(t)] == e, not_onsite_factors) for + e in edges_out ) # compute QNs @@ -246,7 +267,8 @@ function ttn_svd( for v in vs # redo the whole thing like before - edges = filter(e -> dst(e) == v || src(e) == v, es) + # ToDo: use neighborhood instead of going through all edges, see above + edges = align_and_reorder_edges(incident_edges(sites, v), es) dim_in = findfirst(e -> dst(e) == v, edges) dims_out = findall(e -> src(e) == v, edges) # slice isometries at this vertex @@ -340,9 +362,10 @@ function ttn_svd( if is_internal[v] H[v] += iT else - if hasqns(iT) - @assert flux(iT * Op) == Hflux - end + #ToDo: Remove this assert since it seems to be costly + #if hasqns(iT) + # @assert flux(iT * Op) == Hflux + #end H[v] += (iT * Op) end end @@ -409,12 +432,24 @@ function computeSiteProd(sites::IndsNetwork{V,<:Index}, ops::Prod{Op})::ITensor return T end +function _default_vertex_ordering(g::AbstractGraph, root_vertex) + return reverse(post_order_dfs_vertices(g, root_vertex)) +end + +function _default_edge_ordering(g::AbstractGraph, root_vertex) + return reverse(reverse.(post_order_dfs_edges(g, root_vertex))) +end + # This is almost an exact copy of ITensors/src/opsum_to_mpo_generic:sorteachterm except for the site ordering being # given via find_index_in_tree # changed `isless_site` to use tree vertex ordering relative to root function sorteachterm(os::OpSum, sites::IndsNetwork{V,<:Index}, root_vertex::V) where {V} os = copy(os) - findpos(op::Op) = find_index_in_tree(op, sites, root_vertex) + + # linear ordering of vertices in tree graph relative to chosen root, chosen outward from root + ordering = _default_vertex_ordering(sites, root_vertex) + site_positions = Dict(zip(ordering, 1:length(ordering))) + findpos(op::Op) = site_positions[ITensors.site(op)] isless_site(o1::Op, o2::Op) = findpos(o1) < findpos(o2) N = nv(sites) for n in eachindex(os)