Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor interfaces built around alternating update #121

Merged
merged 32 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ac76123
Change required solver function signature. Passes psi and projected H…
b-kloss Jan 11, 2024
b23ab06
Move solver_funcs to separate directory. Adapt dmrg and dmrg-x to new…
b-kloss Jan 16, 2024
c2c5455
Format.
b-kloss Jan 16, 2024
ed0df91
Modify alternating update kwarg naming, structure, and also solver_in…
b-kloss Jan 17, 2024
db524f7
Adapt eigsolve and dmrg to new interfaces, fix dmrg tests (and remove…
b-kloss Jan 17, 2024
d258ccc
Remove default_sweep_regions from dmrg.
b-kloss Jan 17, 2024
6ff116a
Format.
b-kloss Jan 17, 2024
0ac2f2d
rename region_[update]_printer and nsite[s]
Jan 19, 2024
c02662c
Remove applyexp.jl
Jan 19, 2024
1a8ff33
Fix imports/namespaces.
Jan 19, 2024
80620b8
Add ToDo regarding testing applyexp in test_tdvp.jl
Jan 19, 2024
4c33d65
Define sweep_params outside of function call.
Jan 19, 2024
114ef3e
Change NamedTuple access pattern.
Jan 19, 2024
b817587
Change NamedTuple(;) to (;).
Jan 19, 2024
723fae9
(;) also for tdvp.
Jan 19, 2024
1e40a79
Remove second applyexp.jl file.
Jan 19, 2024
5587391
Remove applyexp from ITensorNetworks.jl
Jan 19, 2024
0620943
Start renaming. One tdvp test not passing, observer related.
Jan 19, 2024
7865f81
Cleanup solvers.
b-kloss Jan 20, 2024
e08ba96
Rename to sweep_plan.
b-kloss Jan 20, 2024
30455a4
Fix renaming to tdvp_sweep and kwarg handling in updaters.
b-kloss Jan 20, 2024
86dd746
Fix dmrg-x.
b-kloss Jan 20, 2024
59f45c1
Adapt contract(_updater) to new interface.
b-kloss Jan 20, 2024
8bc62fb
Format.
b-kloss Jan 20, 2024
ac1a923
Fix tdvp_time_dependent tests.
b-kloss Jan 21, 2024
2748927
Format tests.
b-kloss Jan 21, 2024
1fe92b3
Remove obsolete tests from test_tdvp (mostly those with an alternativ…
b-kloss Jan 21, 2024
9e8e635
Remove exponentiate from imports from KrylovKit
b-kloss Jan 21, 2024
23cc545
Apply review suggestions.
b-kloss Jan 21, 2024
3d5ade8
Fix test_tdvp.jl
b-kloss Jan 21, 2024
64ef223
Format.
b-kloss Jan 21, 2024
812de92
Adapt linsolve.
b-kloss Jan 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@
include(joinpath("ITensorsExt", "itensorutils.jl"))
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("Graphs", "abstractdatagraph.jl"))
include(joinpath("solvers","eigsolve.jl"))

Check warning on line 109 in src/ITensorNetworks.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/ITensorNetworks.jl:109:-include(joinpath("solvers","eigsolve.jl")) src/ITensorNetworks.jl:110:-include(joinpath("solvers","exponentiate.jl")) src/ITensorNetworks.jl:109:+include(joinpath("solvers", "eigsolve.jl")) src/ITensorNetworks.jl:110:+include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("solvers","exponentiate.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl")) #this defines the primitive before the solver function
include(joinpath("solvers","applyexp.jl"))

