Environments #145

Merged · 11 commits · Mar 21, 2024
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
@@ -129,6 +129,7 @@ include(joinpath("treetensornetworks", "solvers", "contract.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
include("apply.jl")
include("derivative.jl")

include("exports.jl")

20 changes: 16 additions & 4 deletions src/caches/beliefpropagationcache.jl
@@ -1,13 +1,21 @@
default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_messages(ptn::PartitionedGraph) = Dictionary()
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
return contract_exact(contract_list; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
updated_messages = normalize!(updated_messages)
return ITensor[updated_messages]
end
default_message_update_kwargs() = (; normalize=true, contraction_sequence_alg="optimal")
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioning(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ))

#We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition.
default_partitioning(f::AbstractFormNetwork) = group(v -> first(v), vertices(f))
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
return 0.5 *
@@ -27,11 +35,15 @@ function BeliefPropagationCache(
return BeliefPropagationCache(ptn, messages, default_message)
end

function BeliefPropagationCache(tn::ITensorNetwork, partitioned_vertices; kwargs...)
function BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
ptn = PartitionedGraph(tn, partitioned_vertices)
return BeliefPropagationCache(ptn; kwargs...)
end

function BeliefPropagationCache(tn; kwargs...)
return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...)
end

function partitioned_itensornetwork(bp_cache::BeliefPropagationCache)
return bp_cache.partitioned_itensornetwork
end
@@ -129,7 +141,7 @@ function update_message(
bp_cache::BeliefPropagationCache,
edge::PartitionEdge;
message_update=default_message_update,
message_update_kwargs=default_message_update_kwargs(),
message_update_kwargs=(;),
)
vertex = src(edge)
messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
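As a quick orientation for the reworked defaults above, here is a minimal usage sketch. It is not part of the diff: the network construction (`named_grid`, `siteinds`, `randomITensorNetwork`) follows patterns used elsewhere in this repository's tests and is assumed rather than taken from this PR.

```julia
# Minimal sketch, not part of the diff: exercise the new one-argument
# `BeliefPropagationCache` constructor and the relocated update defaults.
using ITensors
using ITensorNetworks
using ITensorNetworks: BeliefPropagationCache, update
using NamedGraphs: named_grid

g = named_grid((3, 3))
s = siteinds("S=1/2", g)
# `randomITensorNetwork` is assumed to be available as in the package tests.
ψ = randomITensorNetwork(s; link_space=2)

# New convenience constructor: the partitioning falls back to
# `default_partitioning(ψ)`, i.e. one partition per vertex.
bpc = BeliefPropagationCache(ψ)

# Run belief propagation with the defaults now living in this file,
# `default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)`.
bpc = update(bpc; maxiter=20, tol=1e-5)
```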
14 changes: 0 additions & 14 deletions src/contract.jl
@@ -30,17 +30,3 @@ function contract_density_matrix(
end
return out
end

function contract_exact(
contract_list::Vector{ITensor};
contraction_sequence_alg="optimal",
normalize=true,
contractor_kwargs...,
)
seq = contraction_sequence(contract_list; alg=contraction_sequence_alg)
out = ITensors.contract(contract_list; sequence=seq, contractor_kwargs...)
if normalize
normalize!(out)
end
return ITensor[out]
end
38 changes: 38 additions & 0 deletions src/derivative.jl
@@ -0,0 +1,38 @@
default_derivative_algorithm() = "exact"

function derivative(
ψ::AbstractITensorNetwork, vertices::Vector; alg=default_derivative_algorithm(), kwargs...
)
return derivative(Algorithm(alg), ψ, vertices; kwargs...)
end

function derivative(
::Algorithm"exact",
ψ::AbstractITensorNetwork,
vertices::Vector;
contraction_sequence_alg="optimal",
kwargs...,
)
ψ_reduced = Vector{ITensor}(subgraph(ψ, vertices))
sequence = contraction_sequence(ψ_reduced; alg=contraction_sequence_alg)
return ITensor[contract(ψ_reduced; sequence, kwargs...)]
end

function derivative(
::Algorithm"bp",
ψ::AbstractITensorNetwork,
verts::Vector;
(bp_cache!)=nothing,
update_bp_cache=isnothing(bp_cache!),
bp_cache_update_kwargs=default_cache_update_kwargs(bp_cache!),
)
if isnothing(bp_cache!)
bp_cache! = Ref(BeliefPropagationCache(ψ))
end

if update_bp_cache
bp_cache![] = update(bp_cache![]; bp_cache_update_kwargs...)
end

return incoming_messages(bp_cache![], setdiff(vertices(ψ), verts))
end
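A hedged sketch of how the new `derivative` entry point can be called, continuing from the network `ψ` in the sketch above. The vertex bookkeeping (`vertices`, `setdiff`) and the variable names are illustrative assumptions, not code from this PR.

```julia
# Illustrative sketch (variable names are assumptions), continuing from ψ above.
using Graphs: vertices
using ITensorNetworks: BeliefPropagationCache, derivative

# Environment of the tensor at vertex `v`: contract every other tensor of ψ,
# either exactly or via belief-propagation messages.
v = first(vertices(ψ))
env_verts = setdiff(vertices(ψ), [v])

∂ψ_∂v = derivative(ψ, env_verts)               # alg defaults to "exact"
∂ψ_∂v_bp = derivative(ψ, env_verts; alg="bp")  # builds and updates a BP cache internally

# A pre-built cache can be passed via `bp_cache!`; when a cache is supplied,
# `update_bp_cache` defaults to `false`, so the existing messages are reused.
bpc = Ref(BeliefPropagationCache(ψ))
∂ψ_∂v_bp = derivative(ψ, env_verts; alg="bp", (bp_cache!)=bpc)
```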
4 changes: 2 additions & 2 deletions src/formnetworks/abstractformnetwork.jl
@@ -57,9 +57,9 @@ function operator_network(f::AbstractFormNetwork)
)
end

function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
function derivative_state(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(tensornetwork(f), tn_vertices; kwargs...)
return derivative(f, tn_vertices; kwargs...)
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
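For orientation, a short sketch of the renamed `derivative_state` entry point. It assumes `qf` is a quadratic form network and `v` one of its state vertices, mirroring the updated test further down; nothing here goes beyond that test.

```julia
# Sketch mirroring the updated test below; `qf` and `v` are assumed from there.
using ITensorNetworks: derivative_state

# `derivative_state` maps state vertices to form-network vertices and now
# dispatches to `derivative` on the form network itself (the old
# `tensornetwork(f)` indirection is gone).
∂qf_∂v = only(derivative_state(qf, [v]))         # exact environment
∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp")  # BP-approximated environment
```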
1 change: 0 additions & 1 deletion src/gauging.jl
@@ -23,7 +23,6 @@ function default_norm_cache(ψ::ITensorNetwork)
ψψ = norm_network(ψ)
return BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
end
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function ITensorNetwork(
ψ_vidal::VidalITensorNetwork; (cache!)=nothing, update_gauge=false, update_kwargs...
3 changes: 2 additions & 1 deletion test/test_belief_propagation.jl
@@ -8,7 +8,8 @@ using ITensorNetworks:
tensornetwork,
update,
update_factor,
incoming_messages
incoming_messages,
contract
using Test
using Compat
using ITensors
22 changes: 20 additions & 2 deletions test/test_forms.jl
@@ -11,11 +11,14 @@ using ITensorNetworks:
dual_index_map,
bra_network,
ket_network,
operator_network
operator_network,
derivative_state,
BeliefPropagationCache
using Test
using Random
using SplitApplyCombine

@testset "FormNetworkss" begin
@testset "FormNetworks" begin
g = named_grid((1, 4))
s_ket = siteinds("S=1/2", g)
s_bra = prime(s_ket; links=[])
@@ -48,4 +51,19 @@

@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)

∂qf_∂v = only(derivative_state(qf, [v]))
@test (∂qf_∂v) * (qf[ket_vertex_map(qf)(v)] * qf[bra_vertex_map(qf)(v)]) ≈
ITensors.contract(qf)

∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=false)
∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp)
∂qf_∂v_bp = normalize!(∂qf_∂v_bp)
∂qf_∂v = normalize!(∂qf_∂v)
@test ∂qf_∂v_bp != ∂qf_∂v

∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=true)
∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp)
∂qf_∂v_bp = normalize!(∂qf_∂v_bp)
@test ∂qf_∂v_bp ≈ ∂qf_∂v
end