Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor interfaces built around alternating update #121

Merged
merged 32 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ac76123
Change required solver function signature. Passes psi and projected H…
b-kloss Jan 11, 2024
b23ab06
Move solver_funcs to separate directory. Adapt dmrg and dmrg-x to new…
b-kloss Jan 16, 2024
c2c5455
Format.
b-kloss Jan 16, 2024
ed0df91
Modify alternating update kwarg naming, structure, and also solver_in…
b-kloss Jan 17, 2024
db524f7
Adapt eigsolve and dmrg to new interfaces, fix dmrg tests (and remove…
b-kloss Jan 17, 2024
d258ccc
Remove default_sweep_regions from dmrg.
b-kloss Jan 17, 2024
6ff116a
Format.
b-kloss Jan 17, 2024
0ac2f2d
rename region_[update]_printer and nsite[s]
Jan 19, 2024
c02662c
Remove applyexp.jl
Jan 19, 2024
1a8ff33
Fix imports/namespaces.
Jan 19, 2024
80620b8
Add ToDo regarding testing applyexp in test_tdvp.jl
Jan 19, 2024
4c33d65
Define sweep_params outside of function call.
Jan 19, 2024
114ef3e
Change NamedTuple access pattern.
Jan 19, 2024
b817587
Change NamedTuple(;) to (;).
Jan 19, 2024
723fae9
(;) also for tdvp.
Jan 19, 2024
1e40a79
Remove second applyexp.jl file.
Jan 19, 2024
5587391
Remove applyexp from ITensorNetworks.jl
Jan 19, 2024
0620943
Start renaming. One tdvp test not passing, observer related.
Jan 19, 2024
7865f81
Cleanup solvers.
b-kloss Jan 20, 2024
e08ba96
Rename to sweep_plan.
b-kloss Jan 20, 2024
30455a4
Fix renaming to tdvp_sweep and kwarg handling in updaters.
b-kloss Jan 20, 2024
86dd746
Fix dmrg-x.
b-kloss Jan 20, 2024
59f45c1
Adapt contract(_updater) to new interface.
b-kloss Jan 20, 2024
8bc62fb
Format.
b-kloss Jan 20, 2024
ac1a923
Fix tdvp_time_dependent tests.
b-kloss Jan 21, 2024
2748927
Format tests.
b-kloss Jan 21, 2024
1fe92b3
Remove obsolete tests from test_tdvp (mostly those with an alternativ…
b-kloss Jan 21, 2024
9e8e635
Remove exponentiate from imports from KrylovKit
b-kloss Jan 21, 2024
23cc545
Apply review suggestions.
b-kloss Jan 21, 2024
3d5ade8
Fix test_tdvp.jl
b-kloss Jan 21, 2024
64ef223
Format.
b-kloss Jan 21, 2024
812de92
Adapt linsolve.
b-kloss Jan 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ include("tensornetworkoperators.jl")
include(joinpath("ITensorsExt", "itensorutils.jl"))
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("Graphs", "abstractdatagraph.jl"))
include(joinpath("solvers", "eigsolve.jl"))
include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl")) #this defines the primitive before the solver function
include(joinpath("solvers", "applyexp.jl"))
include(joinpath("solvers", "dmrg_x_solver.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand All @@ -114,7 +119,6 @@ include(joinpath("treetensornetworks", "projttns", "projttn.jl"))
include(joinpath("treetensornetworks", "projttns", "projttnsum.jl"))
include(joinpath("treetensornetworks", "projttns", "projttn_apply.jl"))
include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
Expand Down
25 changes: 25 additions & 0 deletions src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
function applyexp_solver()
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
function solver(
init;
psi_ref!,
PH_ref!,
region,
sweep_regions,
sweep_step,
solver_krylovdim=30,
solver_outputlevel=0,
solver_tol=1E-8,
substep,
time_step,
normalize,
)
H = PH_ref![]
#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, exp_info = applyexp(
H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel
)
return psi, (; info=exp_info)
end
return solver
end
19 changes: 19 additions & 0 deletions src/solvers/dmrg_x_solver.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function dmrg_x_solver(
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
init;
psi_ref!,
PH_ref!,
normalize=nothing,
region,
sweep_regions,
sweep_step,
half_sweep,
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
step_kwargs...,
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
)
H = contract(PH_ref![], ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
# TODO: improve this to return the energy estimate too
return U_max, NamedTuple()
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
end
38 changes: 38 additions & 0 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

b-kloss marked this conversation as resolved.
Show resolved Hide resolved
function eigsolve_updater(
init;
psi_ref!,
PH_ref!,
outputlevel,
which_sweep,
region_updates,
which_region_update,
region_kwargs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think region_update_kwargs sounds better to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though the same discussion on exponentiate_updater applies here, I think we should discuss how these are being passed and maybe merge them with updater_kwargs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this will be obsolete soon, I am in favor of sticking with region_kwargs.

updater_kwargs,
)
default_updater_kwargs = (;
solver_which_eigenvalue=:SR,
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
ishermitian=true,
tol=1e-14,
krylovdim=3,
maxiter=1,
outputlevel=0,
eager=false,
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
howmany = 1
which = updater_kwargs[:solver_which_eigenvalue]
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
vals, vecs, info = KrylovKit.eigsolve(
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
PH_ref![],
init,
howmany,
which;
ishermitian=updater_kwargs[:ishermitian],
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
tol=updater_kwargs[:tol],
krylovdim=updater_kwargs[:krylovdim],
maxiter=updater_kwargs[:maxiter],
verbosity=updater_kwargs[:outputlevel],
eager=updater_kwargs[:eager],
)
return vecs[1], (; info, energies=vals)
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
end
48 changes: 48 additions & 0 deletions src/solvers/exponentiate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
function exponentiate_updater(
init;
psi_ref!,
PH_ref!,
outputlevel,
which_sweep,
region_updates,
which_region_update,
region_kwargs,
updater_kwargs,
Comment on lines +9 to +10
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the plan to combine these into updater_kwargs?

Copy link
Contributor Author

@b-kloss b-kloss Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updater_kwargs are updater specific kwargs, while region_args are among the things that we expose to all updaters. in principle, we can nest the region_args into updater_kwargs in the call to region_update but I am not sure if that's preferable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm ok, I just have found it hard to keep track of the logic of why certain keyword arguments are bundled in certain ways, how they will be used, etc.

For example, from the perspective of this function, the only argument I can see that is being used here from region_kwargs is time_step, which I don't think is really any different conceptually from the arguments being passed in updater_kwargs (it's just another thing being used by the solver/updater). So it makes sense to me to just bundle those together in one flat NamedTuple called updater_kwargs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, these will eventually be bundled together in an upcoming PR.

)
default_updater_kwargs = (;
krylovdim=30, #from here only solver kwargs
maxiter=100,
outputlevel=0,
tol=1E-12,
ishermitian=true,
issymmetric=true,
eager=true,
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
#H=copy(PH_ref![])
H = PH_ref![] ###since we are not changing H we don't need the copy
# let's test whether given region and sweep regions we can find out what the previous and next region were
# this will be needed in subspace expansion
#@show step_kwargs
substep = get(region_kwargs, :substep, nothing)
time_step = get(region_kwargs, :time_step, nothing)
@assert !isnothing(time_step) && !isnothing(substep)
region_ind = which_region_update
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
next_region =
region_ind == length(region_updates) ? nothing : first(region_updates[region_ind + 1])
previous_region = region_ind == 1 ? nothing : first(region_updates[region_ind - 1])

phi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
ishermitian=updater_kwargs[:ishermitian],
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
issymmetric=updater_kwargs[:issymmetric],
tol=updater_kwargs[:tol],
krylovdim=updater_kwargs[:krylovdim],
maxiter=updater_kwargs[:maxiter],
verbosity=updater_kwargs[:outputlevel],
eager=updater_kwargs[:eager],
)
return phi, (; info=exp_info)
end
49 changes: 28 additions & 21 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ function process_sweeps(
return maxdim, mindim, cutoff, noise, kwargs
end

function sweep_printer(; outputlevel, psi, sweep, sw_time)
function sweep_printer(; outputlevel, psi, which_sweep, sw_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print("After sweep ", which_sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
print(" cpu_time=", round(sw_time; digits=3))
println()
Expand All @@ -37,7 +37,7 @@ function sweep_printer(; outputlevel, psi, sweep, sw_time)
end

function alternating_update(
solver,
updater,
PH,
psi0::AbstractTTN;
checkdone=(; kws...) -> false,
Expand All @@ -46,55 +46,60 @@ function alternating_update(
(sweep_observer!)=observer(),
sweep_printer=sweep_printer,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
updater_kwargs,
kwargs...,
)
maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)

psi = copy(psi0)

insert_function!(sweep_observer!, "sweep_printer" => sweep_printer)
insert_function!(sweep_observer!, "sweep_printer" => sweep_printer) # FIX THIS

for sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds
for which_sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) &&
maxdim[which_sweep] > write_when_maxdim_exceeds
if outputlevel >= 2
println(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[sweep] = $(maxdim[sweep]), writing environment tensors to disk",
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[which_sweep] = $(maxdim[which_sweep]), writing environment tensors to disk",
)
end
PH = disk(PH)
end

sw_time = @elapsed begin
psi, PH = update_step(
solver,
psi, PH = sweep_update(
updater,
PH,
psi;
outputlevel,
sweep,
maxdim=maxdim[sweep],
mindim=mindim[sweep],
cutoff=cutoff[sweep],
noise=noise[sweep],
which_sweep,
sweep_params=(;
maxdim=maxdim[which_sweep],
mindim=mindim[which_sweep],
cutoff=cutoff[which_sweep],
noise=noise[which_sweep],
),
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
updater_kwargs,
kwargs...,
)
end

update!(sweep_observer!; psi, sweep, sw_time, outputlevel)
update!(sweep_observer!; psi, which_sweep, sw_time, outputlevel)

checkdone(; psi, sweep, outputlevel, kwargs...) && break
checkdone(; psi, which_sweep, outputlevel, kwargs...) && break
end
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) # remove sweep_printer
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer"))
return psi
end

function alternating_update(solver, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
function alternating_update(updater, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjTTN(H)
return alternating_update(solver, PH, psi0; kwargs...)
return alternating_update(updater, PH, psi0; kwargs...)
end

"""
Expand All @@ -116,12 +121,14 @@ each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
"""
function alternating_update(solver, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...)
function alternating_update(
updater, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...
)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjTTNSum(Hs)
return alternating_update(solver, PHs, psi0; kwargs...)
return alternating_update(updater, PHs, psi0; kwargs...)
end
68 changes: 22 additions & 46 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,34 @@
function eigsolve_solver(;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
)
function solver(H, init; normalize=nothing, region=nothing, half_sweep=nothing)
howmany = 1
which = solver_which_eigenvalue
vals, vecs, info = eigsolve(
H,
init,
howmany,
which;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
psi = vecs[1]
return psi, (; solver_info=info, energies=vals)
end
return solver
end

"""
Overload of `ITensors.dmrg`.
"""

function dmrg_sweep(
nsite::Int, graph::AbstractGraph; root_vertex=default_root_vertex(graph)
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
)
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
return tdvp_sweep(2, nsite, Inf, graph; root_vertex, reverse_step=false)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return tdvp_sweep(2, nsite, Inf, graph; root_vertex, reverse_step=false)
order = 2
time_step = Inf
return tdvp_sweep(order, nsite, time_step, graph; root_vertex, reverse_step=false)

so we remember what 2 and Inf mean.

Copy link
Member

@mtfishman mtfishman Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I'm not sure how I feel about calling this dmrg_sweep since it could be used by other solvers like linsolve. I'm also not a huge fan of the function name tdvp_sweep for the same reason.

Maybe we can rethink this API a bit, for example rename the current tdvp_sweep to default_sweep_plan and then just call that with different inputs from within the different solvers. I think a good design for that would be to choose default values that "just work" for DMRG, including defaulting to order=2, and then provide optional inputs like reverse_step=true and time_step for use with TDVP. That may mean moving some of the inputs like order and time step to keyword arguments, which seems like a better design anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and this is what the default_sweep_regions function was trying to be. That is, a sort of "typical" sweep plan that most algorithms (dmrg, linsolve, etc.) could use. We could definitely think about moving that function out of the update_step.jl file but the intent there was for it to be a pretty generic default.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I forgot about default_sweep_regions, I'm not sure what the fate of that is in the current PR since a lot of things are getting changed in this PR. But it seems like default_sweep_regions, tdvp_sweep, and dmrg_sweep should all get consolidated into a single function, which tentatively we are thinking of calling something like default_sweep_plan, default_region_updates, or default_region_update_plans (since it contains both the regions that will be updated but also information about how to do the update), which I think should just get called directly by the different solvers like dmrg, tdvp, linsolve, etc. and passed to alternating_update.

Opinions on what to name that function would be welcome. We want to come up with a good name for that function and what it outputs which we can then use consistently throughout the rest of the code. The list output by tdvp_sweep/dmrg_sweep is called region_updates in the current version of this PR but we are discussing alternative names, which hopefully will be consistent with the name of the new function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I incorporated the suggestion, and the sweep_regions constructor is planend to be refactored in an upcoming PR


function dmrg(
updater,
H,
init::AbstractTTN;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
nsweeps, #it makes sense to require this to be defined
nsite=2,
(sweep_observer!)=observer(),
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
root_vertex=default_root_vertex(init),
updater_kwargs=NamedTuple(;),
kwargs...,
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
)
return alternating_update(
eigsolve_solver(;
solver_which_eigenvalue,
ishermitian,
solver_tol,
solver_krylovdim,
solver_maxiter,
solver_verbosity,
),
H,
init;
kwargs...,
region_updates = dmrg_sweep(nsite, init; root_vertex)

b-kloss marked this conversation as resolved.
Show resolved Hide resolved
psi = alternating_update(
updater, H, init; nsweeps, sweep_observer!, region_updates, updater_kwargs, kwargs...
)
return psi
end

function dmrg(H, init::AbstractTTN; updater=eigsolve_updater, kwargs...)
return dmrg(updater, H, init; kwargs...)
end

"""
Expand Down
14 changes: 1 addition & 13 deletions src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
function dmrg_x_solver(
PH, init; normalize=nothing, region=nothing, half_sweep=nothing, reverse_step=nothing
)
H = contract(PH, ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
# TODO: improve this to return the energy estimate too
return U_max, NamedTuple()
end

function dmrg_x(PH, init::AbstractTTN; kwargs...)
psi = alternating_update(dmrg_x_solver, PH, init; kwargs...)
psi = alternating_update(ITensorNetworks.dmrg_x_solver, PH, init; kwargs...)
return psi
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
end
Loading
Loading