Skip to content

Commit

Permalink
Improved, more general message_diff function
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Apr 5, 2024
1 parent 825d324 commit 6c91629
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 110 deletions.
15 changes: 8 additions & 7 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using NamedGraphs: boundary_partitionedges, partitionvertices, partitionedges

default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_messages(ptn::PartitionedGraph) = Dictionary()
default_message_norm(m::ITensor) = norm(m)
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
Expand All @@ -29,12 +30,12 @@ function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITens
return (; partitioned_vertices=default_partitioned_vertices(ψ))
end

#TODO: Define a version of this that works for QN supporting tensors
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
function message_diff(
message_a::Vector{ITensor}, message_b::Vector{ITensor}; message_norm=default_message_norm
)
lhs, rhs = contract(message_a), contract(message_b)
tr_lhs = length(inds(lhs)) == 1 ? sum(lhs) : sum(diag(lhs))
tr_rhs = length(inds(rhs)) == 1 ? sum(rhs) : sum(diag(rhs))
return 0.5 * norm((denseblocks(lhs) / tr_lhs) - (denseblocks(rhs) / tr_rhs))
norm_lhs, norm_rhs = message_norm(lhs), message_norm(rhs)
return 0.5 * norm((denseblocks(lhs) / norm_lhs) - (denseblocks(rhs) / norm_rhs))
end

struct BeliefPropagationCache{PTN,MTS,DM}
Expand Down Expand Up @@ -229,12 +230,12 @@ function update(
verbose=false,
kwargs...,
)
compute_error = !isnothing(tol) && !hasqns(tensornetwork(bp_cache))
diff = compute_error ? Ref(0.0) : nothing
compute_error = !isnothing(tol)
if isnothing(maxiter)
error("You need to specify a number of iterations for BP!")
end
for i in 1:maxiter
diff = compute_error ? Ref(0.0) : nothing
bp_cache = update(bp_cache, edges; (update_diff!)=diff, kwargs...)
if compute_error && (diff.x / length(edges)) <= tol
if verbose
Expand Down
201 changes: 101 additions & 100 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,136 +32,137 @@ using Test: @test, @testset
@testset "belief_propagation" begin
ITensors.disable_warn_order()

#First test on an MPS, should be exact
g_dims = (1, 6)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
χ = 4
Random.seed!(1234)
ψ = random_tensornetwork(s; link_space=χ)
# #First test on an MPS, should be exact
# g_dims = (1, 6)
# g = named_grid(g_dims)
# s = siteinds("S=1/2", g)
# χ = 4
# Random.seed!(1234)
# ψ = random_tensornetwork(s; link_space=χ)

ψψ = ψ prime(dag(ψ); sites=[])
# ψψ = ψ ⊗ prime(dag(ψ); sites=[])

v = (1, 3)
# v = (1, 3)

= copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
exact_sz = inner(Oψ, ψ) / inner(ψ, ψ)
# Oψ = copy(ψ)
# Oψ[v] = apply(op("Sz", s[v]), ψ[v])
# exact_sz = inner(Oψ, ψ) / inner(ψ, ψ)

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]
# bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
# bpc = update(bpc)
# env_tensors = environment(bpc, [PartitionVertex(v)])
# numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
# denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

@test abs.((numerator / denominator) - exact_sz) <= 1e-14
# @test abs.((numerator / denominator) - exact_sz) <= 1e-14

#Test updating the underlying tensornetwork in the cache
v = first(vertices(ψψ))
new_tensor = randomITensor(inds(ψψ[v]))
bpc = update_factor(bpc, v, new_tensor)
ψψ_updated = tensornetwork(bpc)
@test ψψ_updated[v] == new_tensor
# #Test updating the underlying tensornetwork in the cache
# v = first(vertices(ψψ))
# new_tensor = randomITensor(inds(ψψ[v]))
# bpc = update_factor(bpc, v, new_tensor)
# ψψ_updated = tensornetwork(bpc)
# @test ψψ_updated[v] == new_tensor

#Now test on a tree, should also be exact
g = named_comb_tree((4, 4))
s = siteinds("S=1/2", g)
χ = 2
Random.seed!(1564)
ψ = random_tensornetwork(s; link_space=χ)
# #Now test on a tree, should also be exact
# g = named_comb_tree((4, 4))
# s = siteinds("S=1/2", g)
# χ = 2
# Random.seed!(1564)
# ψ = random_tensornetwork(s; link_space=χ)

ψψ = ψ prime(dag(ψ); sites=[])
# ψψ = ψ ⊗ prime(dag(ψ); sites=[])

v = (1, 3)
# v = (1, 3)

= copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
exact_sz = inner(Oψ, ψ) / inner(ψ, ψ)
# Oψ = copy(ψ)
# Oψ[v] = apply(op("Sz", s[v]), ψ[v])
# exact_sz = inner(Oψ, ψ) / inner(ψ, ψ)

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]
# bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
# bpc = update(bpc)
# env_tensors = environment(bpc, [PartitionVertex(v)])
# numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
# denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

@test abs.((numerator / denominator) - exact_sz) <= 1e-14
# @test abs.((numerator / denominator) - exact_sz) <= 1e-14

# # #Now test two-site expec taking on the partition function of the Ising model. Not exact, but close
g_dims = (3, 4)
g = named_grid(g_dims)
s = IndsNetwork(g; link_space=2)
beta = 0.2
beta, h = 0.3, 0.5
vs = [(2, 3), (3, 3)]
ψψ = ModelNetworks.ising_network(s, beta)
ψOψ = ModelNetworks.ising_network(s, beta; szverts=vs)
ψψ = ModelNetworks.ising_network(s, beta; h)
ψOψ = ModelNetworks.ising_network(s, beta; h, szverts=vs)