Check warning on line 112 in src/ITensorNetworks.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/ITensorNetworks.jl:112:-include(joinpath("solvers","applyexp.jl")) src/ITensorNetworks.jl:113:-include(joinpath("solvers","dmrg_x_solver.jl")) src/ITensorNetworks.jl:112:+include(joinpath("solvers", "applyexp.jl")) src/ITensorNetworks.jl:113:+include(joinpath("solvers", "dmrg_x_solver.jl"))
include(joinpath("solvers","dmrg_x_solver.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand All @@ -114,7 +119,6 @@
include(joinpath("treetensornetworks", "projttns", "projttnsum.jl"))
include(joinpath("treetensornetworks", "projttns", "projttn_apply.jl"))
include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
Expand Down
27 changes: 27 additions & 0 deletions src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function applyexp_solver()
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
function solver(

Check warning on line 2 in src/solvers/applyexp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/applyexp.jl:2:- function solver( src/solvers/applyexp.jl:3:- init; src/solvers/applyexp.jl:4:- psi_ref!, src/solvers/applyexp.jl:5:- PH_ref!, src/solvers/applyexp.jl:6:- region, src/solvers/applyexp.jl:7:- sweep_regions, src/solvers/applyexp.jl:8:- sweep_step, src/solvers/applyexp.jl:9:- solver_krylovdim=30, src/solvers/applyexp.jl:10:- solver_outputlevel=0, src/solvers/applyexp.jl:11:- solver_tol=1E-8, src/solvers/applyexp.jl:12:- substep, src/solvers/applyexp.jl:13:- time_step, src/solvers/applyexp.jl:14:- normalize, src/solvers/applyexp.jl:15:- ) src/solvers/applyexp.jl:16:- H=PH_ref![] src/solvers/applyexp.jl:17:- #applyexp tol is absolute, compute from tol_per_unit_time: src/solvers/applyexp.jl:18:- tol = abs(time_step) * tol_per_unit_time src/solvers/applyexp.jl:19:- psi, exp_info = applyexp( src/solvers/applyexp.jl:20:- H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel src/solvers/applyexp.jl:21:- ) src/solvers/applyexp.jl:22:- return psi, (; info=exp_info) src/solvers/applyexp.jl:23:- end src/solvers/applyexp.jl:24:- return solver src/solvers/applyexp.jl:2:+ function solver( src/solvers/applyexp.jl:3:+ init; src/solvers/applyexp.jl:4:+ psi_ref!, src/solvers/applyexp.jl:5:+ PH_ref!, src/solvers/applyexp.jl:6:+ region, src/solvers/applyexp.jl:7:+ sweep_regions, src/solvers/applyexp.jl:8:+ sweep_step, src/solvers/applyexp.jl:9:+ solver_krylovdim=30, src/solvers/applyexp.jl:10:+ solver_outputlevel=0, src/solvers/applyexp.jl:11:+ solver_tol=1E-8, src/solvers/applyexp.jl:12:+ substep, src/solvers/applyexp.jl:13:+ time_step, src/solvers/applyexp.jl:14:+ normalize, src/solvers/applyexp.jl:15:+ ) src/solvers/applyexp.jl:16:+ H = PH_ref![] src/solvers/applyexp.jl:17:+ #applyexp tol is absolute, compute from tol_per_unit_time: src/solvers/applyexp.jl:18:+ tol = abs(time_step) * tol_per_unit_time src/solvers/applyexp.jl:19:+ psi, exp_info = applyexp( src/solvers/applyexp.jl:20:+ H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel src/solvers/applyexp.jl:21:+ ) src/solvers/applyexp.jl:22:+ return psi, (; info=exp_info) src/solvers/applyexp.jl:23:+ end src/solvers/applyexp.jl:24:+ return solver
init;
psi_ref!,
PH_ref!,
region,
sweep_regions,
sweep_step,
solver_krylovdim=30,
solver_outputlevel=0,
solver_tol=1E-8,
substep,
time_step,
normalize,
)
H=PH_ref![]
#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, exp_info = applyexp(
H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel
)
return psi, (; info=exp_info)
end
return solver
end

Check warning on line 26 in src/solvers/applyexp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/applyexp.jl:26:- src/solvers/applyexp.jl:27:-

19 changes: 19 additions & 0 deletions src/solvers/dmrg_x_solver.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function dmrg_x_solver(
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
init;

Check warning on line 2 in src/solvers/dmrg_x_solver.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/dmrg_x_solver.jl:2:- init; src/solvers/dmrg_x_solver.jl:3:- psi_ref!, src/solvers/dmrg_x_solver.jl:4:- PH_ref!, src/solvers/dmrg_x_solver.jl:5:- normalize=nothing, src/solvers/dmrg_x_solver.jl:6:- region, src/solvers/dmrg_x_solver.jl:7:- sweep_regions, src/solvers/dmrg_x_solver.jl:8:- sweep_step, src/solvers/dmrg_x_solver.jl:9:- half_sweep, src/solvers/dmrg_x_solver.jl:10:- step_kwargs... src/solvers/dmrg_x_solver.jl:11:- ) src/solvers/dmrg_x_solver.jl:2:+ init; src/solvers/dmrg_x_solver.jl:3:+ psi_ref!, src/solvers/dmrg_x_solver.jl:4:+ PH_ref!, src/solvers/dmrg_x_solver.jl:5:+ normalize=nothing, src/solvers/dmrg_x_solver.jl:6:+ region, src/solvers/dmrg_x_solver.jl:7:+ sweep_regions, src/solvers/dmrg_x_solver.jl:8:+ sweep_step, src/solvers/dmrg_x_solver.jl:9:+ half_sweep, src/solvers/dmrg_x_solver.jl:10:+ step_kwargs..., src/solvers/dmrg_x_solver.jl:11:+)
psi_ref!,
PH_ref!,
normalize=nothing,
region,
sweep_regions,
sweep_step,
half_sweep,
step_kwargs...
)
H = contract(PH_ref![], ITensor(1.0))
D, U = eigen(H; ishermitian=true)

