From a167f9d08c50f109b99be747d40b918210ef0091 Mon Sep 17 00:00:00 2001 From: Miles Date: Sun, 22 Sep 2024 14:28:49 -0400 Subject: [PATCH 1/5] Fixes to alternating_update default printing --- src/solvers/alternating_update/alternating_update.jl | 6 +++--- src/solvers/alternating_update/region_update.jl | 1 + src/solvers/defaults.jl | 12 ++++++------ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/solvers/alternating_update/alternating_update.jl b/src/solvers/alternating_update/alternating_update.jl index 2cd5de71..750f3f36 100644 --- a/src/solvers/alternating_update/alternating_update.jl +++ b/src/solvers/alternating_update/alternating_update.jl @@ -9,8 +9,8 @@ function alternating_update( nsites, # define default for each level of solver implementation updater, # this specifies the update performed locally outputlevel=default_outputlevel(), - region_printer=nothing, - sweep_printer=nothing, + region_printer=default_region_printer, + sweep_printer=default_sweep_printer, (sweep_observer!)=nothing, (region_observer!)=nothing, root_vertex=GraphsExtensions.default_root_vertex(init_state), @@ -59,7 +59,7 @@ function alternating_update( (sweep_observer!)=nothing, sweep_printer=default_sweep_printer,#? (region_observer!)=nothing, - region_printer=nothing, + region_printer=default_region_printer, ) state = copy(init_state) @assert !isnothing(sweep_plans) diff --git a/src/solvers/alternating_update/region_update.jl b/src/solvers/alternating_update/region_update.jl index b92adc8c..1deb2367 100644 --- a/src/solvers/alternating_update/region_update.jl +++ b/src/solvers/alternating_update/region_update.jl @@ -112,6 +112,7 @@ function region_update( outputlevel, info..., region_kwargs..., + inserter_kwargs..., internal_kwargs..., ) update_observer!(region_observer!; all_kwargs...) diff --git a/src/solvers/defaults.jl b/src/solvers/defaults.jl index b5d315ff..598c3114 100644 --- a/src/solvers/defaults.jl +++ b/src/solvers/defaults.jl @@ -8,9 +8,9 @@ default_inserter() = default_inserter default_checkdone() = (; kws...) -> false default_transform_operator() = nothing function default_region_printer(; - cutoff, - maxdim, - mindim, + cutoff=nothing, + mindim=nothing, + maxdim=nothing, outputlevel, state, sweep_plan, @@ -23,9 +23,9 @@ function default_region_printer(; region = first(sweep_plan[which_region_update]) @printf("Sweep %d, region=%s \n", which_sweep, region) print(" Truncated using") - @printf(" cutoff=%.1E", cutoff) - @printf(" maxdim=%d", maxdim) - @printf(" mindim=%d", mindim) + !isnothing(cutoff) && @printf(" cutoff=%.1E", cutoff) + !isnothing(maxdim) && @printf(" maxdim=%d", maxdim) + !isnothing(mindim) && @printf(" mindim=%d", mindim) println() if spec != nothing @printf( From 89b2527cae7d082cbcd4e5c6481d2a2ad4baefd7 Mon Sep 17 00:00:00 2001 From: Miles Date: Wed, 25 Sep 2024 17:35:28 -0400 Subject: [PATCH 2/5] Pass inserter_kwargs as NamedTuple into all_kwargs --- src/solvers/defaults.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/solvers/defaults.jl b/src/solvers/defaults.jl index 598c3114..716214fe 100644 --- a/src/solvers/defaults.jl +++ b/src/solvers/defaults.jl @@ -8,9 +8,7 @@ default_inserter() = default_inserter default_checkdone() = (; kws...) -> false default_transform_operator() = nothing function default_region_printer(; - cutoff=nothing, - mindim=nothing, - maxdim=nothing, + inserter_kwargs, outputlevel, state, sweep_plan, @@ -23,9 +21,9 @@ function default_region_printer(; region = first(sweep_plan[which_region_update]) @printf("Sweep %d, region=%s \n", which_sweep, region) print(" Truncated using") - !isnothing(cutoff) && @printf(" cutoff=%.1E", cutoff) - !isnothing(maxdim) && @printf(" maxdim=%d", maxdim) - !isnothing(mindim) && @printf(" mindim=%d", mindim) + haskey(inserter_kwargs, :cutoff) && @printf(" cutoff=%.1E", inserter_kwargs[:cutoff]) + haskey(inserter_kwargs, :maxdim) && @printf(" maxdim=%d", inserter_kwargs[:maxdim]) + haskey(inserter_kwargs, :mindim) && @printf(" mindim=%d", inserter_kwargs[:mindim]) println() if spec != nothing @printf( From 5c7d2113b2987dadfa989f174d1d362cf175f50a Mon Sep 17 00:00:00 2001 From: Miles Date: Wed, 25 Sep 2024 21:29:36 -0400 Subject: [PATCH 3/5] Improvements --- src/solvers/alternating_update/region_update.jl | 1 - src/solvers/defaults.jl | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/solvers/alternating_update/region_update.jl b/src/solvers/alternating_update/region_update.jl index 1deb2367..b92adc8c 100644 --- a/src/solvers/alternating_update/region_update.jl +++ b/src/solvers/alternating_update/region_update.jl @@ -112,7 +112,6 @@ function region_update( outputlevel, info..., region_kwargs..., - inserter_kwargs..., internal_kwargs..., ) update_observer!(region_observer!; all_kwargs...) diff --git a/src/solvers/defaults.jl b/src/solvers/defaults.jl index 716214fe..f9896ea1 100644 --- a/src/solvers/defaults.jl +++ b/src/solvers/defaults.jl @@ -21,9 +21,9 @@ function default_region_printer(; region = first(sweep_plan[which_region_update]) @printf("Sweep %d, region=%s \n", which_sweep, region) print(" Truncated using") - haskey(inserter_kwargs, :cutoff) && @printf(" cutoff=%.1E", inserter_kwargs[:cutoff]) - haskey(inserter_kwargs, :maxdim) && @printf(" maxdim=%d", inserter_kwargs[:maxdim]) - haskey(inserter_kwargs, :mindim) && @printf(" mindim=%d", inserter_kwargs[:mindim]) + haskey(inserter_kwargs, :cutoff) && @printf(" cutoff=%.1E", inserter_kwargs.cutoff) + haskey(inserter_kwargs, :maxdim) && @printf(" maxdim=%d", inserter_kwargs.maxdim) + haskey(inserter_kwargs, :mindim) && @printf(" mindim=%d", inserter_kwargs.mindim) println() if spec != nothing @printf( From 11493249ab63771cff1aca74505920e1f9f7a73a Mon Sep 17 00:00:00 2001 From: Miles Date: Thu, 26 Sep 2024 10:45:24 -0400 Subject: [PATCH 4/5] Add test of outputlevels (and quiet another test) --- .../test_solvers/Project.toml | 1 + .../test_solvers/test_dmrg.jl | 28 ++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/test/test_treetensornetworks/test_solvers/Project.toml b/test/test_treetensornetworks/test_solvers/Project.toml index 77225041..dc5ca10d 100644 --- a/test/test_treetensornetworks/test_solvers/Project.toml +++ b/test/test_treetensornetworks/test_solvers/Project.toml @@ -10,4 +10,5 @@ Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index b352d43c..cf8a1caf 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -22,6 +22,7 @@ using KrylovKit: eigsolve using NamedGraphs.NamedGraphGenerators: named_comb_tree using Observers: observer using StableRNGs: StableRNG +using Suppressor: @capture_out using Test: @test, @test_broken, @testset # This is needed since `eigen` is broken @@ -76,6 +77,31 @@ ITensors.disable_auto_fermion() new_E = inner(psi', H, psi) @test new_E ≈ orig_E =# + + # + # Test outputlevels are working + # + prev_output = "" + for outputlevel in 0:2 + output = @capture_out begin + e, psi = dmrg( + H, + psi; + outputlevel, + nsweeps, + maxdim, + cutoff, + nsites, + updater_kwargs=(; krylovdim=3, maxiter=1), + ) + end + if outputlevel == 0 + @test length(output) == 0 + else + @test length(output) > length(prev_output) + end + prev_output = output + end end @testset "Observers" begin @@ -139,7 +165,7 @@ end nsweeps, maxdim, cutoff, - outputlevel=2, + outputlevel=0, transform_operator=ITensorNetworks.cache_operator_to_disk, transform_operator_kwargs=(; write_when_maxdim_exceeds=11), ) From c84d3b56f65463d50e0b74ddf4db708ee534f9c3 Mon Sep 17 00:00:00 2001 From: Miles Date: Thu, 26 Sep 2024 10:55:10 -0400 Subject: [PATCH 5/5] Improve printer code style --- src/solvers/defaults.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/solvers/defaults.jl b/src/solvers/defaults.jl index f9896ea1..09c2ae2f 100644 --- a/src/solvers/defaults.jl +++ b/src/solvers/defaults.jl @@ -1,4 +1,4 @@ -using Printf: @printf +using Printf: @printf, @sprintf using ITensorMPS: maxlinkdim default_outputlevel() = 0 default_nsites() = 2 @@ -7,6 +7,10 @@ default_extracter() = default_extracter default_inserter() = default_inserter default_checkdone() = (; kws...) -> false default_transform_operator() = nothing + +format(x) = @sprintf("%s", x) +format(x::AbstractFloat) = @sprintf("%.1E", x) + function default_region_printer(; inserter_kwargs, outputlevel, @@ -21,9 +25,11 @@ function default_region_printer(; region = first(sweep_plan[which_region_update]) @printf("Sweep %d, region=%s \n", which_sweep, region) print(" Truncated using") - haskey(inserter_kwargs, :cutoff) && @printf(" cutoff=%.1E", inserter_kwargs.cutoff) - haskey(inserter_kwargs, :maxdim) && @printf(" maxdim=%d", inserter_kwargs.maxdim) - haskey(inserter_kwargs, :mindim) && @printf(" mindim=%d", inserter_kwargs.mindim) + for key in [:cutoff, :maxdim, :mindim] + if haskey(inserter_kwargs, key) + print(" ", key, "=", format(inserter_kwargs[key])) + end + end println() if spec != nothing @printf(