Skip to content

Commit

Permalink
Change required solver function signature. Passes psi and projected H…
Browse files Browse the repository at this point in the history
…amiltonian to solver by reference, passes more information about sweep to solver.
  • Loading branch information
b-kloss committed Jan 11, 2024
1 parent 0b72b7c commit ac76123
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
25 changes: 20 additions & 5 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
function exponentiate_solver()
function solver(
H,
init;
psi_ref!,
PH_ref!,
ishermitian=true,
issymmetric=true,
region,
sweep_regions,
sweep_step,
solver_krylovdim=30,
solver_maxiter=100,
solver_outputlevel=0,
Expand All @@ -13,7 +16,15 @@ function exponentiate_solver()
normalize,
time_step,
)
psi, exp_info = KrylovKit.exponentiate(
#H=copy(PH_ref![])
H=PH_ref![] ###since we are not changing H we don't need the copy

Check warning on line 20 in src/treetensornetworks/solvers/tdvp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/treetensornetworks/solvers/tdvp.jl:20:- H=PH_ref![] ###since we are not changing H we don't need the copy src/treetensornetworks/solvers/tdvp.jl:20:+ H = PH_ref![] ###since we are not changing H we don't need the copy
# let's test whether given region and sweep regions we can find out what the previous and next region were
# this will be needed in subspace expansion
region_ind=sweep_step

Check warning on line 23 in src/treetensornetworks/solvers/tdvp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/treetensornetworks/solvers/tdvp.jl:23:- region_ind=sweep_step src/treetensornetworks/solvers/tdvp.jl:24:- next_region=region_ind==length(sweep_regions) ? nothing : first(sweep_regions[region_ind+1]) src/treetensornetworks/solvers/tdvp.jl:25:- previous_region=region_ind==1 ? nothing : first(sweep_regions[region_ind-1]) src/treetensornetworks/solvers/tdvp.jl:26:- src/treetensornetworks/solvers/tdvp.jl:23:+ region_ind = sweep_step src/treetensornetworks/solvers/tdvp.jl:24:+ next_region = src/treetensornetworks/solvers/tdvp.jl:25:+ region_ind == length(sweep_regions) ? nothing : first(sweep_regions[region_ind + 1]) src/treetensornetworks/solvers/tdvp.jl:26:+ previous_region = region_ind == 1 ? nothing : first(sweep_regions[region_ind - 1]) src/treetensornetworks/solvers/tdvp.jl:27:+
next_region=region_ind==length(sweep_regions) ? nothing : first(sweep_regions[region_ind+1])
previous_region=region_ind==1 ? nothing : first(sweep_regions[region_ind-1])

phi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
Expand All @@ -25,23 +36,27 @@ function exponentiate_solver()
verbosity=solver_outputlevel,
eager=true,
)
return psi, (; info=exp_info)
return phi, (; info=exp_info)
end
return solver
end

function applyexp_solver()
function solver(
H,
init;
tdvp_order,
psi_ref!,
PH_ref!,
region,
sweep_regions,
sweep_step,
solver_krylovdim=30,
solver_outputlevel=0,
solver_tol=1E-8,
substep,
time_step,
normalize,
)
H=PH_ref![]

Check warning on line 59 in src/treetensornetworks/solvers/tdvp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/treetensornetworks/solvers/tdvp.jl:59:- H=PH_ref![] src/treetensornetworks/solvers/tdvp.jl:60:+ H = PH_ref![]
#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, exp_info = applyexp(
Expand Down
6 changes: 5 additions & 1 deletion src/treetensornetworks/solvers/update_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ function local_update(
nsites = (region isa AbstractEdge) ? 0 : length(region)
PH = set_nsite(PH, nsites)
PH = position(PH, psi, region)
phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...)
(psi_ref!) = Ref(psi) # create references, in case solver does (out-of-place) modify PH or psi
(PH_ref!) = Ref(PH)
phi, info = solver(phi;(psi_ref!),(PH_ref!), normalize, region, sweep_regions, sweep_step, step_kwargs..., solver_kwargs...) # args passed by reference are supposed to be modified out of place

