Skip to content

Commit

Permalink
More test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Apr 17, 2024
1 parent fa035f1 commit 8b4e8a5
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ include("solvers/local_solvers/dmrg_x.jl")
include("solvers/local_solvers/contract.jl")
include("solvers/local_solvers/linsolve.jl")
include("treetensornetworks/abstracttreetensornetwork.jl")
include("treetensornetworks/ttn.jl")
include("treetensornetworks/treetensornetwork.jl")
include("treetensornetworks/opsum_to_ttn.jl")
include("treetensornetworks/projttns/abstractprojttn.jl")
include("treetensornetworks/projttns/projttn.jl")
Expand Down
6 changes: 5 additions & 1 deletion src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,11 @@ function neighbor_vertices(ψ::AbstractITensorNetwork, T::ITensor)
end

function linkinds_combiners(tn::AbstractITensorNetwork; edges=edges(tn))
combiners = DataGraph(directed_graph(underlying_graph(tn)); vertex_data_eltype=ITensor, edge_data_eltype=ITensor)
combiners = DataGraph(
directed_graph(underlying_graph(tn));
vertex_data_eltype=ITensor,
edge_data_eltype=ITensor,
)
for e in edges
C = combiner(linkinds(tn, e); tags=edge_tag(e))
combiners[e] = C
Expand Down
6 changes: 4 additions & 2 deletions src/edge_sequences.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Graphs: IsDirected, connected_components, edges, edgetype
using ITensors.NDTensors: Algorithm, @Algorithm_str
using NamedGraphs: NamedGraphs
using NamedGraphs.GraphsExtensions: undirected_graph
using NamedGraphs.GraphsExtensions: GraphsExtensions, undirected_graph
using NamedGraphs.PartitionedGraphs: PartitionEdge, PartitionedGraph, partitioned_graph
using SimpleTraits: SimpleTraits, @traitfn, Not
using SimpleTraits
Expand All @@ -26,7 +26,9 @@ end
end

@traitfn function edge_sequence(
::Algorithm"forest_cover", g::::(!IsDirected); root_vertex=NamedGraphs.default_root_vertex
::Algorithm"forest_cover",
g::::(!IsDirected);
root_vertex=GraphsExtensions.default_root_vertex,
)
forests = NamedGraphs.forest_cover(g)
edges = edgetype(g)[]
Expand Down
3 changes: 2 additions & 1 deletion src/solvers/alternating_update/alternating_update.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using ITensors: state
using ITensors.ITensorMPS: linkind
using NamedGraphs.GraphsExtensions: GraphsExtensions
using Observers: Observers

