diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 35f5702e..536db27c 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -129,6 +129,7 @@ include(joinpath("treetensornetworks", "solvers", "contract.jl")) include(joinpath("treetensornetworks", "solvers", "linsolve.jl")) include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl")) include("apply.jl") +include("environment.jl") include("exports.jl") diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index f5337784..fa9ea51e 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -1,13 +1,18 @@ default_message(inds_e) = ITensor[denseblocks(delta(inds_e))] default_messages(ptn::PartitionedGraph) = Dictionary() function default_message_update(contract_list::Vector{ITensor}; kwargs...) - return contract_exact(contract_list; kwargs...) + sequence = optimal_contraction_sequence(contract_list) + updated_messages = contract(contract_list; sequence, kwargs...) + updated_messages /= norm(updated_messages) + return ITensor[updated_messages] end -default_message_update_kwargs() = (; normalize=true, contraction_sequence_alg="optimal") @traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing @traitfn function default_bp_maxiter(g::::IsDirected) return default_bp_maxiter(undirected_graph(underlying_graph(g))) end +default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) +default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) + function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) lhs, rhs = contract(message_a), contract(message_b) return 0.5 * @@ -27,11 +32,15 @@ function BeliefPropagationCache( return BeliefPropagationCache(ptn, messages, default_message) end -function BeliefPropagationCache(tn::ITensorNetwork, partitioned_vertices; kwargs...) +function BeliefPropagationCache(tn, partitioned_vertices; kwargs...) ptn = PartitionedGraph(tn, partitioned_vertices) return BeliefPropagationCache(ptn; kwargs...) end +function BeliefPropagationCache(tn; kwargs...) + return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...) +end + function partitioned_itensornetwork(bp_cache::BeliefPropagationCache) return bp_cache.partitioned_itensornetwork end @@ -92,7 +101,7 @@ function set_messages(cache::BeliefPropagationCache, messages) ) end -function incoming_messages( +function environment( bp_cache::BeliefPropagationCache, partition_vertices::Vector{<:PartitionVertex}; ignore_edges=PartitionEdge[], @@ -102,15 +111,15 @@ function incoming_messages( return reduce(vcat, ms; init=[]) end -function incoming_messages( +function environment( bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... ) - return incoming_messages(bp_cache, [partition_vertex]; kwargs...) + return environment(bp_cache, [partition_vertex]; kwargs...) end -function incoming_messages(bp_cache::BeliefPropagationCache, verts::Vector) +function environment(bp_cache::BeliefPropagationCache, verts::Vector) partition_verts = partitionvertices(bp_cache, verts) - messages = incoming_messages(bp_cache, partition_verts) + messages = environment(bp_cache, partition_verts) central_tensors = ITensor[ tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts) ] @@ -129,10 +138,10 @@ function update_message( bp_cache::BeliefPropagationCache, edge::PartitionEdge; message_update=default_message_update, - message_update_kwargs=default_message_update_kwargs(), + message_update_kwargs=(;), ) vertex = src(edge) - messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) + messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) state = factor(bp_cache, vertex) return message_update(ITensor[messages; state]; message_update_kwargs...) diff --git a/src/contract.jl b/src/contract.jl index 77c4ee1a..f4e74603 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -30,17 +30,3 @@ function contract_density_matrix( end return out end - -function contract_exact( - contract_list::Vector{ITensor}; - contraction_sequence_alg="optimal", - normalize=true, - contractor_kwargs..., -) - seq = contraction_sequence(contract_list; alg=contraction_sequence_alg) - out = ITensors.contract(contract_list; sequence=seq, contractor_kwargs...) - if normalize - normalize!(out) - end - return ITensor[out] -end diff --git a/src/environment.jl b/src/environment.jl new file mode 100644 index 00000000..262a7c23 --- /dev/null +++ b/src/environment.jl @@ -0,0 +1,42 @@ +default_environment_algorithm() = "exact" + +function environment( + ψ::AbstractITensorNetwork, + vertices::Vector; + alg=default_environment_algorithm(), + kwargs..., +) + return environment(Algorithm(alg), ψ, vertices; kwargs...) +end + +function environment( + ::Algorithm"exact", + ψ::AbstractITensorNetwork, + verts::Vector; + contraction_sequence_alg="optimal", + kwargs..., +) + ψ_reduced = Vector{ITensor}(subgraph(ψ, setdiff(vertices(ψ), verts))) + sequence = contraction_sequence(ψ_reduced; alg=contraction_sequence_alg) + return ITensor[contract(ψ_reduced; sequence, kwargs...)] +end + +function environment( + ::Algorithm"bp", + ψ::AbstractITensorNetwork, + vertices::Vector; + (cache!)=nothing, + partitioned_vertices=default_partitioned_vertices(ψ), + update_cache=isnothing(cache!), + cache_update_kwargs=default_cache_update_kwargs(cache!), +) + if isnothing(cache!) + cache! = Ref(BeliefPropagationCache(ψ, partitioned_vertices)) + end + + if update_cache + cache![] = update(cache![]; cache_update_kwargs...) + end + + return environment(cache![], vertices) +end diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index e6efe54e..f0557ac6 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -23,20 +23,22 @@ function ket_vertices(f::AbstractFormNetwork) return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f)) end -function bra_ket_vertices(f::AbstractFormNetwork) - return vcat(bra_vertices(f), ket_vertices(f)) +function bra_vertices(f::AbstractFormNetwork, original_state_vertices::Vector) + return [bra_vertex_map(f)(osv) for osv in original_state_vertices] end -function bra_vertices(f::AbstractFormNetwork, state_vertices::Vector) - return [bra_vertex_map(f)(sv) for sv in state_vertices] +function ket_vertices(f::AbstractFormNetwork, original_state_vertices::Vector) + return [ket_vertex_map(f)(osv) for osv in original_state_vertices] end -function ket_vertices(f::AbstractFormNetwork, state_vertices::Vector) - return [ket_vertex_map(f)(sv) for sv in state_vertices] +function state_vertices(f::AbstractFormNetwork) + return vcat(bra_vertices(f), ket_vertices(f)) end -function bra_ket_vertices(f::AbstractFormNetwork, state_vertices::Vector) - return vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) +function state_vertices(f::AbstractFormNetwork, original_state_vertices::Vector) + return vcat( + bra_vertices(f, original_state_vertices), ket_vertices(f, original_state_vertices) + ) end function Graphs.induced_subgraph(f::AbstractFormNetwork, vertices::Vector) @@ -57,18 +59,28 @@ function operator_network(f::AbstractFormNetwork) ) end -function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) - tn_vertices = derivative_vertices(f, state_vertices) - return derivative(tensornetwork(f), tn_vertices; kwargs...) -end - -function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) - return setdiff( - vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) - ) +function environment( + f::AbstractFormNetwork, + original_state_vertices::Vector; + alg=default_environment_algorithm(), + kwargs..., +) + form_vertices = state_vertices(f, original_state_vertices) + if alg == "bp" + partitioned_vertices = group(v -> original_state_vertex(f, v), vertices(f)) + return environment( + tensornetwork(f), form_vertices; alg, partitioned_vertices, kwargs... + ) + else + return environment(tensornetwork(f), form_vertices; alg, kwargs...) + end end operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f)) bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f)) ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f)) inv_vertex_map(f::AbstractFormNetwork) = v -> first(v) +operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v) +bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v) +ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v) +original_state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v) diff --git a/src/formnetworks/bilinearformnetwork.jl b/src/formnetworks/bilinearformnetwork.jl index 356b0ed1..5519c1e3 100644 --- a/src/formnetworks/bilinearformnetwork.jl +++ b/src/formnetworks/bilinearformnetwork.jl @@ -52,11 +52,15 @@ function BilinearFormNetwork( end function update( - blf::BilinearFormNetwork, state_vertex, bra_state::ITensor, ket_state::ITensor + blf::BilinearFormNetwork, original_state_vertex, bra_state::ITensor, ket_state::ITensor ) blf = copy(blf) # TODO: Maybe add a check that it really does preserve the graph. - setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex)) - setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex)) + setindex_preserve_graph!( + tensornetwork(blf), bra_state, bra_vertex(blf, original_state_vertex) + ) + setindex_preserve_graph!( + tensornetwork(blf), ket_state, ket_vertex(blf, original_state_vertex) + ) return blf end diff --git a/src/formnetworks/quadraticformnetwork.jl b/src/formnetworks/quadraticformnetwork.jl index 5acee59e..8aac841a 100644 --- a/src/formnetworks/quadraticformnetwork.jl +++ b/src/formnetworks/quadraticformnetwork.jl @@ -57,9 +57,9 @@ function QuadraticFormNetwork( return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map) end -function update(qf::QuadraticFormNetwork, state_vertex, ket_state::ITensor) +function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor) state_inds = inds(ket_state) bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds)) - new_blf = update(bilinear_formnetwork(qf), state_vertex, bra_state, ket_state) + new_blf = update(bilinear_formnetwork(qf), original_state_vertex, bra_state, ket_state) return QuadraticFormNetwork(new_blf, dual_index_map(qf), dual_index_map(qf)) end diff --git a/src/gauging.jl b/src/gauging.jl index 41bd02f0..89a30555 100644 --- a/src/gauging.jl +++ b/src/gauging.jl @@ -23,7 +23,6 @@ function default_norm_cache(ψ::ITensorNetwork) ψψ = norm_network(ψ) return BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) end -default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) function ITensorNetwork( ψ_vidal::VidalITensorNetwork; (cache!)=nothing, update_gauge=false, update_kwargs... diff --git a/test/test_apply.jl b/test/test_apply.jl index 0a815b9a..f32d8853 100644 --- a/test/test_apply.jl +++ b/test/test_apply.jl @@ -1,6 +1,6 @@ using ITensorNetworks using ITensorNetworks: - incoming_messages, + environment, update, contract_inner, norm_network, @@ -29,14 +29,14 @@ using SplitApplyCombine #Simple Belief Propagation Grouping bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bp_cache = update(bp_cache; maxiter=20) - envsSBP = incoming_messages(bp_cache, PartitionVertex.([v1, v2])) + envsSBP = environment(bp_cache, PartitionVertex.([v1, v2])) ψv = VidalITensorNetwork(ψ) #This grouping will correspond to calculating the environments exactly (each column of the grid is a partition) bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ))) bp_cache = update(bp_cache; maxiter=20) - envsGBP = incoming_messages(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) + envsGBP = environment(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) ngates = 5 diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index a388e758..4cbb9a7a 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -8,7 +8,8 @@ using ITensorNetworks: tensornetwork, update, update_factor, - incoming_messages + environment, + contract using Test using Compat using ITensors @@ -40,7 +41,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc) - env_tensors = incoming_messages(bpc, [PartitionVertex(v)]) + env_tensors = environment(bpc, [PartitionVertex(v)]) numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[] denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[] @@ -70,7 +71,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc) - env_tensors = incoming_messages(bpc, [PartitionVertex(v)]) + env_tensors = environment(bpc, [PartitionVertex(v)]) numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[] denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[] @@ -93,7 +94,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc; maxiter=20) - env_tensors = incoming_messages(bpc, vs) + env_tensors = environment(bpc, vs) numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[] denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[] @@ -112,7 +113,7 @@ ITensors.disable_warn_order() bpc = update(bpc; maxiter=20) ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs])) - env_tensors = incoming_messages(bpc, [(v, 2) for v in vs]) + env_tensors = environment(bpc, [(v, 2) for v in vs]) rdm = ITensors.contract( vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]) ) @@ -148,7 +149,7 @@ ITensors.disable_warn_order() message_update_kwargs=(; cutoff=1e-6, maxdim=4), ) - env_tensors = incoming_messages(bpc, [v]) + env_tensors = environment(bpc, [v]) numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[] denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[] diff --git a/test/test_forms.jl b/test/test_forms.jl index 74982629..0bfa2d02 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -1,21 +1,24 @@ using ITensors -using Graphs +using Graphs: nv using NamedGraphs using ITensorNetworks using ITensorNetworks: delta_network, update, tensornetwork, - bra_vertex_map, - ket_vertex_map, + bra_vertex, + ket_vertex, dual_index_map, bra_network, ket_network, - operator_network + operator_network, + environment, + BeliefPropagationCache using Test using Random +using SplitApplyCombine -@testset "FormNetworkss" begin +@testset "FormNetworks" begin g = named_grid((1, 4)) s_ket = siteinds("S=1/2", g) s_bra = prime(s_ket; links=[]) @@ -42,10 +45,24 @@ using Random new_tensor = randomITensor(inds(ψket[v])) qf_updated = update(qf, v, copy(new_tensor)) - @test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)] ≈ + @test tensornetwork(qf_updated)[bra_vertex(qf_updated, v)] ≈ dual_index_map(qf_updated)(dag(new_tensor)) - @test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] ≈ new_tensor + @test tensornetwork(qf_updated)[ket_vertex(qf_updated, v)] ≈ new_tensor @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) + + ∂qf_∂v = only(environment(qf, [v])) + @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) + + ∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=false) + ∂qf_∂v_bp = contract(∂qf_∂v_bp) + ∂qf_∂v_bp /= norm(∂qf_∂v_bp) + ∂qf_∂v /= norm(∂qf_∂v) + @test ∂qf_∂v_bp != ∂qf_∂v + + ∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=true) + ∂qf_∂v_bp = contract(∂qf_∂v_bp) + ∂qf_∂v_bp /= norm(∂qf_∂v_bp) + @test ∂qf_∂v_bp ≈ ∂qf_∂v end