Skip to content

Commit

Permalink
Make contract calls for BP do sequence finding by default
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed May 6, 2024
1 parent 4fea5e2 commit ea15404
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
29 changes: 21 additions & 8 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,27 +267,40 @@ function update_factor(bp_cache, vertex, factor)
return update_factors(bp_cache, Dictionary([vertex], [factor]))
end

function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)
function region_scalar(
bp_cache::BeliefPropagationCache,
pv::PartitionVertex;
contract_kwargs=(; sequence="automatic"),
)
incoming_mts = environment(bp_cache, [pv])
local_state = factor(bp_cache, pv)
return contract(vcat(incoming_mts, local_state))[]
return contract(vcat(incoming_mts, local_state); contract_kwargs...)[]
end

function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
return contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
function region_scalar(
bp_cache::BeliefPropagationCache,
pe::PartitionEdge;
contract_kwargs=(; sequence="automatic"),
)
return contract(
vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))); contract_kwargs...
)[]
end

function vertex_scalars(
bp_cache::BeliefPropagationCache,
pvs=partitionvertices(partitioned_tensornetwork(bp_cache)),
pvs=partitionvertices(partitioned_tensornetwork(bp_cache));
kwargs...,
)
return map(pv -> region_scalar(bp_cache, pv), pvs)
return map(pv -> region_scalar(bp_cache, pv; kwargs...), pvs)
end

function edge_scalars(
bp_cache::BeliefPropagationCache, pes=partitionedges(partitioned_tensornetwork(bp_cache))
bp_cache::BeliefPropagationCache,
pes=partitionedges(partitioned_tensornetwork(bp_cache));
kwargs...,
)
return map(pe -> region_scalar(bp_cache, pe), pes)
return map(pe -> region_scalar(bp_cache, pe; kwargs...), pes)
end

function scalar_factors_quotient(bp_cache::BeliefPropagationCache)
Expand Down
4 changes: 3 additions & 1 deletion src/expect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ using ITensors.ITensorMPS: ITensorMPS, expect

default_expect_alg() = "bp"

function ITensorMPS.expect(ψIψ::AbstractFormNetwork, op::Op; contract_kwargs=(;), kwargs...)
function ITensorMPS.expect(
ψIψ::AbstractFormNetwork, op::Op; contract_kwargs=(; sequence="automatic"), kwargs...
)
v = only(op.sites)
ψIψ_v = ψIψ[operator_vertex(ψIψ, v)]
s = commonind(ψIψ[ket_vertex(ψIψ, v)], ψIψ_v)
Expand Down

0 comments on commit ea15404

Please sign in to comment.