Skip to content

Commit

Permalink
Adapt linsolve.
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss committed Jan 21, 2024
1 parent 64ef223 commit 812de92
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ include(joinpath("solvers", "eigsolve.jl"))
include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand Down
22 changes: 22 additions & 0 deletions src/solvers/linsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
function linsolve_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
default_updater_kwargs = (;
ishermitian=false, tol=1E-14, krylovdim=30, maxiter=100, verbosity=0, a₀, a₁
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs)
P = projected_operator![]
(; a₀, a₁) = updater_kwargs
updater_kwargs = Base.structdiff(updater_kwargs, (; a₀=nothing, a₁=nothing))
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, init, a₀, a₁; updater_kwargs...)
return x, (;)
end
40 changes: 11 additions & 29 deletions src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,39 +28,21 @@ function linsolve(
x₀::AbstractTTN,
a₀::Number=0,
a₁::Number=1;
normalize,
region,
half_sweep,
updater=linsolve_updater,
nsweeps, #it makes sense to require this to be defined
nsites=2,
(sweep_observer!)=observer(),
root_vertex=default_root_vertex(init),
updater_kwargs=(;),
kwargs...,
)
function linsolve_solver(
P,
x₀;
ishermitian=false,
solver_tol=1E-14,
solver_krylovdim=30,
solver_maxiter=100,
solver_verbosity=0,
)
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(
P,
b,
x₀,
a₀,
a₁;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
return x, NamedTuple()
end

updater_kwargs = (; a₀, a₁, updater_kwargs...)
error("`linsolve` for TTN not yet implemented.")

sweep_plan = default_sweep_regions(nsites, x0)
# TODO: Define `itensornetwork_cache`
# TODO: Define `linsolve_cache`

P = linsolve_cache(itensornetwork_cache(x₀', A, x₀), itensornetwork_cache(x₀', b))
return alternating_update(linsolve_solver, P, x₀; kwargs...)
return alternating_update(linsolve_updater, P, x₀; sweep_plan, updater_kwargs, kwargs...)
end
4 changes: 2 additions & 2 deletions test/test_treetensornetworks/test_solvers/test_linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ using Random
x_c =
random_mps(s; states, internal_inds_space=4) +
0.1im * random_mps(s; states, internal_inds_space=2)
b = apply(H, x_c; cutoff)
b = apply(H, x_c; alg="fit", nsweeps=3) #cutoff is unsupported kwarg for apply/contract

x0 = random_mps(s; states, internal_inds_space=10)
x = @test_broken linsolve(
H, b, x0; cutoff, maxdim, nsweeps, ishermitian=true, solver_tol=1E-6
H, b, x0; cutoff, maxdim, nsweeps, updater_kwargs=(; tol=1E-6, ishermitian=true)
)

# @test norm(x - x_c) < 1E-3
Expand Down

0 comments on commit 812de92

Please sign in to comment.