Skip to content

Commit 3ad3cbc

Browse files
committed
Forest cover for specifying edge update order. Better specification of parallel vs sequential via edge kwarg. Further examples in BPSequences
1 parent 4c352ab commit 3ad3cbc

File tree

6 files changed

+131
-68
lines changed

6 files changed

+131
-68
lines changed

examples/belief_propagation/bpsequences.jl

+72-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Metis
44
using ITensorNetworks
55
using Random
66
using SplitApplyCombine
7+
using Graphs
78
using NamedGraphs
89

910
using ITensorNetworks:
@@ -39,6 +40,7 @@ function main()
3940
target_precision=1e-10,
4041
niters=100,
4142
edges=[[e] for e in edges(mts_init)],
43+
verbose=true,
4244
)
4345
print("Sequential updates (sequence is default edge list of the message tensors): ")
4446
belief_propagation(
@@ -48,10 +50,63 @@ function main()
4850
target_precision=1e-10,
4951
niters=100,
5052
edges=[e for e in edges(mts_init)],
53+
verbose=true,
5154
)
5255
print("Sequential updates (sequence is our custom sequence finder): ")
5356
belief_propagation(
54-
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
57+
ψψ,
58+
mts_init;
59+
contract_kwargs=(; alg="exact"),
60+
target_precision=1e-10,
61+
niters=100,
62+
verbose=true,
63+
)
64+
65+
g = NamedGraph(Graphs.random_regular_graph(100, 3))
66+
s = siteinds("S=1/2", g)
67+
χ = 4
68+
69+
Random.seed!(5467)
70+
71+
ψ = randomITensorNetwork(s; link_space=χ)
72+
ψψ = ψ prime(dag(ψ); sites=[])
73+
74+
#Initial message tensors for BP
75+
mts_init = message_tensors(
76+
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
77+
)
78+
79+
println("\nNow testing out a z = 3 random regular graph. Random network with bond dim ")
80+
81+
#Now test out various sequences
82+
print("Parallel updates (sequence is irrelevant): ")
83+
belief_propagation(
84+
ψψ,
85+
mts_init;
86+
contract_kwargs=(; alg="exact"),
87+
target_precision=1e-10,
88+
niters=100,
89+
edges=[[e] for e in edges(mts_init)],
90+
verbose=true,
91+
)
92+
print("Sequential updates (sequence is default edge list of the message tensors): ")
93+
belief_propagation(
94+
ψψ,
95+
mts_init;
96+
contract_kwargs=(; alg="exact"),
97+
target_precision=1e-10,
98+
niters=100,
99+
edges=[e for e in edges(mts_init)],
100+
verbose=true,
101+
)
102+
print("Sequential updates (sequence is our custom sequence finder): ")
103+
belief_propagation(
104+
ψψ,
105+
mts_init;
106+
contract_kwargs=(; alg="exact"),
107+
target_precision=1e-10,
108+
niters=100,
109+
verbose=true,
55110
)
56111

57112
g = named_grid((6, 6))
@@ -79,6 +134,7 @@ function main()
79134
target_precision=1e-10,
80135
niters=100,
81136
edges=[[e] for e in edges(mts_init)],
137+
verbose=true,
82138
)
83139
print("Sequential updates (sequence is default edge list of the message tensors): ")
84140
belief_propagation(
@@ -88,10 +144,16 @@ function main()
88144
target_precision=1e-10,
89145
niters=100,
90146
edges=[e for e in edges(mts_init)],
147+
verbose=true,
91148
)
92149
print("Sequential updates (sequence is our custom sequence finder): ")
93150
belief_propagation(
94-
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
151+
ψψ,
152+
mts_init;
153+
contract_kwargs=(; alg="exact"),
154+
target_precision=1e-10,
155+
niters=100,
156+
verbose=true,
95157
)
96158

97159
g = NamedGraphs.hexagonal_lattice_graph(4, 4)
@@ -119,6 +181,7 @@ function main()
119181
target_precision=1e-10,
120182
niters=100,
121183
edges=[[e] for e in edges(mts_init)],
184+
verbose=true,
122185
)
123186
print("Sequential updates (sequence is default edge list of the message tensors): ")
124187
belief_propagation(
@@ -128,10 +191,16 @@ function main()
128191
target_precision=1e-10,
129192
niters=100,
130193
edges=[e for e in edges(mts_init)],
194+
verbose=true,
131195
)
132196
print("Sequential updates (sequence is our custom sequence finder): ")
133197
return belief_propagation(
134-
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
198+
ψψ,
199+
mts_init;
200+
contract_kwargs=(; alg="exact"),
201+
target_precision=1e-10,
202+
niters=100,
203+
verbose=true,
135204
)
136205
end
137206

