Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Environments #145

Merged
merged 11 commits into from
Mar 21, 2024
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ include(joinpath("treetensornetworks", "solvers", "contract.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
include("apply.jl")
include("environment.jl")

include("exports.jl")

Expand Down
29 changes: 19 additions & 10 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_messages(ptn::PartitionedGraph) = Dictionary()
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
return contract_exact(contract_list; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
updated_messages /= norm(updated_messages)
return ITensor[updated_messages]
end
default_message_update_kwargs() = (; normalize=true, contraction_sequence_alg="optimal")
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
return 0.5 *
Expand All @@ -27,11 +32,15 @@ function BeliefPropagationCache(
return BeliefPropagationCache(ptn, messages, default_message)
end

function BeliefPropagationCache(tn::ITensorNetwork, partitioned_vertices; kwargs...)
function BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
ptn = PartitionedGraph(tn, partitioned_vertices)
return BeliefPropagationCache(ptn; kwargs...)
end

function BeliefPropagationCache(tn; kwargs...)
return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...)
end

function partitioned_itensornetwork(bp_cache::BeliefPropagationCache)
return bp_cache.partitioned_itensornetwork
end
Expand Down Expand Up @@ -92,7 +101,7 @@ function set_messages(cache::BeliefPropagationCache, messages)
)
end

function incoming_messages(
function environment(
bp_cache::BeliefPropagationCache,
partition_vertices::Vector{<:PartitionVertex};
ignore_edges=PartitionEdge[],
Expand All @@ -102,15 +111,15 @@ function incoming_messages(
return reduce(vcat, ms; init=[])
end

function incoming_messages(
function environment(
bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs...
)
return incoming_messages(bp_cache, [partition_vertex]; kwargs...)
return environment(bp_cache, [partition_vertex]; kwargs...)
end

function incoming_messages(bp_cache::BeliefPropagationCache, verts::Vector)
function environment(bp_cache::BeliefPropagationCache, verts::Vector)
partition_verts = partitionvertices(bp_cache, verts)
messages = incoming_messages(bp_cache, partition_verts)
messages = environment(bp_cache, partition_verts)
central_tensors = ITensor[
tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts)
]
Expand All @@ -129,10 +138,10 @@ function update_message(
bp_cache::BeliefPropagationCache,
edge::PartitionEdge;
message_update=default_message_update,
message_update_kwargs=default_message_update_kwargs(),
message_update_kwargs=(;),
)
vertex = src(edge)
messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factor(bp_cache, vertex)

return message_update(ITensor[messages; state]; message_update_kwargs...)
Expand Down
14 changes: 0 additions & 14 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,3 @@ function contract_density_matrix(
end
return out
end

function contract_exact(
contract_list::Vector{ITensor};
contraction_sequence_alg="optimal",
normalize=true,
contractor_kwargs...,
)
seq = contraction_sequence(contract_list; alg=contraction_sequence_alg)
out = ITensors.contract(contract_list; sequence=seq, contractor_kwargs...)
if normalize
normalize!(out)
end
return ITensor[out]
end
42 changes: 42 additions & 0 deletions src/environment.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
default_environment_algorithm() = "exact"

function environment(
ψ::AbstractITensorNetwork,
vertices::Vector;
alg=default_environment_algorithm(),
kwargs...,
)
return environment(Algorithm(alg), ψ, vertices; kwargs...)
end

function environment(
::Algorithm"exact",
ψ::AbstractITensorNetwork,
vertices::Vector;
contraction_sequence_alg="optimal",
kwargs...,
)
ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices))
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
sequence = contraction_sequence(ψ_reduced; alg=contraction_sequence_alg)
return ITensor[contract(ψ_reduced; sequence, kwargs...)]
end

function environment(
::Algorithm"bp",
ψ::AbstractITensorNetwork,
verts::Vector;
(cache!)=nothing,
partitioned_vertices=default_partitioned_vertices(ψ),
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(ψ, partitioned_vertices))
end

if update_cache
cache![] = update(cache![]; cache_update_kwargs...)
end

return environment(cache![], setdiff(vertices(ψ), verts))
end
44 changes: 27 additions & 17 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@ function ket_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f))
end

function bra_ket_vertices(f::AbstractFormNetwork)
return vcat(bra_vertices(f), ket_vertices(f))
function bra_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [bra_vertex_map(f)(osv) for osv in original_state_vertices]
end

function bra_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [bra_vertex_map(f)(sv) for sv in state_vertices]
function ket_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [ket_vertex_map(f)(osv) for osv in original_state_vertices]
end

function ket_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [ket_vertex_map(f)(sv) for sv in state_vertices]
function state_vertices(f::AbstractFormNetwork)
return vcat(bra_vertices(f), ket_vertices(f))
end

function bra_ket_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
function state_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return vcat(
bra_vertices(f, original_state_vertices), ket_vertices(f, original_state_vertices)
)
end

function Graphs.induced_subgraph(f::AbstractFormNetwork, vertices::Vector)
Expand All @@ -57,18 +59,26 @@ function operator_network(f::AbstractFormNetwork)
)
end

function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(tensornetwork(f), tn_vertices; kwargs...)
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
return setdiff(
vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
)
function environment(
f::AbstractFormNetwork,
original_state_vertices::Vector;
alg=default_environment_algorithm(),
kwargs...,
)
form_vertices = setdiff(vertices(f), state_vertices(f, original_state_vertices))
if alg == "bp"
partitions = group(v -> original_state_vertex(f, v), vertices(f))
return environment(tensornetwork(f), form_vertices; alg, partitions, kwargs...)
else
return environment(tensornetwork(f), form_vertices; alg, kwargs...)
end
end

operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f))
bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f))
ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f))
inv_vertex_map(f::AbstractFormNetwork) = v -> first(v)
operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v)
bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v)
ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v)
original_state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v)
10 changes: 7 additions & 3 deletions src/formnetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ function BilinearFormNetwork(
end

function update(
blf::BilinearFormNetwork, state_vertex, bra_state::ITensor, ket_state::ITensor
blf::BilinearFormNetwork, original_state_vertex, bra_state::ITensor, ket_state::ITensor
)
blf = copy(blf)
# TODO: Maybe add a check that it really does preserve the graph.
setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex))
setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex))
setindex_preserve_graph!(
tensornetwork(blf), bra_state, bra_vertex(blf, original_state_vertex)
)
setindex_preserve_graph!(
tensornetwork(blf), ket_state, ket_vertex(blf, original_state_vertex)
)
return blf
end
4 changes: 2 additions & 2 deletions src/formnetworks/quadraticformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ function QuadraticFormNetwork(
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
end

function update(qf::QuadraticFormNetwork, state_vertex, ket_state::ITensor)
function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor)
state_inds = inds(ket_state)
bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds))
new_blf = update(bilinear_formnetwork(qf), state_vertex, bra_state, ket_state)
new_blf = update(bilinear_formnetwork(qf), original_state_vertex, bra_state, ket_state)
return QuadraticFormNetwork(new_blf, dual_index_map(qf), dual_index_map(qf))
end
1 change: 0 additions & 1 deletion src/gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ function default_norm_cache(ψ::ITensorNetwork)
ψψ = norm_network(ψ)
return BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
end
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function ITensorNetwork(
ψ_vidal::VidalITensorNetwork; (cache!)=nothing, update_gauge=false, update_kwargs...
Expand Down
6 changes: 3 additions & 3 deletions test/test_apply.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ITensorNetworks
using ITensorNetworks:
incoming_messages,
environment,
update,
contract_inner,
norm_network,
Expand Down Expand Up @@ -29,14 +29,14 @@ using SplitApplyCombine
#Simple Belief Propagation Grouping
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bp_cache = update(bp_cache; maxiter=20)
envsSBP = incoming_messages(bp_cache, PartitionVertex.([v1, v2]))
envsSBP = environment(bp_cache, PartitionVertex.([v1, v2]))

ψv = VidalITensorNetwork(ψ)

#This grouping will correspond to calculating the environments exactly (each column of the grid is a partition)
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ)))
bp_cache = update(bp_cache; maxiter=20)
envsGBP = incoming_messages(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)])
envsGBP = environment(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)])

