diff --git a/examples/belief_propagation/bpsequences.jl b/examples/belief_propagation/bpsequences.jl index 5691f6ae..fbcf0c14 100644 --- a/examples/belief_propagation/bpsequences.jl +++ b/examples/belief_propagation/bpsequences.jl @@ -4,6 +4,7 @@ using Metis using ITensorNetworks using Random using SplitApplyCombine +using Graphs using NamedGraphs using ITensorNetworks: @@ -39,6 +40,7 @@ function main() target_precision=1e-10, niters=100, edges=[[e] for e in edges(mts_init)], + verbose=true, ) print("Sequential updates (sequence is default edge list of the message tensors): ") belief_propagation( @@ -48,10 +50,63 @@ function main() target_precision=1e-10, niters=100, edges=[e for e in edges(mts_init)], + verbose=true, ) print("Sequential updates (sequence is our custom sequence finder): ") belief_propagation( - ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100 + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + verbose=true, + ) + + g = NamedGraph(Graphs.random_regular_graph(100, 3)) + s = siteinds("S=1/2", g) + χ = 4 + + Random.seed!(5467) + + ψ = randomITensorNetwork(s; link_space=χ) + ψψ = ψ ⊗ prime(dag(ψ); sites=[]) + + #Initial message tensors for BP + mts_init = message_tensors( + ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ)))) + ) + + println("\nNow testing out a z = 3 random regular graph. Random network with bond dim $χ") + + #Now test out various sequences + print("Parallel updates (sequence is irrelevant): ") + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[[e] for e in edges(mts_init)], + verbose=true, + ) + print("Sequential updates (sequence is default edge list of the message tensors): ") + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + edges=[e for e in edges(mts_init)], + verbose=true, + ) + print("Sequential updates (sequence is our custom sequence finder): ") + belief_propagation( + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + verbose=true, ) g = named_grid((6, 6)) @@ -79,6 +134,7 @@ function main() target_precision=1e-10, niters=100, edges=[[e] for e in edges(mts_init)], + verbose=true, ) print("Sequential updates (sequence is default edge list of the message tensors): ") belief_propagation( @@ -88,10 +144,16 @@ function main() target_precision=1e-10, niters=100, edges=[e for e in edges(mts_init)], + verbose=true, ) print("Sequential updates (sequence is our custom sequence finder): ") belief_propagation( - ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100 + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + verbose=true, ) g = NamedGraphs.hexagonal_lattice_graph(4, 4) @@ -119,6 +181,7 @@ function main() target_precision=1e-10, niters=100, edges=[[e] for e in edges(mts_init)], + verbose=true, ) print("Sequential updates (sequence is default edge list of the message tensors): ") belief_propagation( @@ -128,10 +191,16 @@ function main() target_precision=1e-10, niters=100, edges=[e for e in edges(mts_init)], + verbose=true, ) print("Sequential updates (sequence is our custom sequence finder): ") return belief_propagation( - ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100 + ψψ, + mts_init; + contract_kwargs=(; alg="exact"), + target_precision=1e-10, + niters=100, + verbose=true, ) end diff --git a/examples/gauging/gauging_itns.jl b/examples/gauging/gauging_itns.jl index 0f65dc8e..b420d45e 100644 --- a/examples/gauging/gauging_itns.jl +++ b/examples/gauging/gauging_itns.jl @@ -48,7 +48,7 @@ function benchmark_state_gauging( ψ::ITensorNetwork; mode="BeliefPropagation", no_iterations=50, - BP_update_order::String="parallel", + BP_update_order::String="sequential", ) s = siteinds(ψ) @@ -69,9 +69,15 @@ function benchmark_state_gauging( println("On Iteration " * string(i)) if mode == "BeliefPropagation" - times_iters[i] = @elapsed mts, _ = belief_propagation_iteration( - ψψ, mts; contract_kwargs=(; alg="exact"), update_sequence=BP_update_order - ) + if BP_update_order != "parallel" + times_iters[i] = @elapsed mts, _ = belief_propagation_iteration( + ψψ, mts; contract_kwargs=(; alg="exact") + ) + else + times_iters[i] = @elapsed mts, _ = belief_propagation_iteration( + ψψ, mts; contract_kwargs=(; alg="exact"), edges=[[e] for e in edges(mts)] + ) + end times_gauging[i] = @elapsed ψ, bond_tensors = vidal_gauge(ψinit, mts) elseif mode == "Eager" @@ -98,14 +104,16 @@ s = siteinds("S=1/2", g) ψ = randomITensorNetwork(s; link_space=χ) no_iterations = 30 -BPG_simulation_times, BPG_Cs = benchmark_state_gauging(ψ; no_iterations) +BPG_simulation_times, BPG_Cs = benchmark_state_gauging( + ψ; no_iterations, BP_update_order="parallel" +) BPG_sequential_simulation_times, BPG_sequential_Cs = benchmark_state_gauging( - ψ; no_iterations, BP_update_order="sequential" + ψ; no_iterations ) Eager_simulation_times, Eager_Cs = benchmark_state_gauging(ψ; mode="Eager", no_iterations) SU_simulation_times, SU_Cs = benchmark_state_gauging(ψ; mode="SU", no_iterations) -epsilon = 1e-6 +epsilon = 1e-10 println( "Time for BPG (with parallel updates) to reach C < epsilon was " * diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index b4558993..64581e55 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -135,67 +135,31 @@ function belief_propagation_iteration( mts::DataGraph; contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), compute_norm=false, - edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence( undirected_graph(underlying_graph(mts)) ), ) where {E<:NamedEdge} return belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm) end -# """ -# Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs -# """ -# function belief_propagation_iteration( -# tn::ITensorNetwork, -# mts::DataGraph; -# contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), -# compute_norm=false, -# update_sequence::String="sequential", -# edges::Vector{Vector{}} = edge_update_order(undirected_graph(underlying_graph(mts))), -# ) -# new_mts = copy(mts) -# if update_sequence != "parallel" && update_sequence != "sequential" -# error( -# "Specified update order is not currently implemented. Choose parallel or sequential." -# ) -# end -# incoming_mts = update_sequence == "parallel" ? mts : new_mts -# c = 0 -# for e in edges -# environment_tensornetworks = ITensorNetwork[ -# incoming_mts[e_in] for -# e_in in setdiff(boundary_edges(incoming_mts, [src(e)]; dir=:in), [reverse(e)]) -# ] -# new_mts[src(e) => dst(e)] = update_message_tensor( -# tn, incoming_mts[src(e)], environment_tensornetworks; contract_kwargs -# ) - -# if compute_norm -# LHS, RHS = ITensors.contract(ITensor(mts[src(e) => dst(e)])), -# ITensors.contract(ITensor(new_mts[src(e) => dst(e)])) -# LHS /= sum(diag(LHS)) -# RHS /= sum(diag(RHS)) -# c += 0.5 * norm(denseblocks(LHS) - denseblocks(RHS)) -# end -# end -# return new_mts, c / (length(edges)) -# end - function belief_propagation( tn::ITensorNetwork, mts::DataGraph; contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1), niters=20, target_precision::Union{Float64,Nothing}=nothing, - edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence( undirected_graph(underlying_graph(mts)) ), + verbose=false, ) where {E<:NamedEdge} compute_norm = target_precision == nothing ? false : true for i in 1:niters mts, c = belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm) if compute_norm && c <= target_precision - println("BP converged to desired precision after $i iterations.") + if verbose + println("BP converged to desired precision after $i iterations.") + end break end end @@ -210,9 +174,10 @@ function belief_propagation( subgraph_vertices=nothing, niters=20, target_precision::Union{Float64,Nothing}=nothing, + verbose=false, ) mts = message_tensors(tn; nvertices_per_partition, npartitions, subgraph_vertices) - return belief_propagation(tn, mts; contract_kwargs, niters, target_precision) + return belief_propagation(tn, mts; contract_kwargs, niters, target_precision, verbose) end """ @@ -247,3 +212,24 @@ function approx_network_region( return environment_tn ⊗ verts_tn end + +""" +Return a custom edge order for how how to update all BP message tensors on a general undirected graph. +On a tree this will yield a sequence which only needs to be performed once. Based on forest covers and depth first searches amongst the forests. +""" +function belief_propagation_edge_sequence( + g::NamedGraph; root_vertex=NamedGraphs.default_root_vertex +) + @assert !is_directed(g) + forests = NamedGraphs.forest_cover(g) + edges = NamedEdge[] + for forest in forests + trees = NamedGraph[forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + + return edges +end diff --git a/src/beliefpropagation/sqrt_beliefpropagation.jl b/src/beliefpropagation/sqrt_beliefpropagation.jl index c4eff09e..f28ab9a0 100644 --- a/src/beliefpropagation/sqrt_beliefpropagation.jl +++ b/src/beliefpropagation/sqrt_beliefpropagation.jl @@ -45,7 +45,7 @@ end function sqrt_belief_propagation_iteration( tn::ITensorNetwork, sqrt_mts::DataGraph; - edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence( undirected_graph(underlying_graph(mts)) ), ) where {E<:NamedEdge} @@ -56,7 +56,7 @@ function sqrt_belief_propagation( tn::ITensorNetwork, mts::DataGraph; niters=20, - edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order( + edges::Union{Vector{Vector{E}},Vector{E}}=belief_propagation_edge_sequence( undirected_graph(underlying_graph(mts)) ), # target_precision::Union{Float64,Nothing}=nothing, diff --git a/src/gauging.jl b/src/gauging.jl index 4086198c..17b8f748 100644 --- a/src/gauging.jl +++ b/src/gauging.jl @@ -96,6 +96,7 @@ function vidal_gauge( regularization=10 * eps(real(scalartype(ψ))), niters=30, target_canonicalness::Union{Nothing,Float64}=nothing, + verbose=false, svd_kwargs..., ) ψψ = norm_network(ψ) @@ -103,7 +104,12 @@ function vidal_gauge( mts = message_tensors(Z) mts = belief_propagation( - ψψ, mts; contract_kwargs=(; alg="exact"), niters, target_precision=target_canonicalness + ψψ, + mts; + contract_kwargs=(; alg="exact"), + niters, + target_precision=target_canonicalness, + verbose, ) return vidal_gauge( ψ, mts; eigen_message_tensor_cutoff, regularization, niters, svd_kwargs... diff --git a/src/utils.jl b/src/utils.jl index e3e2bc82..057822cb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -26,24 +26,18 @@ function line_to_tree(line::Vector) return [line_to_tree(line[1:(end - 1)]), line[end]] end -function edge_update_order(g) - forests = NamedGraphs.build_forest_cover(g) +#Custom edge order for updating all BP message tensors on a general undirected graph. On a tree this will yield a sequence which only needs to be performed once. +function BP_edge_update_order(g::NamedGraph; root_vertex=NamedGraphs.default_root_vertex) + @assert !is_directed(g) + forests = NamedGraphs.forest_cover(g) edges = NamedEdge[] for forest in forests trees = NamedGraph[forest[vs] for vs in connected_components(forest)] for tree in trees - push!(edges, tree_edge_update_order(tree)...) + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) end end return edges end - -#Find an optimal ordering of the edges in a tree -function tree_edge_update_order( - g::AbstractNamedGraph; root_vertex=first(keys(sort(eccentricities(g); rev=true))) -) - @assert is_tree(g) - es = post_order_dfs_edges(g, root_vertex) - return vcat(es, reverse(reverse.(es))) -end