From c2cc66a1cd686657a5945ac350e2240547f0dc77 Mon Sep 17 00:00:00 2001 From: Joseph Tindall <51231103+JoeyT1994@users.noreply.github.com> Date: Tue, 13 Feb 2024 14:50:15 -0500 Subject: [PATCH] Add support for tensor network forms (bilinear and quadratic) (#136) --- Project.toml | 2 +- src/ITensorNetworks.jl | 4 +- src/exports.jl | 2 + src/formnetworks/abstractformnetwork.jl | 74 ++++++++++++++++++++++++ src/formnetworks/bilinearformnetwork.jl | 62 ++++++++++++++++++++ src/formnetworks/quadraticformnetwork.jl | 65 +++++++++++++++++++++ src/renameitensornetwork.jl | 25 -------- test/test_forms.jl | 51 ++++++++++++++++ 8 files changed, 258 insertions(+), 27 deletions(-) create mode 100644 src/formnetworks/abstractformnetwork.jl create mode 100644 src/formnetworks/bilinearformnetwork.jl create mode 100644 src/formnetworks/quadraticformnetwork.jl delete mode 100644 src/renameitensornetwork.jl create mode 100644 test/test_forms.jl diff --git a/Project.toml b/Project.toml index 874c0c1b..28610a59 100644 --- a/Project.toml +++ b/Project.toml @@ -54,7 +54,7 @@ ITensors = "0.3.23" IsApprox = "0.1" IterTools = "1.4.0" KrylovKit = "0.6.0" -NamedGraphs = "0.1.11" +NamedGraphs = "0.1.20" Observers = "0.2" Requires = "1.3" SimpleTraits = "0.9" diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index dbd1e515..6df0d036 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -95,10 +95,12 @@ include(joinpath("approx_itensornetwork", "binary_tree_partition.jl")) include("contract.jl") include("utility.jl") include("specialitensornetworks.jl") -include("renameitensornetwork.jl") include("boundarymps.jl") include(joinpath("beliefpropagation", "beliefpropagation.jl")) include(joinpath("beliefpropagation", "beliefpropagation_schedule.jl")) +include(joinpath("formnetworks", "abstractformnetwork.jl")) +include(joinpath("formnetworks", "bilinearformnetwork.jl")) +include(joinpath("formnetworks", "quadraticformnetwork.jl")) include("contraction_tree_to_graph.jl") include("gauging.jl") include("utils.jl") diff --git a/src/exports.jl b/src/exports.jl index e76c3ea9..a3ef21f5 100644 --- a/src/exports.jl +++ b/src/exports.jl @@ -71,6 +71,8 @@ export AbstractITensorNetwork, mps, ortho_center, set_ortho_center, + BilinearFormNetwork, + QuadraticFormNetwork, TreeTensorNetwork, TTN, random_ttn, diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl new file mode 100644 index 00000000..e6efe54e --- /dev/null +++ b/src/formnetworks/abstractformnetwork.jl @@ -0,0 +1,74 @@ +default_bra_vertex_suffix() = "bra" +default_ket_vertex_suffix() = "ket" +default_operator_vertex_suffix() = "operator" + +abstract type AbstractFormNetwork{V} <: AbstractITensorNetwork{V} end + +#Needed for interface +dual_index_map(f::AbstractFormNetwork) = not_implemented() +tensornetwork(f::AbstractFormNetwork) = not_implemented() +copy(f::AbstractFormNetwork) = not_implemented() +operator_vertex_suffix(f::AbstractFormNetwork) = not_implemented() +bra_vertex_suffix(f::AbstractFormNetwork) = not_implemented() +ket_vertex_suffix(f::AbstractFormNetwork) = not_implemented() + +function operator_vertices(f::AbstractFormNetwork) + return filter(v -> last(v) == operator_vertex_suffix(f), vertices(f)) +end +function bra_vertices(f::AbstractFormNetwork) + return filter(v -> last(v) == bra_vertex_suffix(f), vertices(f)) +end + +function ket_vertices(f::AbstractFormNetwork) + return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f)) +end + +function bra_ket_vertices(f::AbstractFormNetwork) + return vcat(bra_vertices(f), ket_vertices(f)) +end + +function bra_vertices(f::AbstractFormNetwork, state_vertices::Vector) + return [bra_vertex_map(f)(sv) for sv in state_vertices] +end + +function ket_vertices(f::AbstractFormNetwork, state_vertices::Vector) + return [ket_vertex_map(f)(sv) for sv in state_vertices] +end + +function bra_ket_vertices(f::AbstractFormNetwork, state_vertices::Vector) + return vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) +end + +function Graphs.induced_subgraph(f::AbstractFormNetwork, vertices::Vector) + return induced_subgraph(tensornetwork(f), vertices) +end + +function bra_network(f::AbstractFormNetwork) + return rename_vertices(inv_vertex_map(f), first(induced_subgraph(f, bra_vertices(f)))) +end + +function ket_network(f::AbstractFormNetwork) + return rename_vertices(inv_vertex_map(f), first(induced_subgraph(f, ket_vertices(f)))) +end + +function operator_network(f::AbstractFormNetwork) + return rename_vertices( + inv_vertex_map(f), first(induced_subgraph(f, operator_vertices(f))) + ) +end + +function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) + tn_vertices = derivative_vertices(f, state_vertices) + return derivative(tensornetwork(f), tn_vertices; kwargs...) +end + +function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...) + return setdiff( + vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices)) + ) +end + +operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f)) +bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f)) +ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f)) +inv_vertex_map(f::AbstractFormNetwork) = v -> first(v) diff --git a/src/formnetworks/bilinearformnetwork.jl b/src/formnetworks/bilinearformnetwork.jl new file mode 100644 index 00000000..356b0ed1 --- /dev/null +++ b/src/formnetworks/bilinearformnetwork.jl @@ -0,0 +1,62 @@ +struct BilinearFormNetwork{ + V, + TensorNetwork<:AbstractITensorNetwork{V}, + OperatorVertexSuffix, + BraVertexSuffix, + KetVertexSuffix, +} <: AbstractFormNetwork{V} + tensornetwork::TensorNetwork + operator_vertex_suffix::OperatorVertexSuffix + bra_vertex_suffix::BraVertexSuffix + ket_vertex_suffix::KetVertexSuffix +end + +function BilinearFormNetwork( + operator::AbstractITensorNetwork, + bra::AbstractITensorNetwork, + ket::AbstractITensorNetwork; + operator_vertex_suffix=default_operator_vertex_suffix(), + bra_vertex_suffix=default_bra_vertex_suffix(), + ket_vertex_suffix=default_ket_vertex_suffix(), +) + tn = disjoint_union( + operator_vertex_suffix => operator, bra_vertex_suffix => bra, ket_vertex_suffix => ket + ) + return BilinearFormNetwork( + tn, operator_vertex_suffix, bra_vertex_suffix, ket_vertex_suffix + ) +end + +operator_vertex_suffix(blf::BilinearFormNetwork) = blf.operator_vertex_suffix +bra_vertex_suffix(blf::BilinearFormNetwork) = blf.bra_vertex_suffix +ket_vertex_suffix(blf::BilinearFormNetwork) = blf.ket_vertex_suffix +tensornetwork(blf::BilinearFormNetwork) = blf.tensornetwork +data_graph_type(::Type{<:BilinearFormNetwork}) = data_graph_type(tensornetwork(blf)) +data_graph(blf::BilinearFormNetwork) = data_graph(tensornetwork(blf)) + +function copy(blf::BilinearFormNetwork) + return BilinearFormNetwork( + copy(tensornetwork(blf)), + operator_vertex_suffix(blf), + bra_vertex_suffix(blf), + ket_vertex_suffix(blf), + ) +end + +function BilinearFormNetwork( + bra::AbstractITensorNetwork, ket::AbstractITensorNetwork; kwargs... +) + operator_inds = union_all_inds(siteinds(bra), siteinds(ket)) + O = delta_network(operator_inds) + return BilinearFormNetwork(O, bra, ket; kwargs...) +end + +function update( + blf::BilinearFormNetwork, state_vertex, bra_state::ITensor, ket_state::ITensor +) + blf = copy(blf) + # TODO: Maybe add a check that it really does preserve the graph. + setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex)) + setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex)) + return blf +end diff --git a/src/formnetworks/quadraticformnetwork.jl b/src/formnetworks/quadraticformnetwork.jl new file mode 100644 index 00000000..5acee59e --- /dev/null +++ b/src/formnetworks/quadraticformnetwork.jl @@ -0,0 +1,65 @@ +default_index_map = prime +default_inv_index_map = noprime + +struct QuadraticFormNetwork{V,FormNetwork<:BilinearFormNetwork{V},IndexMap,InvIndexMap} <: + AbstractFormNetwork{V} + formnetwork::FormNetwork + dual_index_map::IndexMap + dual_inv_index_map::InvIndexMap +end + +bilinear_formnetwork(qf::QuadraticFormNetwork) = qf.formnetwork + +#Needed for implementation, forward from bilinear form +for f in [ + :operator_vertex_suffix, + :bra_vertex_suffix, + :ket_vertex_suffix, + :tensornetwork, + :data_graph, + :data_graph_type, +] + @eval begin + function $f(qf::QuadraticFormNetwork, args...; kwargs...) + return $f(bilinear_formnetwork(qf), args...; kwargs...) + end + end +end + +dual_index_map(qf::QuadraticFormNetwork) = qf.dual_index_map +dual_inv_index_map(qf::QuadraticFormNetwork) = qf.dual_inv_index_map +function copy(qf::QuadraticFormNetwork) + return QuadraticFormNetwork( + copy(bilinear_formnetwork(qf)), dual_index_map(qf), dual_inv_index_map(qf) + ) +end + +function QuadraticFormNetwork( + operator::AbstractITensorNetwork, + ket::AbstractITensorNetwork; + dual_index_map=default_index_map, + dual_inv_index_map=default_inv_index_map, + kwargs..., +) + bra = map_inds(dual_index_map, dag(ket)) + blf = BilinearFormNetwork(operator, bra, ket; kwargs...) + return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map) +end + +function QuadraticFormNetwork( + ket::AbstractITensorNetwork; + dual_index_map=default_index_map, + dual_inv_index_map=default_inv_index_map, + kwargs..., +) + bra = map_inds(dual_index_map, dag(ket)) + blf = BilinearFormNetwork(bra, ket; kwargs...) + return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map) +end + +function update(qf::QuadraticFormNetwork, state_vertex, ket_state::ITensor) + state_inds = inds(ket_state) + bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds)) + new_blf = update(bilinear_formnetwork(qf), state_vertex, bra_state, ket_state) + return QuadraticFormNetwork(new_blf, dual_index_map(qf), dual_index_map(qf)) +end diff --git a/src/renameitensornetwork.jl b/src/renameitensornetwork.jl deleted file mode 100644 index da810c15..00000000 --- a/src/renameitensornetwork.jl +++ /dev/null @@ -1,25 +0,0 @@ - -#RENAME THE VERTICES OF AN ITENSORNETWORK, THIS SHOULD NOT BE NEEDED BUT CURRENTLY IS BECAUSE RENAME_VERTICES DOESN'T WRAP ONTO IT -function rename_vertices_itn(psi::ITensorNetwork, name_map::Dictionary) - old_g = NamedGraph(vertices(psi)) - - for e in edges(psi) - add_edge!(old_g, e) - end - - new_g = rename_vertices(old_g, name_map) - - psi_new = ITensorNetwork(new_g) - for v in vertices(psi) - psi_new[name_map[v]] = psi[v] - end - - return psi_new -end - -function rename_vertices_itn(psi::ITensorNetwork, name_map::Function) - original_vertices = vertices(psi) - return rename_vertices_itn( - psi, Dictionary(original_vertices, name_map.(original_vertices)) - ) -end diff --git a/test/test_forms.jl b/test/test_forms.jl new file mode 100644 index 00000000..74982629 --- /dev/null +++ b/test/test_forms.jl @@ -0,0 +1,51 @@ +using ITensors +using Graphs +using NamedGraphs +using ITensorNetworks +using ITensorNetworks: + delta_network, + update, + tensornetwork, + bra_vertex_map, + ket_vertex_map, + dual_index_map, + bra_network, + ket_network, + operator_network +using Test +using Random + +@testset "FormNetworkss" begin + g = named_grid((1, 4)) + s_ket = siteinds("S=1/2", g) + s_bra = prime(s_ket; links=[]) + s_operator = union_all_inds(s_bra, s_ket) + χ, D = 2, 3 + Random.seed!(1234) + ψket = randomITensorNetwork(s_ket; link_space=χ) + ψbra = randomITensorNetwork(s_bra; link_space=χ) + A = randomITensorNetwork(s_operator; link_space=D) + + blf = BilinearFormNetwork(A, ψbra, ψket) + @test nv(blf) == nv(ψket) + nv(ψbra) + nv(A) + @test isempty(externalinds(blf)) + + @test underlying_graph(ket_network(blf)) == underlying_graph(ψket) + @test underlying_graph(operator_network(blf)) == underlying_graph(A) + @test underlying_graph(bra_network(blf)) == underlying_graph(ψbra) + + qf = QuadraticFormNetwork(A, ψket) + @test nv(qf) == 2 * nv(ψbra) + nv(A) + @test isempty(externalinds(qf)) + + v = (1, 1) + new_tensor = randomITensor(inds(ψket[v])) + qf_updated = update(qf, v, copy(new_tensor)) + + @test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)] ≈ + dual_index_map(qf_updated)(dag(new_tensor)) + @test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] ≈ new_tensor + + @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) + @test underlying_graph(operator_network(qf)) == underlying_graph(A) +end