Skip to content

Commit

Permalink
Forest cover for specifying edge update order. Better specification o…
Browse files Browse the repository at this point in the history
…f parallel vs sequential via edge kwarg. Further examples in BPSequences
  • Loading branch information
JoeyT1994 committed Nov 2, 2023
1 parent 4c352ab commit 3ad3cbc
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 68 deletions.
75 changes: 72 additions & 3 deletions examples/belief_propagation/bpsequences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Metis
using ITensorNetworks
using Random
using SplitApplyCombine
using Graphs
using NamedGraphs

using ITensorNetworks:
Expand Down Expand Up @@ -39,6 +40,7 @@ function main()
target_precision=1e-10,
niters=100,
edges=[[e] for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(
Expand All @@ -48,10 +50,63 @@ function main()
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is our custom sequence finder): ")
belief_propagation(
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
verbose=true,
)

g = NamedGraph(Graphs.random_regular_graph(100, 3))
s = siteinds("S=1/2", g)
χ = 4

Random.seed!(5467)

ψ = randomITensorNetwork(s; link_space=χ)
ψψ = ψ prime(dag(ψ); sites=[])

#Initial message tensors for BP
mts_init = message_tensors(
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
)

println("\nNow testing out a z = 3 random regular graph. Random network with bond dim ")

#Now test out various sequences
print("Parallel updates (sequence is irrelevant): ")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[[e] for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is our custom sequence finder): ")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
verbose=true,
)

g = named_grid((6, 6))
Expand Down Expand Up @@ -79,6 +134,7 @@ function main()
target_precision=1e-10,
niters=100,
edges=[[e] for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(
Expand All @@ -88,10 +144,16 @@ function main()
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is our custom sequence finder): ")
belief_propagation(
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
verbose=true,
)

g = NamedGraphs.hexagonal_lattice_graph(4, 4)
Expand Down Expand Up @@ -119,6 +181,7 @@ function main()
target_precision=1e-10,
niters=100,
edges=[[e] for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(
Expand All @@ -128,10 +191,16 @@ function main()
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is our custom sequence finder): ")
return belief_propagation(
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
verbose=true,
)
end

Expand Down
22 changes: 15 additions & 7 deletions examples/gauging/gauging_itns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function benchmark_state_gauging(
ψ::ITensorNetwork;
mode="BeliefPropagation",
no_iterations=50,
BP_update_order::String="parallel",
BP_update_order::String="sequential",
)
s = siteinds(ψ)

Expand All @@ -69,9 +69,15 @@ function benchmark_state_gauging(
println("On Iteration " * string(i))

if mode == "BeliefPropagation"
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
ψψ, mts; contract_kwargs=(; alg="exact"), update_sequence=BP_update_order
)
if BP_update_order != "parallel"
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
ψψ, mts; contract_kwargs=(; alg="exact")
)
else
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
ψψ, mts; contract_kwargs=(; alg="exact"), edges=[[e] for e in edges(mts)]
)
end

times_gauging[i] = @elapsed ψ, bond_tensors = vidal_gauge(ψinit, mts)
elseif mode == "Eager"
Expand All @@ -98,14 +104,16 @@ s = siteinds("S=1/2", g)
ψ = randomITensorNetwork(s; link_space=χ)
no_iterations = 30

BPG_simulation_times, BPG_Cs = benchmark_state_gauging(ψ; no_iterations)
BPG_simulation_times, BPG_Cs = benchmark_state_gauging(
ψ; no_iterations, BP_update_order="parallel"
)
BPG_sequential_simulation_times, BPG_sequential_Cs = benchmark_state_gauging(
ψ; no_iterations, BP_update_order="sequential"
ψ; no_iterations
)
Eager_simulation_times, Eager_Cs = benchmark_state_gauging(ψ; mode="Eager", no_iterations)
SU_simulation_times, SU_Cs = benchmark_state_gauging(ψ; mode="SU", no_iterations)

epsilon = 1e-6
epsilon = 1e-10

println(
"Time for BPG (with parallel updates) to reach C < epsilon was " *
Expand Down
72 changes: 29 additions & 43 deletions src/beliefpropagation/beliefpropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,67 +135,31 @@ function belief_propagation_iteration(
mts::DataGraph;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
compute_norm=false,
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
undirected_graph(underlying_graph(mts))
),
) where {E<:NamedEdge}
return belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
end

# """
# Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
# """
# function belief_propagation_iteration(
# tn::ITensorNetwork,
# mts::DataGraph;
# contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
# compute_norm=false,
# update_sequence::String="sequential",
# edges::Vector{Vector{}} = edge_update_order(undirected_graph(underlying_graph(mts))),
# )
# new_mts = copy(mts)
# if update_sequence != "parallel" && update_sequence != "sequential"
# error(
# "Specified update order is not currently implemented. Choose parallel or sequential."
# )
# end
# incoming_mts = update_sequence == "parallel" ? mts : new_mts
# c = 0
# for e in edges
# environment_tensornetworks = ITensorNetwork[
# incoming_mts[e_in] for
# e_in in setdiff(boundary_edges(incoming_mts, [src(e)]; dir=:in), [reverse(e)])
# ]
# new_mts[src(e) => dst(e)] = update_message_tensor(
# tn, incoming_mts[src(e)], environment_tensornetworks; contract_kwargs
# )

