Skip to content

Commit

Permalink
Refactor and bring down upstream changes
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Dec 10, 2024
1 parent 6a8d4b9 commit 2cb7f85
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
14 changes: 9 additions & 5 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,15 @@ function normalize_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:Part
mts = messages(bp_cache)
for pe in pes
me, mer = only(mts[pe]), only(mts[reverse(pe)])
set!(mts, pe, ITensor[me / norm(me)])
set!(mts, reverse(pe), ITensor[mer / norm(mer)])
n = region_scalar(bp_cache, pe)
set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
me, mer = normalize(me), normalize(mer)
n = dot(me, mer)
if isreal(n) && n < 0
set!(mts, pe, ITensor[(sgn(n) / sqrt(abs(n))) * me])
set!(mts, reverse(pe), ITensor[(1 / sqrt(abs(n))) * mer])
else
set!(mts, pe, ITensor[(1 / sqrt(n)) * me])
set!(mts, reverse(pe), ITensor[(1 / sqrt(n)) * mer])
end
end
return bp_cache
end
Expand Down
33 changes: 15 additions & 18 deletions src/normalize.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
using LinearAlgebra

function rescale(tn::AbstractITensorNetwork, c::Number, vs=collect(vertices(tn)))
tn = copy(tn)
for v in vs
tn[v] *= c
end
return tn
end

function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return normalize(Algorithm(alg), tn; kwargs...)
end

function LinearAlgebra.normalize(alg::Algorithm"exact", tn::AbstractITensorNetwork)
norm_tn = norm_sqr_network(tn)
log_norm = logscalar(alg, norm_tn)
tn = copy(tn)
L = length(vertices(tn))
c = exp(log_norm / L)
for v in vertices(tn)
tn[v] = tn[v] / sqrt(c)
end
return tn
norm_tn = QuadraticFormNetwork(tn)
c = exp(logscalar(alg, norm_tn) / (2 * length(vertices(tn))))
return rescale(tn, 1 / c)
end

function LinearAlgebra.normalize(
Expand All @@ -40,17 +42,12 @@ function LinearAlgebra.normalize(
v_ket, v_bra = ket_vertex(norm_tn, v), bra_vertex(norm_tn, v)
pv = only(partitionvertices(cache![], [v_ket]))
vn = region_scalar(cache![], pv)
state = tn[v] / sqrt(vn)
state_dag = copy(dag(state))
state_dag = replaceinds(
state_dag, inds(state_dag), dual_index_map(norm_tn).(inds(state_dag))
)
set!(vertices_states, v_ket, state)
set!(vertices_states, v_bra, state_dag)
tn[v] = state
norm_tn = rescale(norm_tn, 1 / sqrt(vn), [v_ket, v_bra])
set!(vertices_states, v_ket, norm_tn[v_ket])
set!(vertices_states, v_bra, norm_tn[v_bra])
end

cache![] = update_factors(cache![], vertices_states)

return tn
return ket_network(norm_tn)
end

0 comments on commit 2cb7f85

Please sign in to comment.