From 5ec52289cc224af054ee9eb72c1467333bbf342f Mon Sep 17 00:00:00 2001 From: b-kloss Date: Sun, 21 Jan 2024 09:12:30 -0500 Subject: [PATCH] Initial refactor of `alternating_update` (#121) --- src/ITensorNetworks.jl | 6 +- src/solvers/contract.jl | 19 ++ src/solvers/dmrg_x.jl | 22 ++ src/solvers/eigsolve.jl | 33 ++ src/solvers/exponentiate.jl | 27 ++ src/solvers/linsolve.jl | 22 ++ .../solvers/alternating_update.jl | 87 +++--- src/treetensornetworks/solvers/applyexp.jl | 128 -------- src/treetensornetworks/solvers/contract.jl | 16 +- src/treetensornetworks/solvers/dmrg.jl | 70 ++--- src/treetensornetworks/solvers/dmrg_x.jl | 30 +- src/treetensornetworks/solvers/linsolve.jl | 40 +-- src/treetensornetworks/solvers/tdvp.jl | 123 +++----- .../solvers/tree_sweeping.jl | 8 +- src/treetensornetworks/solvers/update_step.jl | 216 +++++++------ .../test_solvers/test_contract.jl | 4 +- .../test_solvers/test_dmrg.jl | 36 ++- .../test_solvers/test_dmrg_x.jl | 21 +- .../test_solvers/test_linsolve.jl | 4 +- .../test_solvers/test_tdvp.jl | 294 ++++++------------ .../test_solvers/test_tdvp_time_dependent.jl | 106 ++++++- 21 files changed, 621 insertions(+), 691 deletions(-) create mode 100644 src/solvers/contract.jl create mode 100644 src/solvers/dmrg_x.jl create mode 100644 src/solvers/eigsolve.jl create mode 100644 src/solvers/exponentiate.jl create mode 100644 src/solvers/linsolve.jl delete mode 100644 src/treetensornetworks/solvers/applyexp.jl diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 5b660bc9..9d5f14d2 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -106,6 +106,11 @@ include("tensornetworkoperators.jl") include(joinpath("ITensorsExt", "itensorutils.jl")) include(joinpath("Graphs", "abstractgraph.jl")) include(joinpath("Graphs", "abstractdatagraph.jl")) +include(joinpath("solvers", "eigsolve.jl")) +include(joinpath("solvers", "exponentiate.jl")) +include(joinpath("solvers", "dmrg_x.jl")) +include(joinpath("solvers", "contract.jl")) +include(joinpath("solvers", "linsolve.jl")) include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl")) include(joinpath("treetensornetworks", "ttn.jl")) include(joinpath("treetensornetworks", "opsum_to_ttn.jl")) @@ -114,7 +119,6 @@ include(joinpath("treetensornetworks", "projttns", "projttn.jl")) include(joinpath("treetensornetworks", "projttns", "projttnsum.jl")) include(joinpath("treetensornetworks", "projttns", "projttn_apply.jl")) include(joinpath("treetensornetworks", "solvers", "solver_utils.jl")) -include(joinpath("treetensornetworks", "solvers", "applyexp.jl")) include(joinpath("treetensornetworks", "solvers", "update_step.jl")) include(joinpath("treetensornetworks", "solvers", "alternating_update.jl")) include(joinpath("treetensornetworks", "solvers", "tdvp.jl")) diff --git a/src/solvers/contract.jl b/src/solvers/contract.jl new file mode 100644 index 00000000..588d35ae --- /dev/null +++ b/src/solvers/contract.jl @@ -0,0 +1,19 @@ +function contract_updater( + init; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, +) + v = ITensor(true) + projected_operator = projected_operator![] + for j in sites(projected_operator) + v *= projected_operator.psi0[j] + end + vp = contract(projected_operator, v) + return vp, (;) +end diff --git a/src/solvers/dmrg_x.jl b/src/solvers/dmrg_x.jl new file mode 100644 index 00000000..f1054726 --- /dev/null +++ b/src/solvers/dmrg_x.jl @@ -0,0 +1,22 @@ +function dmrg_x_updater( + init; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, +) + # this updater does not seem to accept any kwargs? + default_updater_kwargs = (;) + updater_kwargs = merge(default_updater_kwargs, updater_kwargs) + H = contract(projected_operator![], ITensor(true)) + D, U = eigen(H; ishermitian=true) + u = uniqueind(U, H) + max_overlap, max_ind = findmax(abs, array(dag(init) * U)) + U_max = U * dag(onehot(u => max_ind)) + # TODO: improve this to return the energy estimate too + return U_max, (;) +end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl new file mode 100644 index 00000000..85e99b3f --- /dev/null +++ b/src/solvers/eigsolve.jl @@ -0,0 +1,33 @@ +function eigsolve_updater( + init; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, +) + default_updater_kwargs = (; + which_eigval=:SR, + ishermitian=true, + tol=1e-14, + krylovdim=3, + maxiter=1, + verbosity=0, + eager=false, + ) + updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence + howmany = 1 + (; which_eigval) = updater_kwargs + updater_kwargs = Base.structdiff(updater_kwargs, (; which_eigval=nothing)) + vals, vecs, info = eigsolve( + projected_operator![], init, howmany, which_eigval; updater_kwargs... + ) + return vecs[1], (; info, eigvals=vals) +end + +function _pop_which_eigenvalue(; which_eigenvalue, kwargs...) + return which_eigenvalue, NamedTuple(kwargs) +end diff --git a/src/solvers/exponentiate.jl b/src/solvers/exponentiate.jl new file mode 100644 index 00000000..a4dacebe --- /dev/null +++ b/src/solvers/exponentiate.jl @@ -0,0 +1,27 @@ +function exponentiate_updater( + init; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, +) + default_updater_kwargs = (; + krylovdim=30, + maxiter=100, + verbosity=0, + tol=1E-12, + ishermitian=true, + issymmetric=true, + eager=true, + ) + + updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence + result, exp_info = exponentiate( + projected_operator![], region_kwargs.time_step, init; updater_kwargs... + ) + return result, (; info=exp_info) +end diff --git a/src/solvers/linsolve.jl b/src/solvers/linsolve.jl new file mode 100644 index 00000000..1a595950 --- /dev/null +++ b/src/solvers/linsolve.jl @@ -0,0 +1,22 @@ +function linsolve_updater( + init; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, +) + default_updater_kwargs = (; + ishermitian=false, tol=1E-14, krylovdim=30, maxiter=100, verbosity=0, a₀, a₁ + ) + updater_kwargs = merge(default_updater_kwargs, updater_kwargs) + P = projected_operator![] + (; a₀, a₁) = updater_kwargs + updater_kwargs = Base.structdiff(updater_kwargs, (; a₀=nothing, a₁=nothing)) + b = dag(only(proj_mps(P))) + x, info = KrylovKit.linsolve(P, b, init, a₀, a₁; updater_kwargs...) + return x, (;) +end diff --git a/src/treetensornetworks/solvers/alternating_update.jl b/src/treetensornetworks/solvers/alternating_update.jl index 17d1f62d..9c1ea8b6 100644 --- a/src/treetensornetworks/solvers/alternating_update.jl +++ b/src/treetensornetworks/solvers/alternating_update.jl @@ -26,10 +26,10 @@ function process_sweeps( return maxdim, mindim, cutoff, noise, kwargs end -function sweep_printer(; outputlevel, psi, sweep, sw_time) +function sweep_printer(; outputlevel, state, which_sweep, sw_time) if outputlevel >= 1 - print("After sweep ", sweep, ":") - print(" maxlinkdim=", maxlinkdim(psi)) + print("After sweep ", which_sweep, ":") + print(" maxlinkdim=", maxlinkdim(state)) print(" cpu_time=", round(sw_time; digits=3)) println() flush(stdout) @@ -37,72 +37,77 @@ function sweep_printer(; outputlevel, psi, sweep, sw_time) end function alternating_update( - solver, - PH, - psi0::AbstractTTN; + updater, + projected_operator, + init_state::AbstractTTN; checkdone=(; kws...) -> false, outputlevel::Integer=0, nsweeps::Integer=1, (sweep_observer!)=observer(), sweep_printer=sweep_printer, write_when_maxdim_exceeds::Union{Int,Nothing}=nothing, + updater_kwargs, kwargs..., ) maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...) - psi = copy(psi0) + state = copy(init_state) - insert_function!(sweep_observer!, "sweep_printer" => sweep_printer) + insert_function!(sweep_observer!, "sweep_printer" => sweep_printer) # FIX THIS - for sweep in 1:nsweeps - if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds + for which_sweep in 1:nsweeps + if !isnothing(write_when_maxdim_exceeds) && + maxdim[which_sweep] > write_when_maxdim_exceeds if outputlevel >= 2 println( - "write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[sweep] = $(maxdim[sweep]), writing environment tensors to disk", + "write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[which_sweep] = $(maxdim[which_sweep]), writing environment tensors to disk", ) end - PH = disk(PH) + projected_operator = disk(projected_operator) end - + sweep_params = (; + maxdim=maxdim[which_sweep], + mindim=mindim[which_sweep], + cutoff=cutoff[which_sweep], + noise=noise[which_sweep], + ) sw_time = @elapsed begin - psi, PH = update_step( - solver, - PH, - psi; + state, projected_operator = sweep_update( + updater, + projected_operator, + state; outputlevel, - sweep, - maxdim=maxdim[sweep], - mindim=mindim[sweep], - cutoff=cutoff[sweep], - noise=noise[sweep], + which_sweep, + sweep_params, + updater_kwargs, kwargs..., ) end - update!(sweep_observer!; psi, sweep, sw_time, outputlevel) + update!(sweep_observer!; state, which_sweep, sw_time, outputlevel) - checkdone(; psi, sweep, outputlevel, kwargs...) && break + checkdone(; state, which_sweep, outputlevel, kwargs...) && break end - select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) # remove sweep_printer - return psi + select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) + return state end -function alternating_update(solver, H::AbstractTTN, psi0::AbstractTTN; kwargs...) - check_hascommoninds(siteinds, H, psi0) - check_hascommoninds(siteinds, H, psi0') +function alternating_update(updater, H::AbstractTTN, init_state::AbstractTTN; kwargs...) + check_hascommoninds(siteinds, H, init_state) + check_hascommoninds(siteinds, H, init_state') # Permute the indices to have a better memory layout # and minimize permutations H = ITensors.permute(H, (linkind, siteinds, linkind)) - PH = ProjTTN(H) - return alternating_update(solver, PH, psi0; kwargs...) + projected_operator = ProjTTN(H) + return alternating_update(updater, projected_operator, init_state; kwargs...) end """ - tdvp(Hs::Vector{MPO},psi0::MPS,t::Number; kwargs...) - tdvp(Hs::Vector{MPO},psi0::MPS,t::Number, sweeps::Sweeps; kwargs...) + tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...) + tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...) Use the time dependent variational principle (TDVP) algorithm -to compute `exp(t*H)*psi0` using an efficient algorithm based +to compute `exp(t*H)*init_state` using an efficient algorithm based on alternating optimization of the MPS tensors and local Krylov exponentiation of H. @@ -114,14 +119,16 @@ the set of MPOs [H1,H2,H3,..] is efficiently looped over at each step of the algorithm when optimizing the MPS. Returns: -* `psi::MPS` - time-evolved MPS +* `state::MPS` - time-evolved MPS """ -function alternating_update(solver, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...) +function alternating_update( + updater, Hs::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs... +) for H in Hs - check_hascommoninds(siteinds, H, psi0) - check_hascommoninds(siteinds, H, psi0') + check_hascommoninds(siteinds, H, init_state) + check_hascommoninds(siteinds, H, init_state') end Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind))) - PHs = ProjTTNSum(Hs) - return alternating_update(solver, PHs, psi0; kwargs...) + projected_operators = ProjTTNSum(Hs) + return alternating_update(updater, projected_operators, init_state; kwargs...) end diff --git a/src/treetensornetworks/solvers/applyexp.jl b/src/treetensornetworks/solvers/applyexp.jl deleted file mode 100644 index 6e84036e..00000000 --- a/src/treetensornetworks/solvers/applyexp.jl +++ /dev/null @@ -1,128 +0,0 @@ -# -# To Do: -# - implement assembleLanczosVectors -# - check slice ranges - change end value by 1? -# - -function assemble_lanczos_vecs(lanczos_vectors, linear_comb, norm) - #if length(lanczos_vectors) != length(linear_comb) - # @show length(lanczos_vectors) - # @show length(linear_comb) - #end - xt = norm * linear_comb[1] * lanczos_vectors[1] - for i in 2:length(lanczos_vectors) - xt += norm * linear_comb[i] * lanczos_vectors[i] - end - return xt -end - -struct ApplyExpInfo - numops::Int - converged::Int -end - -function applyexp(H, tau::Number, x0; maxiter=30, tol=1e-12, outputlevel=0, normcutoff=1e-7) - # Initialize Lanczos vectors - v1 = copy(x0) - nrm = norm(v1) - v1 /= nrm - lanczos_vectors = [v1] - - ElT = promote_type(typeof(tau), eltype(x0)) - - bigTmat = zeros(ElT, maxiter + 3, maxiter + 3) - - nmatvec = 0 - - v0 = nothing - beta = 0.0 - for iter in 1:maxiter - tmat_size = iter + 1 - - # Matrix-vector multiplication - w = H(v1) - nmatvec += 1 - - avnorm = norm(w) - alpha = dot(w, v1) - - bigTmat[iter, iter] = alpha - - w -= alpha * v1 - if iter > 1 - w -= beta * v0 - end - v0 = copy(v1) - - beta = norm(w) - - # check for Lanczos sequence exhaustion - if abs(beta) < beta_tol - # Assemble the time evolved state - tmat = bigTmat[1:tmat_size, 1:tmat_size] - tmat_exp = exp(tau * tmat) - linear_comb = tmat_exp[:, 1] - xt = assemble_lanczos_vecs(lanczos_vectors, linear_comb, nrm) - return xt, ApplyExpInfo(nmatvec, 1) - end - - # update next lanczos vector - v1 = copy(w) - v1 /= beta - push!(lanczos_vectors, v1) - bigTmat[iter + 1, iter] = beta - bigTmat[iter, iter + 1] = beta - - # Convergence check - if iter > 0 - # Prepare extended T-matrix for exponentiation - tmat_ext_size = tmat_size + 2 - tmat_ext = bigTmat[1:tmat_ext_size, 1:tmat_ext_size] - - tmat_ext[tmat_size - 1, tmat_size] = 0.0 - tmat_ext[tmat_size + 1, tmat_size] = 1.0 - - # Exponentiate extended T-matrix - tmat_ext_exp = exp(tau * tmat_ext) - - ϕ1 = abs(nrm * tmat_ext_exp[tmat_size, 1]) - ϕ2 = abs(nrm * tmat_ext_exp[tmat_size + 1, 1] * avnorm) - - if ϕ1 > 10 * ϕ2 - error = ϕ2 - elseif (ϕ1 > ϕ2) - error = (ϕ1 * ϕ2) / (ϕ1 - ϕ2) - else - error = ϕ1 - end - - if outputlevel >= 3 - @printf(" Iteration: %d, Error: %.2E\n", iter, error) - end - - if ((error < tol) || (iter == maxiter)) - converged = 1 - if (iter == maxiter) - println("warning: applyexp not converged in $maxiter steps") - converged = 0 - end - - # Assemble the time evolved state - linear_comb = tmat_ext_exp[:, 1] - xt = assemble_lanczos_vecs(lanczos_vectors, linear_comb, nrm) - - if outputlevel >= 3 - println(" Number of iterations: $iter") - end - - return xt, ApplyExpInfo(nmatvec, converged) - end - end # end convergence test - end # iter - - if outputlevel >= 0 - println("In applyexp, number of matrix-vector multiplies: ", nmatvec) - end - - return x0 -end diff --git a/src/treetensornetworks/solvers/contract.jl b/src/treetensornetworks/solvers/contract.jl index 51f6787f..398138e7 100644 --- a/src/treetensornetworks/solvers/contract.jl +++ b/src/treetensornetworks/solvers/contract.jl @@ -1,18 +1,11 @@ -function contract_solver(PH, psi; normalize, region, half_sweep) - v = ITensor(1.0) - for j in sites(PH) - v *= PH.psi0[j] - end - Hpsi0 = contract(PH, v) - return Hpsi0, NamedTuple() -end - function contract( ::Algorithm"fit", tn1::AbstractTTN, tn2::AbstractTTN; init=random_ttn(flatten_external_indsnetwork(tn1, tn2); link_space=trivial_space(tn1)), nsweeps=1, + nsites=2, # used to be default of call to default_sweep_regions + updater_kwargs=(;), kwargs..., ) n = nv(tn1) @@ -42,7 +35,10 @@ function contract( ## end PH = ProjTTNApply(tn2, tn1) - psi = alternating_update(contract_solver, PH, init; nsweeps, kwargs...) + sweep_plan = default_sweep_regions(nsites, init; kwargs...) + psi = alternating_update( + contract_updater, PH, init; nsweeps, sweep_plan, updater_kwargs, kwargs... + ) return psi end diff --git a/src/treetensornetworks/solvers/dmrg.jl b/src/treetensornetworks/solvers/dmrg.jl index c67b3b86..653c00c8 100644 --- a/src/treetensornetworks/solvers/dmrg.jl +++ b/src/treetensornetworks/solvers/dmrg.jl @@ -1,58 +1,36 @@ -function eigsolve_solver(; - solver_which_eigenvalue=:SR, - ishermitian=true, - solver_tol=1e-14, - solver_krylovdim=3, - solver_maxiter=1, - solver_verbosity=0, -) - function solver(H, init; normalize=nothing, region=nothing, half_sweep=nothing) - howmany = 1 - which = solver_which_eigenvalue - vals, vecs, info = eigsolve( - H, - init, - howmany, - which; - ishermitian, - tol=solver_tol, - krylovdim=solver_krylovdim, - maxiter=solver_maxiter, - verbosity=solver_verbosity, - ) - psi = vecs[1] - return psi, (; solver_info=info, energies=vals) - end - return solver -end - """ Overload of `ITensors.dmrg`. """ + +function dmrg_sweep_plan( + nsites::Int, graph::AbstractGraph; root_vertex=default_root_vertex(graph) +) + order = 2 + time_step = Inf + return tdvp_sweep_plan(order, nsites, time_step, graph; root_vertex, reverse_step=false) +end + function dmrg( + updater, H, init::AbstractTTN; - solver_which_eigenvalue=:SR, - ishermitian=true, - solver_tol=1e-14, - solver_krylovdim=3, - solver_maxiter=1, - solver_verbosity=0, + nsweeps, #it makes sense to require this to be defined + nsites=2, + (sweep_observer!)=observer(), + root_vertex=default_root_vertex(init), + updater_kwargs=(;), kwargs..., ) - return alternating_update( - eigsolve_solver(; - solver_which_eigenvalue, - ishermitian, - solver_tol, - solver_krylovdim, - solver_maxiter, - solver_verbosity, - ), - H, - init; - kwargs..., + sweep_plan = dmrg_sweep_plan(nsites, init; root_vertex) + + psi = alternating_update( + updater, H, init; nsweeps, sweep_observer!, sweep_plan, updater_kwargs, kwargs... ) + return psi +end + +function dmrg(H, init::AbstractTTN; updater=eigsolve_updater, kwargs...) + return dmrg(updater, H, init; kwargs...) end """ diff --git a/src/treetensornetworks/solvers/dmrg_x.jl b/src/treetensornetworks/solvers/dmrg_x.jl index 30a97f3a..4e89620e 100644 --- a/src/treetensornetworks/solvers/dmrg_x.jl +++ b/src/treetensornetworks/solvers/dmrg_x.jl @@ -1,16 +1,22 @@ -function dmrg_x_solver( - PH, init; normalize=nothing, region=nothing, half_sweep=nothing, reverse_step=nothing +function dmrg_x( + updater, + operator, + init::AbstractTTN; + nsweeps, #it makes sense to require this to be defined + nsites=2, + (sweep_observer!)=observer(), + root_vertex=default_root_vertex(init), + updater_kwargs=(;), + kwargs..., ) - H = contract(PH, ITensor(1.0)) - D, U = eigen(H; ishermitian=true) - u = uniqueind(U, H) - max_overlap, max_ind = findmax(abs, array(dag(init) * U)) - U_max = U * dag(onehot(u => max_ind)) - # TODO: improve this to return the energy estimate too - return U_max, NamedTuple() -end + sweep_plan = dmrg_sweep_plan(nsites, init; root_vertex) -function dmrg_x(PH, init::AbstractTTN; kwargs...) - psi = alternating_update(dmrg_x_solver, PH, init; kwargs...) + psi = alternating_update( + updater, operator, init; nsweeps, sweep_observer!, sweep_plan, updater_kwargs, kwargs... + ) return psi end + +function dmrg_x(operator, init::AbstractTTN; updater=dmrg_x_updater, kwargs...) + return dmrg_x(updater, operator, init; kwargs...) +end diff --git a/src/treetensornetworks/solvers/linsolve.jl b/src/treetensornetworks/solvers/linsolve.jl index 90ba8572..6f936020 100644 --- a/src/treetensornetworks/solvers/linsolve.jl +++ b/src/treetensornetworks/solvers/linsolve.jl @@ -28,39 +28,21 @@ function linsolve( x₀::AbstractTTN, a₀::Number=0, a₁::Number=1; - normalize, - region, - half_sweep, + updater=linsolve_updater, + nsweeps, #it makes sense to require this to be defined + nsites=2, + (sweep_observer!)=observer(), + root_vertex=default_root_vertex(init), + updater_kwargs=(;), + kwargs..., ) - function linsolve_solver( - P, - x₀; - ishermitian=false, - solver_tol=1E-14, - solver_krylovdim=30, - solver_maxiter=100, - solver_verbosity=0, - ) - b = dag(only(proj_mps(P))) - x, info = KrylovKit.linsolve( - P, - b, - x₀, - a₀, - a₁; - ishermitian, - tol=solver_tol, - krylovdim=solver_krylovdim, - maxiter=solver_maxiter, - verbosity=solver_verbosity, - ) - return x, NamedTuple() - end - + updater_kwargs = (; a₀, a₁, updater_kwargs...) error("`linsolve` for TTN not yet implemented.") + sweep_plan = default_sweep_regions(nsites, x0) # TODO: Define `itensornetwork_cache` # TODO: Define `linsolve_cache` + P = linsolve_cache(itensornetwork_cache(x₀', A, x₀), itensornetwork_cache(x₀', b)) - return alternating_update(linsolve_solver, P, x₀; kwargs...) + return alternating_update(linsolve_updater, P, x₀; sweep_plan, updater_kwargs, kwargs...) end diff --git a/src/treetensornetworks/solvers/tdvp.jl b/src/treetensornetworks/solvers/tdvp.jl index 01fe2985..f6081f46 100644 --- a/src/treetensornetworks/solvers/tdvp.jl +++ b/src/treetensornetworks/solvers/tdvp.jl @@ -1,57 +1,3 @@ -function exponentiate_solver() - function solver( - H, - init; - ishermitian=true, - issymmetric=true, - region, - solver_krylovdim=30, - solver_maxiter=100, - solver_outputlevel=0, - solver_tol=1E-12, - substep, - normalize, - time_step, - ) - psi, exp_info = KrylovKit.exponentiate( - H, - time_step, - init; - ishermitian, - issymmetric, - tol=solver_tol, - krylovdim=solver_krylovdim, - maxiter=solver_maxiter, - verbosity=solver_outputlevel, - eager=true, - ) - return psi, (; info=exp_info) - end - return solver -end - -function applyexp_solver() - function solver( - H, - init; - tdvp_order, - solver_krylovdim=30, - solver_outputlevel=0, - solver_tol=1E-8, - substep, - time_step, - normalize, - ) - #applyexp tol is absolute, compute from tol_per_unit_time: - tol = abs(time_step) * tol_per_unit_time - psi, exp_info = applyexp( - H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel - ) - return psi, (; info=exp_info) - end - return solver -end - function _compute_nsweeps(nsteps, t, time_step, order) nsweeps_per_step = order / 2 nsweeps = 1 @@ -84,15 +30,15 @@ function sub_time_steps(order) end end -function tdvp_sweep( +function tdvp_sweep_plan( order::Int, - nsite::Int, + nsites::Int, time_step::Number, graph::AbstractGraph; root_vertex=default_root_vertex(graph), reverse_step=true, ) - sweep = [] + sweep_plan = [] for (substep, fac) in enumerate(sub_time_steps(order)) sub_time_step = time_step * fac half = half_sweep( @@ -100,39 +46,42 @@ function tdvp_sweep( graph, make_region; root_vertex, - nsite, + nsites, region_args=(; substep, time_step=sub_time_step), reverse_args=(; substep, time_step=-sub_time_step), reverse_step, ) - append!(sweep, half) + append!(sweep_plan, half) end - return sweep + return sweep_plan end function tdvp( - solver, - H, + updater, + operator, t::Number, - init::AbstractTTN; + init_state::AbstractTTN; time_step::Number=t, - nsite=2, + nsites=2, nsteps=nothing, order::Integer=2, (sweep_observer!)=observer(), - root_vertex=default_root_vertex(init), + root_vertex=default_root_vertex(init_state), reverse_step=true, + updater_kwargs=(;), kwargs..., ) nsweeps = _compute_nsweeps(nsteps, t, time_step, order) - sweep_regions = tdvp_sweep(order, nsite, time_step, init; root_vertex, reverse_step) + sweep_plan = tdvp_sweep_plan( + order, nsites, time_step, init_state; root_vertex, reverse_step + ) - function sweep_time_printer(; outputlevel, sweep, kwargs...) + function sweep_time_printer(; outputlevel, which_sweep, kwargs...) if outputlevel >= 1 sweeps_per_step = order ÷ 2 if sweep % sweeps_per_step == 0 - current_time = (sweep / sweeps_per_step) * time_step - println("Current time (sweep $sweep) = ", round(current_time; digits=3)) + current_time = (which_sweep / sweeps_per_step) * time_step + println("Current time (sweep $which_sweep) = ", round(current_time; digits=3)) end end return nothing @@ -140,26 +89,33 @@ function tdvp( insert_function!(sweep_observer!, "sweep_time_printer" => sweep_time_printer) - psi = alternating_update( - solver, H, init; nsweeps, sweep_observer!, sweep_regions, nsite, kwargs... + state = alternating_update( + updater, + operator, + init_state; + nsweeps, + sweep_observer!, + sweep_plan, + updater_kwargs, + kwargs..., ) # remove sweep_time_printer from sweep_observer! select!(sweep_observer!, Observers.DataFrames.Not("sweep_time_printer")) - return psi + return state end """ - tdvp(H::TTN, t::Number, psi0::TTN; kwargs...) + tdvp(operator::TTN, t::Number, init_state::TTN; kwargs...) Use the time dependent variational principle (TDVP) algorithm -to approximately compute `exp(H*t)*psi0` using an efficient algorithm based +to approximately compute `exp(operator*t)*init_state` using an efficient algorithm based on alternating optimization of the state tensors and local Krylov -exponentiation of H. The time parameter `t` can be a real or complex number. +exponentiation of operator. The time parameter `t` can be a real or complex number. Returns: -* `psi` - time-evolved state +* `state` - time-evolved state Optional keyword arguments: * `time_step::Number = t` - time step to use when evolving the state. Smaller time steps generally give more accurate results but can make the algorithm take more computational time to run. @@ -168,15 +124,8 @@ Optional keyword arguments: * `observer` - object implementing the Observer interface which can perform measurements and stop early * `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations """ -function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kwargs...) - if solver_backend == "exponentiate" - solver = exponentiate_solver - elseif solver_backend == "applyexp" - solver = applyexp_solver - else - error( - "solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")", - ) - end - return tdvp(solver(), H, t, init; kwargs...) +function tdvp( + operator, t::Number, init_state::AbstractTTN; updater=exponentiate_updater, kwargs... +) + return tdvp(updater, operator, t, init_state; kwargs...) end diff --git a/src/treetensornetworks/solvers/tree_sweeping.jl b/src/treetensornetworks/solvers/tree_sweeping.jl index 66171e68..99375ba9 100644 --- a/src/treetensornetworks/solvers/tree_sweeping.jl +++ b/src/treetensornetworks/solvers/tree_sweeping.jl @@ -3,12 +3,12 @@ direction(step_number) = isodd(step_number) ? Base.Forward : Base.Reverse function make_region( edge; last_edge=false, - nsite=1, + nsites=1, region_args=(;), reverse_args=region_args, reverse_step=false, ) - if nsite == 1 + if nsites == 1 site = ([src(edge)], region_args) bond = (edge, reverse_args) region = reverse_step ? (site, bond) : (site,) @@ -17,7 +17,7 @@ function make_region( else return region end - elseif nsite == 2 + elseif nsites == 2 sites_two = ([src(edge), dst(edge)], region_args) sites_one = ([dst(edge)], reverse_args) region = reverse_step ? (sites_two, sites_one) : (sites_two,) @@ -27,7 +27,7 @@ function make_region( return region end else - error("nsite=$nsite not supported in alternating_update / update_step") + error("nsites=$nsites not supported in alternating_update / update_step") end end diff --git a/src/treetensornetworks/solvers/update_step.jl b/src/treetensornetworks/solvers/update_step.jl index 0f69ed52..890a13fa 100644 --- a/src/treetensornetworks/solvers/update_step.jl +++ b/src/treetensornetworks/solvers/update_step.jl @@ -1,12 +1,12 @@ -function default_sweep_regions(nsite, graph::AbstractGraph; kwargs...) +function default_sweep_regions(nsites, graph::AbstractGraph; kwargs...) ###move this to a different file, algorithmic level idea return vcat( [ half_sweep( direction(half), graph, make_region; - nsite, + nsites, region_args=(; half_sweep=half), kwargs..., ) for half in 1:2 @@ -14,11 +14,21 @@ function default_sweep_regions(nsite, graph::AbstractGraph; kwargs...) ) end -function step_printer(; - cutoff, maxdim, mindim, outputlevel::Int=0, psi, region, spec, sweep_step +function region_update_printer(; + cutoff, + maxdim, + mindim, + outputlevel::Int=0, + state, + sweep_plan, + spec, + which_region_update, + which_sweep, + kwargs..., ) if outputlevel >= 2 - @printf("Sweep %d, region=%s \n", sweep, region) + region = first(sweep_plan[which_region_update]) + @printf("Sweep %d, region=%s \n", which_sweep, region) print(" Truncated using") @printf(" cutoff=%.1E", cutoff) @printf(" maxdim=%d", maxdim) @@ -28,61 +38,61 @@ function step_printer(; @printf( " Trunc. err=%.2E, bond dimension %d\n", spec.truncerr, - linkdim(psi, edgetype(psi)(region...)) + linkdim(state, edgetype(state)(region...)) ) end flush(stdout) end end -function update_step( +function sweep_update( solver, - PH, - psi::AbstractTTN; - normalize::Bool=false, - nsite::Int=2, - step_printer=step_printer, - (step_observer!)=observer(), - sweep::Int=1, - sweep_regions=default_sweep_regions(nsite, psi), - kwargs..., + projected_operator, + state::AbstractTTN; + normalize::Bool=false, # ToDo: think about where to put the default, probably this default is best defined at algorithmic level + outputlevel, + region_update_printer=region_update_printer, + (region_observer!)=observer(), # ToDo: change name to region_observer! ? + which_sweep::Int, + sweep_params::NamedTuple, + sweep_plan, + updater_kwargs, ) - PH = copy(PH) - psi = copy(psi) - - insert_function!(step_observer!, "step_printer" => step_printer) + insert_function!(region_observer!, "region_update_printer" => region_update_printer) #ToDo fix this # Append empty namedtuple to each element if not already present - # (Needed to handle user-provided sweep_regions) - sweep_regions = append_missing_namedtuple.(to_tuple.(sweep_regions)) + # (Needed to handle user-provided region_updates) + sweep_plan = append_missing_namedtuple.(to_tuple.(sweep_plan)) - if nv(psi) == 1 + if nv(state) == 1 error( "`alternating_update` currently does not support system sizes of 1. You can diagonalize the MPO tensor directly with tools like `LinearAlgebra.eigen`, `KrylovKit.exponentiate`, etc.", ) end - for (sweep_step, (region, step_kwargs)) in enumerate(sweep_regions) - psi, PH = local_update( + for which_region_update in eachindex(sweep_plan) + (region, region_kwargs) = sweep_plan[which_region_update] + region_kwargs = merge(region_kwargs, sweep_params) # sweep params has precedence over step_kwargs + state, projected_operator = region_update( solver, - PH, - psi, - region; + projected_operator, + state; normalize, - step_kwargs, - step_observer!, - sweep, - sweep_regions, - sweep_step, - kwargs..., + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + region_observer!, + updater_kwargs, ) end - select!(step_observer!, Observers.DataFrames.Not("step_printer")) # remove step_printer + select!(region_observer!, Observers.DataFrames.Not("region_update_printer")) # remove update_printer # Just to be sure: - normalize && normalize!(psi) + normalize && normalize!(state) - return psi, PH + return state, projected_operator end # @@ -96,20 +106,20 @@ end # apart and puts it back into the network. # -function extract_local_tensor(psi::AbstractTTN, pos::Vector) - return psi, prod(psi[v] for v in pos) +function extract_local_tensor(state::AbstractTTN, pos::Vector) + return state, prod(state[v] for v in pos) end -function extract_local_tensor(psi::AbstractTTN, e::NamedEdge) - left_inds = uniqueinds(psi, e) - U, S, V = svd(psi[src(e)], left_inds; lefttags=tags(psi, e), righttags=tags(psi, e)) - psi[src(e)] = U - return psi, S * V +function extract_local_tensor(state::AbstractTTN, e::NamedEdge) + left_inds = uniqueinds(state, e) + U, S, V = svd(state[src(e)], left_inds; lefttags=tags(state, e), righttags=tags(state, e)) + state[src(e)] = U + return state, S * V end # sort of multi-site replacebond!; TODO: use dense TTN constructor instead function insert_local_tensor( - psi::AbstractTTN, + state::AbstractTTN, phi::ITensor, pos::Vector; normalize=false, @@ -123,12 +133,12 @@ function insert_local_tensor( ) spec = nothing for (v, vnext) in IterTools.partition(pos, 2, 1) - e = edgetype(psi)(v, vnext) - indsTe = inds(psi[v]) + e = edgetype(state)(v, vnext) + indsTe = inds(state[v]) L, phi, spec = factorize( phi, indsTe; - tags=tags(psi, e), + tags=tags(state, e), maxdim, mindim, cutoff, @@ -136,53 +146,65 @@ function insert_local_tensor( eigen_perturbation, ortho, ) - psi[v] = L + state[v] = L eigen_perturbation = nothing # TODO: fix this end - psi[last(pos)] = phi - psi = set_ortho_center(psi, [last(pos)]) - @assert isortho(psi) && only(ortho_center(psi)) == last(pos) - normalize && (psi[last(pos)] ./= norm(psi[last(pos)])) + state[last(pos)] = phi + state = set_ortho_center(state, [last(pos)]) + @assert isortho(state) && only(ortho_center(state)) == last(pos) + normalize && (state[last(pos)] ./= norm(state[last(pos)])) # TODO: return maxtruncerr, will not be correct in cases where insertion executes multiple factorizations - return psi, spec + return state, spec end -function insert_local_tensor(psi::AbstractTTN, phi::ITensor, e::NamedEdge; kwargs...) - psi[dst(e)] *= phi - psi = set_ortho_center(psi, [dst(e)]) - return psi, nothing +function insert_local_tensor(state::AbstractTTN, phi::ITensor, e::NamedEdge; kwargs...) + state[dst(e)] *= phi + state = set_ortho_center(state, [dst(e)]) + return state, nothing end #TODO: clean this up: +# also can we entirely rely on directionality of edges by construction? current_ortho(::Type{<:Vector{<:V}}, st) where {V} = first(st) current_ortho(::Type{NamedEdge{V}}, st) where {V} = src(st) current_ortho(st) = current_ortho(typeof(st), st) -function local_update( - solver, - PH, - psi, - region; +function region_update( + updater, + projected_operator, + state; normalize, - noise, - cutoff::AbstractFloat=1E-16, - maxdim::Int=typemax(Int), - mindim::Int=1, - outputlevel::Int=0, - step_kwargs=NamedTuple(), - step_observer!, - sweep, - sweep_regions, - sweep_step, - solver_kwargs..., + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + region_observer!, + #insertion_kwargs, #ToDo: later + #extraction_kwargs, #ToDo: implement later with possibility to pass custom extraction/insertion func (or code into func) + updater_kwargs, ) - psi = orthogonalize(psi, current_ortho(region)) - psi, phi = extract_local_tensor(psi, region) - - nsites = (region isa AbstractEdge) ? 0 : length(region) - PH = set_nsite(PH, nsites) - PH = position(PH, psi, region) - phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...) + region = first(sweep_plan[which_region_update]) + state = orthogonalize(state, current_ortho(region)) + state, phi = extract_local_tensor(state, region;) + nsites = (region isa AbstractEdge) ? 0 : length(region) #ToDo move into separate funtion + projected_operator = set_nsite(projected_operator, nsites) + projected_operator = position(projected_operator, state, region) + state! = Ref(state) # create references, in case solver does (out-of-place) modify PH or state + projected_operator! = Ref(projected_operator) + phi, info = updater( + phi; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, + ) # args passed by reference are supposed to be modified out of place + state = state![] # dereference + projected_operator = projected_operator![] if !(phi isa ITensor && info isa NamedTuple) println("Solver returned the following types: $(typeof(phi)), $(typeof(info))") error("In alternating_update, solver must return an ITensor and a NamedTuple") @@ -190,30 +212,40 @@ function local_update( normalize && (phi /= norm(phi)) drho = nothing - ortho = "left" + ortho = "left" #i guess with respect to ordered vertices that's valid but may be cleaner to use next_region logic #if noise > 0.0 && isforward(direction) # drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees... + # so noiseterm is a solver #end - psi, spec = insert_local_tensor( - psi, phi, region; eigen_perturbation=drho, ortho, normalize, maxdim, mindim, cutoff + state, spec = insert_local_tensor( + state, + phi, + region; + eigen_perturbation=drho, + ortho, + normalize, + maxdim=region_kwargs.maxdim, + mindim=region_kwargs.mindim, + cutoff=region_kwargs.cutoff, ) update!( - step_observer!; + region_observer!; cutoff, maxdim, mindim, - sweep_step, - total_sweep_steps=length(sweep_regions), - end_of_sweep=(sweep_step == length(sweep_regions)), - psi, + which_region_update, + sweep_plan, + total_sweep_steps=length(sweep_plan), + end_of_sweep=(which_region_update == length(sweep_plan)), + state, region, - sweep, + which_sweep, spec, outputlevel, info..., - step_kwargs..., + region_kwargs..., ) - return psi, PH + return state, projected_operator end diff --git a/test/test_treetensornetworks/test_solvers/test_contract.jl b/test/test_treetensornetworks/test_solvers/test_contract.jl index 2643b5db..810db6fb 100644 --- a/test/test_treetensornetworks/test_solvers/test_contract.jl +++ b/test/test_treetensornetworks/test_solvers/test_contract.jl @@ -46,7 +46,7 @@ using Test # Test with nsite=1 Hpsi_guess = random_mps(t; internal_inds_space=32) - Hpsi = apply(H, psi; alg="fit", init=Hpsi_guess, nsite=1, nsweeps=4) + Hpsi = apply(H, psi; alg="fit", init=Hpsi_guess, nsites=1, nsweeps=4) @test inner(psit, Hpsi) ≈ inner(psit, H, psi) atol = 1E-4 end @@ -84,7 +84,7 @@ end # Test with nsite=1 Hpsi_guess = random_ttn(t; link_space=4) - Hpsi = apply(H, psi; alg="fit", nsite=1, nsweeps=4, init=Hpsi_guess) + Hpsi = apply(H, psi; alg="fit", nsites=1, nsweeps=4, init=Hpsi_guess) @test inner(psit, Hpsi) ≈ inner(psit, H, psi) atol = 1E-4 end diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index 939b4fee..870012a0 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -5,7 +5,7 @@ using Random using Test using Observers -@testset "MPS DMRG" for nsite in [1, 2] +@testset "MPS DMRG" for nsites in [1, 2] N = 10 cutoff = 1e-12 @@ -30,21 +30,25 @@ using Observers psi_mps = MPS([psi[v] for v in 1:nv(psi)]) e2, psi2 = dmrg(H_mpo, psi_mps; nsweeps, maxdim, outputlevel=0) - psi = dmrg(H, psi; nsweeps, maxdim, cutoff, nsite, solver_krylovdim=3, solver_maxiter=1) + psi = dmrg( + H, psi; nsweeps, maxdim, cutoff, nsites, updater_kwargs=(; krylovdim=3, maxiter=1) + ) @test inner(psi', H, psi) ≈ inner(psi2', H_mpo, psi2) # Alias for `ITensorNetworks.dmrg` psi = eigsolve( - H, psi; nsweeps, maxdim, cutoff, nsite, solver_krylovdim=3, solver_maxiter=1 + H, psi; nsweeps, maxdim, cutoff, nsites, updater_kwargs=(; krylovdim=3, maxiter=1) ) @test inner(psi', H, psi) ≈ inner(psi2', H_mpo, psi2) - # Test custom sweep regions + # Test custom sweep regions #BROKEN, ToDo: Make proper custom sweep regions for test + #= orig_E = inner(psi', H, psi) sweep_regions = [[1], [2], [3], [3], [2], [1]] psi = dmrg(H, psi; nsweeps, maxdim, cutoff, sweep_regions) new_E = inner(psi', H, psi) @test new_E ≈ orig_E + =# end @testset "Observers" begin @@ -67,20 +71,20 @@ end # # Make observers # - sweep(; sweep, kw...) = sweep + sweep(; which_sweep, kw...) = which_sweep sweep_observer! = observer(sweep) - region(; region, kw...) = region - energy(; energies, kw...) = energies[1] - step_observer! = observer(region, sweep, energy) + region(; which_region_update, sweep_plan, kw...) = first(sweep_plan[which_region_update]) + energy(; eigvals, kw...) = eigvals[1] + region_observer! = observer(region, sweep, energy) - psi = dmrg(H, psi; nsweeps, maxdim, cutoff, sweep_observer!, step_observer!) + psi = dmrg(H, psi; nsweeps, maxdim, cutoff, sweep_observer!, region_observer!) # # Test out certain values # - @test step_observer![9, :region] == [2, 1] - @test step_observer![30, :energy] < -4.25 + @test region_observer![9, :region] == [2, 1] + @test region_observer![30, :energy] < -4.25 end @testset "Regression test: Arrays of Parameters" begin @@ -108,7 +112,7 @@ end psi = dmrg(H, psi; nsweeps, maxdim, cutoff) end -@testset "Tree DMRG" for nsite in [1, 2] +@testset "Tree DMRG" for nsites in [1, 2] cutoff = 1e-12 tooth_lengths = fill(2, 3) @@ -126,7 +130,9 @@ end sweeps = Sweeps(nsweeps) # number of sweeps is 5 maxdim!(sweeps, 10, 20, 40, 100) # gradually increase states kept cutoff!(sweeps, cutoff) - psi = dmrg(H, psi; nsweeps, maxdim, cutoff, nsite, solver_krylovdim=3, solver_maxiter=1) + psi = dmrg( + H, psi; nsweeps, maxdim, cutoff, nsites, updater_kwargs=(; krylovdim=3, maxiter=1) + ) # Compare to `ITensors.MPO` version of `dmrg` linear_order = [4, 1, 2, 5, 3, 6] @@ -141,7 +147,7 @@ end @testset "Regression test: tree truncation" begin maxdim = 4 - nsite = 2 + nsites = 2 nsweeps = 10 c = named_comb_tree((3, 2)) @@ -149,7 +155,7 @@ end os = ITensorNetworks.heisenberg(c) H = TTN(os, s) psi = random_ttn(s; link_space=5) - psi = dmrg(H, psi; nsweeps, maxdim, nsite) + psi = dmrg(H, psi; nsweeps, maxdim, nsites) @test all(edge_data(linkdims(psi)) .<= maxdim) end diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg_x.jl b/test/test_treetensornetworks/test_solvers/test_dmrg_x.jl index 65871b56..4bb16268 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg_x.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg_x.jl @@ -16,17 +16,15 @@ using Test ψ = mps(s; states=(v -> rand(["↑", "↓"]))) - dmrg_x_kwargs = ( - nsweeps=20, reverse_step=false, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0 - ) + dmrg_x_kwargs = (nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0) - ϕ = dmrg_x(H, ψ; nsite=2, dmrg_x_kwargs...) + ϕ = dmrg_x(H, ψ; nsites=2, dmrg_x_kwargs...) @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ', H, ϕ) / inner(ϕ, ϕ) rtol = 1e-1 @test inner(H, ψ, H, ψ) ≉ inner(ψ', H, ψ)^2 rtol = 1e-7 @test inner(H, ϕ, H, ϕ) ≈ inner(ϕ', H, ϕ)^2 rtol = 1e-7 - ϕ̃ = dmrg_x(H, ϕ; nsite=1, dmrg_x_kwargs...) + ϕ̃ = dmrg_x(H, ϕ; nsites=1, dmrg_x_kwargs...) @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ̃', H, ϕ̃) / inner(ϕ̃, ϕ̃) rtol = 1e-1 @test inner(H, ϕ̃, H, ϕ̃) ≈ inner(ϕ̃', H, ϕ̃)^2 rtol = 1e-3 @@ -34,10 +32,7 @@ using Test # @test abs(loginner(ϕ̃, ϕ) / n) ≈ 0.0 atol = 1e-6 end -@testset "Tree DMRG-X" for conserve_qns in ( - false, - true, # OpSum → TTN with QNs not working for non-path graphs -) +@testset "Tree DMRG-X" for conserve_qns in (false, true) tooth_lengths = fill(2, 3) root_vertex = (3, 2) c = named_comb_tree(tooth_lengths) @@ -55,17 +50,15 @@ end # `ttns(s; states=v -> rand(["↑", "↓"]))` ψ = normalize!(TTN(s, v -> rand(["↑", "↓"]))) - dmrg_x_kwargs = ( - nsweeps=20, reverse_step=false, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0 - ) + dmrg_x_kwargs = (nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0) - ϕ = dmrg_x(H, ψ; nsite=2, dmrg_x_kwargs...) + ϕ = dmrg_x(H, ψ; nsites=2, dmrg_x_kwargs...) @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ', H, ϕ) / inner(ϕ, ϕ) rtol = 1e-1 @test inner(H, ψ, H, ψ) ≉ inner(ψ', H, ψ)^2 rtol = 1e-2 @test inner(H, ϕ, H, ϕ) ≈ inner(ϕ', H, ϕ)^2 rtol = 1e-7 - ϕ̃ = dmrg_x(H, ϕ; nsite=1, dmrg_x_kwargs...) + ϕ̃ = dmrg_x(H, ϕ; nsites=1, dmrg_x_kwargs...) @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ̃', H, ϕ̃) / inner(ϕ̃, ϕ̃) rtol = 1e-1 @test inner(H, ϕ̃, H, ϕ̃) ≈ inner(ϕ̃', H, ϕ̃)^2 rtol = 1e-6 diff --git a/test/test_treetensornetworks/test_solvers/test_linsolve.jl b/test/test_treetensornetworks/test_solvers/test_linsolve.jl index 68e8e953..72291642 100644 --- a/test/test_treetensornetworks/test_solvers/test_linsolve.jl +++ b/test/test_treetensornetworks/test_solvers/test_linsolve.jl @@ -41,11 +41,11 @@ using Random x_c = random_mps(s; states, internal_inds_space=4) + 0.1im * random_mps(s; states, internal_inds_space=2) - b = apply(H, x_c; cutoff) + b = apply(H, x_c; alg="fit", nsweeps=3) #cutoff is unsupported kwarg for apply/contract x0 = random_mps(s; states, internal_inds_space=10) x = @test_broken linsolve( - H, b, x0; cutoff, maxdim, nsweeps, ishermitian=true, solver_tol=1E-6 + H, b, x0; cutoff, maxdim, nsweeps, updater_kwargs=(; tol=1E-6, ishermitian=true) ) # @test norm(x - x_c) < 1E-3 diff --git a/test/test_treetensornetworks/test_solvers/test_tdvp.jl b/test/test_treetensornetworks/test_solvers/test_tdvp.jl index ec002af1..d5d1ec49 100644 --- a/test/test_treetensornetworks/test_solvers/test_tdvp.jl +++ b/test/test_treetensornetworks/test_solvers/test_tdvp.jl @@ -1,5 +1,6 @@ using ITensors using ITensorNetworks +using ITensorNetworks: exponentiate_updater using KrylovKit: exponentiate using Observers using Random @@ -24,16 +25,7 @@ using Test ψ0 = random_mps(s; internal_inds_space=10) # Time evolve forward: - ψ1 = tdvp(H, -0.1im, ψ0; nsteps=1, cutoff, nsite=1) - - # - #TODO: exponentiate is now the default, so switch this to applyexp - # - #Different backend solvers, default solver_backend = "applyexp" - ψ1_exponentiate_backend = tdvp( - H, -0.1im, ψ0; nsteps=1, cutoff, nsite=1, solver_backend="exponentiate" - ) - @test ψ1 ≈ ψ1_exponentiate_backend rtol = 1e-7 + ψ1 = tdvp(H, -0.1im, ψ0; nsteps=1, cutoff, nsites=1) @test norm(ψ1) ≈ 1.0 @@ -46,12 +38,6 @@ using Test # Time evolve backwards: ψ2 = tdvp(H, +0.1im, ψ1; nsteps=1, cutoff) - #Different backend solvers, default solver_backend = "applyexp" - ψ2_exponentiate_backend = tdvp( - H, +0.1im, ψ1; nsteps=1, cutoff, solver_backend="exponentiate" - ) - @test ψ2 ≈ ψ2_exponentiate_backend rtol = 1e-7 - @test norm(ψ2) ≈ 1.0 # Should rotate back to original state: @@ -80,13 +66,7 @@ using Test ψ0 = random_mps(s; internal_inds_space=10) - ψ1 = tdvp(Hs, -0.1im, ψ0; nsteps=1, cutoff, nsite=1) - - #Different backend solvers, default solver_backend = "applyexp" - ψ1_exponentiate_backend = tdvp( - Hs, -0.1im, ψ0; nsteps=1, cutoff, nsite=1, solver_backend="exponentiate" - ) - @test ψ1 ≈ ψ1_exponentiate_backend rtol = 1e-7 + ψ1 = tdvp(Hs, -0.1im, ψ0; nsteps=1, cutoff, nsites=1) @test norm(ψ1) ≈ 1.0 @@ -99,12 +79,6 @@ using Test # Time evolve backwards: ψ2 = tdvp(Hs, +0.1im, ψ1; nsteps=1, cutoff) - #Different backend solvers, default solver_backend = "applyexp" - ψ2_exponentiate_backend = tdvp( - Hs, +0.1im, ψ1; nsteps=1, cutoff, solver_backend="exponentiate" - ) - @test ψ2 ≈ ψ2_exponentiate_backend rtol = 1e-7 - @test norm(ψ2) ≈ 1.0 # Should rotate back to original state: @@ -130,7 +104,7 @@ using Test ψ0 = random_mps(s; internal_inds_space=10) # Time evolve forward: - ψ1 = tdvp(H, -0.1im, ψ0; time_step=-0.05im, order, cutoff, nsite=1) + ψ1 = tdvp(H, -0.1im, ψ0; time_step=-0.05im, order, cutoff, nsites=1) @test norm(ψ1) ≈ 1.0 @@ -146,50 +120,6 @@ using Test @test abs(inner(ψ0, ψ2)) > 0.99 end - @testset "Custom solver in TDVP" begin - N = 10 - cutoff = 1e-12 - - s = siteinds("S=1/2", N) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - - H = mpo(os, s) - - ψ0 = random_mps(s; internal_inds_space=10) - - function solver(PH, psi0; time_step, kwargs...) - solver_kwargs = (; - ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true - ) - psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...) - return psi, (; info=exp_info) - end - - ψ1 = tdvp(solver, H, -0.1im, ψ0; cutoff, nsite=1) - - @test norm(ψ1) ≈ 1.0 - - ## Should lose fidelity: - #@test abs(inner(ψ0,ψ1)) < 0.9 - - # Average energy should be conserved: - @test real(inner(ψ1', H, ψ1)) ≈ inner(ψ0', H, ψ0) - - # Time evolve backwards: - ψ2 = tdvp(H, +0.1im, ψ1; cutoff) - - @test norm(ψ2) ≈ 1.0 - - # Should rotate back to original state: - @test abs(inner(ψ0, ψ2)) > 0.99 - end - @testset "Accuracy Test" begin N = 4 tau = 0.1 @@ -209,9 +139,9 @@ using Test Ut = exp(-im * tau * HM) - psi = mps(s; states=(n -> isodd(n) ? "Up" : "Dn")) - psi2 = deepcopy(psi) - psix = contract(psi) + state = mps(s; states=(n -> isodd(n) ? "Up" : "Dn")) + psi2 = deepcopy(state) + psix = contract(state) Sz_tdvp = Float64[] Sz_tdvp2 = Float64[] @@ -225,19 +155,17 @@ using Test psix = noprime(Ut * psix) psix /= norm(psix) - psi = tdvp( + state = tdvp( H, -im * tau, - psi; + state; cutoff, normalize=false, - solver_tol=1e-12, - solver_maxiter=500, - solver_krylovdim=25, + updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25), ) # TODO: What should `expect` output? Right now # it outputs a dictionary. - push!(Sz_tdvp, real(expect("Sz", psi; vertices=[c])[c])) + push!(Sz_tdvp, real(expect("Sz", state; vertices=[c])[c])) psi2 = tdvp( H, @@ -245,17 +173,15 @@ using Test psi2; cutoff, normalize=false, - solver_tol=1e-12, - solver_maxiter=500, - solver_krylovdim=25, - solver_backend="exponentiate", + updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25), + updater=exponentiate_updater, ) # TODO: What should `expect` output? Right now # it outputs a dictionary. push!(Sz_tdvp2, real(expect("Sz", psi2; vertices=[c])[c])) push!(Sz_exact, real(scalar(dag(prime(psix, s[c])) * Szc * psix))) - F = abs(scalar(dag(psix) * contract(psi))) + F = abs(scalar(dag(psix) * contract(state))) end @test norm(Sz_tdvp - Sz_exact) < 1e-5 @@ -292,8 +218,8 @@ using Test end append!(gates, reverse(gates)) - psi = mps(s; states=(n -> isodd(n) ? "Up" : "Dn")) - phi = deepcopy(psi) + state = mps(s; states=(n -> isodd(n) ? "Up" : "Dn")) + phi = deepcopy(state) c = div(N, 2) # @@ -307,17 +233,24 @@ using Test #En2 = zeros(Nsteps) for step in 1:Nsteps - psi = apply(gates, psi; cutoff) - #normalize!(psi) + state = apply(gates, state; cutoff) + #normalize!(state) - nsite = (step <= 3 ? 2 : 1) + nsites = (step <= 3 ? 2 : 1) phi = tdvp( - H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15 + H, + -tau * im, + phi; + nsteps=1, + cutoff, + nsites, + normalize=true, + updater_kwargs=(; krylovdim=15), ) - Sz1[step] = real(expect("Sz", psi; vertices=[c])[c]) + Sz1[step] = real(expect("Sz", state; vertices=[c])[c]) #Sz2[step] = real(expect("Sz", phi; vertices=[c])[c]) - En1[step] = real(inner(psi', H, psi)) + En1[step] = real(inner(state', H, state)) #En2[step] = real(inner(phi', H, phi)) end @@ -328,8 +261,8 @@ using Test phi = mps(s; states=(n -> isodd(n) ? "Up" : "Dn")) obs = Observer( - "Sz" => (; psi) -> expect("Sz", psi; vertices=[c])[c], - "En" => (; psi) -> real(inner(psi', H, psi)), + "Sz" => (; state) -> expect("Sz", state; vertices=[c])[c], + "En" => (; state) -> real(inner(state', H, state)), ) phi = tdvp( @@ -366,34 +299,24 @@ using Test H = mpo(os, s) - psi = random_mps(s; internal_inds_space=2) - psi2 = deepcopy(psi) + state = random_mps(s; internal_inds_space=2) trange = 0.0:tau:ttotal for (step, t) in enumerate(trange) - nsite = (step <= 10 ? 2 : 1) - psi = tdvp( - H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15 - ) - #Different backend solvers, default solver_backend = "applyexp" - psi2 = tdvp( + nsites = (step <= 10 ? 2 : 1) + state = tdvp( H, -tau, - psi2; + state; cutoff, - nsite, + nsites, reverse_step, normalize=true, - solver_krylovdim=15, - solver_backend="exponentiate", + updater_kwargs=(; krylovdim=15), ) end - @test psi ≈ psi2 rtol = 1e-6 - - en1 = inner(psi', H, psi) - en2 = inner(psi2', H, psi2) + en1 = inner(state', H, state) @test en1 < -4.25 - @test en1 ≈ en2 end @testset "Observers" begin @@ -418,39 +341,36 @@ using Test # Using Observers.jl # - measure_sz(; psi) = expect("Sz", psi; vertices=[c])[c] - measure_en(; psi) = real(inner(psi', H, psi)) + measure_sz(; state) = expect("Sz", state; vertices=[c])[c] + measure_en(; state) = real(inner(state', H, state)) sweep_obs = Observer("Sz" => measure_sz, "En" => measure_en) get_info(; info) = info - step_measure_sz(; psi) = expect("Sz", psi; vertices=[c])[c] - step_measure_en(; psi) = real(inner(psi', H, psi)) - step_obs = Observer( + step_measure_sz(; state) = expect("Sz", state; vertices=[c])[c] + step_measure_en(; state) = real(inner(state', H, state)) + region_obs = Observer( "Sz" => step_measure_sz, "En" => step_measure_en, "info" => get_info ) - psi2 = mps(s; states=(n -> isodd(n) ? "Up" : "Dn")) + state2 = mps(s; states=(n -> isodd(n) ? "Up" : "Dn")) tdvp( H, -im * ttotal, - psi2; + state2; time_step=-im * tau, cutoff, normalize=false, (sweep_observer!)=sweep_obs, - (step_observer!)=step_obs, + (region_observer!)=region_obs, root_vertex=N, # defaults to 1, which breaks observer equality ) Sz2 = sweep_obs.Sz En2 = sweep_obs.En - Sz2_step = step_obs.Sz - En2_step = step_obs.En - infos = step_obs.info - - #@show sweep_obs - #@show step_obs + Sz2_step = region_obs.Sz + En2_step = region_obs.En + infos = region_obs.info # # Could use ideas of other things to test here @@ -476,7 +396,7 @@ end ψ0 = normalize!(random_ttn(s; link_space=10)) # Time evolve forward: - ψ1 = tdvp(H, -0.1im, ψ0; nsteps=1, cutoff, nsite=1) + ψ1 = tdvp(H, -0.1im, ψ0; nsteps=1, cutoff, nsites=1) @test norm(ψ1) ≈ 1.0 @@ -518,7 +438,7 @@ end ψ0 = normalize!(random_ttn(s; link_space=10)) - ψ1 = tdvp(Hs, -0.1im, ψ0; nsteps=1, cutoff, nsite=1) + ψ1 = tdvp(Hs, -0.1im, ψ0; nsteps=1, cutoff, nsites=1) @test norm(ψ1) ≈ 1.0 @@ -537,50 +457,6 @@ end @test abs(inner(ψ0, ψ2)) > 0.99 end - @testset "Custom solver in TDVP" begin - cutoff = 1e-12 - - tooth_lengths = fill(2, 3) - root_vertex = (3, 2) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c) - - os = ITensorNetworks.heisenberg(c) - - H = TTN(os, s) - - ψ0 = normalize!(random_ttn(s; link_space=10)) - - function solver(PH, psi0; time_step, kwargs...) - solver_kwargs = (; - ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true - ) - psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...) - return psi, (; info=exp_info) - end - - ψ1 = tdvp(solver, H, -0.1im, ψ0; cutoff, nsite=1) - - #@test ψ1 ≈ tdvp(solver, -0.1im, H, ψ0; cutoff, nsite=1) - #@test ψ1 ≈ tdvp(solver, H, ψ0, -0.1im; cutoff, nsite=1) - - @test norm(ψ1) ≈ 1.0 - - ## Should lose fidelity: - #@test abs(inner(ψ0,ψ1)) < 0.9 - - # Average energy should be conserved: - @test real(inner(ψ1', H, ψ1)) ≈ inner(ψ0', H, ψ0) - - # Time evolve backwards: - ψ2 = tdvp(H, +0.1im, ψ1; cutoff) - - @test norm(ψ2) ≈ 1.0 - - # Should rotate back to original state: - @test abs(inner(ψ0, ψ2)) > 0.99 - end - @testset "Accuracy Test" begin tau = 0.1 ttotal = 1.0 @@ -597,8 +473,8 @@ end Ut = exp(-im * tau * HM) - psi = TTN(ComplexF64, s, v -> iseven(sum(isodd.(v))) ? "Up" : "Dn") - psix = contract(psi) + state = TTN(ComplexF64, s, v -> iseven(sum(isodd.(v))) ? "Up" : "Dn") + statex = contract(state) Sz_tdvp = Float64[] Sz_exact = Float64[] @@ -608,22 +484,20 @@ end Nsteps = Int(ttotal / tau) for step in 1:Nsteps - psix = noprime(Ut * psix) - psix /= norm(psix) + statex = noprime(Ut * statex) + statex /= norm(statex) - psi = tdvp( + state = tdvp( H, -im * tau, - psi; + state; cutoff, normalize=false, - solver_tol=1e-12, - solver_maxiter=500, - solver_krylovdim=25, + updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25), ) - push!(Sz_tdvp, real(expect("Sz", psi; vertices=[c])[c])) - push!(Sz_exact, real(scalar(dag(prime(psix, s[c])) * Szc * psix))) - F = abs(scalar(dag(psix) * contract(psi))) + push!(Sz_tdvp, real(expect("Sz", state; vertices=[c])[c])) + push!(Sz_exact, real(scalar(dag(prime(statex, s[c])) * Szc * statex))) + F = abs(scalar(dag(statex) * contract(state))) end @test norm(Sz_tdvp - Sz_exact) < 1e-5 @@ -657,8 +531,8 @@ end end append!(gates, reverse(gates)) - psi = TTN(s, v -> iseven(sum(isodd.(v))) ? "Up" : "Dn") - phi = copy(psi) + state = TTN(s, v -> iseven(sum(isodd.(v))) ? "Up" : "Dn") + phi = copy(state) c = (2, 1) # @@ -672,17 +546,24 @@ end En2 = zeros(Nsteps) for step in 1:Nsteps - psi = apply(gates, psi; cutoff, maxdim) - #normalize!(psi) + state = apply(gates, state; cutoff, maxdim) + #normalize!(state) - nsite = (step <= 3 ? 2 : 1) + nsites = (step <= 3 ? 2 : 1) phi = tdvp( - H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15 + H, + -tau * im, + phi; + nsteps=1, + cutoff, + nsites, + normalize=true, + updater_kwargs=(; krylovdim=15), ) - Sz1[step] = real(expect("Sz", psi; vertices=[c])[c]) + Sz1[step] = real(expect("Sz", state; vertices=[c])[c]) Sz2[step] = real(expect("Sz", phi; vertices=[c])[c]) - En1[step] = real(inner(psi', H, psi)) + En1[step] = real(inner(state', H, state)) En2[step] = real(inner(phi', H, phi)) end @@ -692,8 +573,8 @@ end phi = TTN(s, v -> iseven(sum(isodd.(v))) ? "Up" : "Dn") obs = Observer( - "Sz" => (; psi) -> expect("Sz", psi; vertices=[c])[c], - "En" => (; psi) -> real(inner(psi', H, psi)), + "Sz" => (; state) -> expect("Sz", state; vertices=[c])[c], + "En" => (; state) -> real(inner(state', H, state)), ) phi = tdvp( H, @@ -702,7 +583,7 @@ end time_step=-im * tau, cutoff, normalize=false, - (step_observer!)=obs, + (region_observer!)=obs, root_vertex=(3, 2), ) @@ -722,17 +603,24 @@ end os = ITensorNetworks.heisenberg(c) H = TTN(os, s) - psi = normalize!(random_ttn(s; link_space=2)) + state = normalize!(random_ttn(s; link_space=2)) trange = 0.0:tau:ttotal for (step, t) in enumerate(trange) - nsite = (step <= 10 ? 2 : 1) - psi = tdvp( - H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15 + nsites = (step <= 10 ? 2 : 1) + state = tdvp( + H, + -tau, + state; + cutoff, + nsites, + reverse_step, + normalize=true, + updater_kwargs=(; krylovdim=15), ) end - @test inner(psi', H, psi) < -2.47 + @test inner(state', H, state) < -2.47 end # TODO: verify quantum number suport in ITensorNetworks diff --git a/test/test_treetensornetworks/test_solvers/test_tdvp_time_dependent.jl b/test/test_treetensornetworks/test_solvers/test_tdvp_time_dependent.jl index 4abe5613..763c9df2 100644 --- a/test/test_treetensornetworks/test_solvers/test_tdvp_time_dependent.jl +++ b/test/test_treetensornetworks/test_solvers/test_tdvp_time_dependent.jl @@ -23,6 +23,29 @@ ode_kwargs = (; reltol=1e-8, abstol=1e-8) ω⃗ = [ω₁, ω₂] f⃗ = [t -> cos(ω * t) for ω in ω⃗] +ode_updater_kwargs = (; f=f⃗, solver_alg=ode_alg, ode_kwargs) + +function ode_updater( + init; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, +) + time_step = region_kwargs.time_step + f⃗ = updater_kwargs.f + ode_kwargs = updater_kwargs.ode_kwargs + solver_alg = updater_kwargs.solver_alg + H⃗₀ = projected_operator![] + result, info = ode_solver( + -im * TimeDependentSum(f⃗, H⃗₀), time_step, init; solver_alg, ode_kwargs... + ) + return result, (; info) +end function tdvp_ode_solver(H⃗₀, ψ₀; time_step, kwargs...) psi_t, info = ode_solver( @@ -32,6 +55,7 @@ function tdvp_ode_solver(H⃗₀, ψ₀; time_step, kwargs...) end krylov_kwargs = (; tol=1e-8, eager=true) +krylov_updater_kwargs = (; f=f⃗, krylov_kwargs) function krylov_solver(H⃗₀, ψ₀; time_step, ishermitian=false, issymmetric=false, kwargs...) psi_t, info = krylov_solver( @@ -45,6 +69,38 @@ function krylov_solver(H⃗₀, ψ₀; time_step, ishermitian=false, issymmetric return psi_t, (; info) end +function krylov_updater( + init; + state!, + projected_operator!, + outputlevel, + which_sweep, + sweep_plan, + which_region_update, + region_kwargs, + updater_kwargs, +) + default_updater_kwargs = (; ishermitian=false, issymmetric=false) + + updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedenc + time_step = region_kwargs.time_step + f⃗ = updater_kwargs.f + krylov_kwargs = updater_kwargs.krylov_kwargs + ishermitian = updater_kwargs.ishermitian + issymmetric = updater_kwargs.issymmetric + H⃗₀ = projected_operator![] + + result, info = krylov_solver( + -im * TimeDependentSum(f⃗, H⃗₀), + time_step, + init; + krylov_kwargs..., + ishermitian, + issymmetric, + ) + return result, (; info) +end + @testset "MPS: Time dependent Hamiltonian" begin n = 4 J₁ = 1.0 @@ -53,7 +109,7 @@ end time_step = 0.1 time_total = 1.0 - nsite = 2 + nsites = 2 maxdim = 100 cutoff = 1e-8 @@ -65,9 +121,28 @@ end ψ₀ = complex(mps(s; states=(j -> isodd(j) ? "↑" : "↓"))) - ψₜ_ode = tdvp(tdvp_ode_solver, H⃗₀, time_total, ψ₀; time_step, maxdim, cutoff, nsite) + ψₜ_ode = tdvp( + ode_updater, + H⃗₀, + time_total, + ψ₀; + time_step, + maxdim, + cutoff, + nsites, + updater_kwargs=ode_updater_kwargs, + ) - ψₜ_krylov = tdvp(krylov_solver, H⃗₀, time_total, ψ₀; time_step, cutoff, nsite) + ψₜ_krylov = tdvp( + krylov_updater, + H⃗₀, + time_total, + ψ₀; + time_step, + cutoff, + nsites, + updater_kwargs=krylov_updater_kwargs, + ) ψₜ_full, _ = tdvp_ode_solver(contract.(H⃗₀), contract(ψ₀); time_step=time_total) @@ -96,7 +171,7 @@ end time_step = 0.1 time_total = 1.0 - nsite = 2 + nsites = 2 maxdim = 100 cutoff = 1e-8 @@ -108,9 +183,28 @@ end ψ₀ = TTN(ComplexF64, s, v -> iseven(sum(isodd.(v))) ? "↑" : "↓") - ψₜ_ode = tdvp(tdvp_ode_solver, H⃗₀, time_total, ψ₀; time_step, maxdim, cutoff, nsite) + ψₜ_ode = tdvp( + ode_updater, + H⃗₀, + time_total, + ψ₀; + time_step, + maxdim, + cutoff, + nsites, + updater_kwargs=ode_updater_kwargs, + ) - ψₜ_krylov = tdvp(krylov_solver, H⃗₀, time_total, ψ₀; time_step, cutoff, nsite) + ψₜ_krylov = tdvp( + krylov_updater, + H⃗₀, + time_total, + ψ₀; + time_step, + cutoff, + nsites, + updater_kwargs=krylov_updater_kwargs, + ) ψₜ_full, _ = tdvp_ode_solver(contract.(H⃗₀), contract(ψ₀); time_step=time_total)