Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ttn_svd #157

Merged
merged 16 commits into from
Apr 16, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ ITensorNetworksEinExprsExt = "EinExprs"
AbstractTrees = "0.4.4"
Combinatorics = "1"
Compat = "3, 4"
DataGraphs = "0.1.7"
DataGraphs = "0.1.13"
DataStructures = "0.18"
Dictionaries = "0.4"
Distributions = "0.25.86"
Expand Down
107 changes: 71 additions & 36 deletions src/treetensornetworks/opsum_to_ttn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,25 @@ using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
using ITensors.Ops: Op, OpSum
using NamedGraphs: degrees, is_leaf, vertex_path
using StaticArrays: MVector

using NamedGraphs: boundary_edges
# convert ITensors.OpSum to TreeTensorNetwork

#
# Utility methods
#

# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
function find_index_in_tree(site, g::AbstractGraph, root_vertex)
ordering = reverse(post_order_dfs_vertices(g, root_vertex))
return findfirst(x -> x == site, ordering)
end
function find_index_in_tree(o::Op, g::AbstractGraph, root_vertex)
return find_index_in_tree(ITensors.site(o), g, root_vertex)
function align_edges(edges, reference_edges)
return intersect(Iterators.flatten(zip(edges, reverse.(edges))), reference_edges)
end

# determine 'support' of product operator on tree graph
function span(t::Scaled{C,Prod{Op}}, g::AbstractGraph) where {C}
spn = eltype(g)[]
nterms = length(t)
for i in 1:nterms, j in i:nterms
path = vertex_path(g, ITensors.site(t[i]), ITensors.site(t[j]))
spn = union(spn, path)
end
return spn
function align_and_reorder_edges(edges, reference_edges)
return intersect(reference_edges, align_edges(edges, reference_edges))
end

# determine whether an operator string crosses a given graph vertex
function crosses_vertex(t::Scaled{C,Prod{Op}}, g::AbstractGraph, v) where {C}
return v ∈ span(t, g)
function split_at_vertex(g::AbstractGraph, v)
g = copy(g)
rem_vertex!(g, v)
return Set.(connected_components(g))
end

