Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement inner using BP #147

Merged
merged 41 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
62ad944
Refactor inner interface. Add logscalar(tn) functionality
JoeyT1994 Mar 21, 2024
fe27c86
Updated tests and examples to work with new inner interface
JoeyT1994 Mar 22, 2024
2720491
Merge remote-tracking branch 'upstream/main' into logcontract
JoeyT1994 Mar 25, 2024
1af45d8
Tree tests now pass to new inner function
JoeyT1994 Mar 25, 2024
eb3fa77
Bug Fix.
JoeyT1994 Mar 26, 2024
e6e8c6f
Argument rearrange DMRGx
JoeyT1994 Mar 26, 2024
c6f4dac
Reinstate Tree Inner for TTN Type
JoeyT1994 Mar 26, 2024
372f9b4
Format
JoeyT1994 Mar 26, 2024
fbe9063
Remove import of check_hascommoninds. Fix norm naming
JoeyT1994 Mar 26, 2024
19a158c
Refactor flatten_networks
JoeyT1994 Mar 27, 2024
9971ff3
Remove set nsite and ProjMPS
JoeyT1994 Mar 27, 2024
cb4ccc1
Formatting
JoeyT1994 Mar 27, 2024
8efff49
Modified inner and forms test for new BiLinearForm code
JoeyT1994 Mar 28, 2024
5b1393f
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Mar 28, 2024
74e5b0b
Refactor tests to account for changes
JoeyT1994 Mar 28, 2024
4e8a1ef
All tests refactored. logscalar in terms of scalarnorm
JoeyT1994 Mar 28, 2024
2c6da52
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Mar 29, 2024
ba77b9e
Refactor test bp and test inner for new namespace formatting
JoeyT1994 Mar 29, 2024
691320e
Refactor test_apply and test_additensornetworks
JoeyT1994 Mar 29, 2024
755cbbd
Formatting, namespace fix on test apply
JoeyT1994 Mar 30, 2024
a7080ea
Better scalar definitions in contract.jl
JoeyT1994 Apr 1, 2024
d94900c
Alphabetize namespaces, polish up
JoeyT1994 Apr 1, 2024
9eb55c4
Formatting
JoeyT1994 Apr 1, 2024
9b33331
Bug Fix
JoeyT1994 Apr 2, 2024
be39bc4
Bug Fix
JoeyT1994 Apr 2, 2024
84da9d7
Fixed logscalar bug
JoeyT1994 Apr 2, 2024
4bf3bab
Quick Bug Fix, add_itensornetworks
JoeyT1994 Apr 2, 2024
bd59e90
Merge remote-tracking branch 'upstream/main' into logcontract
JoeyT1994 Apr 3, 2024
c750a08
Merged upstream changes
JoeyT1994 Apr 3, 2024
4227726
Fix tests to be compatable with upstream changes
JoeyT1994 Apr 3, 2024
0664f24
Fix tests to be compatable with upstream changes
JoeyT1994 Apr 3, 2024
beb1fee
Merge remote-tracking branch 'upstream/main' into logcontract and fix…
JoeyT1994 Apr 5, 2024
388c15d
Refactor due to upstream changes
JoeyT1994 Apr 5, 2024
ad21df9
Remove contract_bp example. remove inner(x)
JoeyT1994 Apr 5, 2024
4881095
Added ToDO for norm_sqr_network
JoeyT1994 Apr 5, 2024
825d324
Optional positional argument for scalars(bp_cache)
JoeyT1994 Apr 5, 2024
6c91629
Improved, more general message_diff function
JoeyT1994 Apr 5, 2024
fa0f8a7
Removed support for Vector{ITensor} in contract.jl functions for now
JoeyT1994 Apr 10, 2024
fd5b9a3
Fix Bug in test BP
JoeyT1994 Apr 10, 2024
ec3e840
Merge branch 'main' into logcontract
mtfishman Apr 12, 2024
0026623
Construct correct identity_network in BiLinearFormNetwork
JoeyT1994 Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ julia> using ITensorNetworks: ITensorNetwork, siteinds
julia> using NamedGraphs: named_grid, subgraph

