diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index d7096de3..b930f36b 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -723,6 +723,17 @@ end function contract_with_BP( itn::AbstractITensorNetwork; outputlevel=1, partitioning=group(v -> v, vertices(itn)) +) + return contract_with_BP( + ComplexF64, itn::AbstractITensorNetwork; outputlevel, partitioning + ) +end + +function contract_with_BP( + T::Type, + itn::AbstractITensorNetwork; + outputlevel=1, + partitioning=group(v -> v, vertices(itn)), ) @assert isempty(externalinds(itn)) bp_cache = BeliefPropagationCache(copy(itn), partitioning) @@ -739,15 +750,17 @@ function contract_with_BP( for pv in partitionvertices(pg) incoming_mts = incoming_messages(bp_cache, [pv]) local_state = ITensor[itn[v] for v in vertices(pg, pv)] - log_numerator += log(ITensors.contract(vcat(incoming_mts, local_state))[]) + log_numerator += log(complex(ITensors.contract(vcat(incoming_mts, local_state))[])) end for pe in partitionedges(pg) log_denominator += log( - ITensors.contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[] + complex( + ITensors.contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[] + ), ) end - - return exp(log_numerator - log_denominator) + res = exp(log_numerator - log_denominator) + return T(res) end # TODO: rename `sqnorm` to match https://github.com/JuliaStats/Distances.jl,