Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for tensor network forms (bilinear and quadratic) #136

Merged
merged 19 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/FormNetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
default_bra_vertex_suffix() = "bra"
default_ket_vertex_suffix() = "ket"
default_operator_vertex_suffix() = "operator"

abstract type AbstractFormNetwork{V} <: AbstractITensorNetwork{V} end

#Needed for interface
dual_index_map(f::AbstractFormNetwork) = not_implemented()
tensornetwork(f::AbstractFormNetwork) = not_implemented()
copy(f::AbstractFormNetwork) = not_implemented()
bra_ket_vertices(f::AbstractFormNetwork, state_vertices::Vector) = not_implemented()
operator_vertex_suffix(f::AbstractFormNetwork) = not_implemented()
bra_vertex_suffix(f::AbstractFormNetwork) = not_implemented()
ket_vertex_suffix(f::AbstractFormNetwork) = not_implemented()

function operator_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == operator_vertex_suffix(f), vertices(f))
end
function bra_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == bra_vertex_suffix(f), vertices(f))
end
function ket_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f))
end

function bra_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [bra_vertex_map(f)(sv) for sv in state_vertices]
end

function ket_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [ket_vertex_map(f)(sv) for sv in state_vertices]
end

function Graphs.induced_subgraph(f::AbstractFormNetwork, vertices::Vector)
return induced_subgraph(tensornetwork(f), vertices)
end
function bra(f::AbstractFormNetwork)
return rename_vertices(inv_vertex_map(f), first(induced_subgraph(f, bra_vertices(f))))
end
function ket(f::AbstractFormNetwork)
return rename_vertices(inv_vertex_map(f), first(induced_subgraph(f, ket_vertices(f))))
end
function operator(f::AbstractFormNetwork)
return rename_vertices(
inv_vertex_map(f), first(induced_subgraph(f, operator_vertices(f)))
)
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved

function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(tensornetwork(f), tn_vertices; kwargs...)
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
return setdiff(
vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
)
end

operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f))
bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f))
ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f))
inv_vertex_map(f::AbstractFormNetwork) = v -> first(v)
52 changes: 52 additions & 0 deletions src/FormNetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
struct BilinearFormNetwork{
V,
TensorNetwork<:AbstractITensorNetwork{V},
OperatorVertexSuffix,
BraVertexSuffix,
KetVertexSuffix,
} <: AbstractFormNetwork{V}
tensornetwork::TensorNetwork
operator_vertex_suffix::OperatorVertexSuffix
bra_vertex_suffix::BraVertexSuffix
ket_vertex_suffix::KetVertexSuffix
end

function BilinearFormNetwork(
operator::AbstractITensorNetwork,
bra::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
operator_vertex_suffix=default_operator_vertex_suffix(),
bra_vertex_suffix=default_bra_vertex_suffix(),
ket_vertex_suffix=default_ket_vertex_suffix(),
)
tn = disjoint_union(
operator_vertex_suffix => operator, bra_vertex_suffix => bra, ket_vertex_suffix => ket
)
return BilinearFormNetwork(
tn, operator_vertex_suffix, bra_vertex_suffix, ket_vertex_suffix
)
end

operator_vertex_suffix(blf::BilinearFormNetwork) = blf.operator_vertex_suffix
bra_vertex_suffix(blf::BilinearFormNetwork) = blf.bra_vertex_suffix
ket_vertex_suffix(blf::BilinearFormNetwork) = blf.ket_vertex_suffix
tensornetwork(blf::BilinearFormNetwork) = blf.tensornetwork
data_graph_type(::Type{<:BilinearFormNetwork}) = data_graph_type(tensornetwork(blf))
data_graph(blf::BilinearFormNetwork) = data_graph(tensornetwork(blf))

function copy(blf::BilinearFormNetwork)
return BilinearFormNetwork(
copy(tensornetwork(blf)),
operator_vertex_suffix(blf),
bra_vertex_suffix(blf),
ket_vertex_suffix(blf),
)
end

function BilinearFormNetwork(
bra::AbstractITensorNetwork, ket::AbstractITensorNetwork; kwargs...
)
operator_inds = union_all_inds(siteinds(bra), siteinds(ket))
O = delta_network(operator_inds)
return BilinearFormNetwork(bra, O, ket; kwargs...)
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
end
75 changes: 75 additions & 0 deletions src/FormNetworks/quadraticformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
default_index_map = prime
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
default_inv_index_map = noprime

struct QuadraticFormNetwork{V,FormNetwork<:BilinearFormNetwork{V},IndexMap} <:
AbstractFormNetwork{V}
formnetwork::FormNetwork
dual_index_map::IndexMap
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
end

bilinear_formnetwork(qf::QuadraticFormNetwork) = qf.formnetwork
function QuadraticFormNetwork(
operator::AbstractITensorNetwork,
bra::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
dual_index_map=default_index_map,
kwargs...,
)
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
return QuadraticFormNetwork(
BilinearFormNetwork(operator, bra, ket; kwargs...), dual_index_map
)
end

