Skip to content

Commit

Permalink
Better specification of update sequence for BP
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Oct 26, 2023
1 parent dfec1e2 commit 1e6ec2b
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 106 deletions.
79 changes: 63 additions & 16 deletions examples/belief_propagation/bpsequences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ using ITensorNetworks:
nested_graph_leaf_vertices

function main()

g = named_comb_tree((6,6))
g = named_comb_tree((6, 6))
s = siteinds("S=1/2", g)
χ = 4

Expand All @@ -29,17 +28,33 @@ function main()
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
)

println("First testing out a comb tree. Random network with bond dim ")
println("\nFirst testing out a comb tree. Random network with bond dim ")

#Now test out various sequences
print("Parallel updates (sequence is irrelevant): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "parallel")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[[e] for e in edges(mts_init)],
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential", edges = edges(mts_init))
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
)
print("Sequential updates (sequence is our custom sequence finder): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential")
belief_propagation(
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
)

g = named_grid((6,6))
g = named_grid((6, 6))
s = siteinds("S=1/2", g)
χ = 2

Expand All @@ -53,17 +68,33 @@ function main()
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
)

println("Now testing out a 6x6 grid. Random network with bond dim ")
println("\nNow testing out a 6x6 grid. Random network with bond dim ")

#Now test out various sequences
print("Parallel updates (sequence is irrelevant): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "parallel")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[[e] for e in edges(mts_init)],
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential", edges = edges(mts_init))
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
)
print("Sequential updates (sequence is our custom sequence finder): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential")
belief_propagation(
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
)

g = NamedGraphs.hexagonal_lattice_graph(4,4)
g = NamedGraphs.hexagonal_lattice_graph(4, 4)
s = siteinds("S=1/2", g)
χ = 3

Expand All @@ -77,15 +108,31 @@ function main()
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
)

println("Now testing out a 4 x 4 hexagonal lattice. Random network with bond dim ")
println("\nNow testing out a 4 x 4 hexagonal lattice. Random network with bond dim ")

#Now test out various sequences
print("Parallel updates (sequence is irrelevant): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "parallel")
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[[e] for e in edges(mts_init)],
)
print("Sequential updates (sequence is default edge list of the message tensors): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential", edges = edges(mts_init))
belief_propagation(
ψψ,
mts_init;
contract_kwargs=(; alg="exact"),
target_precision=1e-10,
niters=100,
edges=[e for e in edges(mts_init)],
)
print("Sequential updates (sequence is our custom sequence finder): ")
belief_propagation(ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision = 1e-10, niters = 100, update_sequence = "sequential")
return belief_propagation(
ψψ, mts_init; contract_kwargs=(; alg="exact"), target_precision=1e-10, niters=100
)
end

main()
114 changes: 89 additions & 25 deletions src/beliefpropagation/beliefpropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,31 +75,24 @@ function update_message_tensor(
end

"""
Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
Do a sequential update of message tensors on `edges` for a given ITensornetwork and its partition into sub graphs
"""
function belief_propagation_iteration(
tn::ITensorNetwork,
mts::DataGraph;
mts::DataGraph,
edges::Vector{E};
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
compute_norm=false,
update_sequence::String="sequential",
edges = edge_update_order(undirected_graph(underlying_graph(mts))),
)
) where {E<:NamedEdge}
new_mts = copy(mts)
if update_sequence != "parallel" && update_sequence != "sequential"
error(
"Specified update order is not currently implemented. Choose parallel or sequential."
)
end
incoming_mts = update_sequence == "parallel" ? mts : new_mts
c = 0
for e in edges
environment_tensornetworks = ITensorNetwork[
incoming_mts[e_in] for
e_in in setdiff(boundary_edges(incoming_mts, [src(e)]; dir=:in), [reverse(e)])
new_mts[e_in] for
e_in in setdiff(boundary_edges(new_mts, [src(e)]; dir=:in), [reverse(e)])
]
new_mts[src(e) => dst(e)] = update_message_tensor(
tn, incoming_mts[src(e)], environment_tensornetworks; contract_kwargs
tn, new_mts[src(e)], environment_tensornetworks; contract_kwargs
)

