Skip to content

Commit

Permalink
Modified inner and forms test for new BiLinearForm code
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Mar 28, 2024
1 parent cb4ccc1 commit 8efff49
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 86 deletions.
53 changes: 3 additions & 50 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -635,62 +635,15 @@ function split_index(
return tn
end

function stack_flatten_combine(
tns::Vector{<:AbstractITensorNetwork};
map_bra_linkinds=prime,
flatten=false,
combine_linkinds=false,
)
tns = copy(tns)
stacked_tn = map_bra_linkinds(popfirst!(tns); sites=[])
current_suffix = 1
for tn in tns
stacked_tn_vertices = vertices(stacked_tn)
stacked_tn = disjoint_union(current_suffix => stacked_tn, current_suffix + 1 => tn)

if flatten
@assert issetequal(vertices(tn), stacked_tn_vertices)
for v in vertices(tn)
stacked_tn = contract(
stacked_tn, (v, current_suffix + 1) => (v, current_suffix); merged_vertex=v
)
end
else
if current_suffix != 1
#Strip back
stacked_tn_vertices = [(v, current_suffix) for v in stacked_tn_vertices]
stacked_tn = rename_vertices(
v -> v stacked_tn_vertices ? first(v) : v, stacked_tn
)
end
current_suffix += 1
end
if combine_linkinds
stacked_tn = ITensorNetworks.combine_linkinds(stacked_tn)
end
end

return stacked_tn
end

function stack_flatten_combine(tns::AbstractITensorNetwork...; kwargs...)
return stack_flatten_combine([tns...]; kwargs...)
end

function flatten_networks(
tns::AbstractITensorNetwork...; combine_linkinds=true, map_bra_linkinds=prime
)
return stack_flatten_combine(tns...; flatten=true, combine_linkinds, map_bra_linkinds)
end

#Just make this call to form network, rip out flatten
function inner_network(x::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...)
return stack_flatten_combine(dag(x), y; kwargs...)
return BilinearFormNetwork(x, y; kwargs...)
end

function inner_network(
x::AbstractITensorNetwork, A::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...
)
return stack_flatten_combine(dag(x), A, y; kwargs...)
return BilinearFormNetwork(x, A, y; kwargs...)
end

inner_network(x::AbstractITensorNetwork; kwargs...) = inner_network(x, x; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ end
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioned_vertices::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
default_partitioned_vertices(f::AbstractFormNetwork) = group(v -> original_state_vertex(f, v), vertices(f))
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

#TODO: Define a version of this that works for QN supporting tensors
Expand Down
6 changes: 6 additions & 0 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ function logscalar(
return log(complex(scalar(alg, tn; kwargs...)))
end

#This should just pass to a logscalar(bp_cache::BeliefPropagationCache, ...)
#Catch the 0 case in logscalar not in scalar, then pass to exp(logscalar(...))
#Check if negative before complex() call
#Break down into scalar factors?
#Make general to all algorithms

function logscalar(
alg::Algorithm"bp",
tn::AbstractITensorNetwork;
Expand Down
25 changes: 18 additions & 7 deletions src/formnetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
default_dual_site_index_map = prime
default_dual_link_index_map = sim

struct BilinearFormNetwork{
V,
TensorNetwork<:AbstractITensorNetwork{V},
Expand All @@ -18,9 +21,12 @@ function BilinearFormNetwork(
operator_vertex_suffix=default_operator_vertex_suffix(),
bra_vertex_suffix=default_bra_vertex_suffix(),
ket_vertex_suffix=default_ket_vertex_suffix(),
dual_site_index_map = default_dual_site_index_map,
dual_link_index_map = default_dual_link_index_map
)
bra_mapped = dual_link_index_map(dual_site_index_map(bra; links = []); sites = [])
tn = disjoint_union(
operator_vertex_suffix => operator, bra_vertex_suffix => bra, ket_vertex_suffix => ket
operator_vertex_suffix => operator, bra_vertex_suffix => dag(bra_mapped), ket_vertex_suffix => ket
)
return BilinearFormNetwork(
tn, operator_vertex_suffix, bra_vertex_suffix, ket_vertex_suffix
Expand All @@ -43,24 +49,29 @@ function copy(blf::BilinearFormNetwork)
)
end

#Is the ordering of the indices correct here? CHECK THIS
#Put bra into the vector space!!!!
function BilinearFormNetwork(
bra::AbstractITensorNetwork, ket::AbstractITensorNetwork; kwargs...
bra::AbstractITensorNetwork, ket::AbstractITensorNetwork;
dual_site_index_map = default_dual_site_index_map,
kwargs...
)
operator_inds = union_all_inds(siteinds(bra), siteinds(ket))
@assert issetequal(externalinds(bra), externalinds(ket))
operator_inds = union_all_inds(siteinds(ket), dual_site_index_map(siteinds(ket)))
O = delta_network(operator_inds)
return BilinearFormNetwork(O, bra, ket; kwargs...)
return BilinearFormNetwork(O, bra, ket; dual_site_index_map, kwargs...)
end

function update(
blf::BilinearFormNetwork, original_state_vertex, bra_state::ITensor, ket_state::ITensor
blf::BilinearFormNetwork, original_bra_state_vertex, original_ket_state_vertex, bra_state::ITensor, ket_state::ITensor
)
blf = copy(blf)
# TODO: Maybe add a check that it really does preserve the graph.
setindex_preserve_graph!(
tensornetwork(blf), bra_state, bra_vertex(blf, original_state_vertex)
tensornetwork(blf), bra_state, bra_vertex(blf, original_bra_state_vertex)
)
setindex_preserve_graph!(
tensornetwork(blf), ket_state, ket_vertex(blf, original_state_vertex)
tensornetwork(blf), ket_state, ket_vertex(blf, original_ket_state_vertex)
)
return blf
end
8 changes: 3 additions & 5 deletions src/formnetworks/quadraticformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ function QuadraticFormNetwork(
dual_inv_index_map=default_inv_index_map,
kwargs...,
)
bra = map_inds(dual_index_map, dag(ket))
blf = BilinearFormNetwork(operator, bra, ket; kwargs...)
blf = BilinearFormNetwork(operator, ket, ket; dual_site_index_map = dual_index_map, dual_link_index_map = dual_index_map, kwargs...)
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
end

Expand All @@ -52,14 +51,13 @@ function QuadraticFormNetwork(
dual_inv_index_map=default_inv_index_map,
kwargs...,
)
bra = map_inds(dual_index_map, dag(ket))
blf = BilinearFormNetwork(bra, ket; kwargs...)
blf = BilinearFormNetwork(bra, ket; dual_site_index_map = dual_index_map, dual_link_index_map = dual_index_map, kwargs...)
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
end

function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor)
state_inds = inds(ket_state)
bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds))
new_blf = update(bilinear_formnetwork(qf), original_state_vertex, bra_state, ket_state)
new_blf = update(bilinear_formnetwork(qf), original_state_vertex, original_state_vertex, bra_state, ket_state)
return QuadraticFormNetwork(new_blf, dual_index_map(qf), dual_index_map(qf))
end
2 changes: 0 additions & 2 deletions src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ import ITensors:
dag,
# permute
permute,
#commoninds
hascommoninds,
# linkdims
linkdim,
linkdims,
Expand Down
26 changes: 16 additions & 10 deletions src/inner.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
default_inner_partitioned_vertices(tn) = group(v -> first(v), vertices(tn))
#Default to BP always?!
default_algorithm(tns::Vector) = all(is_tree.(tns)) ? "bp" : "exact"

#Default for map_linkinds should be sim.
#Use form code and just default to identity inbetween x and y
#Have ϕ in the same space as y and then a dual_map kwarg?

function inner(
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
Expand All @@ -10,14 +15,15 @@ function inner(
return inner(Algorithm(alg), ϕ, ψ; kwargs...)
end

#Make [A, ϕ, ψ] a Tuple
function inner(
ϕ::AbstractITensorNetwork,
A::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
alg=default_algorithm([A, ϕ, ψ]),
alg=default_algorithm([ϕ, A, ψ]),
kwargs...,
)
return inner(Algorithm(alg), A, ϕ, ψ; kwargs...)
return inner(Algorithm(alg), ϕ, A, ψ; kwargs...)
end

function inner(
Expand Down Expand Up @@ -92,10 +98,10 @@ function loginner(
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
partitioned_verts=default_inner_partitioned_vertices,
map_bra_linkinds=sim,
dual_link_index_map=sim,
kwargs...,
)
tn = inner_network(ϕ, ψ; map_bra_linkinds)
tn = inner_network(ϕ, ψ; dual_link_index_map)
return logscalar(alg, tn; partitioned_vertices=partitioned_verts(tn), kwargs...)
end

Expand All @@ -105,10 +111,10 @@ function loginner(
A::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
partitioned_verts=default_inner_partitioned_vertices,
map_bra_linkinds=sim,
dual_link_index_map=sim,
kwargs...,
)
tn = inner_network(ϕ, A, ψ; map_bra_linkinds)
tn = inner_network(ϕ, A, ψ; dual_link_index_map)
return logscalar(alg, tn; partitioned_vertices=partitioned_verts(tn), kwargs...)
end

Expand All @@ -117,10 +123,10 @@ function inner(
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
partitioned_verts=default_inner_partitioned_vertices,
map_bra_linkinds=prime,
dual_link_index_map=prime,
kwargs...,
)
tn = inner_network(ϕ, ψ; map_bra_linkinds)
tn = inner_network(ϕ, ψ; dual_link_index_map)
return scalar(alg, tn; partitioned_vertices=partitioned_verts(tn), kwargs...)
end

Expand All @@ -130,10 +136,10 @@ function inner(
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
partitioned_verts=default_inner_partitioned_vertices,
map_bra_linkinds=prime,
dual_link_index_map=prime,
kwargs...,
)
tn = inner_network(ϕ, A, ψ; map_bra_linkinds)
tn = inner_network(ϕ, A, ψ; dual_link_index_map)
return scalar(alg, tn; partitioned_vertices=partitioned_verts(tn), kwargs...)
end

Expand Down
9 changes: 4 additions & 5 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ using SplitApplyCombine

@testset "FormNetworks" begin
g = named_grid((1, 4))
s_ket = siteinds("S=1/2", g)
s_bra = prime(s_ket; links=[])
s_operator = union_all_inds(s_bra, s_ket)
s = siteinds("S=1/2", g)
s_operator = union_all_inds(s, prime(s))
χ, D = 2, 3
Random.seed!(1234)
ψket = randomITensorNetwork(s_ket; link_space=χ)
ψbra = randomITensorNetwork(s_bra; link_space=χ)
ψket = randomITensorNetwork(s; link_space=χ)
ψbra = randomITensorNetwork(s; link_space=χ)
A = randomITensorNetwork(s_operator; link_space=D)

blf = BilinearFormNetwork(A, ψbra, ψket)
Expand Down
14 changes: 7 additions & 7 deletions test/test_inner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,37 @@ using ITensors: siteinds, dag

@testset "Inner products, BP vs exact comparison" begin
Random.seed!(1234)
L = 12
L = 4
χ = 2
g = NamedGraph(Graphs.SimpleGraph(uniform_tree(L)))
s = siteinds("S=1/2", g)
y = randomITensorNetwork(s; link_space=χ)
x = randomITensorNetwork(s; link_space=χ)

#First lets do it with the flattened version of the network
xy = inner_network(x, y; combine_linkinds=true, flatten=true)
xy = inner_network(x, y)
xy_scalar = scalar(xy)
xy_scalar_bp = scalar(xy; alg="bp", partitioned_vertices=group(v -> v, vertices(xy)))
xy_scalar_bp = scalar(xy; alg="bp")
xy_scalar_logbp = exp(
logscalar(xy; alg="bp", partitioned_vertices=group(v -> v, vertices(xy)))
logscalar(xy; alg="bp")
)

@test xy_scalar xy_scalar_bp
@test xy_scalar_bp xy_scalar_logbp
@test xy_scalar xy_scalar_logbp

#Now lets do it via the inner function
xy_scalar = inner(x, y; alg="exact", flatten=true, combine_linkinds=true)
xy_scalar = inner(x, y; alg="exact")
xy_scalar_bp = inner(x, y; alg="bp")
xy_scalar_logbp = exp(loginner(x, y; alg="bp"))

@test xy_scalar xy_scalar_bp
@test xy_scalar_bp xy_scalar_logbp
@test xy_scalar xy_scalar_logbp

#test contraction of three layers for expectation values
# #test contraction of three layers for expectation values
A = ITensorNetwork(TTN(ITensorNetworks.heisenberg(s), s))
xAy_scalar = inner(x', A, y; alg="exact", flatten=true, combine_linkinds=true)
xAy_scalar = inner(x', A, y; alg="exact")
xAy_scalar_bp = inner(x', A, y; alg="bp")
xAy_scalar_logbp = exp(loginner(x', A, y; alg="bp"))

Expand Down

0 comments on commit 8efff49

Please sign in to comment.