Check warning on line 13 in src/solvers/dmrg_x_solver.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/dmrg_x_solver.jl:13:- D, U = eigen(H; ishermitian=true) src/solvers/dmrg_x_solver.jl:13:+ D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
# TODO: improve this to return the energy estimate too
return U_max, NamedTuple()
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
end

Check warning on line 19 in src/solvers/dmrg_x_solver.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/dmrg_x_solver.jl:19:-end src/solvers/dmrg_x_solver.jl:19:+end
40 changes: 40 additions & 0 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

b-kloss marked this conversation as resolved.
Show resolved Hide resolved
function eigsolve_solver(;
solver_which_eigenvalue=:SR, #TODO: settle on pattern to pass solver kwargs

Check warning on line 3 in src/solvers/eigsolve.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/eigsolve.jl:3:- solver_which_eigenvalue=:SR, #TODO: settle on pattern to pass solver kwargs src/solvers/eigsolve.jl:4:- ishermitian=true, src/solvers/eigsolve.jl:5:- solver_tol=1e-14, src/solvers/eigsolve.jl:6:- solver_krylovdim=3, src/solvers/eigsolve.jl:7:- solver_maxiter=1, src/solvers/eigsolve.jl:8:- solver_verbosity=0, src/solvers/eigsolve.jl:3:+ solver_which_eigenvalue=:SR, #TODO: settle on pattern to pass solver kwargs src/solvers/eigsolve.jl:4:+ ishermitian=true, src/solvers/eigsolve.jl:5:+ solver_tol=1e-14, src/solvers/eigsolve.jl:6:+ solver_krylovdim=3, src/solvers/eigsolve.jl:7:+ solver_maxiter=1, src/solvers/eigsolve.jl:8:+ solver_verbosity=0, src/solvers/eigsolve.jl:9:+) src/solvers/eigsolve.jl:10:+ function solver( src/solvers/eigsolve.jl:11:+ init; src/solvers/eigsolve.jl:12:+ psi_ref!, src/solvers/eigsolve.jl:13:+ PH_ref!, src/solvers/eigsolve.jl:14:+ normalize, src/solvers/eigsolve.jl:15:+ region, src/solvers/eigsolve.jl:16:+ sweep_regions, src/solvers/eigsolve.jl:17:+ sweep_step, src/solvers/eigsolve.jl:18:+ sweep_kwargs..., src/solvers/eigsolve.jl:19:+ # slurp solver_kwargs? #TODO: homogenize how the solver kwargs are passed
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
)