# if compute_norm
# LHS, RHS = ITensors.contract(ITensor(mts[src(e) => dst(e)])),
# ITensors.contract(ITensor(new_mts[src(e) => dst(e)]))
# LHS /= sum(diag(LHS))
# RHS /= sum(diag(RHS))
# c += 0.5 * norm(denseblocks(LHS) - denseblocks(RHS))
# end
# end
# return new_mts, c / (length(edges))
# end

function belief_propagation(
tn::ITensorNetwork,
mts::DataGraph;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
niters=20,
target_precision::Union{Float64,Nothing}=nothing,
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
undirected_graph(underlying_graph(mts))
),
verbose=false,
) where {E<:NamedEdge}
compute_norm = target_precision == nothing ? false : true
for i in 1:niters
mts, c = belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
if compute_norm && c <= target_precision
println("BP converged to desired precision after $i iterations.")
if verbose
println("BP converged to desired precision after $i iterations.")
end
break
end
end
Expand All @@ -210,9 +174,10 @@ function belief_propagation(
subgraph_vertices=nothing,
niters=20,
target_precision::Union{Float64,Nothing}=nothing,
verbose=false,
)
mts = message_tensors(tn; nvertices_per_partition, npartitions, subgraph_vertices)
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision)
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision, verbose)
end

"""
Expand Down Expand Up @@ -247,3 +212,24 @@ function approx_network_region(

return environment_tn verts_tn
end

"""
Return a custom edge order for how how to update all BP message tensors on a general undirected graph.
On a tree this will yield a sequence which only needs to be performed once. Based on forest covers and depth first searches amongst the forests.
"""
function belief_propagation_edge_sequence(
g::NamedGraph; root_vertex=NamedGraphs.default_root_vertex
)
@assert !is_directed(g)
forests = NamedGraphs.forest_cover(g)
edges = NamedEdge[]
for forest in forests
trees = NamedGraph[forest[vs] for vs in connected_components(forest)]
for tree in trees
tree_edges = post_order_dfs_edges(tree, root_vertex(tree))
push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...)
end
end

return edges
end
4 changes: 2 additions & 2 deletions src/beliefpropagation/sqrt_beliefpropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
function sqrt_belief_propagation_iteration(
tn::ITensorNetwork,
sqrt_mts::DataGraph;
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
undirected_graph(underlying_graph(mts))
),
) where {E<:NamedEdge}
Expand All @@ -56,7 +56,7 @@ function sqrt_belief_propagation(
tn::ITensorNetwork,
mts::DataGraph;
niters=20,
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
undirected_graph(underlying_graph(mts))
),
# target_precision::Union{Float64,Nothing}=nothing,
Expand Down
8 changes: 7 additions & 1 deletion src/gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,20 @@ function vidal_gauge(
regularization=10 * eps(real(scalartype(ψ))),
niters=30,
target_canonicalness::Union{Nothing,Float64}=nothing,
verbose=false,
svd_kwargs...,
)
ψψ = norm_network(ψ)
Z = partition(ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ)))))
mts = message_tensors(Z)

mts = belief_propagation(
ψψ, mts; contract_kwargs=(; alg="exact"), niters, target_precision=target_canonicalness
ψψ,
mts;
contract_kwargs=(; alg="exact"),
niters,
target_precision=target_canonicalness,
verbose,
)
return vidal_gauge(
ψ, mts; eigen_message_tensor_cutoff, regularization, niters, svd_kwargs...
Expand Down
18 changes: 6 additions & 12 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,18 @@ function line_to_tree(line::Vector)
return [line_to_tree(line[1:(end - 1)]), line[end]]
end

function edge_update_order(g)
forests = NamedGraphs.build_forest_cover(g)
#Custom edge order for updating all BP message tensors on a general undirected graph. On a tree this will yield a sequence which only needs to be performed once.
function BP_edge_update_order(g::NamedGraph; root_vertex=NamedGraphs.default_root_vertex)
@assert !is_directed(g)
forests = NamedGraphs.forest_cover(g)
edges = NamedEdge[]
for forest in forests
trees = NamedGraph[forest[vs] for vs in connected_components(forest)]
for tree in trees
push!(edges, tree_edge_update_order(tree)...)
tree_edges = post_order_dfs_edges(tree, root_vertex(tree))
push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...)
end
end

return edges
end

#Find an optimal ordering of the edges in a tree
function tree_edge_update_order(
g::AbstractNamedGraph; root_vertex=first(keys(sort(eccentricities(g); rev=true)))
)
@assert is_tree(g)
es = post_order_dfs_edges(g, root_vertex)
return vcat(es, reverse(reverse.(es)))
end

0 comments on commit 3ad3cbc

Please sign in to comment.