Skip to content

Commit

Permalink
Add support for tensor network forms (bilinear and quadratic) (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored Feb 13, 2024
1 parent fbb4e53 commit c2cc66a
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ ITensors = "0.3.23"
IsApprox = "0.1"
IterTools = "1.4.0"
KrylovKit = "0.6.0"
NamedGraphs = "0.1.11"
NamedGraphs = "0.1.20"
Observers = "0.2"
Requires = "1.3"
SimpleTraits = "0.9"
Expand Down
4 changes: 3 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ include(joinpath("approx_itensornetwork", "binary_tree_partition.jl"))
include("contract.jl")
include("utility.jl")
include("specialitensornetworks.jl")
include("renameitensornetwork.jl")
include("boundarymps.jl")
include(joinpath("beliefpropagation", "beliefpropagation.jl"))
include(joinpath("beliefpropagation", "beliefpropagation_schedule.jl"))
include(joinpath("formnetworks", "abstractformnetwork.jl"))
include(joinpath("formnetworks", "bilinearformnetwork.jl"))
include(joinpath("formnetworks", "quadraticformnetwork.jl"))
include("contraction_tree_to_graph.jl")
include("gauging.jl")
include("utils.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ export AbstractITensorNetwork,
mps,
ortho_center,
set_ortho_center,
BilinearFormNetwork,
QuadraticFormNetwork,
TreeTensorNetwork,
TTN,
random_ttn,
Expand Down
74 changes: 74 additions & 0 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
default_bra_vertex_suffix() = "bra"
default_ket_vertex_suffix() = "ket"
default_operator_vertex_suffix() = "operator"

abstract type AbstractFormNetwork{V} <: AbstractITensorNetwork{V} end

#Needed for interface
dual_index_map(f::AbstractFormNetwork) = not_implemented()
tensornetwork(f::AbstractFormNetwork) = not_implemented()
copy(f::AbstractFormNetwork) = not_implemented()
operator_vertex_suffix(f::AbstractFormNetwork) = not_implemented()
bra_vertex_suffix(f::AbstractFormNetwork) = not_implemented()
ket_vertex_suffix(f::AbstractFormNetwork) = not_implemented()

function operator_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == operator_vertex_suffix(f), vertices(f))
end
function bra_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == bra_vertex_suffix(f), vertices(f))
end

function ket_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f))
end

function bra_ket_vertices(f::AbstractFormNetwork)
return vcat(bra_vertices(f), ket_vertices(f))
end

function bra_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [bra_vertex_map(f)(sv) for sv in state_vertices]
end

function ket_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [ket_vertex_map(f)(sv) for sv in state_vertices]
end

function bra_ket_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
end

function Graphs.induced_subgraph(f::AbstractFormNetwork, vertices::Vector)
return induced_subgraph(tensornetwork(f), vertices)
end

function bra_network(f::AbstractFormNetwork)
return rename_vertices(inv_vertex_map(f), first(induced_subgraph(f, bra_vertices(f))))
end

function ket_network(f::AbstractFormNetwork)
return rename_vertices(inv_vertex_map(f), first(induced_subgraph(f, ket_vertices(f))))
end

function operator_network(f::AbstractFormNetwork)
return rename_vertices(
inv_vertex_map(f), first(induced_subgraph(f, operator_vertices(f)))
)
end

function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(tensornetwork(f), tn_vertices; kwargs...)
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
return setdiff(
vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
)
end

operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f))
bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f))
ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f))
inv_vertex_map(f::AbstractFormNetwork) = v -> first(v)
62 changes: 62 additions & 0 deletions src/formnetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
struct BilinearFormNetwork{
V,
TensorNetwork<:AbstractITensorNetwork{V},
OperatorVertexSuffix,
BraVertexSuffix,
KetVertexSuffix,
} <: AbstractFormNetwork{V}
tensornetwork::TensorNetwork
operator_vertex_suffix::OperatorVertexSuffix
bra_vertex_suffix::BraVertexSuffix
ket_vertex_suffix::KetVertexSuffix
end

function BilinearFormNetwork(
operator::AbstractITensorNetwork,
bra::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
operator_vertex_suffix=default_operator_vertex_suffix(),
bra_vertex_suffix=default_bra_vertex_suffix(),
ket_vertex_suffix=default_ket_vertex_suffix(),
)
tn = disjoint_union(
operator_vertex_suffix => operator, bra_vertex_suffix => bra, ket_vertex_suffix => ket
)
return BilinearFormNetwork(
tn, operator_vertex_suffix, bra_vertex_suffix, ket_vertex_suffix
)
end

operator_vertex_suffix(blf::BilinearFormNetwork) = blf.operator_vertex_suffix
bra_vertex_suffix(blf::BilinearFormNetwork) = blf.bra_vertex_suffix
ket_vertex_suffix(blf::BilinearFormNetwork) = blf.ket_vertex_suffix
tensornetwork(blf::BilinearFormNetwork) = blf.tensornetwork
data_graph_type(::Type{<:BilinearFormNetwork}) = data_graph_type(tensornetwork(blf))
data_graph(blf::BilinearFormNetwork) = data_graph(tensornetwork(blf))