Check warning on line 187 in src/treetensornetworks/solvers/update_step.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/treetensornetworks/solvers/update_step.jl:187:- phi, info = solver(phi;(psi_ref!),(PH_ref!), normalize, region, sweep_regions, sweep_step, step_kwargs..., solver_kwargs...) # args passed by reference are supposed to be modified out of place src/treetensornetworks/solvers/update_step.jl:187:+ phi, info = solver( src/treetensornetworks/solvers/update_step.jl:188:+ phi; src/treetensornetworks/solvers/update_step.jl:189:+ (psi_ref!), src/treetensornetworks/solvers/update_step.jl:190:+ (PH_ref!), src/treetensornetworks/solvers/update_step.jl:191:+ normalize, src/treetensornetworks/solvers/update_step.jl:192:+ region, src/treetensornetworks/solvers/update_step.jl:193:+ sweep_regions, src/treetensornetworks/solvers/update_step.jl:194:+ sweep_step, src/treetensornetworks/solvers/update_step.jl:195:+ step_kwargs..., src/treetensornetworks/solvers/update_step.jl:196:+ solver_kwargs..., src/treetensornetworks/solvers/update_step.jl:197:+ ) # args passed by reference are supposed to be modified out of place
psi = psi_ref![] # dereference
PH = PH_ref![]
if !(phi isa ITensor && info isa NamedTuple)
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))")
error("In alternating_update, solver must return an ITensor and a NamedTuple")
Expand Down
6 changes: 4 additions & 2 deletions test/test_treetensornetworks/test_solvers/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,11 @@ using Test

ψ0 = random_mps(s; internal_inds_space=10)

function solver(PH, psi0; time_step, kwargs...)
function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...)

Check warning on line 166 in test/test_treetensornetworks/test_solvers/test_tdvp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: test/test_treetensornetworks/test_solvers/test_tdvp.jl:166:- function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...) test/test_treetensornetworks/test_solvers/test_tdvp.jl:166:+ function solver(psi0; (psi_ref!), (PH_ref!), time_step, kwargs...)
solver_kwargs = (;
ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true
)
PH=PH_ref![]

Check warning on line 170 in test/test_treetensornetworks/test_solvers/test_tdvp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: test/test_treetensornetworks/test_solvers/test_tdvp.jl:170:- PH=PH_ref![] test/test_treetensornetworks/test_solvers/test_tdvp.jl:170:+ PH = PH_ref![]
psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...)
return psi, (; info=exp_info)
end
Expand Down Expand Up @@ -551,10 +552,11 @@ end

ψ0 = normalize!(random_ttn(s; link_space=10))

function solver(PH, psi0; time_step, kwargs...)
function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...)

Check warning on line 555 in test/test_treetensornetworks/test_solvers/test_tdvp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: test/test_treetensornetworks/test_solvers/test_tdvp.jl:555:- function solver(psi0; (psi_ref!),(PH_ref!), time_step, kwargs...) test/test_treetensornetworks/test_solvers/test_tdvp.jl:555:+ function solver(psi0; (psi_ref!), (PH_ref!), time_step, kwargs...)
solver_kwargs = (;
ishermitian=true, tol=1e-12, krylovdim=30, maxiter=100, verbosity=0, eager=true
)
PH=PH_ref![]

Check warning on line 559 in test/test_treetensornetworks/test_solvers/test_tdvp.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: test/test_treetensornetworks/test_solvers/test_tdvp.jl:559:- PH=PH_ref![] test/test_treetensornetworks/test_solvers/test_tdvp.jl:559:+ PH = PH_ref![]
psi, exp_info = exponentiate(PH, time_step, psi0; solver_kwargs...)
return psi, (; info=exp_info)
end
Expand Down

0 comments on commit ac76123

Please sign in to comment.