Skip to content

Commit

Permalink
Remove prime-level logic from apply, and use contract when primelevel…
Browse files Browse the repository at this point in the history
… is not increased by operator.
  • Loading branch information
b-kloss committed Feb 1, 2024
1 parent b6a8f32 commit 80ccddc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 25 deletions.
21 changes: 2 additions & 19 deletions src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,33 +74,16 @@ end
Overload of `ITensors.apply`.
"""
function apply(tn1::AbstractTTN, tn2::AbstractTTN; init, kwargs...)
#plin=plev(first(externalinds(init)))
#outputindsnet=
#plout=plev(outputindsnet[first(vertices(outputindsnet))])
plev_inc = plev_diff(flatten_external_indsnetwork(tn1, tn2), external_indsnetwork(init))
init = prime(init, plev_inc)
init = init'
tn12 = contract(tn1, tn2; init, kwargs...)
return replaceprime(tn12, 1 => 0)
end

function sum_apply(
tns::Vector{<:Tuple{<:AbstractTTN,<:AbstractTTN}}; alg="fit", init, kwargs...
)
#plin=plev(first(externalinds(init)))
#outputindsnet = flatten_external_indsnetwork(first(tns), first(tn2s))
#plout=plev(outputindsnet[first(vertices(outputindsnet))])
plev_inc = plev_diff(
flatten_external_indsnetwork(first(first(tns)), last(first(tns))),
external_indsnetwork(init),
)
init = prime(init, plev_inc)
init = init'
alg != "fit" && error("sum_apply not implemented for other algorithms than fit.")
tn12 = contract(Algorithm(alg), tns; init, kwargs...)
return replaceprime(tn12, 1 => 0)
end

function plev_diff(a::IndsNetwork, b::IndsNetwork)
pla = plev(only(a[first(vertices(a))]))
plb = plev(only(b[first(vertices(b))]))
return pla - plb
end
12 changes: 6 additions & 6 deletions test/test_treetensornetworks/test_solvers/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ using Test
psit[j] *= delta(s[j], t[j])
end
# Test with nsweeps=3
Hpsi = apply(H, psi; alg="fit", init=psit, nsweeps=3)
Hpsi = contract(H, psi; alg="fit", init=psit, nsweeps=3)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5
# Test with less good initial guess MPS not equal to psi
psi_guess = truncate(psit; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init=psi_guess)
Hpsi = contract(H, psi; alg="fit", nsweeps=4, init=psi_guess)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5

# Test with nsite=1
Hpsi_guess = random_mps(t; internal_inds_space=32)
Hpsi = apply(H, psi; alg="fit", init=Hpsi_guess, nsites=1, nsweeps=4)
Hpsi = contract(H, psi; alg="fit", init=Hpsi_guess, nsites=1, nsweeps=4)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4
end

Expand Down Expand Up @@ -93,17 +93,17 @@ end
H = replaceinds(H, prime(s; links=[]) => t)

# Test with nsweeps=2
Hpsi = apply(H, psi; alg="fit", init=psit, nsweeps=2)
Hpsi = contract(H, psi; alg="fit", init=psit, nsweeps=2)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5

# Test with less good initial guess MPS not equal to psi
Hpsi_guess = truncate(psit; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init=Hpsi_guess)
Hpsi = contract(H, psi; alg="fit", nsweeps=4, init=Hpsi_guess)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5

# Test with nsite=1
Hpsi_guess = random_ttn(t; link_space=4)
Hpsi = apply(H, psi; alg="fit", nsites=1, nsweeps=4, init=Hpsi_guess)
Hpsi = contract(H, psi; alg="fit", nsites=1, nsweeps=4, init=Hpsi_guess)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4
end

Expand Down

0 comments on commit 80ccddc

Please sign in to comment.