examples/gauging/gauging_itns.jl

+15-7
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ function benchmark_state_gauging(
4848
ψ::ITensorNetwork;
4949
mode="BeliefPropagation",
5050
no_iterations=50,
51-
BP_update_order::String="parallel",
51+
BP_update_order::String="sequential",
5252
)
5353
s = siteinds(ψ)
5454

@@ -69,9 +69,15 @@ function benchmark_state_gauging(
6969
println("On Iteration " * string(i))
7070

7171
if mode == "BeliefPropagation"
72-
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
73-
ψψ, mts; contract_kwargs=(; alg="exact"), update_sequence=BP_update_order
74-
)
72+
if BP_update_order != "parallel"
73+
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
74+
ψψ, mts; contract_kwargs=(; alg="exact")
75+
)
76+
else
77+
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
78+
ψψ, mts; contract_kwargs=(; alg="exact"), edges=[[e] for e in edges(mts)]
79+
)
80+
end
7581

7682
times_gauging[i] = @elapsed ψ, bond_tensors = vidal_gauge(ψinit, mts)
7783
elseif mode == "Eager"
@@ -98,14 +104,16 @@ s = siteinds("S=1/2", g)
98104
ψ = randomITensorNetwork(s; link_space=χ)
99105
no_iterations = 30
100106

101-
BPG_simulation_times, BPG_Cs = benchmark_state_gauging(ψ; no_iterations)
107+
BPG_simulation_times, BPG_Cs = benchmark_state_gauging(
108+
ψ; no_iterations, BP_update_order="parallel"
109+
)
102110
BPG_sequential_simulation_times, BPG_sequential_Cs = benchmark_state_gauging(
103-
ψ; no_iterations, BP_update_order="sequential"
111+
ψ; no_iterations
104112
)
105113
Eager_simulation_times, Eager_Cs = benchmark_state_gauging(ψ; mode="Eager", no_iterations)
106114
SU_simulation_times, SU_Cs = benchmark_state_gauging(ψ; mode="SU", no_iterations)
107115

108-
epsilon = 1e-6
116+
epsilon = 1e-10
109117

110118
println(
111119
"Time for BPG (with parallel updates) to reach C < epsilon was " *

src/beliefpropagation/beliefpropagation.jl

+29-43
Original file line numberDiff line numberDiff line change
@@ -135,67 +135,31 @@ function belief_propagation_iteration(
135135
mts::DataGraph;
136136
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
137137
compute_norm=false,
138-
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
138+
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
139139
undirected_graph(underlying_graph(mts))
140140
),
141141
) where {E<:NamedEdge}
142142
return belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
143143
end
144144

145-
# """
146-
# Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
147-
# """
148-
# function belief_propagation_iteration(
149-
# tn::ITensorNetwork,
150-
# mts::DataGraph;
151-
# contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
152-
# compute_norm=false,
153-
# update_sequence::String="sequential",
154-
# edges::Vector{Vector{}} = edge_update_order(undirected_graph(underlying_graph(mts))),
155-
# )
156-
# new_mts = copy(mts)
157-
# if update_sequence != "parallel" && update_sequence != "sequential"
158-
# error(
159-
# "Specified update order is not currently implemented. Choose parallel or sequential."
160-
# )
161-
# end
162-
# incoming_mts = update_sequence == "parallel" ? mts : new_mts
163-
# c = 0
164-
# for e in edges
165-
# environment_tensornetworks = ITensorNetwork[
166-
# incoming_mts[e_in] for
167-
# e_in in setdiff(boundary_edges(incoming_mts, [src(e)]; dir=:in), [reverse(e)])
168-
# ]
169-
# new_mts[src(e) => dst(e)] = update_message_tensor(
170-
# tn, incoming_mts[src(e)], environment_tensornetworks; contract_kwargs
171-
# )
172-
173-
# if compute_norm
174-
# LHS, RHS = ITensors.contract(ITensor(mts[src(e) => dst(e)])),
175-
# ITensors.contract(ITensor(new_mts[src(e) => dst(e)]))
176-
# LHS /= sum(diag(LHS))
177-
# RHS /= sum(diag(RHS))
178-
# c += 0.5 * norm(denseblocks(LHS) - denseblocks(RHS))
179-
# end
180-
# end
181-
# return new_mts, c / (length(edges))
182-
# end
183-
184145
function belief_propagation(
185146
tn::ITensorNetwork,
186147
mts::DataGraph;
187148
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
188149
niters=20,
189150
target_precision::Union{Float64,Nothing}=nothing,
190-
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
151+
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
191152
undirected_graph(underlying_graph(mts))
192153
),
154+
verbose=false,
193155
) where {E<:NamedEdge}
194156
compute_norm = target_precision == nothing ? false : true
195157
for i in 1:niters
196158
mts, c = belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
197159
if compute_norm && c <= target_precision
198-
println("BP converged to desired precision after $i iterations.")
160+
if verbose
161+
println("BP converged to desired precision after $i iterations.")
162+
end
199163
break
200164
end
201165
end
@@ -210,9 +174,10 @@ function belief_propagation(
210174
subgraph_vertices=nothing,
211175
niters=20,
212176
target_precision::Union{Float64,Nothing}=nothing,
177+
verbose=false,
213178
)
214179
mts = message_tensors(tn; nvertices_per_partition, npartitions, subgraph_vertices)
215-
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision)
180+
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision, verbose)
216181
end
217182