#Needed for implementation, forward from bilinear form
for f in [
:operator_vertex_suffix,
:bra_vertex_suffix,
:ket_vertex_suffix,
:tensornetwork,
:data_graph,
:data_graph_type,
]
@eval begin
function $f(qf::QuadraticFormNetwork, args...; kwargs...)
return $f(bilinear_formnetwork(qf), args...; kwargs...)
end
end
end

dual_index_map(qf::QuadraticFormNetwork) = qf.dual_index_map
function copy(qf::QuadraticFormNetwork)
return QuadraticFormNetwork(copy(bilinear_formnetwork(qf)), dual_index_map(qf))
end

function QuadraticFormNetwork(
operator::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
dual_index_map=default_index_map,
kwargs...,
)
bra = map_inds(dual_index_map, dag(ket))
return QuadraticFormNetwork(operator, bra, ket; dual_index_map, kwargs...)
end

function QuadraticFormNetwork(
ket::AbstractITensorNetwork; dual_index_map=default_index_map, kwargs...
)
s = siteinds(ket)
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
operator_inds = union_all_inds(s, dual_index_map(s; links=[]))
operator = delta_network(operator_inds)
return QuadraticFormNetwork(operator, ket; kwargs...)
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved

function bra_ket_vertices(qf::QuadraticFormNetwork, state_vertices::Vector)
return vcat(bra_vertices(qf, state_vertices), ket_vertices(qf, state_vertices))
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved

function update(qf::QuadraticFormNetwork, state_vertex, state::ITensor)
qf = copy(qf)
state_inds = inds(state)
state_dag = replaceinds(dag(state), state_inds, dual_index_map(qf).(state_inds))
# TODO: Maybe add a check that it really does preserve the graph.
setindex_preserve_graph!(tensornetwork(qf), state, ket_vertex_map(qf)(state_vertex))
setindex_preserve_graph!(tensornetwork(qf), state_dag, bra_vertex_map(qf)(state_vertex))
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
return qf
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 3 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,12 @@ include(joinpath("approx_itensornetwork", "binary_tree_partition.jl"))
include("contract.jl")
include("utility.jl")
include("specialitensornetworks.jl")
include("renameitensornetwork.jl")
include("boundarymps.jl")
include(joinpath("beliefpropagation", "beliefpropagation.jl"))
include(joinpath("beliefpropagation", "beliefpropagation_schedule.jl"))
include(joinpath("formnetworks", "abstractformnetwork.jl"))
include(joinpath("formnetworks", "bilinearformnetwork.jl"))
include(joinpath("formnetworks", "quadraticformnetwork.jl"))
include("contraction_tree_to_graph.jl")
include("gauging.jl")
include("utils.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ export AbstractITensorNetwork,
mps,
ortho_center,
set_ortho_center,
QuadraticFormNetwork,
BilinearFormNetwork,
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
TreeTensorNetwork,
TTN,
random_ttn,
Expand Down
25 changes: 0 additions & 25 deletions src/renameitensornetwork.jl

This file was deleted.

9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ function line_to_tree(line::Vector)
end
return [line_to_tree(line[1:(end - 1)]), line[end]]
end

function Base.union(
graph1::AbstractNamedGraph,
graph2::AbstractNamedGraph,
graph3::AbstractNamedGraph,
graph_rest::AbstractNamedGraph...,
)
return union(union(graph1, graph2), graph3, graph_rest...)
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
51 changes: 51 additions & 0 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using ITensors
using Graphs
using NamedGraphs
using ITensorNetworks
using ITensorNetworks:
delta_network,
update,
tensornetwork,
bra_vertex_map,
ket_vertex_map,
dual_index_map,
bra,
ket,
operator
using Test
using Random

@testset "FormNetworkss" begin
g = named_grid((1, 4))
s_ket = siteinds("S=1/2", g)
s_bra = prime(s_ket; links=[])
s_operator = union_all_inds(s_bra, s_ket)
χ, D = 2, 3
Random.seed!(1234)
ψket = randomITensorNetwork(s_ket; link_space=χ)
ψbra = randomITensorNetwork(s_bra; link_space=χ)
A = randomITensorNetwork(s_operator; link_space=D)

blf = BilinearFormNetwork(A, ψbra, ψket)
@test nv(blf) == nv(ψket) + nv(ψbra) + nv(A)
@test isempty(externalinds(blf))

@test underlying_graph(ket(blf)) == underlying_graph(ψket)
@test underlying_graph(operator(blf)) == underlying_graph(A)
@test underlying_graph(bra(blf)) == underlying_graph(ψbra)

qf = QuadraticFormNetwork(A, ψket)
@test nv(qf) == 2 * nv(ψbra) + nv(A)
@test isempty(externalinds(qf))

v = (1, 1)
new_tensor = randomITensor(inds(ψket[v]))
qf_updated = update(qf, v, copy(new_tensor))

@test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)] ≈
dual_index_map(qf_updated)(dag(new_tensor))
@test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] ≈ new_tensor

@test underlying_graph(ket(qf)) == underlying_graph(ψket)
@test underlying_graph(operator(qf)) == underlying_graph(A)
end
Loading