Skip to content

Commit

Permalink
Implement inner using BP (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored Apr 12, 2024
1 parent 906f184 commit f813653
Show file tree
Hide file tree
Showing 23 changed files with 554 additions and 208 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ julia> using ITensorNetworks: ITensorNetwork, siteinds
julia> using NamedGraphs: named_grid, subgraph

julia> tn = ITensorNetwork(named_grid(4); link_space=2)
ITensorNetworks.ITensorNetwork{Int64} with 4 vertices:
ITensorNetwork{Int64} with 4 vertices:
4-element Vector{Int64}:
1
2
Expand Down Expand Up @@ -90,7 +90,7 @@ and here is a similar example for making a tensor network on a grid (a tensor pr

```julia
julia> tn = ITensorNetwork(named_grid((2, 2)); link_space=2)
ITensorNetworks.ITensorNetwork{Tuple{Int64, Int64}} with 4 vertices:
ITensorNetwork{Tuple{Int64, Int64}} with 4 vertices:
4-element Vector{Tuple{Int64, Int64}}:
(1, 1)
(2, 1)
Expand Down Expand Up @@ -125,7 +125,7 @@ julia> neighbors(tn, (1, 2))
(2, 2)

julia> tn_1 = subgraph(v -> v[1] == 1, tn)
ITensorNetworks.ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
2-element Vector{Tuple{Int64, Int64}}:
(1, 1)
(1, 2)
Expand All @@ -139,7 +139,7 @@ with vertex data:
(1, 2) │ ((dim=2|id=723|"1×1,1×2"), (dim=2|id=712|"1×2,2×2"))

julia> tn_2 = subgraph(v -> v[1] == 2, tn)
ITensorNetworks.ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
2-element Vector{Tuple{Int64, Int64}}:
(2, 1)
(2, 2)
Expand Down Expand Up @@ -184,7 +184,7 @@ and edge data:
0-element Dictionaries.Dictionary{NamedGraphs.NamedEdge{Int64}, Vector{ITensors.Index}}

julia> tn1 = ITensorNetwork(s; link_space=2)
ITensorNetworks.ITensorNetwork{Int64} with 3 vertices:
ITensorNetwork{Int64} with 3 vertices:
3-element Vector{Int64}:
1
2
Expand All @@ -201,7 +201,7 @@ with vertex data:
3 │ ((dim=2|id=656|"S=1/2,Site,n=3"), (dim=2|id=190|"2,3"))
julia> tn2 = ITensorNetwork(s; link_space=2)
ITensorNetworks.ITensorNetwork{Int64} with 3 vertices:
ITensorNetwork{Int64} with 3 vertices:
3-element Vector{Int64}:
1
2
Expand Down
3 changes: 2 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ include("opsum.jl")
include("sitetype.jl")
include("abstractitensornetwork.jl")
include("contraction_sequences.jl")
include("expect.jl")
include("tebd.jl")
include("itensornetwork.jl")
include("mincut.jl")
Expand Down Expand Up @@ -64,6 +63,8 @@ include("solvers/contract.jl")
include("solvers/linsolve.jl")
include("solvers/sweep_plans/sweep_plans.jl")
include("apply.jl")
include("inner.jl")
include("expect.jl")
include("environment.jl")
include("exports.jl")
include("ModelHamiltonians/ModelHamiltonians.jl")
Expand Down
82 changes: 82 additions & 0 deletions src/ITensorsExtensions/ITensorsExtensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module ITensorsExtensions
using LinearAlgebra: LinearAlgebra, eigen, pinv
using ITensors:
ITensor,
Index,
commonind,
dag,
hasqns,
inds,
isdiag,
itensor,
map_diag,
noncommonind,
noprime,
replaceinds,
space,
sqrt_decomp
using ITensors.NDTensors:
NDTensors,
Block,
Tensor,
blockdim,
blockoffsets,
denseblocks,
diaglength,
getdiagindex,
nzblocks,
setdiagindex!,
svd,
tensor,
DiagBlockSparseTensor,
DenseTensor,
BlockOffsets
using Observers: update!, insert_function!

function NDTensors.blockoffsets(dense::DenseTensor)
return BlockOffsets{ndims(dense)}([Block(ntuple(Returns(1), ndims(dense)))], [0])
end
function NDTensors.nzblocks(dense::DenseTensor)
return nzblocks(blockoffsets(dense))
end
NDTensors.blockdim(ind::Int, ::Block{1}) = ind
NDTensors.blockdim(i::Index{Int}, b::Integer) = blockdim(i, Block(b))
NDTensors.blockdim(i::Index{Int}, b::Block) = blockdim(space(i), b)

LinearAlgebra.isdiag(it::ITensor) = isdiag(tensor(it))

# Convenience functions
sqrt_diag(it::ITensor) = map_diag(sqrt, it)
inv_diag(it::ITensor) = map_diag(inv, it)
invsqrt_diag(it::ITensor) = map_diag(inv sqrt, it)
pinv_diag(it::ITensor) = map_diag(pinv, it)
pinvsqrt_diag(it::ITensor) = map_diag(pinv sqrt, it)

function map_itensor(
f::Function, A::ITensor, lind=first(inds(A)); regularization=nothing, kwargs...
)
USV = svd(A, lind; kwargs...)
U, S, V, spec, u, v = USV
S = map_diag(s -> f(s + regularization), S)
sqrtDL, δᵤᵥ, sqrtDR = sqrt_decomp(S, u, v)
sqrtDR = denseblocks(sqrtDR) * denseblocks(δᵤᵥ)
L, R = U * sqrtDL, V * sqrtDR
return L * R
end

# Analagous to `denseblocks`.
# Extract the diagonal entries into a diagonal tensor.
function diagblocks(D::Tensor)
nzblocksD = nzblocks(D)
T = DiagBlockSparseTensor(eltype(D), nzblocksD, inds(D))
for b in nzblocksD
for n in 1:diaglength(D)
setdiagindex!(T, getdiagindex(D, n), n)
end
end
return T
end

diagblocks(it::ITensor) = itensor(diagblocks(tensor(it)))

end
66 changes: 12 additions & 54 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -735,67 +735,23 @@ function flatten_networks(
return flatten_networks(flatten_networks(tn1, tn2; kwargs...), tn3, tn_tail...; kwargs...)
end

#Ideally this will dispatch to inner_network but this is a temporary fast version for now
function norm_network(tn::AbstractITensorNetwork)
tnbra = rename_vertices(v -> (v, 1), data_graph(tn))
tndag = copy(tn)
for v in vertices(tndag)
setindex_preserve_graph!(tndag, dag(tndag[v]), v)
end
tnket = rename_vertices(v -> (v, 2), data_graph(prime(tndag; sites=[])))
# TODO: Use a different constructor here?
tntn = _ITensorNetwork(union(tnbra, tnket))
for v in vertices(tn)
if !isempty(commoninds(tntn[(v, 1)], tntn[(v, 2)]))
add_edge!(tntn, (v, 1) => (v, 2))
end
end
return tntn
function inner_network(x::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...)
return BilinearFormNetwork(x, y; kwargs...)
end

# TODO: Use or replace with `flatten_networks`
function inner_network(
tn1::AbstractITensorNetwork,
tn2::AbstractITensorNetwork;
map_bra_linkinds=sim,
combine_linkinds=false,
flatten=combine_linkinds,
kwargs...,
x::AbstractITensorNetwork, A::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...
)
@assert issetequal(vertices(tn1), vertices(tn2))
tn1 = map_bra_linkinds(tn1; sites=[])
inner_net = (dag(tn1), tn2; kwargs...)
if flatten
for v in vertices(tn1)
inner_net = contract(inner_net, (v, 2) => (v, 1); merged_vertex=v)
end
end
if combine_linkinds
inner_net = ITensorNetworks.combine_linkinds(inner_net)
end
return inner_net
return BilinearFormNetwork(A, x, y; kwargs...)
end

# TODO: Rename `inner`.
function contract_inner(
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
sequence=nothing,
contraction_sequence_kwargs=(;),
)
tn = inner_network(ϕ, ψ; combine_linkinds=true)
if isnothing(sequence)
sequence = contraction_sequence(tn; contraction_sequence_kwargs...)
end
return contract(tn; sequence)[]
# TODO: We should make this use the QuadraticFormNetwork constructor here.
# Parts of the code (tests relying on norm_sqr being two layer and the gauging code
# which relies on specific message tensors) currently would break in that case so we need to resolve
function norm_sqr_network::AbstractITensorNetwork)
return disjoint_union("bra" => dag(prime(ψ; sites=[])), "ket" => ψ)
end

# TODO: rename `sqnorm` to match https://github.com/JuliaStats/Distances.jl,
# or `norm_sqr` to match `LinearAlgebra.norm_sqr`
norm_sqr::AbstractITensorNetwork; sequence) = contract_inner(ψ, ψ; sequence)

norm_sqr_network::AbstractITensorNetwork; kwargs...) = inner_network(ψ, ψ; kwargs...)

#
# Printing
#
Expand Down Expand Up @@ -942,7 +898,7 @@ function ITensorMPS.add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork

#Create vertices of tn12 as direct sum of tn1[v] and tn2[v]. Work out the matching indices by matching edges. Make index tags those of tn1[v]
for v in vertices(tn1)
@assert siteinds(tn1, v) == siteinds(tn2, v)
@assert issetequal(siteinds(tn1, v), siteinds(tn2, v))

e1_v = filter(x -> src(x) == v || dst(x) == v, edges_tn1)
e2_v = filter(x -> src(x) == v || dst(x) == v, edges_tn2)
Expand All @@ -966,3 +922,5 @@ function ITensorMPS.add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork
end

Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)

ITensors.hasqns(tn::AbstractITensorNetwork) = any(v -> hasqns(tn[v]), vertices(tn))
2 changes: 1 addition & 1 deletion src/approx_itensornetwork/binary_tree_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,6 @@ function _partition(
return rename_vertices(par, name_map)
end

function _partition(tn::ITensorNetwork, inds_btree::DataGraph; alg::String)
function _partition(tn::ITensorNetwork, inds_btree::DataGraph; alg)
return _partition(Algorithm(alg), tn, inds_btree)
end
59 changes: 51 additions & 8 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ using NamedGraphs: PartitionVertex
using LinearAlgebra: diag
using ITensors: dir
using ITensors.ITensorMPS: ITensorMPS
using NamedGraphs: boundary_partitionedges
using NamedGraphs: boundary_partitionedges, partitionvertices, partitionedges

default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_messages(ptn::PartitionedGraph) = Dictionary()
default_message_norm(m::ITensor) = norm(m)
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
Expand All @@ -21,12 +22,20 @@ end
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioned_vertices::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
function default_partitioned_vertices(f::AbstractFormNetwork)
return group(v -> original_state_vertex(f, v), vertices(f))
end
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)
function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)
return (; partitioned_vertices=default_partitioned_vertices(ψ))
end

function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
function message_diff(
message_a::Vector{ITensor}, message_b::Vector{ITensor}; message_norm=default_message_norm
)
lhs, rhs = contract(message_a), contract(message_b)
return 0.5 *
norm((denseblocks(lhs) / sum(diag(lhs))) - (denseblocks(rhs) / sum(diag(rhs))))
norm_lhs, norm_rhs = message_norm(lhs), message_norm(rhs)
return 0.5 * norm((denseblocks(lhs) / norm_lhs) - (denseblocks(rhs) / norm_rhs))
end

struct BeliefPropagationCache{PTN,MTS,DM}
Expand All @@ -47,8 +56,14 @@ function BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
return BeliefPropagationCache(ptn; kwargs...)
end

function BeliefPropagationCache(tn; kwargs...)
return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...)
function BeliefPropagationCache(
tn; partitioned_vertices=default_partitioned_vertices(tn), kwargs...
)
return BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
end

function cache(alg::Algorithm"bp", tn; kwargs...)
return BeliefPropagationCache(tn; kwargs...)
end

function partitioned_tensornetwork(bp_cache::BeliefPropagationCache)
Expand Down Expand Up @@ -118,7 +133,7 @@ function environment(
)
bpes = boundary_partitionedges(bp_cache, partition_vertices; dir=:in)
ms = messages(bp_cache, setdiff(bpes, ignore_edges))
return reduce(vcat, ms; init=[])
return reduce(vcat, ms; init=ITensor[])
end

function environment(
Expand Down Expand Up @@ -216,11 +231,11 @@ function update(
kwargs...,
)
compute_error = !isnothing(tol)
diff = compute_error ? Ref(0.0) : nothing
if isnothing(maxiter)
error("You need to specify a number of iterations for BP!")
end
for i in 1:maxiter
diff = compute_error ? Ref(0.0) : nothing
bp_cache = update(bp_cache, edges; (update_diff!)=diff, kwargs...)
if compute_error && (diff.x / length(edges)) <= tol
if verbose
Expand Down Expand Up @@ -251,3 +266,31 @@ end
function update_factor(bp_cache, vertex, factor)
return update_factors(bp_cache, [vertex], ITensor[factor])
end

function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)
incoming_mts = environment(bp_cache, [pv])
local_state = factor(bp_cache, pv)
return contract(vcat(incoming_mts, local_state))[]
end

function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
return contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
end

function vertex_scalars(
bp_cache::BeliefPropagationCache,
pvs::Vector=partitionvertices(partitioned_tensornetwork(bp_cache)),
)
return [region_scalar(bp_cache, pv) for pv in pvs]
end

function edge_scalars(
bp_cache::BeliefPropagationCache,
pes::Vector=partitionedges(partitioned_tensornetwork(bp_cache)),
)
return [region_scalar(bp_cache, pe) for pe in pes]
end

function scalar_factors_quotient(bp_cache::BeliefPropagationCache)
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
end
Loading

0 comments on commit f813653

Please sign in to comment.