function copy(blf::BilinearFormNetwork)
return BilinearFormNetwork(
copy(tensornetwork(blf)),
operator_vertex_suffix(blf),
bra_vertex_suffix(blf),
ket_vertex_suffix(blf),
)
end

function BilinearFormNetwork(
bra::AbstractITensorNetwork, ket::AbstractITensorNetwork; kwargs...
)
operator_inds = union_all_inds(siteinds(bra), siteinds(ket))
O = delta_network(operator_inds)
return BilinearFormNetwork(O, bra, ket; kwargs...)
end

function update(
blf::BilinearFormNetwork, state_vertex, bra_state::ITensor, ket_state::ITensor
)
blf = copy(blf)
# TODO: Maybe add a check that it really does preserve the graph.
setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex))
setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex))
return blf
end
65 changes: 65 additions & 0 deletions src/formnetworks/quadraticformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
default_index_map = prime
default_inv_index_map = noprime

struct QuadraticFormNetwork{V,FormNetwork<:BilinearFormNetwork{V},IndexMap,InvIndexMap} <:
AbstractFormNetwork{V}
formnetwork::FormNetwork
dual_index_map::IndexMap
dual_inv_index_map::InvIndexMap
end

bilinear_formnetwork(qf::QuadraticFormNetwork) = qf.formnetwork

#Needed for implementation, forward from bilinear form
for f in [
:operator_vertex_suffix,
:bra_vertex_suffix,
:ket_vertex_suffix,
:tensornetwork,
:data_graph,
:data_graph_type,
]
@eval begin
function $f(qf::QuadraticFormNetwork, args...; kwargs...)
return $f(bilinear_formnetwork(qf), args...; kwargs...)
end
end
end

dual_index_map(qf::QuadraticFormNetwork) = qf.dual_index_map
dual_inv_index_map(qf::QuadraticFormNetwork) = qf.dual_inv_index_map
function copy(qf::QuadraticFormNetwork)
return QuadraticFormNetwork(
copy(bilinear_formnetwork(qf)), dual_index_map(qf), dual_inv_index_map(qf)
)
end

function QuadraticFormNetwork(
operator::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
dual_index_map=default_index_map,
dual_inv_index_map=default_inv_index_map,
kwargs...,
)
bra = map_inds(dual_index_map, dag(ket))
blf = BilinearFormNetwork(operator, bra, ket; kwargs...)
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
end

function QuadraticFormNetwork(
ket::AbstractITensorNetwork;
dual_index_map=default_index_map,
dual_inv_index_map=default_inv_index_map,
kwargs...,
)
bra = map_inds(dual_index_map, dag(ket))
blf = BilinearFormNetwork(bra, ket; kwargs...)
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
end

function update(qf::QuadraticFormNetwork, state_vertex, ket_state::ITensor)
state_inds = inds(ket_state)
bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds))
new_blf = update(bilinear_formnetwork(qf), state_vertex, bra_state, ket_state)
return QuadraticFormNetwork(new_blf, dual_index_map(qf), dual_index_map(qf))
end
25 changes: 0 additions & 25 deletions src/renameitensornetwork.jl

This file was deleted.

51 changes: 51 additions & 0 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using ITensors
using Graphs
using NamedGraphs
using ITensorNetworks
using ITensorNetworks:
delta_network,
update,
tensornetwork,
bra_vertex_map,
ket_vertex_map,
dual_index_map,
bra_network,
ket_network,
operator_network
using Test
using Random

@testset "FormNetworkss" begin
g = named_grid((1, 4))
s_ket = siteinds("S=1/2", g)
s_bra = prime(s_ket; links=[])
s_operator = union_all_inds(s_bra, s_ket)
χ, D = 2, 3
Random.seed!(1234)
ψket = randomITensorNetwork(s_ket; link_space=χ)
ψbra = randomITensorNetwork(s_bra; link_space=χ)
A = randomITensorNetwork(s_operator; link_space=D)

blf = BilinearFormNetwork(A, ψbra, ψket)
@test nv(blf) == nv(ψket) + nv(ψbra) + nv(A)
@test isempty(externalinds(blf))

@test underlying_graph(ket_network(blf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(blf)) == underlying_graph(A)
@test underlying_graph(bra_network(blf)) == underlying_graph(ψbra)

qf = QuadraticFormNetwork(A, ψket)
@test nv(qf) == 2 * nv(ψbra) + nv(A)
@test isempty(externalinds(qf))

v = (1, 1)
new_tensor = randomITensor(inds(ψket[v]))
qf_updated = update(qf, v, copy(new_tensor))

@test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)]
dual_index_map(qf_updated)(dag(new_tensor))
@test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] new_tensor

@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)
end

0 comments on commit c2cc66a

Please sign in to comment.