From 94691eb9f26bee77b89ec92b3026a8820918814c Mon Sep 17 00:00:00 2001 From: Joey Date: Wed, 13 Mar 2024 16:47:17 -0400 Subject: [PATCH 01/11] Derivative functionality --- src/ITensorNetworks.jl | 1 + src/caches/beliefpropagationcache.jl | 4 ++++ src/contract.jl | 1 + src/derivative.jl | 20 ++++++++++++++++++++ src/gauging.jl | 1 - test/test_forms.jl | 14 ++++++++++++-- 6 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 src/derivative.jl diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 35f5702e..9e81a3e6 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -129,6 +129,7 @@ include(joinpath("treetensornetworks", "solvers", "contract.jl")) include(joinpath("treetensornetworks", "solvers", "linsolve.jl")) include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl")) include("apply.jl") +include("derivative.jl") include("exports.jl") diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index f5337784..019d7c3e 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -8,6 +8,10 @@ default_message_update_kwargs() = (; normalize=true, contraction_sequence_alg="o @traitfn function default_bp_maxiter(g::::IsDirected) return default_bp_maxiter(undirected_graph(underlying_graph(g))) end +default_cache(ψ::ITensorNetwork) = BeliefPropagationCache(ψ, [[v] for v in vertices(tn)]) +default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) + + function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) lhs, rhs = contract(message_a), contract(message_b) return 0.5 * diff --git a/src/contract.jl b/src/contract.jl index 77c4ee1a..d8b4e8d0 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -31,6 +31,7 @@ function contract_density_matrix( return out end +#This should probably just be removed in favour of the second function above function contract_exact( contract_list::Vector{ITensor}; contraction_sequence_alg="optimal", diff --git a/src/derivative.jl b/src/derivative.jl new file mode 100644 index 00000000..599baa29 --- /dev/null +++ b/src/derivative.jl @@ -0,0 +1,20 @@ +default_derivative_algorithm() = "exact" + +function derivative(ψ::AbstractITensorNetwork, vertices::Vector; alg = default_derivative_algorithm(), kwargs...) + return derivative(Algorithm(alg), ψ, vertices; kwargs...) +end + +function derivative(::Algorithm"exact", ψ::AbstractITensorNetwork, vertices::Vector; contraction_sequence_alg = "optimal", kwargs...) + ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices)) + return contract_exact(ψ_reduced; normalize = false, contraction_sequence_alg, kwargs...) +end + +function derivative(::Algorithm"bp", ψ::AbstractITensorNetwork, vertices::Vector; (bp_cache!) = nothing, bp_cache_update_kwargs = default_cache_update_kwargs(bp_cache)) + + if isnothing(bp_cache!) + bp_cache! = Ref(default_cache(ψ)) + end + bp_cache![] = update(bp_cache![]; bp_cache_update_kwargs...) + return incoming_messages(bp_cache![], vertices) +end + \ No newline at end of file diff --git a/src/gauging.jl b/src/gauging.jl index 41bd02f0..89a30555 100644 --- a/src/gauging.jl +++ b/src/gauging.jl @@ -23,7 +23,6 @@ function default_norm_cache(ψ::ITensorNetwork) ψψ = norm_network(ψ) return BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) end -default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) function ITensorNetwork( ψ_vidal::VidalITensorNetwork; (cache!)=nothing, update_gauge=false, update_kwargs... diff --git a/test/test_forms.jl b/test/test_forms.jl index 74982629..e9771c68 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -11,11 +11,14 @@ using ITensorNetworks: dual_index_map, bra_network, ket_network, - operator_network + operator_network, + derivative, + BeliefPropagationCache using Test using Random +using SplitApplyCombine -@testset "FormNetworkss" begin +@testset "FormNetworks" begin g = named_grid((1, 4)) s_ket = siteinds("S=1/2", g) s_bra = prime(s_ket; links=[]) @@ -48,4 +51,11 @@ using Random @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) + + ∂qf_∂v = only(derivative(qf, [v])) + @test (∂qf_∂v)*(qf[ket_vertex_map(qf)(v)]*qf[bra_vertex_map(qf)(v)]) ≈ ITensors.contract(qf) + + partition_vertices= group(v -> first(v), vertices(qf)) + qf_cache = BeliefPropagationCache(tensornetwork(qf), partition_vertices) + ∂qf_∂v = derivative(qf, [v]; alg = "bp", (bp_cache!) = Ref(qf_cache)) end From e9a6bc5ffea4330262738b74f909a4ec819c05f7 Mon Sep 17 00:00:00 2001 From: Joey Date: Mon, 18 Mar 2024 12:37:28 +0000 Subject: [PATCH 02/11] New Derivative Functionality --- src/caches/beliefpropagationcache.jl | 20 +++++++++---- src/contract.jl | 15 ---------- src/derivative.jl | 40 ++++++++++++++++++------- src/formnetworks/abstractformnetwork.jl | 4 +-- test/test_belief_propagation.jl | 3 +- test/test_forms.jl | 20 +++++++++---- 6 files changed, 61 insertions(+), 41 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 019d7c3e..3150400d 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -1,16 +1,20 @@ default_message(inds_e) = ITensor[denseblocks(delta(inds_e))] default_messages(ptn::PartitionedGraph) = Dictionary() function default_message_update(contract_list::Vector{ITensor}; kwargs...) - return contract_exact(contract_list; kwargs...) + sequence = optimal_contraction_sequence(contract_list) + updated_messages = contract(contract_list; sequence, kwargs...) + updated_messages = normalize!(updated_messages) + return ITensor[updated_messages] end -default_message_update_kwargs() = (; normalize=true, contraction_sequence_alg="optimal") @traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing @traitfn function default_bp_maxiter(g::::IsDirected) return default_bp_maxiter(undirected_graph(underlying_graph(g))) end -default_cache(ψ::ITensorNetwork) = BeliefPropagationCache(ψ, [[v] for v in vertices(tn)]) -default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) +default_partitioning(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) +#We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition. +default_partitioning(f::AbstractFormNetwork) = group(v -> first(v), vertices(f)) +default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) lhs, rhs = contract(message_a), contract(message_b) @@ -31,11 +35,15 @@ function BeliefPropagationCache( return BeliefPropagationCache(ptn, messages, default_message) end -function BeliefPropagationCache(tn::ITensorNetwork, partitioned_vertices; kwargs...) +function BeliefPropagationCache(tn, partitioned_vertices; kwargs...) ptn = PartitionedGraph(tn, partitioned_vertices) return BeliefPropagationCache(ptn; kwargs...) end +function BeliefPropagationCache(tn; kwargs...) + return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...) +end + function partitioned_itensornetwork(bp_cache::BeliefPropagationCache) return bp_cache.partitioned_itensornetwork end @@ -133,7 +141,7 @@ function update_message( bp_cache::BeliefPropagationCache, edge::PartitionEdge; message_update=default_message_update, - message_update_kwargs=default_message_update_kwargs(), + message_update_kwargs=(;), ) vertex = src(edge) messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) diff --git a/src/contract.jl b/src/contract.jl index d8b4e8d0..f4e74603 100644 --- a/src/contract.jl +++ b/src/contract.jl @@ -30,18 +30,3 @@ function contract_density_matrix( end return out end - -#This should probably just be removed in favour of the second function above -function contract_exact( - contract_list::Vector{ITensor}; - contraction_sequence_alg="optimal", - normalize=true, - contractor_kwargs..., -) - seq = contraction_sequence(contract_list; alg=contraction_sequence_alg) - out = ITensors.contract(contract_list; sequence=seq, contractor_kwargs...) - if normalize - normalize!(out) - end - return ITensor[out] -end diff --git a/src/derivative.jl b/src/derivative.jl index 599baa29..731c6663 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -1,20 +1,38 @@ default_derivative_algorithm() = "exact" -function derivative(ψ::AbstractITensorNetwork, vertices::Vector; alg = default_derivative_algorithm(), kwargs...) - return derivative(Algorithm(alg), ψ, vertices; kwargs...) +function derivative( + ψ::AbstractITensorNetwork, vertices::Vector; alg=default_derivative_algorithm(), kwargs... +) + return derivative(Algorithm(alg), ψ, vertices; kwargs...) end -function derivative(::Algorithm"exact", ψ::AbstractITensorNetwork, vertices::Vector; contraction_sequence_alg = "optimal", kwargs...) - ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices)) - return contract_exact(ψ_reduced; normalize = false, contraction_sequence_alg, kwargs...) +function derivative( + ::Algorithm"exact", + ψ::AbstractITensorNetwork, + vertices::Vector; + contraction_sequence_alg="optimal", + kwargs..., +) + ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices)) + sequence = contraction_sequence(ψ_reduced; alg=contraction_sequence_alg) + return ITensor[contract(ψ_reduced; sequence, kwargs...)] end -function derivative(::Algorithm"bp", ψ::AbstractITensorNetwork, vertices::Vector; (bp_cache!) = nothing, bp_cache_update_kwargs = default_cache_update_kwargs(bp_cache)) +function derivative( + ::Algorithm"bp", + ψ::AbstractITensorNetwork, + verts::Vector; + (bp_cache!)=nothing, + update_bp_cache=isnothing(bp_cache!), + bp_cache_update_kwargs=default_cache_update_kwargs(bp_cache!), +) + if isnothing(bp_cache!) + bp_cache! = Ref(BeliefPropagationCache(ψ)) + end - if isnothing(bp_cache!) - bp_cache! = Ref(default_cache(ψ)) - end + if update_bp_cache bp_cache![] = update(bp_cache![]; bp_cache_update_kwargs...) - return incoming_messages(bp_cache![], vertices) + end + + return incoming_messages(bp_cache![], setdiff(vertices(ψ), verts)) end - \ No newline at end of file diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index e6efe54e..ccd4aa67 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -57,9 +57,9 @@ function operator_network(f::AbstractFormNetwork) ) end -function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) +function derivative_state(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) tn_vertices = derivative_vertices(f, state_vertices) - return derivative(tensornetwork(f), tn_vertices; kwargs...) + return derivative(f, tn_vertices; kwargs...) end function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index a388e758..80d8df11 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -8,7 +8,8 @@ using ITensorNetworks: tensornetwork, update, update_factor, - incoming_messages + incoming_messages, + contract using Test using Compat using ITensors diff --git a/test/test_forms.jl b/test/test_forms.jl index e9771c68..7f74c8a7 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -12,7 +12,7 @@ using ITensorNetworks: bra_network, ket_network, operator_network, - derivative, + derivative_state, BeliefPropagationCache using Test using Random @@ -52,10 +52,18 @@ using SplitApplyCombine @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) - ∂qf_∂v = only(derivative(qf, [v])) - @test (∂qf_∂v)*(qf[ket_vertex_map(qf)(v)]*qf[bra_vertex_map(qf)(v)]) ≈ ITensors.contract(qf) + ∂qf_∂v = only(derivative_state(qf, [v])) + @test (∂qf_∂v) * (qf[ket_vertex_map(qf)(v)] * qf[bra_vertex_map(qf)(v)]) ≈ + ITensors.contract(qf) - partition_vertices= group(v -> first(v), vertices(qf)) - qf_cache = BeliefPropagationCache(tensornetwork(qf), partition_vertices) - ∂qf_∂v = derivative(qf, [v]; alg = "bp", (bp_cache!) = Ref(qf_cache)) + ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=false) + ∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp) + ∂qf_∂v_bp = normalize!(∂qf_∂v_bp) + ∂qf_∂v = normalize!(∂qf_∂v) + @test ∂qf_∂v_bp != ∂qf_∂v + + ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=true) + ∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp) + ∂qf_∂v_bp = normalize!(∂qf_∂v_bp) + @test ∂qf_∂v_bp ≈ ∂qf_∂v end From 33f39a28b1a74ab77651bb741ae3e8b8ee69b080 Mon Sep 17 00:00:00 2001 From: Joey Date: Tue, 19 Mar 2024 10:40:05 +0000 Subject: [PATCH 03/11] Rename variables. Extra functionality --- src/derivative.jl | 16 +++++++-------- src/formnetworks/abstractformnetwork.jl | 3 +++ src/formnetworks/bilinearformnetwork.jl | 4 ++-- test/test_forms.jl | 27 ++++++++++++------------- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/derivative.jl b/src/derivative.jl index 731c6663..347f9c5d 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -22,17 +22,17 @@ function derivative( ::Algorithm"bp", ψ::AbstractITensorNetwork, verts::Vector; - (bp_cache!)=nothing, - update_bp_cache=isnothing(bp_cache!), - bp_cache_update_kwargs=default_cache_update_kwargs(bp_cache!), + (cache!)=nothing, + update_cache=isnothing(cache!), + cache_update_kwargs=default_cache_update_kwargs(cache!), ) - if isnothing(bp_cache!) - bp_cache! = Ref(BeliefPropagationCache(ψ)) + if isnothing(cache!) + cache! = Ref(BeliefPropagationCache(ψ)) end - if update_bp_cache - bp_cache![] = update(bp_cache![]; bp_cache_update_kwargs...) + if update_cache + cache![] = update(cache![]; cache_update_kwargs...) end - return incoming_messages(bp_cache![], setdiff(vertices(ψ), verts)) + return incoming_messages(cache![], setdiff(vertices(ψ), verts)) end diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index ccd4aa67..a8187751 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -72,3 +72,6 @@ operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f) bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f)) ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f)) inv_vertex_map(f::AbstractFormNetwork) = v -> first(v) +operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v) +bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v) +ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v) diff --git a/src/formnetworks/bilinearformnetwork.jl b/src/formnetworks/bilinearformnetwork.jl index 356b0ed1..f93d89ee 100644 --- a/src/formnetworks/bilinearformnetwork.jl +++ b/src/formnetworks/bilinearformnetwork.jl @@ -56,7 +56,7 @@ function update( ) blf = copy(blf) # TODO: Maybe add a check that it really does preserve the graph. - setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex)) - setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex)) + setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex(blf, state_vertex)) + setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex(blf, state_vertex)) return blf end diff --git a/test/test_forms.jl b/test/test_forms.jl index 7f74c8a7..a7652c52 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -1,13 +1,13 @@ using ITensors -using Graphs +using Graphs: nv using NamedGraphs using ITensorNetworks using ITensorNetworks: delta_network, update, tensornetwork, - bra_vertex_map, - ket_vertex_map, + bra_vertex, + ket_vertex, dual_index_map, bra_network, ket_network, @@ -45,25 +45,24 @@ using SplitApplyCombine new_tensor = randomITensor(inds(ψket[v])) qf_updated = update(qf, v, copy(new_tensor)) - @test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)] ≈ + @test tensornetwork(qf_updated)[bra_vertex(qf_updated, v)] ≈ dual_index_map(qf_updated)(dag(new_tensor)) - @test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] ≈ new_tensor + @test tensornetwork(qf_updated)[ket_vertex(qf_updated, v)] ≈ new_tensor @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) ∂qf_∂v = only(derivative_state(qf, [v])) - @test (∂qf_∂v) * (qf[ket_vertex_map(qf)(v)] * qf[bra_vertex_map(qf)(v)]) ≈ - ITensors.contract(qf) + @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=false) - ∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp) - ∂qf_∂v_bp = normalize!(∂qf_∂v_bp) - ∂qf_∂v = normalize!(∂qf_∂v) + ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=false) + ∂qf_∂v_bp = contract(∂qf_∂v_bp) + ∂qf_∂v_bp /= norm(∂qf_∂v_bp) + ∂qf_∂v /= norm(∂qf_∂v) @test ∂qf_∂v_bp != ∂qf_∂v - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=true) - ∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp) - ∂qf_∂v_bp = normalize!(∂qf_∂v_bp) + ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=true) + ∂qf_∂v_bp = contract(∂qf_∂v_bp) + ∂qf_∂v_bp /= norm(∂qf_∂v_bp) @test ∂qf_∂v_bp ≈ ∂qf_∂v end From c818f136f79643330a9f7d5024524c7d2d14e89b Mon Sep 17 00:00:00 2001 From: Joey Date: Tue, 19 Mar 2024 10:53:19 +0000 Subject: [PATCH 04/11] Better default partitioning of abstractformnetwork --- src/caches/beliefpropagationcache.jl | 2 +- src/formnetworks/abstractformnetwork.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 3150400d..81ef80ba 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -13,7 +13,7 @@ end default_partitioning(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) #We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition. -default_partitioning(f::AbstractFormNetwork) = group(v -> first(v), vertices(f)) +default_partitioning(f::AbstractFormNetwork) = group(v -> state_vertex(f, v), vertices(f)) default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index a8187751..82d79a24 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -75,3 +75,4 @@ inv_vertex_map(f::AbstractFormNetwork) = v -> first(v) operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v) bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v) ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v) +state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v) From 3e875cf08c11b8126c1a2b02453beb1f10378956 Mon Sep 17 00:00:00 2001 From: Joey Date: Tue, 19 Mar 2024 11:52:33 +0000 Subject: [PATCH 05/11] Avoid in-place normalization in default_message_update --- src/caches/beliefpropagationcache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 81ef80ba..072d9553 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -3,7 +3,7 @@ default_messages(ptn::PartitionedGraph) = Dictionary() function default_message_update(contract_list::Vector{ITensor}; kwargs...) sequence = optimal_contraction_sequence(contract_list) updated_messages = contract(contract_list; sequence, kwargs...) - updated_messages = normalize!(updated_messages) + updated_messages /= norm(updated_messages) return ITensor[updated_messages] end @traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing From 40d08adc8a12a044c3414796a5a59f6b36de1f16 Mon Sep 17 00:00:00 2001 From: Joey Date: Tue, 19 Mar 2024 14:41:22 +0000 Subject: [PATCH 06/11] derivate_state -> derivative.jl --- src/caches/beliefpropagationcache.jl | 13 ++++++------- src/derivative.jl | 5 +++-- src/formnetworks/abstractformnetwork.jl | 14 ++++++++++++-- test/test_apply.jl | 6 +++--- test/test_belief_propagation.jl | 12 ++++++------ test/test_forms.jl | 8 ++++---- 6 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 072d9553..ab6204b5 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -11,7 +11,6 @@ end return default_bp_maxiter(undirected_graph(underlying_graph(g))) end default_partitioning(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) - #We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition. default_partitioning(f::AbstractFormNetwork) = group(v -> state_vertex(f, v), vertices(f)) default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) @@ -104,7 +103,7 @@ function set_messages(cache::BeliefPropagationCache, messages) ) end -function incoming_messages( +function environment( bp_cache::BeliefPropagationCache, partition_vertices::Vector{<:PartitionVertex}; ignore_edges=PartitionEdge[], @@ -114,15 +113,15 @@ function incoming_messages( return reduce(vcat, ms; init=[]) end -function incoming_messages( +function environment( bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... ) - return incoming_messages(bp_cache, [partition_vertex]; kwargs...) + return environment(bp_cache, [partition_vertex]; kwargs...) end -function incoming_messages(bp_cache::BeliefPropagationCache, verts::Vector) +function environment(bp_cache::BeliefPropagationCache, verts::Vector) partition_verts = partitionvertices(bp_cache, verts) - messages = incoming_messages(bp_cache, partition_verts) + messages = environment(bp_cache, partition_verts) central_tensors = ITensor[ tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts) ] @@ -144,7 +143,7 @@ function update_message( message_update_kwargs=(;), ) vertex = src(edge) - messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) + messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) state = factor(bp_cache, vertex) return message_update(ITensor[messages; state]; message_update_kwargs...) diff --git a/src/derivative.jl b/src/derivative.jl index 347f9c5d..4acbd1a8 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -23,16 +23,17 @@ function derivative( ψ::AbstractITensorNetwork, verts::Vector; (cache!)=nothing, + partitions=default_partitioning(ψ), update_cache=isnothing(cache!), cache_update_kwargs=default_cache_update_kwargs(cache!), ) if isnothing(cache!) - cache! = Ref(BeliefPropagationCache(ψ)) + cache! = Ref(BeliefPropagationCache(ψ, partitions)) end if update_cache cache![] = update(cache![]; cache_update_kwargs...) end - return incoming_messages(cache![], setdiff(vertices(ψ), verts)) + return environment(cache![], setdiff(vertices(ψ), verts)) end diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index 82d79a24..66b395ad 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -57,9 +57,19 @@ function operator_network(f::AbstractFormNetwork) ) end -function derivative_state(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) +function derivative( + f::AbstractFormNetwork, + state_vertices::Vector; + alg=default_derivative_algorithm(), + kwargs..., +) tn_vertices = derivative_vertices(f, state_vertices) - return derivative(f, tn_vertices; kwargs...) + if alg == "bp" + partitions = group(v -> state_vertex(f, v), vertices(f)) + return derivative(tensornetwork(f), tn_vertices; alg, partitions, kwargs...) + else + return derivative(tensornetwork(f), tn_vertices; alg, kwargs...) + end end function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) diff --git a/test/test_apply.jl b/test/test_apply.jl index 0a815b9a..f32d8853 100644 --- a/test/test_apply.jl +++ b/test/test_apply.jl @@ -1,6 +1,6 @@ using ITensorNetworks using ITensorNetworks: - incoming_messages, + environment, update, contract_inner, norm_network, @@ -29,14 +29,14 @@ using SplitApplyCombine #Simple Belief Propagation Grouping bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bp_cache = update(bp_cache; maxiter=20) - envsSBP = incoming_messages(bp_cache, PartitionVertex.([v1, v2])) + envsSBP = environment(bp_cache, PartitionVertex.([v1, v2])) ψv = VidalITensorNetwork(ψ) #This grouping will correspond to calculating the environments exactly (each column of the grid is a partition) bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ))) bp_cache = update(bp_cache; maxiter=20) - envsGBP = incoming_messages(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) + envsGBP = environment(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) ngates = 5 diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index 80d8df11..4cbb9a7a 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -8,7 +8,7 @@ using ITensorNetworks: tensornetwork, update, update_factor, - incoming_messages, + environment, contract using Test using Compat @@ -41,7 +41,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc) - env_tensors = incoming_messages(bpc, [PartitionVertex(v)]) + env_tensors = environment(bpc, [PartitionVertex(v)]) numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[] denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[] @@ -71,7 +71,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc) - env_tensors = incoming_messages(bpc, [PartitionVertex(v)]) + env_tensors = environment(bpc, [PartitionVertex(v)]) numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[] denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[] @@ -94,7 +94,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc; maxiter=20) - env_tensors = incoming_messages(bpc, vs) + env_tensors = environment(bpc, vs) numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[] denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[] @@ -113,7 +113,7 @@ ITensors.disable_warn_order() bpc = update(bpc; maxiter=20) ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs])) - env_tensors = incoming_messages(bpc, [(v, 2) for v in vs]) + env_tensors = environment(bpc, [(v, 2) for v in vs]) rdm = ITensors.contract( vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]) ) @@ -149,7 +149,7 @@ ITensors.disable_warn_order() message_update_kwargs=(; cutoff=1e-6, maxdim=4), ) - env_tensors = incoming_messages(bpc, [v]) + env_tensors = environment(bpc, [v]) numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[] denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[] diff --git a/test/test_forms.jl b/test/test_forms.jl index a7652c52..63686336 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -12,7 +12,7 @@ using ITensorNetworks: bra_network, ket_network, operator_network, - derivative_state, + derivative, BeliefPropagationCache using Test using Random @@ -52,16 +52,16 @@ using SplitApplyCombine @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) - ∂qf_∂v = only(derivative_state(qf, [v])) + ∂qf_∂v = only(derivative(qf, [v])) @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=false) + ∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=false) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) ∂qf_∂v /= norm(∂qf_∂v) @test ∂qf_∂v_bp != ∂qf_∂v - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=true) + ∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=true) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) @test ∂qf_∂v_bp ≈ ∂qf_∂v From 73680b6e0ba7419427f4ff50f471c073ed94763e Mon Sep 17 00:00:00 2001 From: Joey Date: Wed, 20 Mar 2024 10:30:30 +0000 Subject: [PATCH 07/11] Derivative -> Environment --- src/ITensorNetworks.jl | 2 +- src/caches/beliefpropagationcache.jl | 2 -- src/{derivative.jl => environment.jl} | 15 +++++++++------ src/formnetworks/abstractformnetwork.jl | 12 ++++++------ test/test_forms.jl | 8 ++++---- 5 files changed, 20 insertions(+), 19 deletions(-) rename src/{derivative.jl => environment.jl} (74%) diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 9e81a3e6..536db27c 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -129,7 +129,7 @@ include(joinpath("treetensornetworks", "solvers", "contract.jl")) include(joinpath("treetensornetworks", "solvers", "linsolve.jl")) include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl")) include("apply.jl") -include("derivative.jl") +include("environment.jl") include("exports.jl") diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index ab6204b5..92f761c2 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -11,8 +11,6 @@ end return default_bp_maxiter(undirected_graph(underlying_graph(g))) end default_partitioning(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) -#We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition. -default_partitioning(f::AbstractFormNetwork) = group(v -> state_vertex(f, v), vertices(f)) default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) diff --git a/src/derivative.jl b/src/environment.jl similarity index 74% rename from src/derivative.jl rename to src/environment.jl index 4acbd1a8..7d1c2ef1 100644 --- a/src/derivative.jl +++ b/src/environment.jl @@ -1,12 +1,15 @@ -default_derivative_algorithm() = "exact" +default_environment_algorithm() = "exact" -function derivative( - ψ::AbstractITensorNetwork, vertices::Vector; alg=default_derivative_algorithm(), kwargs... +function environment( + ψ::AbstractITensorNetwork, + vertices::Vector; + alg=default_environment_algorithm(), + kwargs..., ) - return derivative(Algorithm(alg), ψ, vertices; kwargs...) + return environment(Algorithm(alg), ψ, vertices; kwargs...) end -function derivative( +function environment( ::Algorithm"exact", ψ::AbstractITensorNetwork, vertices::Vector; @@ -18,7 +21,7 @@ function derivative( return ITensor[contract(ψ_reduced; sequence, kwargs...)] end -function derivative( +function environment( ::Algorithm"bp", ψ::AbstractITensorNetwork, verts::Vector; diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index 66b395ad..07a2a0fd 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -57,22 +57,22 @@ function operator_network(f::AbstractFormNetwork) ) end -function derivative( +function environment( f::AbstractFormNetwork, state_vertices::Vector; - alg=default_derivative_algorithm(), + alg=default_environment_algorithm(), kwargs..., ) - tn_vertices = derivative_vertices(f, state_vertices) + tn_vertices = environment_vertices(f, state_vertices) if alg == "bp" partitions = group(v -> state_vertex(f, v), vertices(f)) - return derivative(tensornetwork(f), tn_vertices; alg, partitions, kwargs...) + return environment(tensornetwork(f), tn_vertices; alg, partitions, kwargs...) else - return derivative(tensornetwork(f), tn_vertices; alg, kwargs...) + return environment(tensornetwork(f), tn_vertices; alg, kwargs...) end end -function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) +function environment_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) return setdiff( vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) ) diff --git a/test/test_forms.jl b/test/test_forms.jl index 63686336..0bfa2d02 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -12,7 +12,7 @@ using ITensorNetworks: bra_network, ket_network, operator_network, - derivative, + environment, BeliefPropagationCache using Test using Random @@ -52,16 +52,16 @@ using SplitApplyCombine @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) - ∂qf_∂v = only(derivative(qf, [v])) + ∂qf_∂v = only(environment(qf, [v])) @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) - ∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=false) + ∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=false) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) ∂qf_∂v /= norm(∂qf_∂v) @test ∂qf_∂v_bp != ∂qf_∂v - ∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=true) + ∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=true) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) @test ∂qf_∂v_bp ≈ ∂qf_∂v From 01a89fa1ed4d69dfd0b560e2cf720213db7d9341 Mon Sep 17 00:00:00 2001 From: Joey Date: Wed, 20 Mar 2024 11:44:50 +0000 Subject: [PATCH 08/11] environment_vertices -> state_vertices --- src/formnetworks/abstractformnetwork.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index 07a2a0fd..b2e27fb9 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -63,7 +63,7 @@ function environment( alg=default_environment_algorithm(), kwargs..., ) - tn_vertices = environment_vertices(f, state_vertices) + tn_vertices = state_vertices(f, state_vertices) if alg == "bp" partitions = group(v -> state_vertex(f, v), vertices(f)) return environment(tensornetwork(f), tn_vertices; alg, partitions, kwargs...) @@ -72,7 +72,7 @@ function environment( end end -function environment_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) +function state_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) return setdiff( vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) ) From 70258b413efd94ff0ef662bd6423253aa82746cb Mon Sep 17 00:00:00 2001 From: Joey Date: Wed, 20 Mar 2024 13:47:36 +0000 Subject: [PATCH 09/11] Better naming convention for vertices of forms --- src/formnetworks/abstractformnetwork.jl | 36 +++++++++++------------- src/formnetworks/bilinearformnetwork.jl | 10 +++++-- src/formnetworks/quadraticformnetwork.jl | 4 +-- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index b2e27fb9..5f70955d 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -23,20 +23,22 @@ function ket_vertices(f::AbstractFormNetwork) return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f)) end -function bra_ket_vertices(f::AbstractFormNetwork) - return vcat(bra_vertices(f), ket_vertices(f)) +function bra_vertices(f::AbstractFormNetwork, original_state_vertices::Vector) + return [bra_vertex_map(f)(osv) for osv in original_state_vertices] end -function bra_vertices(f::AbstractFormNetwork, state_vertices::Vector) - return [bra_vertex_map(f)(sv) for sv in state_vertices] +function ket_vertices(f::AbstractFormNetwork, original_state_vertices::Vector) + return [ket_vertex_map(f)(osv) for osv in original_state_vertices] end -function ket_vertices(f::AbstractFormNetwork, state_vertices::Vector) - return [ket_vertex_map(f)(sv) for sv in state_vertices] +function state_vertices(f::AbstractFormNetwork) + return vcat(bra_vertices(f), ket_vertices(f)) end -function bra_ket_vertices(f::AbstractFormNetwork, state_vertices::Vector) - return vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) +function state_vertices(f::AbstractFormNetwork, original_state_vertices::Vector) + return vcat( + bra_vertices(f, original_state_vertices), ket_vertices(f, original_state_vertices) + ) end function Graphs.induced_subgraph(f::AbstractFormNetwork, vertices::Vector) @@ -59,25 +61,19 @@ end function environment( f::AbstractFormNetwork, - state_vertices::Vector; + original_state_vertices::Vector; alg=default_environment_algorithm(), kwargs..., ) - tn_vertices = state_vertices(f, state_vertices) + form_vertices = setdiff(vertices(f), state_vertices(f, original_state_vertices)) if alg == "bp" - partitions = group(v -> state_vertex(f, v), vertices(f)) - return environment(tensornetwork(f), tn_vertices; alg, partitions, kwargs...) + partitions = group(v -> original_state_vertex(f, v), vertices(f)) + return environment(tensornetwork(f), form_vertices; alg, partitions, kwargs...) else - return environment(tensornetwork(f), tn_vertices; alg, kwargs...) + return environment(tensornetwork(f), form_vertices; alg, kwargs...) end end -function state_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) - return setdiff( - vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) - ) -end - operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f)) bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f)) ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f)) @@ -85,4 +81,4 @@ inv_vertex_map(f::AbstractFormNetwork) = v -> first(v) operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v) bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v) ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v) -state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v) +original_state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v) diff --git a/src/formnetworks/bilinearformnetwork.jl b/src/formnetworks/bilinearformnetwork.jl index f93d89ee..5519c1e3 100644 --- a/src/formnetworks/bilinearformnetwork.jl +++ b/src/formnetworks/bilinearformnetwork.jl @@ -52,11 +52,15 @@ function BilinearFormNetwork( end function update( - blf::BilinearFormNetwork, state_vertex, bra_state::ITensor, ket_state::ITensor + blf::BilinearFormNetwork, original_state_vertex, bra_state::ITensor, ket_state::ITensor ) blf = copy(blf) # TODO: Maybe add a check that it really does preserve the graph. - setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex(blf, state_vertex)) - setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex(blf, state_vertex)) + setindex_preserve_graph!( + tensornetwork(blf), bra_state, bra_vertex(blf, original_state_vertex) + ) + setindex_preserve_graph!( + tensornetwork(blf), ket_state, ket_vertex(blf, original_state_vertex) + ) return blf end diff --git a/src/formnetworks/quadraticformnetwork.jl b/src/formnetworks/quadraticformnetwork.jl index 5acee59e..8aac841a 100644 --- a/src/formnetworks/quadraticformnetwork.jl +++ b/src/formnetworks/quadraticformnetwork.jl @@ -57,9 +57,9 @@ function QuadraticFormNetwork( return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map) end -function update(qf::QuadraticFormNetwork, state_vertex, ket_state::ITensor) +function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor) state_inds = inds(ket_state) bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds)) - new_blf = update(bilinear_formnetwork(qf), state_vertex, bra_state, ket_state) + new_blf = update(bilinear_formnetwork(qf), original_state_vertex, bra_state, ket_state) return QuadraticFormNetwork(new_blf, dual_index_map(qf), dual_index_map(qf)) end From 1d04a046a2005d91ed3b9b4ee2954721a1b29d4f Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 21 Mar 2024 08:39:26 +0000 Subject: [PATCH 10/11] Default partitioning - default partitioned vertices --- src/caches/beliefpropagationcache.jl | 2 +- src/environment.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 92f761c2..fa9ea51e 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -10,7 +10,7 @@ end @traitfn function default_bp_maxiter(g::::IsDirected) return default_bp_maxiter(undirected_graph(underlying_graph(g))) end -default_partitioning(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) +default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor}) diff --git a/src/environment.jl b/src/environment.jl index 7d1c2ef1..efc99ad3 100644 --- a/src/environment.jl +++ b/src/environment.jl @@ -26,12 +26,12 @@ function environment( ψ::AbstractITensorNetwork, verts::Vector; (cache!)=nothing, - partitions=default_partitioning(ψ), + partitioned_vertices=default_partitioned_vertices(ψ), update_cache=isnothing(cache!), cache_update_kwargs=default_cache_update_kwargs(cache!), ) if isnothing(cache!) - cache! = Ref(BeliefPropagationCache(ψ, partitions)) + cache! = Ref(BeliefPropagationCache(ψ, partitioned_vertices)) end if update_cache From 9bb5c7a9cfca36770c6dcfc2a03d088e6b480f11 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 21 Mar 2024 17:10:30 +0000 Subject: [PATCH 11/11] Logic for setdiff in environment fixed --- src/environment.jl | 8 ++++---- src/formnetworks/abstractformnetwork.jl | 8 +++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/environment.jl b/src/environment.jl index efc99ad3..262a7c23 100644 --- a/src/environment.jl +++ b/src/environment.jl @@ -12,11 +12,11 @@ end function environment( ::Algorithm"exact", ψ::AbstractITensorNetwork, - vertices::Vector; + verts::Vector; contraction_sequence_alg="optimal", kwargs..., ) - ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices)) + ψ_reduced = Vector{ITensor}(subgraph(ψ, setdiff(vertices(ψ), verts))) sequence = contraction_sequence(ψ_reduced; alg=contraction_sequence_alg) return ITensor[contract(ψ_reduced; sequence, kwargs...)] end @@ -24,7 +24,7 @@ end function environment( ::Algorithm"bp", ψ::AbstractITensorNetwork, - verts::Vector; + vertices::Vector; (cache!)=nothing, partitioned_vertices=default_partitioned_vertices(ψ), update_cache=isnothing(cache!), @@ -38,5 +38,5 @@ function environment( cache![] = update(cache![]; cache_update_kwargs...) end - return environment(cache![], setdiff(vertices(ψ), verts)) + return environment(cache![], vertices) end diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index 5f70955d..f0557ac6 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -65,10 +65,12 @@ function environment( alg=default_environment_algorithm(), kwargs..., ) - form_vertices = setdiff(vertices(f), state_vertices(f, original_state_vertices)) + form_vertices = state_vertices(f, original_state_vertices) if alg == "bp" - partitions = group(v -> original_state_vertex(f, v), vertices(f)) - return environment(tensornetwork(f), form_vertices; alg, partitions, kwargs...) + partitioned_vertices = group(v -> original_state_vertex(f, v), vertices(f)) + return environment( + tensornetwork(f), form_vertices; alg, partitioned_vertices, kwargs... + ) else return environment(tensornetwork(f), form_vertices; alg, kwargs...) end