julia> tn = ITensorNetwork(named_grid(4); link_space=2)
ITensorNetworks.ITensorNetwork{Int64} with 4 vertices:
ITensorNetwork{Int64} with 4 vertices:
4-element Vector{Int64}:
1
2
Expand Down Expand Up @@ -90,7 +90,7 @@ and here is a similar example for making a tensor network on a grid (a tensor pr

```julia
julia> tn = ITensorNetwork(named_grid((2, 2)); link_space=2)
ITensorNetworks.ITensorNetwork{Tuple{Int64, Int64}} with 4 vertices:
ITensorNetwork{Tuple{Int64, Int64}} with 4 vertices:
4-element Vector{Tuple{Int64, Int64}}:
(1, 1)
(2, 1)
Expand Down Expand Up @@ -125,7 +125,7 @@ julia> neighbors(tn, (1, 2))
(2, 2)

julia> tn_1 = subgraph(v -> v[1] == 1, tn)
ITensorNetworks.ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
2-element Vector{Tuple{Int64, Int64}}:
(1, 1)
(1, 2)
Expand All @@ -139,7 +139,7 @@ with vertex data:
(1, 2) │ ((dim=2|id=723|"1×1,1×2"), (dim=2|id=712|"1×2,2×2"))

julia> tn_2 = subgraph(v -> v[1] == 2, tn)
ITensorNetworks.ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
ITensorNetwork{Tuple{Int64, Int64}} with 2 vertices:
2-element Vector{Tuple{Int64, Int64}}:
(2, 1)
(2, 2)
Expand Down Expand Up @@ -184,7 +184,7 @@ and edge data:
0-element Dictionaries.Dictionary{NamedGraphs.NamedEdge{Int64}, Vector{ITensors.Index}}

julia> tn1 = ITensorNetwork(s; link_space=2)
ITensorNetworks.ITensorNetwork{Int64} with 3 vertices:
ITensorNetwork{Int64} with 3 vertices:
3-element Vector{Int64}:
1
2
Expand All @@ -201,7 +201,7 @@ with vertex data:
3 │ ((dim=2|id=656|"S=1/2,Site,n=3"), (dim=2|id=190|"2,3"))

julia> tn2 = ITensorNetwork(s; link_space=2)
ITensorNetworks.ITensorNetwork{Int64} with 3 vertices:
ITensorNetwork{Int64} with 3 vertices:
3-element Vector{Int64}:
1
2
Expand Down
3 changes: 2 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ include("opsum.jl")
include("sitetype.jl")
include("abstractitensornetwork.jl")
include("contraction_sequences.jl")
include("expect.jl")
include("tebd.jl")
include("itensornetwork.jl")
include("mincut.jl")
Expand Down Expand Up @@ -64,6 +63,8 @@ include("solvers/contract.jl")
include("solvers/linsolve.jl")
include("solvers/sweep_plans/sweep_plans.jl")
include("apply.jl")
include("inner.jl")
include("expect.jl")
include("environment.jl")
include("exports.jl")
include("ModelHamiltonians/ModelHamiltonians.jl")
Expand Down
82 changes: 82 additions & 0 deletions src/ITensorsExtensions/ITensorsExtensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module ITensorsExtensions
using LinearAlgebra: LinearAlgebra, eigen, pinv
using ITensors:
ITensor,
Index,
commonind,
dag,
hasqns,
inds,
isdiag,
itensor,
map_diag,
noncommonind,
noprime,
replaceinds,
space,
sqrt_decomp
using ITensors.NDTensors:
NDTensors,
Block,
Tensor,
blockdim,
blockoffsets,
denseblocks,
diaglength,
getdiagindex,
nzblocks,
setdiagindex!,
svd,
tensor,
DiagBlockSparseTensor,
DenseTensor,
BlockOffsets
using Observers: update!, insert_function!

function NDTensors.blockoffsets(dense::DenseTensor)
return BlockOffsets{ndims(dense)}([Block(ntuple(Returns(1), ndims(dense)))], [0])
end
function NDTensors.nzblocks(dense::DenseTensor)
return nzblocks(blockoffsets(dense))
end
NDTensors.blockdim(ind::Int, ::Block{1}) = ind
NDTensors.blockdim(i::Index{Int}, b::Integer) = blockdim(i, Block(b))
NDTensors.blockdim(i::Index{Int}, b::Block) = blockdim(space(i), b)

LinearAlgebra.isdiag(it::ITensor) = isdiag(tensor(it))

# Convenience functions
sqrt_diag(it::ITensor) = map_diag(sqrt, it)
inv_diag(it::ITensor) = map_diag(inv, it)
invsqrt_diag(it::ITensor) = map_diag(inv ∘ sqrt, it)
pinv_diag(it::ITensor) = map_diag(pinv, it)
pinvsqrt_diag(it::ITensor) = map_diag(pinv ∘ sqrt, it)

function map_itensor(
f::Function, A::ITensor, lind=first(inds(A)); regularization=nothing, kwargs...
)
USV = svd(A, lind; kwargs...)
U, S, V, spec, u, v = USV
S = map_diag(s -> f(s + regularization), S)
sqrtDL, δᵤᵥ, sqrtDR = sqrt_decomp(S, u, v)
sqrtDR = denseblocks(sqrtDR) * denseblocks(δᵤᵥ)
L, R = U * sqrtDL, V * sqrtDR
return L * R
end

# Analagous to `denseblocks`.
# Extract the diagonal entries into a diagonal tensor.
function diagblocks(D::Tensor)
nzblocksD = nzblocks(D)
T = DiagBlockSparseTensor(eltype(D), nzblocksD, inds(D))
for b in nzblocksD
for n in 1:diaglength(D)
setdiagindex!(T, getdiagindex(D, n), n)
end
end
return T
end

diagblocks(it::ITensor) = itensor(diagblocks(tensor(it)))

end
66 changes: 12 additions & 54 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -735,67 +735,23 @@ function flatten_networks(
return flatten_networks(flatten_networks(tn1, tn2; kwargs...), tn3, tn_tail...; kwargs...)
end

#Ideally this will dispatch to inner_network but this is a temporary fast version for now
function norm_network(tn::AbstractITensorNetwork)
tnbra = rename_vertices(v -> (v, 1), data_graph(tn))
tndag = copy(tn)
for v in vertices(tndag)
setindex_preserve_graph!(tndag, dag(tndag[v]), v)
end
tnket = rename_vertices(v -> (v, 2), data_graph(prime(tndag; sites=[])))
# TODO: Use a different constructor here?
tntn = _ITensorNetwork(union(tnbra, tnket))
for v in vertices(tn)
if !isempty(commoninds(tntn[(v, 1)], tntn[(v, 2)]))
add_edge!(tntn, (v, 1) => (v, 2))
end
end
return tntn
function inner_network(x::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...)
return BilinearFormNetwork(x, y; kwargs...)
end

# TODO: Use or replace with `flatten_networks`
function inner_network(
tn1::AbstractITensorNetwork,
tn2::AbstractITensorNetwork;
map_bra_linkinds=sim,
combine_linkinds=false,
flatten=combine_linkinds,
kwargs...,
x::AbstractITensorNetwork, A::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...
)
@assert issetequal(vertices(tn1), vertices(tn2))
tn1 = map_bra_linkinds(tn1; sites=[])
inner_net = ⊗(dag(tn1), tn2; kwargs...)
if flatten
for v in vertices(tn1)
inner_net = contract(inner_net, (v, 2) => (v, 1); merged_vertex=v)
end
end
if combine_linkinds
inner_net = ITensorNetworks.combine_linkinds(inner_net)
end
return inner_net
return BilinearFormNetwork(A, x, y; kwargs...)
end

# TODO: Rename `inner`.
function contract_inner(
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
sequence=nothing,
contraction_sequence_kwargs=(;),
)
tn = inner_network(ϕ, ψ; combine_linkinds=true)
if isnothing(sequence)
sequence = contraction_sequence(tn; contraction_sequence_kwargs...)
end
return contract(tn; sequence)[]
# TODO: We should make this use the QuadraticFormNetwork constructor here.
# Parts of the code (tests relying on norm_sqr being two layer and the gauging code
# which relies on specific message tensors) currently would break in that case so we need to resolve
function norm_sqr_network(ψ::AbstractITensorNetwork)
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
return disjoint_union("bra" => dag(prime(ψ; sites=[])), "ket" => ψ)
end

# TODO: rename `sqnorm` to match https://github.com/JuliaStats/Distances.jl,
# or `norm_sqr` to match `LinearAlgebra.norm_sqr`
norm_sqr(ψ::AbstractITensorNetwork; sequence) = contract_inner(ψ, ψ; sequence)

norm_sqr_network(ψ::AbstractITensorNetwork; kwargs...) = inner_network(ψ, ψ; kwargs...)

#
# Printing
#
Expand Down Expand Up @@ -942,7 +898,7 @@ function ITensorMPS.add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork

#Create vertices of tn12 as direct sum of tn1[v] and tn2[v]. Work out the matching indices by matching edges. Make index tags those of tn1[v]
for v in vertices(tn1)
@assert siteinds(tn1, v) == siteinds(tn2, v)
@assert issetequal(siteinds(tn1, v), siteinds(tn2, v))

e1_v = filter(x -> src(x) == v || dst(x) == v, edges_tn1)
e2_v = filter(x -> src(x) == v || dst(x) == v, edges_tn2)
Expand All @@ -966,3 +922,5 @@ function ITensorMPS.add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork
end

Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)

ITensors.hasqns(tn::AbstractITensorNetwork) = any(v -> hasqns(tn[v]), vertices(tn))
2 changes: 1 addition & 1 deletion src/approx_itensornetwork/binary_tree_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,6 @@ function _partition(
return rename_vertices(par, name_map)
end

function _partition(tn::ITensorNetwork, inds_btree::DataGraph; alg::String)
function _partition(tn::ITensorNetwork, inds_btree::DataGraph; alg)
return _partition(Algorithm(alg), tn, inds_btree)
end
59 changes: 51 additions & 8 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ using NamedGraphs: PartitionVertex
using LinearAlgebra: diag
using ITensors: dir
using ITensors.ITensorMPS: ITensorMPS
using NamedGraphs: boundary_partitionedges
using NamedGraphs: boundary_partitionedges, partitionvertices, partitionedges

default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_messages(ptn::PartitionedGraph) = Dictionary()
default_message_norm(m::ITensor) = norm(m)
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
Expand All @@ -21,12 +22,20 @@ end
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
function default_partitioned_vertices(f::AbstractFormNetwork)
return group(v -> original_state_vertex(f, v), vertices(f))
end
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)
function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)
return (; partitioned_vertices=default_partitioned_vertices(ψ))
end

function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
function message_diff(
message_a::Vector{ITensor}, message_b::Vector{ITensor}; message_norm=default_message_norm
)
lhs, rhs = contract(message_a), contract(message_b)
return 0.5 *
norm((denseblocks(lhs) / sum(diag(lhs))) - (denseblocks(rhs) / sum(diag(rhs))))
norm_lhs, norm_rhs = message_norm(lhs), message_norm(rhs)
return 0.5 * norm((denseblocks(lhs) / norm_lhs) - (denseblocks(rhs) / norm_rhs))
end

struct BeliefPropagationCache{PTN,MTS,DM}
Expand All @@ -47,8 +56,14 @@ function BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
return BeliefPropagationCache(ptn; kwargs...)
end

function BeliefPropagationCache(tn; kwargs...)
return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...)
function BeliefPropagationCache(
tn; partitioned_vertices=default_partitioned_vertices(tn), kwargs...
)
return BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
end

function cache(alg::Algorithm"bp", tn; kwargs...)
return BeliefPropagationCache(tn; kwargs...)
end

function partitioned_tensornetwork(bp_cache::BeliefPropagationCache)
Expand Down Expand Up @@ -118,7 +133,7 @@ function environment(
)
bpes = boundary_partitionedges(bp_cache, partition_vertices; dir=:in)
ms = messages(bp_cache, setdiff(bpes, ignore_edges))
return reduce(vcat, ms; init=[])
return reduce(vcat, ms; init=ITensor[])
end

function environment(
Expand Down Expand Up @@ -216,11 +231,11 @@ function update(
kwargs...,
)
compute_error = !isnothing(tol)
diff = compute_error ? Ref(0.0) : nothing
if isnothing(maxiter)
error("You need to specify a number of iterations for BP!")
end
for i in 1:maxiter
diff = compute_error ? Ref(0.0) : nothing
bp_cache = update(bp_cache, edges; (update_diff!)=diff, kwargs...)
if compute_error && (diff.x / length(edges)) <= tol
if verbose
Expand Down Expand Up @@ -251,3 +266,31 @@ end
function update_factor(bp_cache, vertex, factor)
return update_factors(bp_cache, [vertex], ITensor[factor])
end

function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)
incoming_mts = environment(bp_cache, [pv])
local_state = factor(bp_cache, pv)
return contract(vcat(incoming_mts, local_state))[]
end

function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
return contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
end

function vertex_scalars(
bp_cache::BeliefPropagationCache,
pvs::Vector=partitionvertices(partitioned_tensornetwork(bp_cache)),
)
return [region_scalar(bp_cache, pv) for pv in pvs]
end

function edge_scalars(
bp_cache::BeliefPropagationCache,
pes::Vector=partitionedges(partitioned_tensornetwork(bp_cache)),
)
return [region_scalar(bp_cache, pe) for pe in pes]
end

function scalar_factors_quotient(bp_cache::BeliefPropagationCache)
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
end
Loading
Loading