Check warning on line 10 in src/solvers/eigsolve.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/eigsolve.jl:10:- src/solvers/eigsolve.jl:11:- function solver( src/solvers/eigsolve.jl:12:- init; src/solvers/eigsolve.jl:13:- psi_ref!, src/solvers/eigsolve.jl:14:- PH_ref!, src/solvers/eigsolve.jl:15:- normalize, src/solvers/eigsolve.jl:16:- region, src/solvers/eigsolve.jl:17:- sweep_regions, src/solvers/eigsolve.jl:18:- sweep_step, src/solvers/eigsolve.jl:19:- sweep_kwargs... src/solvers/eigsolve.jl:20:- # slurp solver_kwargs? #TODO: homogenize how the solver kwargs are passed src/solvers/eigsolve.jl:21:- ) src/solvers/eigsolve.jl:22:- howmany = 1 src/solvers/eigsolve.jl:23:- which = solver_which_eigenvalue src/solvers/eigsolve.jl:24:- vals, vecs, info = eigsolve( src/solvers/eigsolve.jl:25:- PH_ref![], src/solvers/eigsolve.jl:26:- init, src/solvers/eigsolve.jl:27:- howmany, src/solvers/eigsolve.jl:28:- which; src/solvers/eigsolve.jl:29:- ishermitian, src/solvers/eigsolve.jl:30:- tol=solver_tol, src/solvers/eigsolve.jl:31:- krylovdim=solver_krylovdim, src/solvers/eigsolve.jl:32:- maxiter=solver_maxiter, src/solvers/eigsolve.jl:33:- verbosity=solver_verbosity, src/solvers/eigsolve.jl:34:- ) src/solvers/eigsolve.jl:35:- phi = vecs[1] src/solvers/eigsolve.jl:36:- return phi, (; solver_info=info, energies=vals) src/solvers/eigsolve.jl:37:- end src/solvers/eigsolve.jl:38:- return solver src/solvers/eigsolve.jl:21:+ howmany = 1 src/solvers/eigsolve.jl:22:+ which = solver_which_eigenvalue src/solvers/eigsolve.jl:23:+ vals, vecs, info = eigsolve( src/solvers/eigsolve.jl:24:+ PH_ref![], src/solvers/eigsolve.jl:25:+ init, src/solvers/eigsolve.jl:26:+ howmany, src/solvers/eigsolve.jl:27:+ which; src/solvers/eigsolve.jl:28:+ ishermitian, src/solvers/eigsolve.jl:29:+ tol=solver_tol, src/solvers/eigsolve.jl:30:+ krylovdim=solver_krylovdim, src/solvers/eigsolve.jl:31:+ maxiter=solver_maxiter, src/solvers/eigsolve.jl:32:+ verbosity=solver_verbosity, src/solvers/eigsolve.jl:33:+ ) src/solvers/eigsolve.jl:34:+ phi = vecs[1] src/solvers/eigsolve.jl:35:+ return phi, (; solver_info=info, energies=vals)
function solver(
init;
psi_ref!,
PH_ref!,
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
normalize,
region,
sweep_regions,
sweep_step,
sweep_kwargs...
# slurp solver_kwargs? #TODO: homogenize how the solver kwargs are passed
)
howmany = 1
which = solver_which_eigenvalue
vals, vecs, info = eigsolve(
PH_ref![],
init,
howmany,
which;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
phi = vecs[1]
return phi, (; solver_info=info, energies=vals)
end
return solver
end

Check warning on line 40 in src/solvers/eigsolve.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/solvers/eigsolve.jl:40:- src/solvers/eigsolve.jl:37:+ return solver src/solvers/eigsolve.jl:38:+end
42 changes: 42 additions & 0 deletions src/solvers/exponentiate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
function exponentiate_solver()
function solver(
init;
psi_ref!,
PH_ref!,
ishermitian=true,
issymmetric=true,
region,
sweep_regions,
sweep_step,
solver_krylovdim=30,
solver_maxiter=100,
solver_outputlevel=0,
solver_tol=1E-12,
substep,
normalize,
time_step,
)
#H=copy(PH_ref![])
H=PH_ref![] ###since we are not changing H we don't need the copy
# let's test whether given region and sweep regions we can find out what the previous and next region were
# this will be needed in subspace expansion
region_ind=sweep_step
next_region=region_ind==length(sweep_regions) ? nothing : first(sweep_regions[region_ind+1])
previous_region=region_ind==1 ? nothing : first(sweep_regions[region_ind-1])
b-kloss marked this conversation as resolved.
Show resolved Hide resolved

phi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
ishermitian,
issymmetric,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_outputlevel,
eager=true,
)
return phi, (; info=exp_info)
end
return solver
end
29 changes: 1 addition & 28 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,3 @@
function eigsolve_solver(;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
)
function solver(H, init; normalize=nothing, region=nothing, half_sweep=nothing)
howmany = 1
which = solver_which_eigenvalue
vals, vecs, info = eigsolve(
H,
init,
howmany,
which;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
psi = vecs[1]
return psi, (; solver_info=info, energies=vals)
end
return solver
end

"""
Overload of `ITensors.dmrg`.
"""
Expand All @@ -51,6 +23,7 @@ function dmrg(
),
H,
init;
reverse_step=false,
kwargs...,
)
end
Expand Down
14 changes: 1 addition & 13 deletions src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
function dmrg_x_solver(
PH, init; normalize=nothing, region=nothing, half_sweep=nothing, reverse_step=nothing
)
H = contract(PH, ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
# TODO: improve this to return the energy estimate too
return U_max, NamedTuple()
end

function dmrg_x(PH, init::AbstractTTN; kwargs...)
psi = alternating_update(dmrg_x_solver, PH, init; kwargs...)
psi = alternating_update(ITensorNetworks.dmrg_x_solver, PH, init; kwargs...)
return psi
b-kloss marked this conversation as resolved.
Show resolved Hide resolved
end
65 changes: 1 addition & 64 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,3 @@
function exponentiate_solver()
function solver(
H,
init;
ishermitian=true,
issymmetric=true,
region,
solver_krylovdim=30,
solver_maxiter=100,
solver_outputlevel=0,
solver_tol=1E-12,
substep,
normalize,
time_step,
)
psi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
ishermitian,
issymmetric,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_outputlevel,
eager=true,
)
return psi, (; info=exp_info)
end
return solver
end

function applyexp_solver()
function solver(
H,
init;
tdvp_order,
solver_krylovdim=30,
solver_outputlevel=0,
solver_tol=1E-8,
substep,
time_step,
normalize,
)
#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, exp_info = applyexp(
H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel
)
return psi, (; info=exp_info)
end
return solver
end

