From 40d08adc8a12a044c3414796a5a59f6b36de1f16 Mon Sep 17 00:00:00 2001 From: Joey Date: Tue, 19 Mar 2024 14:41:22 +0000 Subject: [PATCH] derivate_state -> derivative.jl --- src/caches/beliefpropagationcache.jl | 13 ++++++------- src/derivative.jl | 5 +++-- src/formnetworks/abstractformnetwork.jl | 14 ++++++++++++-- test/test_apply.jl | 6 +++--- test/test_belief_propagation.jl | 12 ++++++------ test/test_forms.jl | 8 ++++---- 6 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl index 072d9553..ab6204b5 100644 --- a/src/caches/beliefpropagationcache.jl +++ b/src/caches/beliefpropagationcache.jl @@ -11,7 +11,6 @@ end return default_bp_maxiter(undirected_graph(underlying_graph(g))) end default_partitioning(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) - #We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition. default_partitioning(f::AbstractFormNetwork) = group(v -> state_vertex(f, v), vertices(f)) default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) @@ -104,7 +103,7 @@ function set_messages(cache::BeliefPropagationCache, messages) ) end -function incoming_messages( +function environment( bp_cache::BeliefPropagationCache, partition_vertices::Vector{<:PartitionVertex}; ignore_edges=PartitionEdge[], @@ -114,15 +113,15 @@ function incoming_messages( return reduce(vcat, ms; init=[]) end -function incoming_messages( +function environment( bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs... ) - return incoming_messages(bp_cache, [partition_vertex]; kwargs...) + return environment(bp_cache, [partition_vertex]; kwargs...) end -function incoming_messages(bp_cache::BeliefPropagationCache, verts::Vector) +function environment(bp_cache::BeliefPropagationCache, verts::Vector) partition_verts = partitionvertices(bp_cache, verts) - messages = incoming_messages(bp_cache, partition_verts) + messages = environment(bp_cache, partition_verts) central_tensors = ITensor[ tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts) ] @@ -144,7 +143,7 @@ function update_message( message_update_kwargs=(;), ) vertex = src(edge) - messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) + messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)]) state = factor(bp_cache, vertex) return message_update(ITensor[messages; state]; message_update_kwargs...) diff --git a/src/derivative.jl b/src/derivative.jl index 347f9c5d..4acbd1a8 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -23,16 +23,17 @@ function derivative( ψ::AbstractITensorNetwork, verts::Vector; (cache!)=nothing, + partitions=default_partitioning(ψ), update_cache=isnothing(cache!), cache_update_kwargs=default_cache_update_kwargs(cache!), ) if isnothing(cache!) - cache! = Ref(BeliefPropagationCache(ψ)) + cache! = Ref(BeliefPropagationCache(ψ, partitions)) end if update_cache cache![] = update(cache![]; cache_update_kwargs...) end - return incoming_messages(cache![], setdiff(vertices(ψ), verts)) + return environment(cache![], setdiff(vertices(ψ), verts)) end diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index 82d79a24..66b395ad 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -57,9 +57,19 @@ function operator_network(f::AbstractFormNetwork) ) end -function derivative_state(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) +function derivative( + f::AbstractFormNetwork, + state_vertices::Vector; + alg=default_derivative_algorithm(), + kwargs..., +) tn_vertices = derivative_vertices(f, state_vertices) - return derivative(f, tn_vertices; kwargs...) + if alg == "bp" + partitions = group(v -> state_vertex(f, v), vertices(f)) + return derivative(tensornetwork(f), tn_vertices; alg, partitions, kwargs...) + else + return derivative(tensornetwork(f), tn_vertices; alg, kwargs...) + end end function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) diff --git a/test/test_apply.jl b/test/test_apply.jl index 0a815b9a..f32d8853 100644 --- a/test/test_apply.jl +++ b/test/test_apply.jl @@ -1,6 +1,6 @@ using ITensorNetworks using ITensorNetworks: - incoming_messages, + environment, update, contract_inner, norm_network, @@ -29,14 +29,14 @@ using SplitApplyCombine #Simple Belief Propagation Grouping bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bp_cache = update(bp_cache; maxiter=20) - envsSBP = incoming_messages(bp_cache, PartitionVertex.([v1, v2])) + envsSBP = environment(bp_cache, PartitionVertex.([v1, v2])) ψv = VidalITensorNetwork(ψ) #This grouping will correspond to calculating the environments exactly (each column of the grid is a partition) bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ))) bp_cache = update(bp_cache; maxiter=20) - envsGBP = incoming_messages(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) + envsGBP = environment(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) ngates = 5 diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl index 80d8df11..4cbb9a7a 100644 --- a/test/test_belief_propagation.jl +++ b/test/test_belief_propagation.jl @@ -8,7 +8,7 @@ using ITensorNetworks: tensornetwork, update, update_factor, - incoming_messages, + environment, contract using Test using Compat @@ -41,7 +41,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc) - env_tensors = incoming_messages(bpc, [PartitionVertex(v)]) + env_tensors = environment(bpc, [PartitionVertex(v)]) numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[] denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[] @@ -71,7 +71,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc) - env_tensors = incoming_messages(bpc, [PartitionVertex(v)]) + env_tensors = environment(bpc, [PartitionVertex(v)]) numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[] denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[] @@ -94,7 +94,7 @@ ITensors.disable_warn_order() bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ))) bpc = update(bpc; maxiter=20) - env_tensors = incoming_messages(bpc, vs) + env_tensors = environment(bpc, vs) numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[] denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[] @@ -113,7 +113,7 @@ ITensors.disable_warn_order() bpc = update(bpc; maxiter=20) ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs])) - env_tensors = incoming_messages(bpc, [(v, 2) for v in vs]) + env_tensors = environment(bpc, [(v, 2) for v in vs]) rdm = ITensors.contract( vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]) ) @@ -149,7 +149,7 @@ ITensors.disable_warn_order() message_update_kwargs=(; cutoff=1e-6, maxdim=4), ) - env_tensors = incoming_messages(bpc, [v]) + env_tensors = environment(bpc, [v]) numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[] denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[] diff --git a/test/test_forms.jl b/test/test_forms.jl index a7652c52..63686336 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -12,7 +12,7 @@ using ITensorNetworks: bra_network, ket_network, operator_network, - derivative_state, + derivative, BeliefPropagationCache using Test using Random @@ -52,16 +52,16 @@ using SplitApplyCombine @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) - ∂qf_∂v = only(derivative_state(qf, [v])) + ∂qf_∂v = only(derivative(qf, [v])) @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=false) + ∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=false) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) ∂qf_∂v /= norm(∂qf_∂v) @test ∂qf_∂v_bp != ∂qf_∂v - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=true) + ∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=true) ∂qf_∂v_bp = contract(∂qf_∂v_bp) ∂qf_∂v_bp /= norm(∂qf_∂v_bp) @test ∂qf_∂v_bp ≈ ∂qf_∂v