From 30455a4c09ce4156bd2fc344e4b4d222aa6af206 Mon Sep 17 00:00:00 2001 From: Benedikt Kloss Date: Sat, 20 Jan 2024 17:32:15 -0500 Subject: [PATCH] Fix renaming to tdvp_sweep and kwarg handling in updaters. --- src/solvers/eigsolve.jl | 10 +++++++--- src/solvers/exponentiate.jl | 7 ++++--- src/treetensornetworks/solvers/update_step.jl | 20 +++++++++---------- .../test_solvers/test_dmrg.jl | 4 ++-- .../test_solvers/test_tdvp.jl | 3 ++- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 98f9c28f..eb9e8aef 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -4,7 +4,7 @@ function eigsolve_updater( projected_operator!, outputlevel, which_sweep, - region_updates, + sweep_plan, which_region_update, region_kwargs, updater_kwargs, @@ -20,13 +20,17 @@ function eigsolve_updater( ) updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence howmany = 1 - which = updater_kwargs.which_eigenvalue + which, updater_kwargs = _pop_which_eigenvalue(;updater_kwargs...) vals, vecs, info = eigsolve( projected_operator![], init, howmany, which; - updater_kwargs... # leaves it to the user to supply only supported kwargs + updater_kwargs... #this leaves it ) return vecs[1], (; info, eigvals=vals) end + +function _pop_which_eigenvalue(;which_eigenvalue, kwargs...) + return which_eigenvalue, NamedTuple(kwargs) +end \ No newline at end of file diff --git a/src/solvers/exponentiate.jl b/src/solvers/exponentiate.jl index 8cd0e3ff..7d56ad39 100644 --- a/src/solvers/exponentiate.jl +++ b/src/solvers/exponentiate.jl @@ -4,13 +4,13 @@ function exponentiate_updater( projected_operator!, outputlevel, which_sweep, - region_updates, + sweep_plan, which_region_update, region_kwargs, updater_kwargs, ) default_updater_kwargs = (; - krylovdim=30, #from here only solver kwargs + krylovdim=30, maxiter=100, verbosity=0, tol=1E-12, @@ -18,10 +18,11 @@ function exponentiate_updater( issymmetric=true, eager=true, ) + updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence result, exp_info = exponentiate( projected_operator![], - time_step, + region_kwargs.time_step, init; updater_kwargs... ) diff --git a/src/treetensornetworks/solvers/update_step.jl b/src/treetensornetworks/solvers/update_step.jl index 90452c43..445fb329 100644 --- a/src/treetensornetworks/solvers/update_step.jl +++ b/src/treetensornetworks/solvers/update_step.jl @@ -15,10 +15,10 @@ function default_sweep_regions(nsites, graph::AbstractGraph; kwargs...) ###move end function region_update_printer(; - cutoff, maxdim, mindim, outputlevel::Int=0, state, region_updates, spec, which_region_update, which_sweep,kwargs... + cutoff, maxdim, mindim, outputlevel::Int=0, state, sweep_plan, spec, which_region_update, which_sweep,kwargs... ) if outputlevel >= 2 - region=first(region_updates[which_region_update]) + region=first(sweep_plan[which_region_update]) @printf("Sweep %d, region=%s \n", which_sweep, region) print(" Truncated using") @printf(" cutoff=%.1E", cutoff) @@ -53,7 +53,7 @@ function sweep_update( # Append empty namedtuple to each element if not already present # (Needed to handle user-provided region_updates) - region_updates = append_missing_namedtuple.(to_tuple.(region_updates)) + sweep_plan = append_missing_namedtuple.(to_tuple.(sweep_plan)) if nv(state) == 1 error( @@ -71,7 +71,7 @@ function sweep_update( normalize, outputlevel, which_sweep, - region_updates, + sweep_plan, which_region_update, region_kwargs, region_observer!, @@ -167,7 +167,7 @@ function region_update( normalize, outputlevel, which_sweep, - region_updates, + sweep_plan, which_region_update, region_kwargs, region_observer!, @@ -175,7 +175,7 @@ function region_update( #extraction_kwargs, #ToDo: implement later with possibility to pass custom extraction/insertion func (or code into func) updater_kwargs ) - region=first(region_updates[which_region_update]) + region=first(sweep_plan[which_region_update]) state = orthogonalize(state, current_ortho(region)) state, phi = extract_local_tensor(state, region;) nsites = (region isa AbstractEdge) ? 0 : length(region) #ToDo move into separate funtion @@ -189,7 +189,7 @@ function region_update( projected_operator!, outputlevel, which_sweep, - region_updates, + sweep_plan, which_region_update, region_kwargs, updater_kwargs @@ -222,9 +222,9 @@ function region_update( maxdim, mindim, which_region_update, - region_updates, - total_sweep_steps=length(region_updates), - end_of_sweep=(which_region_update == length(region_updates)), + sweep_plan, + total_sweep_steps=length(sweep_plan), + end_of_sweep=(which_region_update == length(sweep_plan)), state, region, which_sweep, diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index 92a318d7..9bc3f608 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -74,8 +74,8 @@ end sweep(; which_sweep, kw...) = which_sweep sweep_observer! = observer(sweep) - region(; which_region_update, region_updates, kw...) = - first(region_updates[which_region_update]) + region(; which_region_update, sweep_plan, kw...) = + first(sweep_plan[which_region_update]) energy(; eigvals, kw...) = eigvals[1] region_observer! = observer(region, sweep, energy) diff --git a/test/test_treetensornetworks/test_solvers/test_tdvp.jl b/test/test_treetensornetworks/test_solvers/test_tdvp.jl index 68d9a49b..d53c081c 100644 --- a/test/test_treetensornetworks/test_solvers/test_tdvp.jl +++ b/test/test_treetensornetworks/test_solvers/test_tdvp.jl @@ -542,7 +542,8 @@ end # Should rotate back to original state: @test abs(inner(ψ0, ψ2)) > 0.99 end -#= + # ToDo: Discuss whether the test commented out here is necessary given the new design? + #= @testset "Custom updater in TDVP" begin cutoff = 1e-12