Skip to content

Commit

Permalink
Belief propagation order flexibility (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored Nov 30, 2023
1 parent 94eb10f commit 6c1a0f6
Show file tree
Hide file tree
Showing 14 changed files with 354 additions and 97 deletions.
6 changes: 4 additions & 2 deletions examples/belief_propagation/bpexample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Metis
using ITensorNetworks
using Random
using SplitApplyCombine
using NamedGraphs

using ITensorNetworks:
belief_propagation,
Expand Down Expand Up @@ -34,7 +35,8 @@ function main()
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
)

mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"))
mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"), niters=20)

numerator_network = approx_network_region(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
Expand All @@ -52,7 +54,7 @@ function main()
)
Zpp = partition(ψψ; subgraph_vertices=nested_graph_leaf_vertices(Zp))
mts = message_tensors(Zpp)
mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"))
mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"), niters=20)
numerator_network = approx_network_region(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
Expand Down
81 changes: 81 additions & 0 deletions examples/belief_propagation/bpsequences.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
using Compat
using ITensors
using Metis
using ITensorNetworks
using Random
using SplitApplyCombine
using Graphs
using NamedGraphs

using ITensorNetworks:
belief_propagation,
approx_network_region,
contract_inner,
message_tensors,
nested_graph_leaf_vertices,
edge_sequence

function main()
g_labels = [
"Comb Tree",
"100 Site Random Regular Graph z = 3",
"6x6 Square Grid",
"4x4 Hexagonal Lattice",
]
gs = [
named_comb_tree((6, 6)),
NamedGraph(Graphs.random_regular_graph(100, 3)),
named_grid((6, 6)),
NamedGraphs.hexagonal_lattice_graph(4, 4),
]
χs = [4, 4, 2, 3]

for (i, g) in enumerate(gs)
Random.seed!(5467)
g_label = g_labels[i]
χ = χs[i]
s = siteinds("S=1/2", g)
ψ = randomITensorNetwork(s; link_space=χ)
ψψ = ψ prime(dag(ψ); sites=[])

#Initial message tensors for BP
mts_init = message_tensors(
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
)

println("\nFirst testing out a $g_label. Random network with bond dim ")

#Now test out various sequences
print("Parallel updates (sequence is irrelevant): ")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=edge_sequence(mts_init; alg="parallel"),
verbose=true,
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
verbose=true,
)
print("Sequential updates (sequence is our custom sequence finder): ")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
verbose=true,
)
end
end

main()
6 changes: 3 additions & 3 deletions examples/dynamics/heavy_hex_ising_real_tebd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ end
function ibm_processor_graph(n::Int64, m::Int64)
g = heavy_hex_lattice_graph(n, m)
dims = maximum(vertices(hexagonal_lattice_graph(n, m)))
v1, v2 = (1, dims[2]), (dims[1], 1)
v1, v2 = (dims[1], 1), (1, dims[2])
add_vertices!(g, [v1, v2])
add_edge!(g, v1 => v1 .- (0, 1))
add_edge!(g, v2 => v2 .+ (0, 1))
add_edge!(g, v1 => v1 .- (1, 0))
add_edge!(g, v2 => v2 .+ (1, 0))

return g
end
Expand Down
64 changes: 46 additions & 18 deletions examples/gauging/gauging_itns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using Metis
using ITensorNetworks
using Random
using SplitApplyCombine
using ProfileView

using ITensorNetworks:
message_tensors,
Expand All @@ -18,7 +17,8 @@ using ITensorNetworks:
vidal_to_symmetric_gauge,
initialize_bond_tensors,
vidal_itn_isometries,
norm_network
norm_network,
edge_sequence

using NamedGraphs
using NamedGraphs: add_edges!, rem_vertex!, hexagonal_lattice_graph
Expand All @@ -45,7 +45,10 @@ end

"""Bring an ITN into the Vidal gauge, various methods possible. Result is timed"""
function benchmark_state_gauging(
ψ::ITensorNetwork; mode="BeliefPropagation", no_iterations=50
ψ::ITensorNetwork;
mode="belief_propagation",
no_iterations=50,
BP_update_order::String="sequential",
)
s = siteinds(ψ)