if compute_norm
Expand All @@ -113,24 +106,96 @@ function belief_propagation_iteration(
return new_mts, c / (length(edges))
end

"""
Do parallel updates between groups of edges of all message tensors for a given ITensornetwork and its partition into sub graphs
"""
function belief_propagation_iteration(
tn::ITensorNetwork,
mts::DataGraph,
edge_groups::Vector{Vector{E}};
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
compute_norm=false,
) where {E<:NamedEdge}
new_mts = copy(mts)
c = 0
for edges in edge_groups
updated_mts, ct = belief_propagation_iteration(
tn, mts, edges; contract_kwargs, compute_norm
)
for e in edges
new_mts[e] = updated_mts[e]
end
c += ct
end
return new_mts, c / (length(edge_groups))
end

function belief_propagation_iteration(
tn::ITensorNetwork,
mts::DataGraph;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
compute_norm=false,
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
undirected_graph(underlying_graph(mts))
),
) where {E<:NamedEdge}
return belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
end

# """
# Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
# """
# function belief_propagation_iteration(
# tn::ITensorNetwork,
# mts::DataGraph;
# contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
# compute_norm=false,
# update_sequence::String="sequential",
# edges::Vector{Vector{}} = edge_update_order(undirected_graph(underlying_graph(mts))),
# )
# new_mts = copy(mts)
# if update_sequence != "parallel" && update_sequence != "sequential"
# error(
# "Specified update order is not currently implemented. Choose parallel or sequential."
# )
# end
# incoming_mts = update_sequence == "parallel" ? mts : new_mts
# c = 0
# for e in edges
# environment_tensornetworks = ITensorNetwork[
# incoming_mts[e_in] for
# e_in in setdiff(boundary_edges(incoming_mts, [src(e)]; dir=:in), [reverse(e)])
# ]
# new_mts[src(e) => dst(e)] = update_message_tensor(
# tn, incoming_mts[src(e)], environment_tensornetworks; contract_kwargs
# )

# if compute_norm
# LHS, RHS = ITensors.contract(ITensor(mts[src(e) => dst(e)])),
# ITensors.contract(ITensor(new_mts[src(e) => dst(e)]))
# LHS /= sum(diag(LHS))
# RHS /= sum(diag(RHS))
# c += 0.5 * norm(denseblocks(LHS) - denseblocks(RHS))
# end
# end
# return new_mts, c / (length(edges))
# end

function belief_propagation(
tn::ITensorNetwork,
mts::DataGraph;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
niters=20,
target_precision::Union{Float64,Nothing}=nothing,
update_sequence::String="sequential",
edges = edge_update_order(undirected_graph(underlying_graph(mts)))
)
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
undirected_graph(underlying_graph(mts))
),
) where {E<:NamedEdge}
compute_norm = target_precision == nothing ? false : true
for i in 1:niters
mts, c = belief_propagation_iteration(
tn, mts; contract_kwargs, compute_norm, update_sequence, edges
)
mts, c = belief_propagation_iteration(tn, mts, edges; contract_kwargs, compute_norm)
if compute_norm && c <= target_precision
println(
"BP converged to desired precision after $i iterations.",
)
println("BP converged to desired precision after $i iterations.")
break
end
end
Expand All @@ -144,11 +209,10 @@ function belief_propagation(
npartitions=nothing,
subgraph_vertices=nothing,
niters=20,
update_sequence::String="sequential",
target_precision::Union{Float64,Nothing}=nothing,
)
mts = message_tensors(tn; nvertices_per_partition, npartitions, subgraph_vertices)
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision, update_sequence)
return belief_propagation(tn, mts; contract_kwargs, niters, target_precision)
end

