From ac761235decb43665ed0247bbf132e176fc9a4c3 Mon Sep 17 00:00:00 2001 From: Benedikt Kloss Date: Thu, 11 Jan 2024 17:16:50 -0500 Subject: [PATCH] Change required solver function signature. Passes psi and projected Hamiltonian to solver by reference, passes more information about sweep to solver. --- src/treetensornetworks/solvers/tdvp.jl | 25 +++++++++++++++---- src/treetensornetworks/solvers/update_step.jl | 6 ++++- .../test_solvers/test_tdvp.jl | 6 +++-- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/treetensornetworks/solvers/tdvp.jl b/src/treetensornetworks/solvers/tdvp.jl index 01fe2985..dfc553c6 100644 --- a/src/treetensornetworks/solvers/tdvp.jl +++ b/src/treetensornetworks/solvers/tdvp.jl @@ -1,10 +1,13 @@ function exponentiate_solver() function solver( - H, init; + psi_ref!, + PH_ref!, ishermitian=true, issymmetric=true, region, + sweep_regions, + sweep_step, solver_krylovdim=30, solver_maxiter=100, solver_outputlevel=0, @@ -13,7 +16,15 @@ function exponentiate_solver() normalize, time_step, ) - psi, exp_info = KrylovKit.exponentiate( + #H=copy(PH_ref![]) + H=PH_ref![] ###since we are not changing H we don't need the copy + # let's test whether given region and sweep regions we can find out what the previous and next region were + # this will be needed in subspace expansion + region_ind=sweep_step + next_region=region_ind==length(sweep_regions) ? nothing : first(sweep_regions[region_ind+1]) + previous_region=region_ind==1 ? nothing : first(sweep_regions[region_ind-1]) + + phi, exp_info = KrylovKit.exponentiate( H, time_step, init; @@ -25,16 +36,19 @@ function exponentiate_solver() verbosity=solver_outputlevel, eager=true, ) - return psi, (; info=exp_info) + return phi, (; info=exp_info) end return solver end function applyexp_solver() function solver( - H, init; - tdvp_order, + psi_ref!, + PH_ref!, + region, + sweep_regions, + sweep_step, solver_krylovdim=30, solver_outputlevel=0, solver_tol=1E-8, @@ -42,6 +56,7 @@ function applyexp_solver() time_step, normalize, ) + H=PH_ref![] #applyexp tol is absolute, compute from tol_per_unit_time: tol = abs(time_step) * tol_per_unit_time psi, exp_info = applyexp( diff --git a/src/treetensornetworks/solvers/update_step.jl b/src/treetensornetworks/solvers/update_step.jl index 0f69ed52..30e3f29e 100644 --- a/src/treetensornetworks/solvers/update_step.jl +++ b/src/treetensornetworks/solvers/update_step.jl @@ -182,7 +182,11 @@ function local_update( nsites = (region isa AbstractEdge) ? 0 : length(region) PH = set_nsite(PH, nsites) PH = position(PH, psi, region) - phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...) + (psi_ref!) = Ref(psi) # create references, in case solver does (out-of-place) modify PH or psi + (PH_ref!) = Ref(PH) + phi, info = solver(phi;(psi_ref!),(PH_ref!), normalize, region, sweep_regions, sweep_step, step_kwargs..., solver_kwargs...) # args passed by reference are supposed to be modified out of place + psi = psi_ref![] # dereference + PH = PH_ref![] if !(phi isa ITensor && info isa NamedTuple) println("Solver returned the following types: $(typeof(phi)), $(typeof(info))") error("In alternating_update, solver must return an ITensor and a NamedTuple") diff --git a/test/test_treetensornetworks/test_solvers/test_tdvp.jl b/test/test_treetensornetworks/test_solvers/test_tdvp.jl index ec002af1..efa76f02 100644 --- a/test/test_treetensornetworks/test_solvers/test_tdvp.jl +++ b/test/test_treetensornetworks/test_solvers/test_tdvp.jl @@ -163,10 +163,11 @@ using Test ψ0 = random_mps(s; internal_inds_space=10) - function solver(PH, psi0; time_step, kwargs...) + function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...) solver_kwargs = (; ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true ) + PH=PH_ref![] psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...) return psi, (; info=exp_info) end @@ -551,10 +552,11 @@ end ψ0 = normalize!(random_ttn(s; link_space=10)) - function solver(PH, psi0; time_step, kwargs...) + function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...) solver_kwargs = (; ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true ) + PH=PH_ref![] psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...) return psi, (; info=exp_info) end