Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ttn_svd #157

Merged
merged 16 commits into from
Apr 16, 2024
115 changes: 86 additions & 29 deletions src/treetensornetworks/opsum_to_ttn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,14 @@ using StaticArrays: MVector
# Utility methods
#

# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
function find_index_in_tree(site, g::AbstractGraph, root_vertex)
ordering = reverse(post_order_dfs_vertices(g, root_vertex))
return findfirst(x -> x == site, ordering)
end
function find_index_in_tree(o::Op, g::AbstractGraph, root_vertex)
return find_index_in_tree(ITensors.site(o), g, root_vertex)
end

# determine 'support' of product operator on tree graph
function span(t::Scaled{C,Prod{Op}}, g::AbstractGraph) where {C}
spn = eltype(g)[]
spn = Set{eltype(g)}()
nterms = length(t)
for i in 1:nterms, j in i:nterms
path = vertex_path(g, ITensors.site(t[i]), ITensors.site(t[j]))
spn = union(spn, path)
nterms == 1 && return Set([ITensors.site(t[1])])
for i in 1:nterms, j in (i + 1):nterms
path = Set(vertex_path(g, ITensors.site(t[i]), ITensors.site(t[j])))
spn = union!(spn, path)
end
return spn
end
Expand All @@ -38,6 +30,41 @@ function crosses_vertex(t::Scaled{C,Prod{Op}}, g::AbstractGraph, v) where {C}
return v ∈ span(t, g)
end

function align_edges(edges, reference_edges)
return intersect(reference_edges, Iterators.flatten((edges, reverse.(edges))))
end

# return a dict from vertices `w` of `g`, except for `v`, to the incident edge of `v`
# which lies in edge_path(g,w,v)
function vertices_to_incident_edges_dict(g::AbstractGraph, v, incident_edges)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
#split graph into subtrees by removing vertex v
_g = copy(underlying_graph(g))
rem_vertex!(_g, v)
subgraphs = Set.(connected_components(_g))

#for each incident edge, store the vertex that's not `v`
vs = vertextype(g)[]
for e in incident_edges
push!(vs, only(setdiff([src(e), dst(e)], [v])))
end

#return a Dictionary from vertices to incident_edges
_vs = vertextype(g)[]
_es = edgetype(g)[]
for (e, v) in zip(incident_edges, vs)
for i in eachindex(subgraphs)
if v in subgraphs[i]
append!(_vs, subgraphs[i])
append!(_es, fill(e, length(subgraphs[i])))
deleteat!(subgraphs, i)
break
end
end
end
@assert isempty(subgraphs)
return Dict(zip(_vs, _es))
end

