Skip to content

Commit

Permalink
Initial refactor of alternating_update (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss authored Jan 21, 2024
1 parent 9a75be2 commit 5ec5228
Show file tree
Hide file tree
Showing 21 changed files with 621 additions and 691 deletions.
6 changes: 5 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ include("tensornetworkoperators.jl")
include(joinpath("ITensorsExt", "itensorutils.jl"))
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("Graphs", "abstractdatagraph.jl"))
include(joinpath("solvers", "eigsolve.jl"))
include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand All @@ -114,7 +119,6 @@ include(joinpath("treetensornetworks", "projttns", "projttn.jl"))
include(joinpath("treetensornetworks", "projttns", "projttnsum.jl"))
include(joinpath("treetensornetworks", "projttns", "projttn_apply.jl"))
include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
Expand Down
19 changes: 19 additions & 0 deletions src/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function contract_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
v = ITensor(true)
projected_operator = projected_operator![]
for j in sites(projected_operator)
v *= projected_operator.psi0[j]
end
vp = contract(projected_operator, v)
return vp, (;)
end
22 changes: 22 additions & 0 deletions src/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function dmrg_x_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
# this updater does not seem to accept any kwargs?
default_updater_kwargs = (;)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs)
H = contract(projected_operator![], ITensor(true))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
# TODO: improve this to return the energy estimate too
return U_max, (;)
end
33 changes: 33 additions & 0 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
function eigsolve_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
default_updater_kwargs = (;
which_eigval=:SR,
ishermitian=true,
tol=1e-14,
krylovdim=3,
maxiter=1,
verbosity=0,
eager=false,
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
howmany = 1
(; which_eigval) = updater_kwargs
updater_kwargs = Base.structdiff(updater_kwargs, (; which_eigval=nothing))
vals, vecs, info = eigsolve(
projected_operator![], init, howmany, which_eigval; updater_kwargs...
)
return vecs[1], (; info, eigvals=vals)
end

function _pop_which_eigenvalue(; which_eigenvalue, kwargs...)
return which_eigenvalue, NamedTuple(kwargs)
end
27 changes: 27 additions & 0 deletions src/solvers/exponentiate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function exponentiate_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
default_updater_kwargs = (;
krylovdim=30,
maxiter=100,
verbosity=0,
tol=1E-12,
ishermitian=true,
issymmetric=true,
eager=true,
)

updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
result, exp_info = exponentiate(
projected_operator![], region_kwargs.time_step, init; updater_kwargs...
)
return result, (; info=exp_info)
end
22 changes: 22 additions & 0 deletions src/solvers/linsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function linsolve_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
default_updater_kwargs = (;
ishermitian=false, tol=1E-14, krylovdim=30, maxiter=100, verbosity=0, a₀, a₁
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs)
P = projected_operator![]
(; a₀, a₁) = updater_kwargs
updater_kwargs = Base.structdiff(updater_kwargs, (; a₀=nothing, a₁=nothing))
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, init, a₀, a₁; updater_kwargs...)
return x, (;)
end
87 changes: 47 additions & 40 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,83 +26,88 @@ function process_sweeps(
return maxdim, mindim, cutoff, noise, kwargs
end

function sweep_printer(; outputlevel, psi, sweep, sw_time)
function sweep_printer(; outputlevel, state, which_sweep, sw_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
print("After sweep ", which_sweep, ":")
print(" maxlinkdim=", maxlinkdim(state))
print(" cpu_time=", round(sw_time; digits=3))
println()
flush(stdout)
end
end

function alternating_update(
solver,
PH,
psi0::AbstractTTN;
updater,
projected_operator,
init_state::AbstractTTN;
checkdone=(; kws...) -> false,
outputlevel::Integer=0,
nsweeps::Integer=1,
(sweep_observer!)=observer(),
sweep_printer=sweep_printer,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
updater_kwargs,
kwargs...,
)
maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)

psi = copy(psi0)
state = copy(init_state)

insert_function!(sweep_observer!, "sweep_printer" => sweep_printer)
insert_function!(sweep_observer!, "sweep_printer" => sweep_printer) # FIX THIS

for sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds
for which_sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) &&
maxdim[which_sweep] > write_when_maxdim_exceeds
if outputlevel >= 2
println(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[sweep] = $(maxdim[sweep]), writing environment tensors to disk",
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[which_sweep] = $(maxdim[which_sweep]), writing environment tensors to disk",
)
end
PH = disk(PH)
projected_operator = disk(projected_operator)
end

sweep_params = (;
maxdim=maxdim[which_sweep],
mindim=mindim[which_sweep],
cutoff=cutoff[which_sweep],
noise=noise[which_sweep],
)
sw_time = @elapsed begin
psi, PH = update_step(
solver,
PH,
psi;
state, projected_operator = sweep_update(
updater,
projected_operator,
state;
outputlevel,
sweep,
maxdim=maxdim[sweep],
mindim=mindim[sweep],
cutoff=cutoff[sweep],
noise=noise[sweep],
which_sweep,
sweep_params,
updater_kwargs,
kwargs...,
)
end

update!(sweep_observer!; psi, sweep, sw_time, outputlevel)
update!(sweep_observer!; state, which_sweep, sw_time, outputlevel)

checkdone(; psi, sweep, outputlevel, kwargs...) && break
checkdone(; state, which_sweep, outputlevel, kwargs...) && break
end
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) # remove sweep_printer
return psi
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer"))
return state
end

function alternating_update(solver, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
function alternating_update(updater, H::AbstractTTN, init_state::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, H, init_state)
check_hascommoninds(siteinds, H, init_state')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjTTN(H)
return alternating_update(solver, PH, psi0; kwargs...)
projected_operator = ProjTTN(H)
return alternating_update(updater, projected_operator, init_state; kwargs...)
end

"""
tdvp(Hs::Vector{MPO},psi0::MPS,t::Number; kwargs...)
tdvp(Hs::Vector{MPO},psi0::MPS,t::Number, sweeps::Sweeps; kwargs...)
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...)
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...)
Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*psi0` using an efficient algorithm based
to compute `exp(t*H)*init_state` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
exponentiation of H.
Expand All @@ -114,14 +119,16 @@ the set of MPOs [H1,H2,H3,..] is efficiently looped over at
each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
* `state::MPS` - time-evolved MPS
"""
function alternating_update(solver, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...)
function alternating_update(
updater, Hs::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs...
)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
check_hascommoninds(siteinds, H, init_state)
check_hascommoninds(siteinds, H, init_state')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjTTNSum(Hs)
return alternating_update(solver, PHs, psi0; kwargs...)
projected_operators = ProjTTNSum(Hs)
return alternating_update(updater, projected_operators, init_state; kwargs...)
end
Loading

0 comments on commit 5ec5228

Please sign in to comment.