Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Apr 3, 2024
1 parent 015d2e2 commit 423deac
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 72 deletions.
53 changes: 0 additions & 53 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,59 +37,6 @@ end

reset_ortho_center::AbstractTTN) = set_ortho_center(ψ, vertices(ψ))

#
# Dense constructors
#

# construct from dense ITensor, using IndsNetwork of site indices
function (::Type{TTNT})(
A::ITensor, is::IndsNetwork; ortho_center=default_root_vertex(is), kwargs...
) where {TTNT<:AbstractTTN}
for v in vertices(is)
@assert hasinds(A, is[v])
end
@assert ortho_center vertices(is)
ψ = ITensorNetwork(is)
= A
for e in post_order_dfs_edges(ψ, ortho_center)
left_inds = uniqueinds(is, e)
L, R = factorize(Ã, left_inds; tags=edge_tag(e), ortho="left", kwargs...)
l = commonind(L, R)
ψ[src(e)] = L
is[e] = [l]
= R
end
ψ[ortho_center] =
T = TTNT(ψ)
T = orthogonalize(T, ortho_center)
return T
end

# construct from dense ITensor, using AbstractNamedGraph and vector of site indices
# TODO: remove if it doesn't turn out to be useful
function (::Type{TTNT})(
A::ITensor, sites::Vector, g::AbstractNamedGraph; vertex_order=vertices(g), kwargs...
) where {TTNT<:AbstractTTN}
is = IndsNetwork(g; site_space=Dictionary(vertex_order, sites))
return TTNT(A, is; kwargs...)
end

# construct from dense array, using IndsNetwork
# TODO: probably remove this one, doesn't seem very useful
function (::Type{TTNT})(
A::AbstractArray{<:Number}, is::IndsNetwork; vertex_order=vertices(is), kwargs...
) where {TTNT<:AbstractTTN}
sites = [is[v] for v in vertex_order]
return TTNT(itensor(A, sites...), is; kwargs...)
end

# construct from dense array, using NamedDimGraph and vector of site indices
function (::Type{TTNT})(
A::AbstractArray{<:Number}, sites::Vector, args...; kwargs...
) where {TTNT<:AbstractTTN}
return TTNT(itensor(A, sites...), sites, args...; kwargs...)
end

#
# Orthogonalization
#
Expand Down
24 changes: 24 additions & 0 deletions src/treetensornetworks/ttn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,30 @@ function ttn(
return ttn(ElT, sites, ops; kwargs...)
end

# construct from dense ITensor, using IndsNetwork of site indices
function ttn(
A::ITensor, is::IndsNetwork; ortho_center=default_root_vertex(is), kwargs...
)
for v in vertices(is)
@assert hasinds(A, is[v])
end
@assert ortho_center vertices(is)
ψ = ITensorNetwork(is)
= A
for e in post_order_dfs_edges(ψ, ortho_center)
left_inds = uniqueinds(is, e)
L, R = factorize(Ã, left_inds; tags=edge_tag(e), ortho="left", kwargs...)
l = commonind(L, R)
ψ[src(e)] = L
is[e] = [l]
= R
end
ψ[ortho_center] =
T = ttn(ψ)
T = orthogonalize(T, ortho_center)
return T
end

# Special constructors

function mps(external_inds::Vector{<:Index}; states)
Expand Down
7 changes: 3 additions & 4 deletions test/test_treetensornetworks/test_solvers/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using ITensorNetworks:
ProjOuterProdTTN,
ProjTTNSum,
ttn,
TreeTensorNetwork,
apply,
contract,
delta,
Expand Down Expand Up @@ -157,10 +156,10 @@ end
M1 = replaceprime(randomMPO(sites) + randomMPO(sites), 1 => 2, 0 => 1)
M2 = randomMPO(sites) + randomMPO(sites)
M12_ref = contract(M1, M2; alg="naive")
t12_ref = TreeTensorNetwork([M12_ref[v] for v in eachindex(M12_ref)])
t12_ref = ttn([M12_ref[v] for v in eachindex(M12_ref)])

t1 = TreeTensorNetwork([M1[v] for v in eachindex(M1)])
t2 = TreeTensorNetwork([M2[v] for v in eachindex(M2)])
t1 = ttn([M1[v] for v in eachindex(M1)])
t2 = ttn([M2[v] for v in eachindex(M2)])

# Test with good initial guess
@test contract(t1, t2; alg="fit", init=t12_ref, nsweeps=1) t12_ref rtol = 1e-7
Expand Down
15 changes: 0 additions & 15 deletions test/test_ttno.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,11 @@ using Test: @test, @testset
O = randomITensor(sites_o...)
# dense TTN constructor from IndsNetwork
@disable_warn_order o1 = ttn(O, is_isp; cutoff)
# dense TTN constructor from Vector{Vector{Index}} and NamedDimGraph
@disable_warn_order o2 = ttn(O, sites_o, c; vertex_order, cutoff)
# convert to array with proper index order
AO = Array(O, sites_o...)
# dense array constructor from IndsNetwork
@disable_warn_order o3 = ttn(AO, is_isp; vertex_order, cutoff)
# dense array constructor from Vector{Vector{Index}} and NamedDimGraph
@disable_warn_order o4 = ttn(AO, sites_o, c; vertex_order, cutoff)
# see if this actually worked
root_vertex = only(ortho_center(o1))
@disable_warn_order begin
O1 = contract(o1, root_vertex)
O2 = contract(o2, root_vertex)
O3 = contract(o3, root_vertex)
O4 = contract(o4, root_vertex)
end
@test norm(O - O1) < 1e2 * cutoff
@test norm(O - O2) < 1e2 * cutoff
@test norm(O - O3) < 1e2 * cutoff
@test norm(O - O4) < 1e2 * cutoff
end

@testset "Ortho" begin
Expand Down

0 comments on commit 423deac

Please sign in to comment.