contract_seq = contraction_sequence(ψψ)
actual_szsz =
contract(ψOψ; sequence=contract_seq)[] / contract(ψψ; sequence=contract_seq)[]

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc; maxiter=20)
bpc = BeliefPropagationCache(ψψ, group(v -> v, vertices(ψψ)))
bpc = update(bpc; maxiter=20, verbose=true, tol=1e-5)

env_tensors = environment(bpc, vs)
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[]

@show abs.((numerator / denominator) - actual_szsz)
@test abs.((numerator / denominator) - actual_szsz) <= 0.05

# # #Test forming a two-site RDM. Check it has the correct size, trace 1 and is PSD
g_dims = (3, 3)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
vs = [(2, 2), (2, 3)]
χ = 3
ψ = random_tensornetwork(s; link_space=χ)
ψψ = ψ prime(dag(ψ); sites=[])

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc; maxiter=20)

ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
env_tensors = environment(bpc, [(v, 2) for v in vs])
rdm = contract(vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]))

rdm = array((rdm * combiner(inds(rdm; plev=0)...)) * combiner(inds(rdm; plev=1)...))
rdm /= tr(rdm)

eigs = eigvals(rdm)
@test size(rdm) == (2^length(vs), 2^length(vs))
@test all(>=(0), real(eigs)) && all(==(0), imag(eigs))

# # #Test more advanced block BP with MPS message tensors on a grid
g_dims = (4, 3)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
χ = 2
ψ = random_tensornetwork(s; link_space=χ)
v = (2, 2)

ψψ = flatten_networks(ψ, dag(ψ); combine_linkinds=false, map_bra_linkinds=prime)
= copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
ψOψ = flatten_networks(ψ, dag(Oψ); combine_linkinds=false, map_bra_linkinds=prime)

combiners = linkinds_combiners(ψψ)
ψψ = combine_linkinds(ψψ, combiners)
ψOψ = combine_linkinds(ψOψ, combiners)

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(
bpc;
message_update=ITensorNetworks.contract_density_matrix,
message_update_kwargs=(; cutoff=1e-6, maxdim=4),
)

env_tensors = environment(bpc, [v])
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[]

exact_sz =
contract_boundary_mps(ψOψ; cutoff=1e-16) / contract_boundary_mps(ψψ; cutoff=1e-16)

@test abs.((numerator / denominator) - exact_sz) <= 1e-5
# g_dims = (3, 3)
# g = named_grid(g_dims)
# s = siteinds("S=1/2", g)
# vs = [(2, 2), (2, 3)]
# χ = 3
# ψ = random_tensornetwork(s; link_space=χ)
# ψψ = ψ ⊗ prime(dag(ψ); sites=[])

# bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
# bpc = update(bpc; maxiter=20)

# ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
# env_tensors = environment(bpc, [(v, 2) for v in vs])
# rdm = contract(vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]))

# rdm = array((rdm * combiner(inds(rdm; plev=0)...)) * combiner(inds(rdm; plev=1)...))
# rdm /= tr(rdm)

# eigs = eigvals(rdm)
# @test size(rdm) == (2^length(vs), 2^length(vs))
# @test all(>=(0), real(eigs)) && all(==(0), imag(eigs))

# # # #Test more advanced block BP with MPS message tensors on a grid
# g_dims = (4, 3)
# g = named_grid(g_dims)
# s = siteinds("S=1/2", g)
# χ = 2
# ψ = random_tensornetwork(s; link_space=χ)
# v = (2, 2)

# ψψ = flatten_networks(ψ, dag(ψ); combine_linkinds=false, map_bra_linkinds=prime)
# Oψ = copy(ψ)
# Oψ[v] = apply(op("Sz", s[v]), ψ[v])
# ψOψ = flatten_networks(ψ, dag(Oψ); combine_linkinds=false, map_bra_linkinds=prime)

# combiners = linkinds_combiners(ψψ)
# ψψ = combine_linkinds(ψψ, combiners)
# ψOψ = combine_linkinds(ψOψ, combiners)

# bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
# bpc = update(
# bpc;
# message_update=ITensorNetworks.contract_density_matrix,
# message_update_kwargs=(; cutoff=1e-6, maxdim=4),
# )

# env_tensors = environment(bpc, [v])
# numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[]
# denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[]

# exact_sz =
# contract_boundary_mps(ψOψ; cutoff=1e-16) / contract_boundary_mps(ψψ; cutoff=1e-16)

# @test abs.((numerator / denominator) - exact_sz) <= 1e-5
end
end
8 changes: 5 additions & 3 deletions test/test_gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ using Test: @test, @testset
ψ = random_tensornetwork(s; link_space=χ)

# Move directly to vidal gauge
ψ_vidal = VidalITensorNetwork(ψ)
@test gauge_error(ψ_vidal) < 1e-5
ψ_vidal = VidalITensorNetwork(
ψ; cache_update_kwargs=(; maxiter=20, tol=1e-12, verbose=true)
)
@test gauge_error(ψ_vidal) < 1e-8

# Move to symmetric gauge
cache_ref = Ref{BeliefPropagationCache}()
Expand All @@ -39,7 +41,7 @@ using Test: @test, @testset
@test inner(ψ_symm, ψ) / sqrt(inner(ψ_symm, ψ_symm) * inner(ψ, ψ)) 1.0

#Test all message tensors are approximately diagonal even when we keep running BP
bp_cache = update(bp_cache; maxiter=20)
bp_cache = update(bp_cache; maxiter=10)
for m_e in values(messages(bp_cache))
@test diagITensor(vector(diag(only(m_e))), inds(only(m_e))) only(m_e) atol = 1e-8
end
Expand Down

0 comments on commit 6c91629

Please sign in to comment.