Skip to content

Commit

Permalink
Optimize ttn_svd (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss authored Apr 16, 2024
1 parent ce7b3e4 commit e920800
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ ITensorNetworksEinExprsExt = "EinExprs"
AbstractTrees = "0.4.4"
Combinatorics = "1"
Compat = "3, 4"
DataGraphs = "0.1.7"
DataGraphs = "0.1.13"
DataStructures = "0.18"
Dictionaries = "0.4"
Distributions = "0.25.86"
Expand Down
107 changes: 71 additions & 36 deletions src/treetensornetworks/opsum_to_ttn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,25 @@ using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
using ITensors.Ops: Op, OpSum
using NamedGraphs: degrees, is_leaf, vertex_path
using StaticArrays: MVector

using NamedGraphs: boundary_edges
# convert ITensors.OpSum to TreeTensorNetwork

#
# Utility methods
#

# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
function find_index_in_tree(site, g::AbstractGraph, root_vertex)
ordering = reverse(post_order_dfs_vertices(g, root_vertex))
return findfirst(x -> x == site, ordering)
end
function find_index_in_tree(o::Op, g::AbstractGraph, root_vertex)
return find_index_in_tree(ITensors.site(o), g, root_vertex)
function align_edges(edges, reference_edges)
return intersect(Iterators.flatten(zip(edges, reverse.(edges))), reference_edges)
end

# determine 'support' of product operator on tree graph
function span(t::Scaled{C,Prod{Op}}, g::AbstractGraph) where {C}
spn = eltype(g)[]
nterms = length(t)
for i in 1:nterms, j in i:nterms
path = vertex_path(g, ITensors.site(t[i]), ITensors.site(t[j]))
spn = union(spn, path)
end
return spn
function align_and_reorder_edges(edges, reference_edges)
return intersect(reference_edges, align_edges(edges, reference_edges))
end

# determine whether an operator string crosses a given graph vertex
function crosses_vertex(t::Scaled{C,Prod{Op}}, g::AbstractGraph, v) where {C}
return v span(t, g)
function split_at_vertex(g::AbstractGraph, v)
g = copy(g)
rem_vertex!(g, v)
return Set.(connected_components(g))
end

#
Expand All @@ -45,7 +34,7 @@ end
"""
ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex, kwargs...)
Construct a dense TreeTensorNetwork from a symbolic OpSum representation of a
Construct a TreeTensorNetwork from a symbolic OpSum representation of a
Hamiltonian, compressing shared interaction channels.
"""
function ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex; kwargs...)
Expand All @@ -71,9 +60,9 @@ function ttn_svd(
thishasqns = any(v -> hasqns(sites[v]), vertices(sites))

# traverse tree outwards from root vertex
vs = reverse(post_order_dfs_vertices(sites, root_vertex)) # store vertices in fixed ordering relative to root
vs = _default_vertex_ordering(sites, root_vertex)
# ToDo: Add check in ttn_svd that the ordering matches that of find_index_in_tree, which is used in sorteachterm #fermion-sign!
es = reverse(reverse.(post_order_dfs_edges(sites, root_vertex))) # store edges in fixed ordering relative to root
es = _default_edge_ordering(sites, root_vertex) # store edges in fixed ordering relative to root
# some things to keep track of
degrees = Dict(v => degree(sites, v) for v in vs) # rank of every TTN tensor in network
Vs = Dict(e => Dict{QN,Matrix{coefficient_type}}() for e in es) # link isometries for SVD compression of TTN
Expand Down Expand Up @@ -105,6 +94,8 @@ function ttn_svd(
for v in vs
is_internal[v] = isempty(sites[v])
if isempty(sites[v])
# FIXME: This logic only works for trivial flux, breaks for nonzero flux
# ToDo: add assert or fix and add test!
sites[v] = [Index(Hflux => 1)]
end
end
Expand All @@ -118,35 +109,65 @@ function ttn_svd(
# build compressed finite state machine representation
for v in vs
# for every vertex, find all edges that contain this vertex
edges = filter(e -> dst(e) == v || src(e) == v, es)
edges = align_and_reorder_edges(incident_edges(sites, v), es)

# use the corresponding ordering as index order for tensor elements at this site
dim_in = findfirst(e -> dst(e) == v, edges)
edge_in = (isnothing(dim_in) ? [] : edges[dim_in])
dims_out = findall(e -> src(e) == v, edges)
edges_out = edges[dims_out]

# for every site w except v, determine the incident edge to v that lies
# in the edge_path(w,v)
subgraphs = split_at_vertex(sites, v)
_boundary_edges = align_edges(
[only(boundary_edges(underlying_graph(sites), subgraph)) for subgraph in subgraphs],
edges,
)
which_incident_edge = Dict(
Iterators.flatten([
subgraphs[i] .=> ((_boundary_edges[i]),) for i in eachindex(subgraphs)
]),
)

# sanity check, leaves only have single incoming or outgoing edge
@assert !isempty(dims_out) || !isnothing(dim_in)
(isempty(dims_out) || isnothing(dim_in)) && @assert is_leaf(sites, v)

for term in os
# loop over OpSum and pick out terms that act on current vertex
crosses_vertex(term, sites, v) || continue

factors = ITensors.terms(term)
if v in ITensors.site.(factors)
crosses_vertex = true
else
crosses_vertex =
!isone(
length(Set([which_incident_edge[site] for site in ITensors.site.(factors)]))
)
end
#if term doesn't cross vertex, skip it
crosses_vertex || continue

# filter out factor that acts on current vertex
onsite = filter(t -> (ITensors.site(t) == v), factors)
not_onsite_factors = setdiff(factors, onsite)

# filter out factors that come in from the direction of the incoming edge
incoming = filter(
t -> edge_in edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
t -> which_incident_edge[ITensors.site(t)] == edge_in, not_onsite_factors
)

# also store all non-incoming factors in standard order, used for channel merging
not_incoming = filter(
t -> edge_in edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
t -> (ITensors.site(t) == v) || which_incident_edge[ITensors.site(t)] != edge_in,
factors,
)
# filter out factor that acts on current vertex
onsite = filter(t -> (ITensors.site(t) == v), ITensors.terms(term))

# for every outgoing edge, filter out factors that go out along that edge
outgoing = Dict(
e => filter(t -> e edge_path(sites, v, ITensors.site(t)), ITensors.terms(term))
for e in edges_out
e => filter(t -> which_incident_edge[ITensors.site(t)] == e, not_onsite_factors) for
e in edges_out
)

# compute QNs
Expand Down Expand Up @@ -246,7 +267,8 @@ function ttn_svd(

for v in vs
# redo the whole thing like before
edges = filter(e -> dst(e) == v || src(e) == v, es)
# ToDo: use neighborhood instead of going through all edges, see above
edges = align_and_reorder_edges(incident_edges(sites, v), es)
dim_in = findfirst(e -> dst(e) == v, edges)
dims_out = findall(e -> src(e) == v, edges)
# slice isometries at this vertex
Expand Down Expand Up @@ -340,9 +362,10 @@ function ttn_svd(
if is_internal[v]
H[v] += iT
else
if hasqns(iT)
@assert flux(iT * Op) == Hflux
end
#ToDo: Remove this assert since it seems to be costly
#if hasqns(iT)
# @assert flux(iT * Op) == Hflux
#end
H[v] += (iT * Op)
end
end
Expand Down Expand Up @@ -409,12 +432,24 @@ function computeSiteProd(sites::IndsNetwork{V,<:Index}, ops::Prod{Op})::ITensor
return T
end

function _default_vertex_ordering(g::AbstractGraph, root_vertex)
return reverse(post_order_dfs_vertices(g, root_vertex))
end

function _default_edge_ordering(g::AbstractGraph, root_vertex)
return reverse(reverse.(post_order_dfs_edges(g, root_vertex)))
end

# This is almost an exact copy of ITensors/src/opsum_to_mpo_generic:sorteachterm except for the site ordering being
# given via find_index_in_tree
# changed `isless_site` to use tree vertex ordering relative to root
function sorteachterm(os::OpSum, sites::IndsNetwork{V,<:Index}, root_vertex::V) where {V}
os = copy(os)
findpos(op::Op) = find_index_in_tree(op, sites, root_vertex)

# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
ordering = _default_vertex_ordering(sites, root_vertex)
site_positions = Dict(zip(ordering, 1:length(ordering)))
findpos(op::Op) = site_positions[ITensors.site(op)]
isless_site(o1::Op, o2::Op) = findpos(o1) < findpos(o2)
N = nv(sites)
for n in eachindex(os)
Expand Down

0 comments on commit e920800

Please sign in to comment.