function alternating_update(
Expand All @@ -13,7 +14,7 @@ function alternating_update(
sweep_printer=nothing,
(sweep_observer!)=nothing,
(region_observer!)=nothing,
root_vertex=default_root_vertex(init_state),
root_vertex=GraphsExtensions.default_root_vertex(init_state),
extracter_kwargs=(;),
extracter=default_extracter(),
updater_kwargs=(;),
Expand Down
12 changes: 9 additions & 3 deletions src/solvers/sweep_plans/sweep_plans.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using Graphs: AbstractEdge, dst, src
using NamedGraphs.GraphsExtensions: GraphsExtensions

direction(step_number) = isodd(step_number) ? Base.Forward : Base.Reverse

function overlap(edge_a::AbstractEdge, edge_b::AbstractEdge)
Expand Down Expand Up @@ -58,7 +61,7 @@ end
function forward_sweep(
dir::Base.ForwardOrdering,
graph::AbstractGraph;
root_vertex=default_root_vertex(graph),
root_vertex=GraphsExtensions.default_root_vertex(graph),
region_kwargs,
reverse_kwargs=region_kwargs,
reverse_step=false,
Expand Down Expand Up @@ -141,7 +144,10 @@ function default_sweep_plans(
end

function default_sweep_plan(
graph::AbstractGraph; root_vertex=default_root_vertex(graph), region_kwargs, nsites::Int
graph::AbstractGraph;
root_vertex=GraphsExtensions.default_root_vertex(graph),
region_kwargs,
nsites::Int,
)
return vcat(
[
Expand All @@ -158,7 +164,7 @@ end

function tdvp_sweep_plan(
graph::AbstractGraph;
root_vertex=default_root_vertex(graph),
root_vertex=GraphsExtensions.default_root_vertex(graph),
region_kwargs,
reverse_step=true,
order::Int,
Expand Down
4 changes: 3 additions & 1 deletion src/solvers/tdvp.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using NamedGraphs.GraphsExtensions: GraphsExtensions

#ToDo: Cleanup _compute_nsweeps, maybe restrict flexibility to simplify code
function _compute_nsweeps(nsweeps::Int, t::Number, time_step::Number)
return error("Cannot specify both nsweeps and time_step in tdvp")
Expand Down Expand Up @@ -101,7 +103,7 @@ function tdvp(
sweep_printer=nothing,
(sweep_observer!)=nothing,
(region_observer!)=nothing,
root_vertex=default_root_vertex(init_state),
root_vertex=GraphsExtensions.default_root_vertex(init_state),
reverse_step=true,
extracter_kwargs=(;),
extracter=default_extracter(), # ToDo: extracter could be inside extracter_kwargs, at the cost of having to extract it in region_update
Expand Down
36 changes: 20 additions & 16 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Graphs: has_vertex
using NamedGraphs.GraphsExtensions:
edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
using IsApprox: IsApprox, Approx
using ITensors: @Algorithm_str, directsum, hasinds, permute, plev
using ITensors.ITensorMPS: linkind, loginner, lognorm, orthogonalize
Expand All @@ -21,11 +21,6 @@ end
ITensorNetwork(tn::AbstractTTN) = error("Not implemented")
ortho_region(tn::AbstractTTN) = error("Not implemented")

function default_root_vertex(gs::AbstractGraph...)
# @assert all(is_tree.(gs))
return first(leaf_vertices(gs[end]))
end

#
# Orthogonality center
#
Expand Down Expand Up @@ -64,7 +59,9 @@ end
# Truncation
#

function Base.truncate(tn::AbstractTTN; root_vertex=default_root_vertex(tn), kwargs...)
function Base.truncate(
tn::AbstractTTN; root_vertex=GraphsExtensions.default_root_vertex(tn), kwargs...
)
for e in post_order_dfs_edges(tn, root_vertex)
# always orthogonalize towards source first to make truncations controlled
tn = orthogonalize(tn, src(e))
Expand All @@ -84,7 +81,9 @@ end
#

# TODO: decide on contraction order: reverse dfs vertices or forward dfs edges?
function NDTensors.contract(tn::AbstractTTN, root_vertex=default_root_vertex(tn); kwargs...)
function NDTensors.contract(
tn::AbstractTTN, root_vertex=GraphsExtensions.default_root_vertex(tn); kwargs...
)
tn = copy(tn)
# reverse post order vertices
traversal_order = reverse(post_order_dfs_vertices(tn, root_vertex))
Expand All @@ -98,7 +97,7 @@ function NDTensors.contract(tn::AbstractTTN, root_vertex=default_root_vertex(tn)
end

function ITensors.inner(
x::AbstractTTN, y::AbstractTTN; root_vertex=default_root_vertex(x, y)
x::AbstractTTN, y::AbstractTTN; root_vertex=GraphsExtensions.default_root_vertex(x)
)
xᴴ = sim(dag(x); sites=[])
y = sim(y; sites=[])
Expand Down Expand Up @@ -186,7 +185,7 @@ end

# TODO: stick with this traversal or find optimal contraction sequence?
function ITensorMPS.loginner(
tn1::AbstractTTN, tn2::AbstractTTN; root_vertex=default_root_vertex(tn1, tn2)
tn1::AbstractTTN, tn2::AbstractTTN; root_vertex=GraphsExtensions.default_root_vertex(tn1)
)
N = nv(tn1)
if nv(tn2) != N
Expand Down Expand Up @@ -228,14 +227,16 @@ function Base.:+(
::Algorithm"densitymatrix",
tns::AbstractTTN...;
cutoff=1e-15,
root_vertex=default_root_vertex(tns...),
root_vertex=GraphsExtensions.default_root_vertex(first(tns)),
kwargs...,
)
return error("Not implemented (yet) for trees.")
end

function Base.:+(
::Algorithm"directsum", tns::AbstractTTN...; root_vertex=default_root_vertex(tns...)
::Algorithm"directsum",
tns::AbstractTTN...;
root_vertex=GraphsExtensions.default_root_vertex(first(tns)),
)
@assert all(tn -> nv(first(tns)) == nv(tn), tns)

Expand Down Expand Up @@ -302,7 +303,10 @@ end

# TODO: implement using multi-graph disjoint union
function ITensors.inner(
y::AbstractTTN, A::AbstractTTN, x::AbstractTTN; root_vertex=default_root_vertex(x, A, y)
y::AbstractTTN,
A::AbstractTTN,
x::AbstractTTN;
root_vertex=GraphsExtensions.default_root_vertex(x),
)
traversal_order = reverse(post_order_dfs_vertices(x, root_vertex))
ydag = sim(dag(y); sites=[])
Expand All @@ -320,7 +324,7 @@ function ITensors.inner(
y::AbstractTTN,
A::AbstractTTN,
x::AbstractTTN;
root_vertex=default_root_vertex(B, y, A, x),
root_vertex=GraphsExtensions.default_root_vertex(B),
)
N = nv(B)
if nv(y) != N || nv(x) != N || nv(A) != N
Expand Down Expand Up @@ -349,8 +353,8 @@ function ITensorMPS.expect(
operator::String,
state::AbstractTTN;
vertices=vertices(state),
# ToDo: verify that this is a sane default
root_vertex=default_root_vertex(siteinds(state)),
# TODO: verify that this is a sane default
root_vertex=GraphsExtensions.default_root_vertex(state),
)
# TODO: Optimize this with proper caching.
state /= norm(state)
Expand Down
23 changes: 6 additions & 17 deletions src/treetensornetworks/opsum_to_ttn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ using ITensors.ITensorMPS: ITensorMPS, cutoff, linkdims, truncate!
using ITensors.LazyApply: Prod, Sum, coefficient
using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
using ITensors.Ops: Op, OpSum
using NamedGraphs.GraphsExtensions: boundary_edges, degrees, is_leaf, vertex_path
using NamedGraphs.GraphsExtensions:
GraphsExtensions, boundary_edges, degrees, is_leaf, vertex_path
using StaticArrays: MVector

# convert ITensors.OpSum to TreeTensorNetwork
Expand Down Expand Up @@ -522,29 +523,17 @@ Convert an OpSum object `os` to a TreeTensorNetwork, with indices given by `site
function ttn(
os::OpSum,
sites::IndsNetwork;
root_vertex=default_root_vertex(sites),
splitblocks=false,
algorithm="svd",
root_vertex=GraphsExtensions.default_root_vertex(sites),
kwargs...,
)::TTN
)
length(ITensors.terms(os)) == 0 && error("OpSum has no terms")
is_tree(sites) || error("Site index graph must be a tree.")
is_leaf(sites, root_vertex) || error("Tree root must be a leaf vertex.")

os = deepcopy(os)
os = sorteachterm(os, sites, root_vertex)
os = ITensorMPS.sortmergeterms(os) # not exported
if algorithm == "svd"
T = ttn_svd(os, sites, root_vertex; kwargs...)
else
error("Currently only SVD is supported as TTN constructor backend.")
end

if splitblocks
error("splitblocks not yet implemented for AbstractTreeTensorNetwork.")
T = ITensors.splitblocks(linkinds, T) # TODO: make this work
end
return T
os = ITensorMPS.sortmergeterms(os)
return ttn_svd(os, sites, root_vertex; kwargs...)
end

function mpo(os::OpSum, external_inds::Vector; kwargs...)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Graphs: path_graph
using ITensors: ITensor
using LinearAlgebra: factorize, normalize
using NamedGraphs: vertextype
using NamedGraphs.GraphsExtensions: GraphsExtensions, vertextype

"""
TreeTensorNetwork{V} <: AbstractTreeTensorNetwork{V}
Expand Down Expand Up @@ -76,7 +76,12 @@ function mps(f, is::Vector{<:Index}; kwargs...)
end

# Construct from dense ITensor, using IndsNetwork of site indices.
function ttn(a::ITensor, is::IndsNetwork; ortho_region=[default_root_vertex(is)], kwargs...)
function ttn(
a::ITensor,
is::IndsNetwork;
ortho_region=[GraphsExtensions.default_root_vertex(is)],
kwargs...,
)
for v in vertices(is)
@assert hasinds(a, is[v])
end
Expand Down

0 comments on commit 8b4e8a5

Please sign in to comment.