Skip to content

Commit

Permalink
Fix bug in TNO construction
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Apr 5, 2024
1 parent 33f15fb commit df3f11a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ function NDTensors.contract(
neighbors_src = setdiff(neighbors(tn, src(edge)), [dst(edge)])
neighbors_dst = setdiff(neighbors(tn, dst(edge)), [src(edge)])
new_itensor = tn[src(edge)] * tn[dst(edge)]

# The following is equivalent to:
#
# tn[dst(edge)] = new_itensor
Expand All @@ -482,6 +481,7 @@ function NDTensors.contract(
for n_dst in neighbors_dst
add_edge!(tn, merged_vertex => n_dst)
end

setindex_preserve_graph!(tn, new_itensor, merged_vertex)

return tn
Expand Down
13 changes: 11 additions & 2 deletions src/itensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using DataGraphs: DataGraphs, DataGraph
using Dictionaries: Indices, dictionary
using ITensors: ITensors, ITensor
using ITensors: ITensors, ITensor, op, state
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype

struct Private end
Expand Down Expand Up @@ -164,8 +164,17 @@ end
function generic_state(a::AbstractArray, inds::Vector)
return itensor(a, inds)
end
function generic_state(x::Op, inds::NamedTuple)
# TODO: Figure out what to do if there is more than one site.
@assert length(inds.siteinds) == 2
i = inds.siteinds[findfirst(i -> plev(i) == 0, inds.siteinds)]
@assert i' inds.siteinds
site_tensors = [op(x.which_op, i)]
link_tensors = [[onehot(i => 1) for i in inds.linkinds[e]] for e in keys(inds.linkinds)]
return contract(reduce(vcat, link_tensors; init=site_tensors))
end
function generic_state(s::AbstractString, inds::NamedTuple)
# TODO: Handle the case of multiple site indices.
# TODO: Figure out what to do if there is more than one site.
site_tensors = [state(s, only(inds.siteinds))]
link_tensors = [[onehot(i => 1) for i in inds.linkinds[e]] for e in keys(inds.linkinds)]
return contract(reduce(vcat, link_tensors; init=site_tensors))
Expand Down
3 changes: 2 additions & 1 deletion src/tensornetworkoperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ function gate_group_to_tno(s::IndsNetwork, gates::Vector{ITensor})
#Construct indsnetwork for TNO
s_O = union_all_inds(s, prime(s; links=[]))

O = delta_network(s_O)
# Make a TNO with `I` on every site.
O = ITensorNetwork(Op("I"), s_O)

for gate in gates
v⃗ = vertices(s)[findall(i -> (length(commoninds(s[i], inds(gate))) != 0), vertices(s))]
Expand Down
3 changes: 3 additions & 0 deletions test/test_tno.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ using Test: @test, @testset
ψ = random_tensornetwork(s; link_space=2)

ψ_gated = copy(ψ)

for gate in gates
ψ_gated = apply(gate, ψ_gated)
end
ψ_tnod = copy(ψ)

for tno in tnos
ψ_tnod = flatten_networks(ψ_tnod, tno)
for v in vertices(ψ_tnod)
Expand All @@ -54,6 +56,7 @@ using Test: @test, @testset
z1 = contract_inner(ψ_gated, ψ_gated)
z2 = contract_inner(ψ_tnod, ψ_tnod)
z3 = contract_inner(ψ_tno, ψ_tno)

f12 = contract_inner(ψ_tnod, ψ_gated) / sqrt(z1 * z2)
f13 = contract_inner(ψ_tno, ψ_gated) / sqrt(z1 * z3)
f23 = contract_inner(ψ_tno, ψ_tnod) / sqrt(z2 * z3)
Expand Down

0 comments on commit df3f11a

Please sign in to comment.