Skip to content

Commit

Permalink
Updated interface. FormNetworks. QuadraticFormNetwork wraps BilinearF…
Browse files Browse the repository at this point in the history
…ormNetwork
  • Loading branch information
JoeyT1994 committed Feb 9, 2024
1 parent 3cd5222 commit d0e0c0f
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 188 deletions.
21 changes: 21 additions & 0 deletions src/FormNetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
abstract type AbstractFormNetwork{V} <: AbstractITensorNetwork{V} end

#Needed for interface
bra_vertex_map(f::AbstractFormNetwork) = not_implemented()
ket_vertex_map(f::AbstractFormNetwork) = not_implemented()
operator_vertex_map(f::AbstractFormNetwork) = not_implemented()
dual_index_map(f::AbstractFormNetwork) = not_implemented()
tensornetwork(f::AbstractFormNetwork) = not_implemented()
copy(f::AbstractFormNetwork) = not_implemented()
derivative_vertices(f::AbstractFormNetwork) = not_implemented()

bra(f::AbstractFormNetwork) = induced_subgraph(f, collect(values(bra_vertex_map(f))))
ket(f::AbstractFormNetwork) = induced_subgraph(f, collect(values(ket_vertex_map(f))))
function operator(f::AbstractFormNetwork)
return induced_subgraph(f, collect(values(operator_vertex_map(f))))
end

function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(tensornetwork(f), tn_vertices; kwargs...)
end
62 changes: 62 additions & 0 deletions src/FormNetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
default_bra_vertex_map(v) = (v, "bra")
default_ket_vertex_map(v) = (v, "ket")
default_operator_vertex_map(v) = (v, "operator")
default_operator_constructor(s::IndsNetwork) = delta_network(s)

struct BilinearFormNetwork{V,KetMap,BraMap,OperatorMap} <: AbstractFormNetwork{V}
tn::AbstractITensorNetwork{V}
bra_vertex_map::BraMap
ket_vertex_map::KetMap
operator_vertex_map::OperatorMap
end

function BilinearFormNetwork(
operator::AbstractITensorNetwork,
bra::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
bra_vertex_map=default_bra_vertex_map,
ket_vertex_map=default_ket_vertex_map,
operator_vertex_map=default_operator_vertex_map,
)
@assert Set(externalinds(operator)) ==
union(Set(externalinds(ket)), Set(externalinds(bra)))
@assert isempty(findall(in(internalinds(bra)), internalinds(ket)))
@assert isempty(findall(in(internalinds(bra)), internalinds(operator)))
@assert isempty(findall(in(internalinds(ket)), internalinds(operator)))

bra_renamed = rename_vertices_itn(bra, bra_vertex_map)
ket_renamed = rename_vertices_itn(ket, ket_vertex_map)
operator_renamed = rename_vertices_itn(operator, operator_vertex_map)

tn = union(union(bra_renamed, operator_renamed), ket_renamed)

return BilinearFormNetwork(tn, bra_vertex_map, ket_vertex_map, operator_vertex_map)
end

#Needed for implementation
bra_vertex_map(blf::BilinearFormNetwork) = blf.bra_vertex_map
ket_vertex_map(blf::BilinearFormNetwork) = blf.ket_vertex_map
operator_vertex_map(blf::BilinearFormNetwork) = blf.operator_vertex_map
tensornetwork(blf::BilinearFormNetwork) = blf.tn
data_graph_type(::Type{<:BilinearFormNetwork}) = data_graph_type(tensornetwork(blf))
data_graph(blf::BilinearFormNetwork) = data_graph(tensornetwork(blf))

function copy(blf::BilinearFormNetwork)
return BilinearFormNetwork(
copy(tensornetwork(blf)),
bra_vertex_map(blf),
ket_vertex_map(blf),
operator_vertex_map(blf),
)
end

function BilinearFormNetwork(
bra::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
operator_constructor=default_operator_constructor,
kwargs...,
)
operator_space = union_all_inds(siteinds(bra), siteinds(ket))
O = tno_constructor(operator_space)
return BilinearFormNetwork(bra, O, ket; kwargs...)
end
86 changes: 86 additions & 0 deletions src/FormNetworks/quadraticformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
default_index_map = prime

struct QuadraticFormNetwork{V,FormNetwork<:BilinearFormNetwork{V},IndexMap} <:
AbstractFormNetwork{V}
formnetwork::FormNetwork
dual_index_map::IndexMap
end