function _compute_nsweeps(nsteps, t, time_step, order)
nsweeps_per_step = order / 2
nsweeps = 1
Expand Down Expand Up @@ -168,15 +114,6 @@ Optional keyword arguments:
* `observer` - object implementing the Observer interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kwargs...)
if solver_backend == "exponentiate"
solver = exponentiate_solver
elseif solver_backend == "applyexp"
solver = applyexp_solver
else
error(
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
function tdvp(H, t::Number, init::AbstractTTN; solver=exponentiate_solver, kwargs...)
return tdvp(solver(), H, t, init; kwargs...)
end
9 changes: 7 additions & 2 deletions src/treetensornetworks/solvers/update_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ function update_step(
step_printer=step_printer,
(step_observer!)=observer(),
sweep::Int=1,
sweep_regions=default_sweep_regions(nsite, psi),
reverse_step::Bool=false,
sweep_regions=default_sweep_regions(nsite, psi;reverse_step),
kwargs...,
)
PH = copy(PH)
Expand Down Expand Up @@ -182,7 +183,11 @@ function local_update(
nsites = (region isa AbstractEdge) ? 0 : length(region)
PH = set_nsite(PH, nsites)
PH = position(PH, psi, region)
phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...)
(psi_ref!) = Ref(psi) # create references, in case solver does (out-of-place) modify PH or psi
(PH_ref!) = Ref(PH)
phi, info = solver(phi;(psi_ref!),(PH_ref!), normalize, region, sweep_regions, sweep_step, step_kwargs..., solver_kwargs...) # args passed by reference are supposed to be modified out of place
psi = psi_ref![] # dereference
PH = PH_ref![]
if !(phi isa ITensor && info isa NamedTuple)
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))")
error("In alternating_update, solver must return an ITensor and a NamedTuple")
Expand Down
19 changes: 11 additions & 8 deletions test/test_treetensornetworks/test_solvers/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using KrylovKit: exponentiate
using Observers
using Random
using Test
#exponentiate_solver=ITensorNetworks.exponentiate_solver #ToDo: how to best handle importing etc.

