Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement inner(x,A,y) for TTN via BP. #146

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,48 @@ function contract_inner(
return contract(tn; sequence)[]
end

function contract_with_BP(
itn::AbstractITensorNetwork; outputlevel=1, partitioning=group(v -> v, vertices(itn))
)
return contract_with_BP(
ComplexF64, itn::AbstractITensorNetwork; outputlevel, partitioning
)
end

function contract_with_BP(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contract_with_BP isn't a good function name, what about overloading ITensors.contract(::Algorithm"bp", kwargs...) and running with contract(itn; alg="bp", kwargs...).

T::Type,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused by this part of the interface, why can you set the element type in this way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the right way to do it. Basically, due to the log in this implementation, we need to promote the scalars of which the element type is taken to complex. Then the idea was that contract_with_BP returns a complex scalar, unless you specify a type, in which case it converts the scalar to the specified type. This should of course be handled differently for a proper implementation as contract with BP backend.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. I think contract_with_BP as it stands is too far away from being a library function, so we should think about how to deal with that before continuing forward with this PR.

itn::AbstractITensorNetwork;
outputlevel=1,
partitioning=group(v -> v, vertices(itn)),
)
@assert isempty(externalinds(itn))
bp_cache = BeliefPropagationCache(copy(itn), partitioning)

bp_cache = update(bp_cache)

pg = partitioned_itensornetwork(bp_cache)

if !is_tree(partitioned_graph(pg)) && outputlevel > 0
println("Partitioned graph is not a tree, result will be approximate!!")
end
Comment on lines +745 to +747
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove this, I think this is inherent if someone selects that they are using BP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it may be good to allow to pass a flag like assert_exact=true, to check that the BP will indeed be exact for the contraction (for the case of contracting TTNs)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's think about the design of that, but I don't think we have to address that in this PR. I agree with the sentiment that we don't want to secretly perform approximate contractions but give users the impression that we are doing exact contractions.


log_numerator, log_denominator = 0, 0
for pv in partitionvertices(pg)
incoming_mts = incoming_messages(bp_cache, [pv])
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
local_state = ITensor[itn[v] for v in vertices(pg, pv)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you use pg[pv] here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or factor(bp_cache, pv)?

log_numerator += log(complex(ITensors.contract(vcat(incoming_mts, local_state))[]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
log_numerator += log(complex(ITensors.contract(vcat(incoming_mts, local_state))[]))
log_numerator += log(complex(contract(vcat(incoming_mts, local_state))[]))

end
for pe in partitionedges(pg)
log_denominator += log(
complex(
ITensors.contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ITensors.contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]
contract(vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))))[]

),
)
end
res = exp(log_numerator - log_denominator)
return T(res)
end

# TODO: rename `sqnorm` to match https://github.com/JuliaStats/Distances.jl,
# or `norm_sqr` to match `LinearAlgebra.norm_sqr`
norm_sqr(ψ::AbstractITensorNetwork; sequence) = contract_inner(ψ, ψ; sequence)
Expand Down
20 changes: 7 additions & 13 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,20 +365,14 @@ end
# Inner products
#

# TODO: implement using multi-graph disjoint union
function inner(
y::AbstractTTN, A::AbstractTTN, x::AbstractTTN; root_vertex=default_root_vertex(x, A, y)
)
traversal_order = reverse(post_order_dfs_vertices(x, root_vertex))
check_hascommoninds(siteinds, A, x)
check_hascommoninds(siteinds, A, y)
# TODO: Remove dispatch on AbstractTTN (unless we want to trigger warning for trees if BP is nonexact)
function inner(y::AbstractTTN, A::AbstractTTN, x::AbstractTTN; kwargs...)
ydag = sim(dag(y); sites=[])
x = sim(x; sites=[])
O = ydag[root_vertex] * A[root_vertex] * x[root_vertex]
for v in traversal_order[2:end]
O = O * ydag[v] * A[v] * x[v]
end
return O[]
blf = BilinearFormNetwork(A, ydag, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should be using a form here, that is mostly meant for optimization.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is your goal to make a tensor network out of the network y, A, and x? If so, you can use disjoint_union.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as far as I understand the BilinearFormNetwork just stacks the network here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's do that without going through BilinearFormNetwork.

blf_scalar_bp = contract_with_BP(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we just use contract here, then we can allow an alg keyword argument to get passed from inner, and users can control what algorithm is used to contract the network.

To handle partitioning in the more general case, something we could do is that instead of passing an ITensorNetwork and a partitioning, we pass a PartitionedGraph to contract, then it is up to the contract backends to contract the partitioned tensor network (either using or ignoring the partitioning).

tensornetwork(blf); outputlevel=1, partitioning=group(v -> first(v), vertices(blf))
)
return blf_scalar_bp
end

# TODO: implement using multi-graph disjoint
Expand Down
4 changes: 3 additions & 1 deletion test/test_forms.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ITensors
using ITensors: contract
using Graphs
using NamedGraphs
using ITensorNetworks
Expand All @@ -15,7 +16,7 @@ using ITensorNetworks:
using Test
using Random

@testset "FormNetworkss" begin
@testset "FormNetworks" begin
g = named_grid((1, 4))
s_ket = siteinds("S=1/2", g)
s_bra = prime(s_ket; links=[])
Expand Down Expand Up @@ -49,3 +50,4 @@ using Random
@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)
end
nothing
53 changes: 53 additions & 0 deletions test/test_treetensornetworks/test_treetensornetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using ITensors
using ITensors: contract
using ITensorNetworks
using ITensorNetworks: contract_with_BP, group
using Test

@testset "TTN constructor defaulting to link_space=1" begin
tooth_lengths = fill(5, 6)
c = named_comb_tree(tooth_lengths)
s = siteinds("S=1/2", c)
d = Dict()
for (i, v) in enumerate(vertices(s))
d[v] = isodd(i) ? "Up" : "Dn"
end
states = v -> d[v]
#test a few signatures
state = TTN(s, states)
lds = edge_data(linkdims(state))
@test all([isone(lds[k]) for k in keys(lds)])
state = TTN(s)
lds = edge_data(linkdims(state))
@test all([isone(lds[k]) for k in keys(lds)])
end

@testset "Inner products for TTN via BP" begin
Random.seed!(1234)
Lx, Ly = 3, 3
χ = 2
g = named_comb_tree((Lx, Ly))
s = siteinds("S=1/2", g)
y = TTN(randomITensorNetwork(ComplexF64, s; link_space=χ))
x = TTN(randomITensorNetwork(ComplexF64, s; link_space=χ))

A = TTN(ITensorNetworks.heisenberg(s), s)
#First lets do it with the flattened version of the network
xy = inner_network(x, y; combine_linkinds=true)
xy_scalar = contract(xy)[]
xy_scalar_bp = contract_with_BP(xy)

@test_broken xy_scalar ≈ xy_scalar_bp

#Now lets keep it unflattened and do Block BP to keep the partitioned graph as a tree
xy = inner_network(x, y; combine_linkinds=false)
xy_scalar = contract(xy)[]
xy_scalar_bp = contract_with_BP(xy; partitioning=group(v -> first(v), vertices(xy)))

@test xy_scalar ≈ xy_scalar_bp
# test contraction of three layers for expectation values
# for TTN inner with this signature passes via contract_with_BP
@test inner(prime(x), A, y) ≈
inner(x, apply(A, y; nsweeps=10, maxdim=16, cutoff=1e-10, init=y)) rtol = 1e-6
end
nothing
Loading