Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some issues around keyword arguments #114

Merged
merged 2 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/treetensornetworks/solvers/applyexp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ struct ApplyExpInfo
converged::Int
end

function applyexp(H, tau::Number, x0; kwargs...)
maxiter = get(kwargs, :maxiter, 30)
tol = get(kwargs, :tol, 1E-12)
outputlevel = get(kwargs, :outputlevel, 0)
beta_tol = get(kwargs, :normcutoff, 1E-7)

function applyexp(H, tau::Number, x0; maxiter=30, tol=1e-12, outputlevel=0, normcutoff=1e-7)
# Initialize Lanczos vectors
v1 = copy(x0)
nrm = norm(v1)
Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function contract_solver(PH, psi; kwargs...)
function contract_solver(PH, psi; normalize, region, half_sweep)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
Expand Down
54 changes: 43 additions & 11 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...)
function solver(H, init; kws...)
function eigsolve_solver(;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
)
function solver(H, init; normalize=nothing, region=nothing, half_sweep=nothing)
howmany = 1
which = solver_which_eigenvalue
solver_kwargs = (;
ishermitian=get(kwargs, :ishermitian, true),
tol=get(kwargs, :solver_tol, 1E-14),
krylovdim=get(kwargs, :solver_krylovdim, 3),
maxiter=get(kwargs, :solver_maxiter, 1),
verbosity=get(kwargs, :solver_verbosity, 0),
vals, vecs, info = eigsolve(
H,
init,
howmany,
which;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
vals, vecs, info = eigsolve(H, init, howmany, which; solver_kwargs...)
psi = vecs[1]
return psi, (; solver_info=info, energies=vals)
end
Expand All @@ -19,8 +29,30 @@ end
"""
Overload of `ITensors.dmrg`.
"""
function dmrg(H, init::AbstractTTN; kwargs...)
return alternating_update(eigsolve_solver(; kwargs...), H, init; kwargs...)
function dmrg(
H,
init::AbstractTTN;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
kwargs...,
)
return alternating_update(
eigsolve_solver(;
solver_which_eigenvalue,
ishermitian,
solver_tol,
solver_krylovdim,
solver_maxiter,
solver_verbosity,
),
H,
init;
kwargs...,
)
end

"""
Expand Down
4 changes: 3 additions & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
function dmrg_x_solver(PH, init; kwargs...)
function dmrg_x_solver(
PH, init; normalize=nothing, region=nothing, half_sweep=nothing, reverse_step=nothing
)
H = contract(PH, ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
Expand Down
22 changes: 16 additions & 6 deletions src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ Keyword arguments:
Overload of `KrylovKit.linsolve`.
"""
function linsolve(
A::AbstractTTN, b::AbstractTTN, x₀::AbstractTTN, a₀::Number=0, a₁::Number=1; kwargs...
A::AbstractTTN,
b::AbstractTTN,
x₀::AbstractTTN,
a₀::Number=0,
a₁::Number=1;
normalize,
region,
half_sweep,
)
function linsolve_solver(
P,
Expand All @@ -33,17 +40,20 @@ function linsolve(
solver_krylovdim=30,
solver_maxiter=100,
solver_verbosity=0,
kwargs...,
)
solver_kwargs = (;
ishermitian=ishermitian,
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(
P,
b,
x₀,
a₀,
a₁;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, x₀, a₀, a₁; solver_kwargs...)
return x, NamedTuple()
end

Expand Down
37 changes: 23 additions & 14 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function exponentiate_solver(; kwargs...)
function exponentiate_solver()
function solver(
H,
init;
Expand All @@ -10,10 +10,13 @@ function exponentiate_solver(; kwargs...)
solver_outputlevel=0,
solver_tol=1E-12,
substep,
normalize,
time_step,
kws...,
)
solver_kwargs = (;
psi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
ishermitian,
issymmetric,
tol=solver_tol,
Expand All @@ -22,14 +25,12 @@ function exponentiate_solver(; kwargs...)
verbosity=solver_outputlevel,
eager=true,
)

psi, exp_info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, (; info=exp_info)
end
return solver
end

function applyexp_solver(; kwargs...)
function applyexp_solver()
function solver(
H,
init;
Expand All @@ -39,13 +40,13 @@ function applyexp_solver(; kwargs...)
solver_tol=1E-8,
substep,
time_step,
kws...,
normalize,
)
solver_kwargs = (; maxiter=solver_krylovdim, outputlevel=solver_outputlevel)

#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, exp_info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
psi, exp_info = applyexp(
H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel
)
return psi, (; info=exp_info)
end
return solver
Expand Down Expand Up @@ -84,7 +85,12 @@ function sub_time_steps(order)
end

function tdvp_sweep(
order::Int, nsite::Int, time_step::Number, graph::AbstractGraph; kwargs...
order::Int,
nsite::Int,
time_step::Number,
graph::AbstractGraph;
root_vertex=default_root_vertex(graph),
reverse_step=true,
)
sweep = []
for (substep, fac) in enumerate(sub_time_steps(order))
Expand All @@ -93,10 +99,11 @@ function tdvp_sweep(
direction(substep),
graph,
make_region;
root_vertex,
nsite,
region_args=(; substep, time_step=sub_time_step),
reverse_args=(; substep, time_step=-sub_time_step),
reverse_step=true,
reverse_step,
)
append!(sweep, half)
end
Expand All @@ -113,10 +120,12 @@ function tdvp(
nsteps=nothing,
order::Integer=2,
(sweep_observer!)=observer(),
root_vertex=default_root_vertex(init),
reverse_step=true,
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; kwargs...)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; root_vertex, reverse_step)

function sweep_time_printer(; outputlevel, sweep, kwargs...)
if outputlevel >= 1
Expand Down Expand Up @@ -169,5 +178,5 @@ function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kw
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
return tdvp(solver(; kwargs...), H, t, init; kwargs...)
return tdvp(solver(), H, t, init; kwargs...)
end
34 changes: 18 additions & 16 deletions src/treetensornetworks/solvers/update_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,29 @@ function insert_local_tensor(
psi::AbstractTTN,
phi::ITensor,
pos::Vector;
which_decomp=nothing,
normalize=false,
# factorize kwargs
maxdim=nothing,
mindim=nothing,
cutoff=nothing,
which_decomp=nothing,
eigen_perturbation=nothing,
kwargs...,
ortho=nothing,
)
spec = nothing
for (v, vnext) in IterTools.partition(pos, 2, 1)
e = edgetype(psi)(v, vnext)
indsTe = inds(psi[v])
L, phi, spec = factorize(
phi, indsTe; which_decomp, tags=tags(psi, e), eigen_perturbation, kwargs...
phi,
indsTe;
tags=tags(psi, e),
maxdim,
mindim,
cutoff,
which_decomp,
eigen_perturbation,
ortho,
)
psi[v] = L
eigen_perturbation = nothing # TODO: fix this
Expand Down Expand Up @@ -162,16 +174,15 @@ function local_update(
sweep,
sweep_regions,
sweep_step,
kwargs...,
solver_kwargs...,
)
psi = orthogonalize(psi, current_ortho(region))
psi, phi = extract_local_tensor(psi, region)

nsites = (region isa AbstractEdge) ? 0 : length(region)
PH = set_nsite(PH, nsites)
PH = position(PH, psi, region)

phi, info = solver(PH, phi; normalize, region, step_kwargs..., kwargs...)
phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...)
if !(phi isa ITensor && info isa NamedTuple)
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))")
error("In alternating_update, solver must return an ITensor and a NamedTuple")
Expand All @@ -185,16 +196,7 @@ function local_update(
#end

psi, spec = insert_local_tensor(
psi,
phi,
region;
eigen_perturbation=drho,
ortho,
normalize,
cutoff,
maxdim,
mindim,
kwargs...,
psi, phi, region; eigen_perturbation=drho, ortho, normalize, maxdim, mindim, cutoff
)

update!(
Expand Down
4 changes: 2 additions & 2 deletions test/test_treetensornetworks/test_solvers/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ using Test
@test inner(psit, Hpsi) ≈ inner(psit, H, psi) atol = 1E-5

# Test with less good initial guess MPS not equal to psi
psi_guess = truncate(psi; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init_state=psi_guess)
psi_guess = truncate(psit; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init=psi_guess)
@test inner(psit, Hpsi) ≈ inner(psit, H, psi) atol = 1E-5

# Test with nsite=1
Expand Down
26 changes: 6 additions & 20 deletions test/test_treetensornetworks/test_solvers/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ using Test
solver_tol=1e-12,
solver_maxiter=500,
solver_krylovdim=25,
solver_solver_backend="exponentiate",
solver_backend="exponentiate",
)
# TODO: What should `expect` output? Right now
# it outputs a dictionary.
Expand Down Expand Up @@ -312,14 +312,7 @@ using Test

nsite = (step <= 3 ? 2 : 1)
phi = tdvp(
H,
-tau * im,
phi;
nsteps=1,
cutoff,
nsite,
normalize=true,
exponentiate_krylovdim=15,
H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15
)

Sz1[step] = real(expect("Sz", psi; vertices=[c])[c])
Expand Down Expand Up @@ -379,7 +372,7 @@ using Test
for (step, t) in enumerate(trange)
nsite = (step <= 10 ? 2 : 1)
psi = tdvp(
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, exponentiate_krylovdim=15
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15
)
#Different backend solvers, default solver_backend = "applyexp"
psi2 = tdvp(
Expand All @@ -390,7 +383,7 @@ using Test
nsite,
reverse_step,
normalize=true,
exponentiate_krylovdim=15,
solver_krylovdim=15,
solver_backend="exponentiate",
)
end
Expand Down Expand Up @@ -684,14 +677,7 @@ end

nsite = (step <= 3 ? 2 : 1)
phi = tdvp(
H,
-tau * im,
phi;
nsteps=1,
cutoff,
nsite,
normalize=true,
exponentiate_krylovdim=15,
H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15
)

Sz1[step] = real(expect("Sz", psi; vertices=[c])[c])
Expand Down Expand Up @@ -742,7 +728,7 @@ end
for (step, t) in enumerate(trange)
nsite = (step <= 10 ? 2 : 1)
psi = tdvp(
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, exponentiate_krylovdim=15
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15
)
end

Expand Down
Loading