218183
"""
@@ -247,3 +212,24 @@ function approx_network_region(
247212

248213
return environment_tn verts_tn
249214
end
215+
216+
"""
217+
Return a custom edge order for how how to update all BP message tensors on a general undirected graph.
218+
On a tree this will yield a sequence which only needs to be performed once. Based on forest covers and depth first searches amongst the forests.
219+
"""
220+
function belief_propagation_edge_sequence(
221+
g::NamedGraph; root_vertex=NamedGraphs.default_root_vertex
222+
)
223+
@assert !is_directed(g)
224+
forests = NamedGraphs.forest_cover(g)
225+
edges = NamedEdge[]
226+
for forest in forests
227+
trees = NamedGraph[forest[vs] for vs in connected_components(forest)]
228+
for tree in trees
229+
tree_edges = post_order_dfs_edges(tree, root_vertex(tree))
230+
push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...)
231+
end
232+
end
233+
234+
return edges
235+
end

src/beliefpropagation/sqrt_beliefpropagation.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
function sqrt_belief_propagation_iteration(
4646
tn::ITensorNetwork,
4747
sqrt_mts::DataGraph;
48-
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
48+
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
4949
undirected_graph(underlying_graph(mts))
5050
),
5151
) where {E<:NamedEdge}
@@ -56,7 +56,7 @@ function sqrt_belief_propagation(
5656
tn::ITensorNetwork,
5757
mts::DataGraph;
5858
niters=20,
59-
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
59+
edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence(
6060
undirected_graph(underlying_graph(mts))
6161
),
6262
# target_precision::Union{Float64,Nothing}=nothing,

src/gauging.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,20 @@ function vidal_gauge(
9696
regularization=10 * eps(real(scalartype(ψ))),
9797
niters=30,
9898
target_canonicalness::Union{Nothing,Float64}=nothing,
99+
verbose=false,
99100
svd_kwargs...,
100101
)
101102
ψψ = norm_network(ψ)
102103
Z = partition(ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ)))))
103104
mts = message_tensors(Z)
104105

105106
mts = belief_propagation(
106-
ψψ, mts; contract_kwargs=(; alg="exact"), niters, target_precision=target_canonicalness
107+
ψψ,
108+
mts;
109+
contract_kwargs=(; alg="exact"),
110+
niters,
111+
target_precision=target_canonicalness,
112+
verbose,
107113
)
108114
return vidal_gauge(
109115
ψ, mts; eigen_message_tensor_cutoff, regularization, niters, svd_kwargs...

src/utils.jl

+6-12
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,18 @@ function line_to_tree(line::Vector)
2626
return [line_to_tree(line[1:(end - 1)]), line[end]]
2727
end
2828

29-
function edge_update_order(g)
30-
forests = NamedGraphs.build_forest_cover(g)
29+
#Custom edge order for updating all BP message tensors on a general undirected graph. On a tree this will yield a sequence which only needs to be performed once.
30+
function BP_edge_update_order(g::NamedGraph; root_vertex=NamedGraphs.default_root_vertex)
31+
@assert !is_directed(g)
32+
forests = NamedGraphs.forest_cover(g)
3133
edges = NamedEdge[]
3234
for forest in forests
3335
trees = NamedGraph[forest[vs] for vs in connected_components(forest)]
3436
for tree in trees
35-
push!(edges, tree_edge_update_order(tree)...)
37+
tree_edges = post_order_dfs_edges(tree, root_vertex(tree))
38+
push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...)
3639
end
3740
end
3841

3942
return edges
4043
end
41-
42-
#Find an optimal ordering of the edges in a tree
43-
function tree_edge_update_order(
44-
g::AbstractNamedGraph; root_vertex=first(keys(sort(eccentricities(g); rev=true)))
45-
)
46-
@assert is_tree(g)
47-
es = post_order_dfs_edges(g, root_vertex)
48-
return vcat(es, reverse(reverse.(es)))
49-
end

0 commit comments

Comments
 (0)