#
Expand All @@ -45,7 +34,7 @@ end
"""
ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex, kwargs...)

Construct a dense TreeTensorNetwork from a symbolic OpSum representation of a
Construct a TreeTensorNetwork from a symbolic OpSum representation of a
Hamiltonian, compressing shared interaction channels.
"""
function ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex; kwargs...)
Expand All @@ -71,9 +60,9 @@ function ttn_svd(
thishasqns = any(v -> hasqns(sites[v]), vertices(sites))

# traverse tree outwards from root vertex
vs = reverse(post_order_dfs_vertices(sites, root_vertex)) # store vertices in fixed ordering relative to root
vs = _default_vertex_ordering(sites, root_vertex)
# ToDo: Add check in ttn_svd that the ordering matches that of find_index_in_tree, which is used in sorteachterm #fermion-sign!
es = reverse(reverse.(post_order_dfs_edges(sites, root_vertex))) # store edges in fixed ordering relative to root
es = _default_edge_ordering(sites, root_vertex) # store edges in fixed ordering relative to root
# some things to keep track of
degrees = Dict(v => degree(sites, v) for v in vs) # rank of every TTN tensor in network
Vs = Dict(e => Dict{QN,Matrix{coefficient_type}}() for e in es) # link isometries for SVD compression of TTN
Expand Down Expand Up @@ -105,6 +94,8 @@ function ttn_svd(
for v in vs
is_internal[v] = isempty(sites[v])
if isempty(sites[v])
# FIXME: This logic only works for trivial flux, breaks for nonzero flux
# ToDo: add assert or fix and add test!
sites[v] = [Index(Hflux => 1)]
end
end
Expand All @@ -118,35 +109,65 @@ function ttn_svd(
# build compressed finite state machine representation
for v in vs
# for every vertex, find all edges that contain this vertex
edges = filter(e -> dst(e) == v || src(e) == v, es)
edges = align_and_reorder_edges(incident_edges(sites, v), es)

# use the corresponding ordering as index order for tensor elements at this site
dim_in = findfirst(e -> dst(e) == v, edges)
edge_in = (isnothing(dim_in) ? [] : edges[dim_in])
dims_out = findall(e -> src(e) == v, edges)
edges_out = edges[dims_out]

# for every site w except v, determine the incident edge to v that lies
# in the edge_path(w,v)
subgraphs = split_at_vertex(sites, v)
_boundary_edges = align_edges(
[only(boundary_edges(underlying_graph(sites), subgraph)) for subgraph in subgraphs],
edges,
)
which_incident_edge = Dict(
Iterators.flatten([
subgraphs[i] .=> ((_boundary_edges[i]),) for i in eachindex(subgraphs)
]),
)

# sanity check, leaves only have single incoming or outgoing edge
@assert !isempty(dims_out) || !isnothing(dim_in)
(isempty(dims_out) || isnothing(dim_in)) && @assert is_leaf(sites, v)

for term in os
# loop over OpSum and pick out terms that act on current vertex
crosses_vertex(term, sites, v) || continue

factors = ITensors.terms(term)
if v in ITensors.site.(factors)
crosses_vertex = true
else
crosses_vertex =
!isone(
length(Set([which_incident_edge[site] for site in ITensors.site.(factors)]))
)
end
#if term doesn't cross vertex, skip it
crosses_vertex || continue

# filter out factor that acts on current vertex
onsite = filter(t -> (ITensors.site(t) == v), factors)
not_onsite_factors = setdiff(factors, onsite)

# filter out factors that come in from the direction of the incoming edge
incoming = filter(
t -> edge_in ∈ edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
t -> which_incident_edge[ITensors.site(t)] == edge_in, not_onsite_factors
)

# also store all non-incoming factors in standard order, used for channel merging
not_incoming = filter(
t -> edge_in ∉ edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
t -> (ITensors.site(t) == v) || which_incident_edge[ITensors.site(t)] != edge_in,
factors,
)
# filter out factor that acts on current vertex
onsite = filter(t -> (ITensors.site(t) == v), ITensors.terms(term))

# for every outgoing edge, filter out factors that go out along that edge
outgoing = Dict(
e => filter(t -> e ∈ edge_path(sites, v, ITensors.site(t)), ITensors.terms(term))
for e in edges_out
e => filter(t -> which_incident_edge[ITensors.site(t)] == e, not_onsite_factors) for
e in edges_out
)

# compute QNs
Expand Down Expand Up @@ -246,7 +267,8 @@ function ttn_svd(

for v in vs
# redo the whole thing like before
edges = filter(e -> dst(e) == v || src(e) == v, es)
# ToDo: use neighborhood instead of going through all edges, see above
edges = align_and_reorder_edges(incident_edges(sites, v), es)
dim_in = findfirst(e -> dst(e) == v, edges)
dims_out = findall(e -> src(e) == v, edges)
# slice isometries at this vertex
Expand Down Expand Up @@ -340,9 +362,10 @@ function ttn_svd(
if is_internal[v]
H[v] += iT
else
if hasqns(iT)
@assert flux(iT * Op) == Hflux
end
#ToDo: Remove this assert since it seems to be costly
#if hasqns(iT)
# @assert flux(iT * Op) == Hflux
#end
H[v] += (iT * Op)
end
end
Expand Down Expand Up @@ -409,12 +432,24 @@ function computeSiteProd(sites::IndsNetwork{V,<:Index}, ops::Prod{Op})::ITensor
return T
end

function _default_vertex_ordering(g::AbstractGraph, root_vertex)
return reverse(post_order_dfs_vertices(g, root_vertex))
end

function _default_edge_ordering(g::AbstractGraph, root_vertex)
return reverse(reverse.(post_order_dfs_edges(g, root_vertex)))
end

# This is almost an exact copy of ITensors/src/opsum_to_mpo_generic:sorteachterm except for the site ordering being
# given via find_index_in_tree
# changed `isless_site` to use tree vertex ordering relative to root
function sorteachterm(os::OpSum, sites::IndsNetwork{V,<:Index}, root_vertex::V) where {V}
os = copy(os)
findpos(op::Op) = find_index_in_tree(op, sites, root_vertex)

# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
ordering = _default_vertex_ordering(sites, root_vertex)
site_positions = Dict(zip(ordering, 1:length(ordering)))
findpos(op::Op) = site_positions[ITensors.site(op)]
isless_site(o1::Op, o2::Op) = findpos(o1) < findpos(o2)
N = nv(sites)
for n in eachindex(os)
Expand Down
Loading