Skip to content

Commit

Permalink
New Derivative Functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Mar 18, 2024
1 parent 94691eb commit e9a6bc5
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 41 deletions.
20 changes: 14 additions & 6 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_messages(ptn::PartitionedGraph) = Dictionary()
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
return contract_exact(contract_list; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
updated_messages = normalize!(updated_messages)
return ITensor[updated_messages]
end
default_message_update_kwargs() = (; normalize=true, contraction_sequence_alg="optimal")
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_cache::ITensorNetwork) = BeliefPropagationCache(ψ, [[v] for v in vertices(tn)])
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)
default_partitioning::AbstractITensorNetwork) = group(v -> v, vertices(ψ))

#We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition.
default_partitioning(f::AbstractFormNetwork) = group(v -> first(v), vertices(f))
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
Expand All @@ -31,11 +35,15 @@ function BeliefPropagationCache(
return BeliefPropagationCache(ptn, messages, default_message)
end

function BeliefPropagationCache(tn::ITensorNetwork, partitioned_vertices; kwargs...)
function BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
ptn = PartitionedGraph(tn, partitioned_vertices)
return BeliefPropagationCache(ptn; kwargs...)
end

function BeliefPropagationCache(tn; kwargs...)
return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...)
end

function partitioned_itensornetwork(bp_cache::BeliefPropagationCache)
return bp_cache.partitioned_itensornetwork
end
Expand Down Expand Up @@ -133,7 +141,7 @@ function update_message(
bp_cache::BeliefPropagationCache,
edge::PartitionEdge;
message_update=default_message_update,
message_update_kwargs=default_message_update_kwargs(),
message_update_kwargs=(;),
)
vertex = src(edge)
messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
Expand Down
15 changes: 0 additions & 15 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,3 @@ function contract_density_matrix(
end
return out
end

#This should probably just be removed in favour of the second function above
function contract_exact(
contract_list::Vector{ITensor};
contraction_sequence_alg="optimal",
normalize=true,
contractor_kwargs...,
)
seq = contraction_sequence(contract_list; alg=contraction_sequence_alg)
out = ITensors.contract(contract_list; sequence=seq, contractor_kwargs...)
if normalize
normalize!(out)
end
return ITensor[out]
end
40 changes: 29 additions & 11 deletions src/derivative.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
default_derivative_algorithm() = "exact"

function derivative::AbstractITensorNetwork, vertices::Vector; alg = default_derivative_algorithm(), kwargs...)
return derivative(Algorithm(alg), ψ, vertices; kwargs...)
function derivative(
ψ::AbstractITensorNetwork, vertices::Vector; alg=default_derivative_algorithm(), kwargs...
)
return derivative(Algorithm(alg), ψ, vertices; kwargs...)
end

function derivative(::Algorithm"exact", ψ::AbstractITensorNetwork, vertices::Vector; contraction_sequence_alg = "optimal", kwargs...)
ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices))
return contract_exact(ψ_reduced; normalize = false, contraction_sequence_alg, kwargs...)
function derivative(
::Algorithm"exact",
ψ::AbstractITensorNetwork,
vertices::Vector;
contraction_sequence_alg="optimal",
kwargs...,
)
ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices))
sequence = contraction_sequence(ψ_reduced; alg=contraction_sequence_alg)
return ITensor[contract(ψ_reduced; sequence, kwargs...)]
end

function derivative(::Algorithm"bp", ψ::AbstractITensorNetwork, vertices::Vector; (bp_cache!) = nothing, bp_cache_update_kwargs = default_cache_update_kwargs(bp_cache))
function derivative(
::Algorithm"bp",
ψ::AbstractITensorNetwork,
verts::Vector;
(bp_cache!)=nothing,
update_bp_cache=isnothing(bp_cache!),
bp_cache_update_kwargs=default_cache_update_kwargs(bp_cache!),
)
if isnothing(bp_cache!)
bp_cache! = Ref(BeliefPropagationCache(ψ))
end

if isnothing(bp_cache!)
bp_cache! = Ref(default_cache(ψ))
end
if update_bp_cache
bp_cache![] = update(bp_cache![]; bp_cache_update_kwargs...)
return incoming_messages(bp_cache![], vertices)
end

return incoming_messages(bp_cache![], setdiff(vertices(ψ), verts))
end

4 changes: 2 additions & 2 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ function operator_network(f::AbstractFormNetwork)
)
end

function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
function derivative_state(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(tensornetwork(f), tn_vertices; kwargs...)
return derivative(f, tn_vertices; kwargs...)
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
Expand Down
3 changes: 2 additions & 1 deletion test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ITensorNetworks:
tensornetwork,
update,
update_factor,
incoming_messages
incoming_messages,
contract
using Test
using Compat
using ITensors
Expand Down
20 changes: 14 additions & 6 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using ITensorNetworks:
bra_network,
ket_network,
operator_network,
derivative,
derivative_state,
BeliefPropagationCache
using Test
using Random
Expand Down Expand Up @@ -52,10 +52,18 @@ using SplitApplyCombine
@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)

∂qf_∂v = only(derivative(qf, [v]))
@test (∂qf_∂v)*(qf[ket_vertex_map(qf)(v)]*qf[bra_vertex_map(qf)(v)]) ITensors.contract(qf)
∂qf_∂v = only(derivative_state(qf, [v]))
@test (∂qf_∂v) * (qf[ket_vertex_map(qf)(v)] * qf[bra_vertex_map(qf)(v)])
ITensors.contract(qf)

partition_vertices= group(v -> first(v), vertices(qf))
qf_cache = BeliefPropagationCache(tensornetwork(qf), partition_vertices)
∂qf_∂v = derivative(qf, [v]; alg = "bp", (bp_cache!) = Ref(qf_cache))
∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=false)
∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp)
∂qf_∂v_bp = normalize!(∂qf_∂v_bp)
∂qf_∂v = normalize!(∂qf_∂v)
@test ∂qf_∂v_bp != ∂qf_∂v

∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=true)
∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp)
∂qf_∂v_bp = normalize!(∂qf_∂v_bp)
@test ∂qf_∂v_bp ∂qf_∂v
end

0 comments on commit e9a6bc5

Please sign in to comment.