Skip to content

Commit

Permalink
Convert argument of log to complex for proper loginner via BP.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benedikt Kloss committed Mar 19, 2024
1 parent 7ee14bd commit 438954a
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,17 @@ end

function contract_with_BP(
itn::AbstractITensorNetwork; outputlevel=1, partitioning=group(v -> v, vertices(itn))
)
return contract_with_BP(
ComplexF64, itn::AbstractITensorNetwork; outputlevel, partitioning
)
end

function contract_with_BP(
T::Type,
itn::AbstractITensorNetwork;
outputlevel=1,
partitioning=group(v -> v, vertices(itn)),
)
@assert isempty(externalinds(itn))
bp_cache = BeliefPropagationCache(copy(itn), partitioning)
Expand All @@ -739,15 +750,17 @@ function contract_with_BP(
for pv in partitionvertices(pg)
incoming_mts = incoming_messages(bp_cache, [pv])
local_state = ITensor[itn[v] for v in vertices(pg, pv)]
log_numerator += log(ITensors.contract(vcat(incoming_mts, local_state))[])
log_numerator += log(complex(ITensors.contract(vcat(incoming_mts, local_state))[]))
end
for pe in partitionedges(pg)
log_denominator += log(
ITensors.contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
complex(
ITensors.contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
),
)
end

return exp(log_numerator - log_denominator)
res = exp(log_numerator - log_denominator)
return T(res)
end

# TODO: rename `sqnorm` to match https://github.com/JuliaStats/Distances.jl,
Expand Down

0 comments on commit 438954a

Please sign in to comment.