Expand All @@ -65,12 +68,19 @@ function benchmark_state_gauging(
for i in 1:no_iterations
println("On Iteration " * string(i))

if mode == "BeliefPropagation"
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
ψψ, mts; contract_kwargs=(; alg="exact")
)
if mode == "belief_propagation"
if BP_update_order != "parallel"
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
ψψ, mts; contract_kwargs=(; alg="exact")
)
else
times_iters[i] = @elapsed mts, _ = belief_propagation_iteration(
ψψ, mts; contract_kwargs=(; alg="exact"), edges=edge_sequence(mts; alg="parallel")
)
end

times_gauging[i] = @elapsed ψ, bond_tensors = vidal_gauge(ψinit, mts)
elseif mode == "Eager"
elseif mode == "eager"
times_iters[i] = @elapsed ψ, bond_tensors, mts = eager_gauging(ψ, bond_tensors, mts)
else
times_iters[i] = @elapsed begin
Expand All @@ -82,7 +92,7 @@ function benchmark_state_gauging(

C[i] = vidal_itn_canonicalness(ψ, bond_tensors)
end

@show times_iters, time
simulation_times = cumsum(times_iters) + times_gauging

return simulation_times, C
Expand All @@ -94,23 +104,41 @@ s = siteinds("S=1/2", g)
ψ = randomITensorNetwork(s; link_space=χ)
no_iterations = 30

BPG_simulation_times, BPG_Cs = benchmark_state_gauging(ψ; no_iterations)
Eager_simulation_times, Eager_Cs = benchmark_state_gauging(ψ; mode="Eager", no_iterations)
BPG_simulation_times, BPG_Cs = benchmark_state_gauging(
ψ; no_iterations, BP_update_order="parallel"
)
BPG_sequential_simulation_times, BPG_sequential_Cs = benchmark_state_gauging(
ψ; no_iterations
)
Eager_simulation_times, Eager_Cs = benchmark_state_gauging(ψ; mode="eager", no_iterations)
SU_simulation_times, SU_Cs = benchmark_state_gauging(ψ; mode="SU", no_iterations)

epsilon = 1e-6
epsilon = 1e-10

println(
"Time for BPG to reach C < epsilon was " *
"Time for BPG (with parallel updates) to reach C < epsilon was " *
string(BPG_simulation_times[findfirst(x -> x < 0, BPG_Cs .- epsilon)]) *
" seconds",
" seconds. No iters was " *
string(findfirst(x -> x < 0, BPG_Cs .- epsilon)),
)
println(
"Time for BPG (with sequential updates) to reach C < epsilon was " *
string(
BPG_sequential_simulation_times[findfirst(x -> x < 0, BPG_sequential_Cs .- epsilon)]
) *
" seconds. No iters was " *
string(findfirst(x -> x < 0, BPG_sequential_Cs .- epsilon)),
)

println(
"Time for Eager to reach C < epsilon was " *
"Time for Eager Gauging to reach C < epsilon was " *
string(Eager_simulation_times[findfirst(x -> x < 0, Eager_Cs .- epsilon)]) *
" seconds",
" seconds. No iters was " *
string(findfirst(x -> x < 0, Eager_Cs .- epsilon)),
)
println(
"Time for SU to reach C < epsilon was " *
"Time for SU Gauging (with sequential updates) to reach C < epsilon was " *
string(SU_simulation_times[findfirst(x -> x < 0, SU_Cs .- epsilon)]) *
" seconds",
" seconds. No iters was " *
string(findfirst(x -> x < 0, SU_Cs .- epsilon)),
)
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ include("specialitensornetworks.jl")
include("renameitensornetwork.jl")
include("boundarymps.jl")
include(joinpath("beliefpropagation", "beliefpropagation.jl"))
include(joinpath("beliefpropagation", "beliefpropagation_schedule.jl"))
include(joinpath("beliefpropagation", "sqrt_beliefpropagation.jl"))
include("contraction_tree_to_graph.jl")
include("gauging.jl")
Expand Down
86 changes: 64 additions & 22 deletions src/beliefpropagation/beliefpropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ function message_tensors_skeleton(subgraphs::DataGraph)
end

function message_tensors(
subgraphs::DataGraph; itensor_constructor=inds_e -> dense(delta(inds_e))
subgraphs::DataGraph;
itensor_constructor=inds_e -> ITensor[dense(delta(i)) for i in inds_e],
)
mts = message_tensors_skeleton(subgraphs)
for e in edges(subgraphs)
inds_e = commoninds(subgraphs[src(e)], subgraphs[dst(e)])
mts[e] = ITensorNetwork(map(itensor_constructor, inds_e))
itensors = itensor_constructor(inds_e)
mts[e] = ITensorNetwork(itensors)
mts[reverse(e)] = dag(mts[e])
end
return mts
Expand Down Expand Up @@ -74,24 +76,24 @@ function update_message_tensor(
end

"""
Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
Do a sequential update of message tensors on `edges` for a given ITensornetwork and its partition into sub graphs
"""
function belief_propagation_iteration(
tn::ITensorNetwork,
mts::DataGraph;
mts::DataGraph,
edges::Vector{<:AbstractEdge};
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
compute_norm=false,
)
new_mts = copy(mts)
c = 0
es = edges(mts)
for e in es
for e in edges
environment_tensornetworks = ITensorNetwork[
mts[e_in] for e_in in setdiff(boundary_edges(mts, [src(e)]; dir=:in), [reverse(e)])
new_mts[e_in] for
e_in in setdiff(boundary_edges(new_mts, [src(e)]; dir=:in), [reverse(e)])
]

new_mts[src(e) => dst(e)] = update_message_tensor(
tn, mts[src(e)], environment_tensornetworks; contract_kwargs
tn, new_mts[src(e)], environment_tensornetworks; contract_kwargs
)

if compute_norm
Expand All @@ -102,25 +104,64 @@ function belief_propagation_iteration(
c += 0.5 * norm(denseblocks(LHS) - denseblocks(RHS))
end
end
return new_mts, c / (length(es))
return new_mts, c / (length(edges))
end

"""
Do parallel updates between groups of edges of all message tensors for a given ITensornetwork and its partition into sub graphs.
Currently we send the full message tensor data struct to belief_propagation_iteration for each subgraph. But really we only need the
mts relevant to that subgraph.
"""
function belief_propagation_iteration(
tn::ITensorNetwork,
mts::DataGraph,
edge_groups::Vector{<:Vector{<:AbstractEdge}};
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
compute_norm=false,
)
new_mts = copy(mts)
c = 0
for edges in edge_groups
updated_mts, ct = belief_propagation_iteration(
tn, mts, edges; contract_kwargs, compute_norm
)
for e in edges
new_mts[e] = updated_mts[e]
end
c += ct
end
return new_mts, c / (length(edge_groups))
end

function belief_propagation_iteration(
tn::ITensorNetwork,
mts::DataGraph;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
compute_norm=false,
edges=edge_sequence(mts),
)
return belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
end

function belief_propagation(
tn::ITensorNetwork,
mts::DataGraph;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
niters=20,
target_precision::Union{Float64,Nothing}=nothing,
niters=default_bp_niters(mts),
target_precision=nothing,
edges=edge_sequence(mts),
verbose=false,
)
compute_norm = target_precision == nothing ? false : true
compute_norm = !isnothing(target_precision)
if isnothing(niters)
error("You need to specify a number of iterations for BP!")
end
for i in 1:niters
mts, c = belief_propagation_iteration(tn, mts; contract_kwargs, compute_norm)
mts, c = belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
if compute_norm && c <= target_precision
println(
"Belief Propagation finished. Reached a canonicalness of " *
string(c) *
" after $i iterations. ",
)
if verbose
println("BP converged to desired precision after $i iterations.")
end
break
end
end
Expand All @@ -133,11 +174,12 @@ function belief_propagation(
nvertices_per_partition=nothing,
npartitions=nothing,
subgraph_vertices=nothing,
niters=20,
target_precision::Union{Float64,Nothing}=nothing,
niters=default_bp_niters(mts),
target_precision=nothing,
verbose=false,
)
mts = message_tensors(tn; nvertices_per_partition, npartitions, subgraph_vertices)
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision)
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision, verbose)
end

"""
Expand Down
Loading

0 comments on commit 6c1a0f6

Please sign in to comment.