"""
Expand Down
94 changes: 55 additions & 39 deletions src/beliefpropagation/sqrt_beliefpropagation.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,19 @@
# using ITensors: scalartype
# using ITensorNetworks: find_subgraph, map_diag, sqrt_diag, boundary_edges

function sqrt_belief_propagation(
tn::ITensorNetwork,
mts::DataGraph;
niters=20,
update_sequence::String="sequential",
# target_precision::Union{Float64,Nothing}=nothing,
)
# compute_norm = target_precision == nothing ? false : true
sqrt_mts = sqrt_message_tensors(tn, mts)
for i in 1:niters
sqrt_mts, c = sqrt_belief_propagation_iteration(tn, sqrt_mts; update_sequence) #; compute_norm)
# if compute_norm && c <= target_precision
# println(
# "Belief Propagation finished. Reached a canonicalness of " *
# string(c) *
# " after $i iterations. ",
# )
# break
# end
end
return sqr_message_tensors(sqrt_mts)
end

function sqrt_belief_propagation_iteration(
tn::ITensorNetwork,
sqrt_mts::DataGraph;
update_sequence::String="sequential",
edges=edge_update_order(undirected_graph(underlying_graph(mts))),

# compute_norm=false,
)
tn::ITensorNetwork, sqrt_mts::DataGraph, edges::Vector{E}
) where {E<:NamedEdge}
new_sqrt_mts = copy(sqrt_mts)
if update_sequence != "parallel" && update_sequence != "sequential"
error(
"Specified update order is not currently implemented. Choose parallel or sequential."
)
end
incoming_sqrt_mts = update_sequence == "parallel" ? sqrt_mts : new_sqrt_mts
c = 0.0
for e in edges
environment_tensornetworks = ITensorNetwork[
incoming_sqrt_mts[e_in] for
e_in in setdiff(boundary_edges(incoming_sqrt_mts, [src(e)]; dir=:in), [reverse(e)])
new_sqrt_mts[e_in] for
e_in in setdiff(boundary_edges(new_sqrt_mts, [src(e)]; dir=:in), [reverse(e)])
]

new_sqrt_mts[src(e) => dst(e)] = update_sqrt_message_tensor(
tn, incoming_sqrt_mts[src(e)], environment_tensornetworks;
tn, new_sqrt_mts[src(e)], environment_tensornetworks;
)

# if compute_norm
Expand All @@ -61,6 +27,56 @@ function sqrt_belief_propagation_iteration(
return new_sqrt_mts, c / (length(edges))
end

function sqrt_belief_propagation_iteration(
tn::ITensorNetwork, sqrt_mts::DataGraph, edges::Vector{Vector{E}}
) where {E<:NamedEdge}
new_sqrt_mts = copy(sqrt_mts)
c = 0.0
for e_group in edges
updated_sqrt_mts, ct = sqrt_belief_propagation_iteration(tn, sqr_mts, e_group)
for e in e_group
new_sqrt_mts[e] = updated_sqrt_mts[e]
end
c += ct
end
return new_sqrt_mts, c / (length(edges))
end

function sqrt_belief_propagation_iteration(
tn::ITensorNetwork,
sqrt_mts::DataGraph;
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
undirected_graph(underlying_graph(mts))
),
) where {E<:NamedEdge}
return sqrt_belief_propagation_iteration(tn, sqrt_mts, edges)
end

function sqrt_belief_propagation(
tn::ITensorNetwork,
mts::DataGraph;
niters=20,
edges::Union{Vector{Vector{E}},Vector{E}}=edge_update_order(
undirected_graph(underlying_graph(mts))
),
# target_precision::Union{Float64,Nothing}=nothing,
) where {E<:NamedEdge}
# compute_norm = target_precision == nothing ? false : true
sqrt_mts = sqrt_message_tensors(tn, mts)
for i in 1:niters
sqrt_mts, c = sqrt_belief_propagation_iteration(tn, sqrt_mts, edges) #; compute_norm)
# if compute_norm && c <= target_precision
# println(
# "Belief Propagation finished. Reached a canonicalness of " *
# string(c) *
# " after $i iterations. ",
# )
# break
# end
end
return sqr_message_tensors(sqrt_mts)
end

function update_sqrt_message_tensor(
tn::ITensorNetwork, subgraph_vertices::Vector, sqrt_mts::Vector{ITensorNetwork};
)
Expand Down
Loading

0 comments on commit 1e6ec2b

Please sign in to comment.