ngates = 5

Expand Down
13 changes: 7 additions & 6 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ITensorNetworks:
tensornetwork,
update,
update_factor,
incoming_messages
environment,
contract
using Test
using Compat
using ITensors
Expand Down Expand Up @@ -40,7 +41,7 @@ ITensors.disable_warn_order()

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = incoming_messages(bpc, [PartitionVertex(v)])
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

Expand Down Expand Up @@ -70,7 +71,7 @@ ITensors.disable_warn_order()

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = incoming_messages(bpc, [PartitionVertex(v)])
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

Expand All @@ -93,7 +94,7 @@ ITensors.disable_warn_order()
bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc; maxiter=20)

env_tensors = incoming_messages(bpc, vs)
env_tensors = environment(bpc, vs)
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[]

Expand All @@ -112,7 +113,7 @@ ITensors.disable_warn_order()
bpc = update(bpc; maxiter=20)

ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
env_tensors = incoming_messages(bpc, [(v, 2) for v in vs])
env_tensors = environment(bpc, [(v, 2) for v in vs])
rdm = ITensors.contract(
vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]])
)
Expand Down Expand Up @@ -148,7 +149,7 @@ ITensors.disable_warn_order()
message_update_kwargs=(; cutoff=1e-6, maxdim=4),
)

env_tensors = incoming_messages(bpc, [v])
env_tensors = environment(bpc, [v])
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[]

Expand Down
31 changes: 24 additions & 7 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
using ITensors
using Graphs
using Graphs: nv
using NamedGraphs
using ITensorNetworks
using ITensorNetworks:
delta_network,
update,
tensornetwork,
bra_vertex_map,
ket_vertex_map,
bra_vertex,
ket_vertex,
dual_index_map,
bra_network,
ket_network,
operator_network
operator_network,
environment,
BeliefPropagationCache
using Test
using Random
using SplitApplyCombine

@testset "FormNetworkss" begin
@testset "FormNetworks" begin
g = named_grid((1, 4))
s_ket = siteinds("S=1/2", g)
s_bra = prime(s_ket; links=[])
Expand All @@ -42,10 +45,24 @@ using Random
new_tensor = randomITensor(inds(ψket[v]))
qf_updated = update(qf, v, copy(new_tensor))

@test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)] ≈
@test tensornetwork(qf_updated)[bra_vertex(qf_updated, v)] ≈
dual_index_map(qf_updated)(dag(new_tensor))
@test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] ≈ new_tensor
@test tensornetwork(qf_updated)[ket_vertex(qf_updated, v)] ≈ new_tensor

@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)

∂qf_∂v = only(environment(qf, [v]))
@test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf)

∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=false)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
∂qf_∂v /= norm(∂qf_∂v)
@test ∂qf_∂v_bp != ∂qf_∂v

∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=true)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
@test ∂qf_∂v_bp ≈ ∂qf_∂v
end
Loading