@testset "MPS TDVP" begin
@testset "Basic TDVP" begin
Expand Down Expand Up @@ -31,7 +32,7 @@ using Test
#
#Different backend solvers, default solver_backend = "applyexp"
ψ1_exponentiate_backend = tdvp(
H, -0.1im, ψ0; nsteps=1, cutoff, nsite=1, solver_backend="exponentiate"
H, -0.1im, ψ0; nsteps=1, cutoff, nsite=1, solver=exponentiate_solver
)
@test ψ1 ≈ ψ1_exponentiate_backend rtol = 1e-7

Expand All @@ -48,7 +49,7 @@ using Test

#Different backend solvers, default solver_backend = "applyexp"
ψ2_exponentiate_backend = tdvp(
H, +0.1im, ψ1; nsteps=1, cutoff, solver_backend="exponentiate"
H, +0.1im, ψ1; nsteps=1, cutoff, solver=exponentiate_solver
)
@test ψ2 ≈ ψ2_exponentiate_backend rtol = 1e-7

Expand Down Expand Up @@ -84,7 +85,7 @@ using Test

#Different backend solvers, default solver_backend = "applyexp"
ψ1_exponentiate_backend = tdvp(
Hs, -0.1im, ψ0; nsteps=1, cutoff, nsite=1, solver_backend="exponentiate"
Hs, -0.1im, ψ0; nsteps=1, cutoff, nsite=1, solver=exponentiate_solver
)
@test ψ1 ≈ ψ1_exponentiate_backend rtol = 1e-7

Expand All @@ -101,7 +102,7 @@ using Test

#Different backend solvers, default solver_backend = "applyexp"
ψ2_exponentiate_backend = tdvp(
Hs, +0.1im, ψ1; nsteps=1, cutoff, solver_backend="exponentiate"
Hs, +0.1im, ψ1; nsteps=1, cutoff, solver=exponentiate_solver
)
@test ψ2 ≈ ψ2_exponentiate_backend rtol = 1e-7

Expand Down Expand Up @@ -163,10 +164,11 @@ using Test

ψ0 = random_mps(s; internal_inds_space=10)

function solver(PH, psi0; time_step, kwargs...)
function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...)
solver_kwargs = (;
ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true
)
PH=PH_ref![]
psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...)
return psi, (; info=exp_info)
end
Expand Down Expand Up @@ -248,7 +250,7 @@ using Test
solver_tol=1e-12,
solver_maxiter=500,
solver_krylovdim=25,
solver_backend="exponentiate",
solver=exponentiate_solver,
)
# TODO: What should `expect` output? Right now
# it outputs a dictionary.
Expand Down Expand Up @@ -384,7 +386,7 @@ using Test
reverse_step,
normalize=true,
solver_krylovdim=15,
solver_backend="exponentiate",
solver=ITensorNetworks.exponentiate_solver,
)
end

Expand Down Expand Up @@ -551,10 +553,11 @@ end

ψ0 = normalize!(random_ttn(s; link_space=10))

function solver(PH, psi0; time_step, kwargs...)
function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...)
solver_kwargs = (;
ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true
)
PH=PH_ref![]
psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...)
return psi, (; info=exp_info)
end
Expand Down
Loading