#
# Tree adaptations of functionalities in ITensors.jl/src/physics/autompo/opsum_to_mpo.jl
#
Expand Down Expand Up @@ -71,9 +98,9 @@ function ttn_svd(
thishasqns = any(v -> hasqns(sites[v]), vertices(sites))

# traverse tree outwards from root vertex
vs = reverse(post_order_dfs_vertices(sites, root_vertex)) # store vertices in fixed ordering relative to root
vs = _default_vertex_ordering(sites, root_vertex)
# ToDo: Add check in ttn_svd that the ordering matches that of find_index_in_tree, which is used in sorteachterm #fermion-sign!
es = reverse(reverse.(post_order_dfs_edges(sites, root_vertex))) # store edges in fixed ordering relative to root
es = _default_edge_ordering(sites, root_vertex) # store edges in fixed ordering relative to root
# some things to keep track of
degrees = Dict(v => degree(sites, v) for v in vs) # rank of every TTN tensor in network
Vs = Dict(e => Dict{QN,Matrix{coefficient_type}}() for e in es) # link isometries for SVD compression of TTN
Expand Down Expand Up @@ -105,6 +132,8 @@ function ttn_svd(
for v in vs
is_internal[v] = isempty(sites[v])
if isempty(sites[v])
# FIXME: This logic only works for trivial flux, breaks for nonzero flux
# ToDo: add assert or fix and add test!
sites[v] = [Index(Hflux => 1)]
end
end
Expand All @@ -114,39 +143,53 @@ function ttn_svd(
site_coef_done = Prod{Op}[] # list of terms for which the coefficient has been added to a site factor
# temporary symbolic representation of TTN Hamiltonian
tempTTN = Dict(v => QNArrElem{Scaled{coefficient_type,Prod{Op}},degrees[v]}[] for v in vs)

#ToDo: precompute span of each term and store
# compute span of each term
spans = Dict{eltype(os),Set{vertextype_sites}}()
for term in os
spans[term] = span(term, sites)
end
# build compressed finite state machine representation
for v in vs
# for every vertex, find all edges that contain this vertex
edges = filter(e -> dst(e) == v || src(e) == v, es)
edges = align_edges(incident_edges(sites, v), es)

# use the corresponding ordering as index order for tensor elements at this site
dim_in = findfirst(e -> dst(e) == v, edges)
edge_in = (isnothing(dim_in) ? [] : edges[dim_in])
dims_out = findall(e -> src(e) == v, edges)
edges_out = edges[dims_out]

which_incident_edge = vertices_to_incident_edges_dict(sites, v, edges)
# sanity check, leaves only have single incoming or outgoing edge
@assert !isempty(dims_out) || !isnothing(dim_in)
(isempty(dims_out) || isnothing(dim_in)) && @assert is_leaf(sites, v)

for term in os
# loop over OpSum and pick out terms that act on current vertex
crosses_vertex(term, sites, v) || continue

v in spans[term] || continue
factors = ITensors.terms(term)

# filter out factor that acts on current vertex
onsite = filter(t -> (ITensors.site(t) == v), factors)
not_onsite_factors = setdiff(factors, onsite)

# filter out factors that come in from the direction of the incoming edge
incoming = filter(
t -> edge_in ∈ edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
t -> which_incident_edge[ITensors.site(t)] == edge_in, not_onsite_factors
)

# also store all non-incoming factors in standard order, used for channel merging
not_incoming = filter(
t -> edge_in ∉ edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
t -> (ITensors.site(t) == v) || which_incident_edge[ITensors.site(t)] != edge_in,
factors,
)
# filter out factor that acts on current vertex
onsite = filter(t -> (ITensors.site(t) == v), ITensors.terms(term))

# for every outgoing edge, filter out factors that go out along that edge
outgoing = Dict(
e => filter(t -> e ∈ edge_path(sites, v, ITensors.site(t)), ITensors.terms(term))
for e in edges_out
e => filter(t -> which_incident_edge[ITensors.site(t)] == e, not_onsite_factors) for
e in edges_out
)

# compute QNs
Expand Down Expand Up @@ -246,7 +289,8 @@ function ttn_svd(

for v in vs
# redo the whole thing like before
edges = filter(e -> dst(e) == v || src(e) == v, es)
# ToDo: use neighborhood instead of going through all edges, see above
edges = align_edges(incident_edges(sites, v), es)
dim_in = findfirst(e -> dst(e) == v, edges)
dims_out = findall(e -> src(e) == v, edges)
# slice isometries at this vertex
Expand Down Expand Up @@ -340,9 +384,10 @@ function ttn_svd(
if is_internal[v]
H[v] += iT
else
if hasqns(iT)
@assert flux(iT * Op) == Hflux
end
#ToDo: Remove this assert since it seems to be costly
#if hasqns(iT)
# @assert flux(iT * Op) == Hflux
#end
H[v] += (iT * Op)
end
end
Expand Down Expand Up @@ -409,12 +454,24 @@ function computeSiteProd(sites::IndsNetwork{V,<:Index}, ops::Prod{Op})::ITensor
return T
end

function _default_vertex_ordering(g::AbstractGraph, root_vertex)
return reverse(post_order_dfs_vertices(g, root_vertex))
end

function _default_edge_ordering(g::AbstractGraph, root_vertex)
return reverse(reverse.(post_order_dfs_edges(g, root_vertex)))
end

# This is almost an exact copy of ITensors/src/opsum_to_mpo_generic:sorteachterm except for the site ordering being
# given via find_index_in_tree
# changed `isless_site` to use tree vertex ordering relative to root
function sorteachterm(os::OpSum, sites::IndsNetwork{V,<:Index}, root_vertex::V) where {V}
os = copy(os)
findpos(op::Op) = find_index_in_tree(op, sites, root_vertex)

# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
ordering = _default_vertex_ordering(sites, root_vertex)
site_positions = Dict(zip(ordering, 1:length(ordering)))
findpos(op::Op) = site_positions[ITensors.site(op)]
isless_site(o1::Op, o2::Op) = findpos(o1) < findpos(o2)
N = nv(sites)
for n in eachindex(os)
Expand Down
Loading