Skip to content

Commit

Permalink
derivate_state -> derivative.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Mar 19, 2024
1 parent 3e875cf commit 40d08ad
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 24 deletions.
13 changes: 6 additions & 7 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ end
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioning::AbstractITensorNetwork) = group(v -> v, vertices(ψ))

#We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition.
default_partitioning(f::AbstractFormNetwork) = group(v -> state_vertex(f, v), vertices(f))
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)
Expand Down Expand Up @@ -104,7 +103,7 @@ function set_messages(cache::BeliefPropagationCache, messages)
)
end

function incoming_messages(
function environment(
bp_cache::BeliefPropagationCache,
partition_vertices::Vector{<:PartitionVertex};
ignore_edges=PartitionEdge[],
Expand All @@ -114,15 +113,15 @@ function incoming_messages(
return reduce(vcat, ms; init=[])
end

function incoming_messages(
function environment(
bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs...
)
return incoming_messages(bp_cache, [partition_vertex]; kwargs...)
return environment(bp_cache, [partition_vertex]; kwargs...)
end

function incoming_messages(bp_cache::BeliefPropagationCache, verts::Vector)
function environment(bp_cache::BeliefPropagationCache, verts::Vector)
partition_verts = partitionvertices(bp_cache, verts)
messages = incoming_messages(bp_cache, partition_verts)
messages = environment(bp_cache, partition_verts)
central_tensors = ITensor[
tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts)
]
Expand All @@ -144,7 +143,7 @@ function update_message(
message_update_kwargs=(;),
)
vertex = src(edge)
messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factor(bp_cache, vertex)

return message_update(ITensor[messages; state]; message_update_kwargs...)
Expand Down
5 changes: 3 additions & 2 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ function derivative(
ψ::AbstractITensorNetwork,
verts::Vector;
(cache!)=nothing,
partitions=default_partitioning(ψ),
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(ψ))
cache! = Ref(BeliefPropagationCache, partitions))
end

if update_cache
cache![] = update(cache![]; cache_update_kwargs...)
end

return incoming_messages(cache![], setdiff(vertices(ψ), verts))
return environment(cache![], setdiff(vertices(ψ), verts))
end
14 changes: 12 additions & 2 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,19 @@ function operator_network(f::AbstractFormNetwork)
)
end

function derivative_state(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
function derivative(
f::AbstractFormNetwork,
state_vertices::Vector;
alg=default_derivative_algorithm(),
kwargs...,
)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(f, tn_vertices; kwargs...)
if alg == "bp"
partitions = group(v -> state_vertex(f, v), vertices(f))
return derivative(tensornetwork(f), tn_vertices; alg, partitions, kwargs...)
else
return derivative(tensornetwork(f), tn_vertices; alg, kwargs...)
end
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
Expand Down
6 changes: 3 additions & 3 deletions test/test_apply.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ITensorNetworks
using ITensorNetworks:
incoming_messages,
environment,
update,
contract_inner,
norm_network,
Expand Down Expand Up @@ -29,14 +29,14 @@ using SplitApplyCombine
#Simple Belief Propagation Grouping
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bp_cache = update(bp_cache; maxiter=20)
envsSBP = incoming_messages(bp_cache, PartitionVertex.([v1, v2]))
envsSBP = environment(bp_cache, PartitionVertex.([v1, v2]))

ψv = VidalITensorNetwork(ψ)

#This grouping will correspond to calculating the environments exactly (each column of the grid is a partition)
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ)))
bp_cache = update(bp_cache; maxiter=20)
envsGBP = incoming_messages(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)])
envsGBP = environment(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)])

ngates = 5

Expand Down
12 changes: 6 additions & 6 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using ITensorNetworks:
tensornetwork,
update,
update_factor,
incoming_messages,
environment,
contract
using Test
using Compat
Expand Down Expand Up @@ -41,7 +41,7 @@ ITensors.disable_warn_order()

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = incoming_messages(bpc, [PartitionVertex(v)])
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

Expand Down Expand Up @@ -71,7 +71,7 @@ ITensors.disable_warn_order()

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = incoming_messages(bpc, [PartitionVertex(v)])
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

Expand All @@ -94,7 +94,7 @@ ITensors.disable_warn_order()
bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc; maxiter=20)

env_tensors = incoming_messages(bpc, vs)
env_tensors = environment(bpc, vs)
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[]

Expand All @@ -113,7 +113,7 @@ ITensors.disable_warn_order()
bpc = update(bpc; maxiter=20)

ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
env_tensors = incoming_messages(bpc, [(v, 2) for v in vs])
env_tensors = environment(bpc, [(v, 2) for v in vs])
rdm = ITensors.contract(
vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]])
)
Expand Down Expand Up @@ -149,7 +149,7 @@ ITensors.disable_warn_order()
message_update_kwargs=(; cutoff=1e-6, maxdim=4),
)

env_tensors = incoming_messages(bpc, [v])
env_tensors = environment(bpc, [v])
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[]

Expand Down
8 changes: 4 additions & 4 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using ITensorNetworks:
bra_network,
ket_network,
operator_network,
derivative_state,
derivative,
BeliefPropagationCache
using Test
using Random
Expand Down Expand Up @@ -52,16 +52,16 @@ using SplitApplyCombine
@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)

∂qf_∂v = only(derivative_state(qf, [v]))
∂qf_∂v = only(derivative(qf, [v]))
@test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) contract(qf)

∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=false)
∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=false)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
∂qf_∂v /= norm(∂qf_∂v)
@test ∂qf_∂v_bp != ∂qf_∂v

∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=true)
∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=true)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
@test ∂qf_∂v_bp ∂qf_∂v
Expand Down

0 comments on commit 40d08ad

Please sign in to comment.