bilinear_formnetwork(qf::QuadraticFormNetwork) = qf.formnetwork
function QuadraticFormNetwork(
operator::AbstractITensorNetwork,
bra::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
dual_index_map=default_index_map,
kwargs...,
)
return QuadraticFormNetwork(
BilinearFormNetwork(operator, bra, ket; kwargs...), dual_index_map
)
end

#Needed for implementation, forward from bilinear form
for f in [
:bra_vertex_map,
:ket_vertex_map,
:operator_vertex_map,
:tensornetwork,
:data_graph,
:data_graph_type,
]
@eval begin
function $f(qf::QuadraticFormNetwork, args...; kwargs...)
return $f(bilinear_formnetwork(qf), args...; kwargs...)
end
end
end

dual_index_map(qf::QuadraticFormNetwork) = qf.dual_index_map
function copy(qf::QuadraticFormNetwork)
return QuadraticFormNetwork(copy(bilinear_formnetwork(qf)), dual_index_map(qf))
end

function QuadraticFormNetwork(
operator::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
dual_index_map=default_index_map,
kwargs...,
)
bra = map_inds(dual_index_map, dag(ket))
return QuadraticFormNetwork(operator, bra, ket; dual_index_map, kwargs...)
end

function QuadraticFormNetwork(
ket::AbstractITensorNetwork;
dual_index_map=default_index_map,
operator_constructor=default_operator_constructor,
kwargs...,
)
s = siteinds(ket)
operator_space = union_all_inds(s, dual_index_map(s; links=[]))
operator = operator_constructor(operator_space)
return QuadraticFormNetwork(operator, ket; kwargs...)
end

function bra_vertices(qf::QuadraticFormNetwork, state_vertices::Vector)
return [bra_vertex_map(qf)(sv) for sv in state_vertices]
end

function ket_vertices(qf::QuadraticFormNetwork, state_vertices::Vector)
return [ket_vertex_map(qf)(sv) for sv in state_vertices]
end

function derivative_vertices(qf::QuadraticFormNetwork, state_vertices::Vector; kwargs...)
return setdiff(
vertices(f), vcat(bra_vertices(qf, state_vertices), ket_vertices(qf, state_vertices))
)
end

function update(qf::QuadraticFormNetwork, state_vertex, state::ITensor)
qf = copy(qf)
state_inds = inds(state)
state_dag = replaceinds(dag(state), state_inds, dual_index_map(qf).(state_inds))
setindex_preserve_graph!(tensornetwork(qf), state, ket_vertex_map(qf)(state_vertex))
setindex_preserve_graph!(tensornetwork(qf), state_dag, bra_vertex_map(qf)(state_vertex))
return qf
end
14 changes: 0 additions & 14 deletions src/Forms/abstractbilinearform.jl

This file was deleted.

42 changes: 0 additions & 42 deletions src/Forms/bilinearform.jl

This file was deleted.

44 changes: 0 additions & 44 deletions src/Forms/construct_form.jl

This file was deleted.

70 changes: 0 additions & 70 deletions src/Forms/quadraticform.jl

This file was deleted.

7 changes: 3 additions & 4 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,9 @@ include("renameitensornetwork.jl")
include("boundarymps.jl")
include(joinpath("beliefpropagation", "beliefpropagation.jl"))
include(joinpath("beliefpropagation", "beliefpropagation_schedule.jl"))
include(joinpath("Forms", "abstractbilinearform.jl"))
include(joinpath("Forms", "bilinearform.jl"))
include(joinpath("Forms", "quadraticform.jl"))
include(joinpath("Forms", "construct_form.jl"))
include(joinpath("FormNetworks", "abstractformnetwork.jl"))
include(joinpath("FormNetworks", "bilinearformnetwork.jl"))
include(joinpath("FormNetworks", "quadraticformnetwork.jl"))
include("contraction_tree_to_graph.jl")
include("gauging.jl")
include("utils.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ export AbstractITensorNetwork,
mps,
ortho_center,
set_ortho_center,
QuadraticForm,
BilinearForm,
QuadraticFormNetwork,
BilinearFormNetwork,
TreeTensorNetwork,
TTN,
random_ttn,
Expand Down
Loading

0 comments on commit d0e0c0f

Please sign in to comment.