Skip to content

Commit

Permalink
Rewrite tensor network constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Apr 5, 2024
1 parent ae4ad2c commit 706b7aa
Show file tree
Hide file tree
Showing 11 changed files with 536 additions and 466 deletions.
18 changes: 15 additions & 3 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ end
# TODO: broadcasting

function Base.union(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork; kwargs...)
tn = ITensorNetwork(union(data_graph(tn1), data_graph(tn2)); kwargs...)
# TODO: Use a different constructor call here?
tn = _ITensorNetwork(union(data_graph(tn1), data_graph(tn2)); kwargs...)
# Add any new edges that are introduced during the union
for v1 in vertices(tn1)
for v2 in vertices(tn2)
Expand All @@ -129,7 +130,8 @@ function Base.union(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork; kw
end

function NamedGraphs.rename_vertices(f::Function, tn::AbstractITensorNetwork)
return ITensorNetwork(rename_vertices(f, data_graph(tn)))
# TODO: Use a different constructor call here?
return _ITensorNetwork(rename_vertices(f, data_graph(tn)))
end

#
Expand Down Expand Up @@ -736,7 +738,8 @@ function norm_network(tn::AbstractITensorNetwork)
setindex_preserve_graph!(tndag, dag(tndag[v]), v)
end
tnket = rename_vertices(v -> (v, 2), data_graph(prime(tndag; sites=[])))
tntn = ITensorNetwork(union(tnbra, tnket))
# TODO: Use a different constructor here?
tntn = _ITensorNetwork(union(tnbra, tnket))
for v in vertices(tn)
if !isempty(commoninds(tntn[(v, 1)], tntn[(v, 2)]))
add_edge!(tntn, (v, 1) => (v, 2))
Expand Down Expand Up @@ -809,6 +812,9 @@ end

Base.show(io::IO, graph::AbstractITensorNetwork) = show(io, MIME"text/plain"(), graph)

# TODO: Move to an `ITensorNetworksVisualizationInterfaceExt`
# package extension (and define a `VisualizationInterface` package
# based on `ITensorVisualizationCore`.).
function ITensorVisualizationCore.visualize(
tn::AbstractITensorNetwork,
args...;
Expand Down Expand Up @@ -865,6 +871,7 @@ function site_combiners(tn::AbstractITensorNetwork{V}) where {V}
return Cs
end

# TODO: Combine with `insert_links`.
function insert_missing_internal_inds(
tn::AbstractITensorNetwork, edges; internal_inds_space=trivial_space(tn)
)
Expand All @@ -880,12 +887,17 @@ function insert_missing_internal_inds(
return tn
end

# TODO: Combine with `insert_links`.
function insert_missing_internal_inds(
tn::AbstractITensorNetwork; internal_inds_space=trivial_space(tn)
)
return insert_internal_inds(tn, edges(tn); internal_inds_space)
end

# TODO: What to output? Could be an `IndsNetwork`. Or maybe
# that would be a different function `commonindsnetwork`.
# Even in that case, this could output a `Dictionary`
# from the edges to the common inds on that edge.
function ITensors.commoninds(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
inds = Index[]
for v1 in vertices(tn1)
Expand Down
Loading

0 comments on commit 706b7aa

Please sign in to comment.