diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 3da4ca67..d80b2644 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -106,7 +106,7 @@ end function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge) mts = messages(bp_cache) - return get(mts, edge, default_message(bp_cache, edge)) + return get(() -> default_message(bp_cache, edge), mts, edge) end function messages(bp_cache::BeliefPropagationCache, edges; kwargs...) return map(edge -> message(bp_cache, edge; kwargs...), edges) @@ -152,15 +152,16 @@ end function environment(bp_cache::BeliefPropagationCache, verts::Vector) partition_verts = partitionvertices(bp_cache, verts) messages = environment(bp_cache, partition_verts) - central_tensors = ITensor[ - tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts) - ] + central_tensors = factors(bp_cache, setdiff(vertices(bp_cache, partition_verts), verts)) return vcat(messages, central_tensors) end +function factors(bp_cache::BeliefPropagationCache, verts::Vector) + return ITensor[tensornetwork(bp_cache)[v] for v in verts] +end + function factor(bp_cache::BeliefPropagationCache, vertex::PartitionVertex) - ptn = partitioned_tensornetwork(bp_cache) - return collect(eachtensor(subgraph(ptn, vertex))) + return factors(bp_cache, vertices(bp_cache, vertex)) end """