Skip to content

Commit

Permalink
Format tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss committed Jan 21, 2024
1 parent ac1a923 commit 2748927
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 55 deletions.
3 changes: 1 addition & 2 deletions test/test_treetensornetworks/test_solvers/test_dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ end
sweep(; which_sweep, kw...) = which_sweep
sweep_observer! = observer(sweep)

region(; which_region_update, sweep_plan, kw...) =
first(sweep_plan[which_region_update])
region(; which_region_update, sweep_plan, kw...) = first(sweep_plan[which_region_update])
energy(; eigvals, kw...) = eigvals[1]
region_observer! = observer(region, sweep, energy)

Expand Down
12 changes: 4 additions & 8 deletions test/test_treetensornetworks/test_solvers/test_dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ using Test

ψ = mps(s; states=(v -> rand(["", ""])))

dmrg_x_kwargs = (
nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0
)
dmrg_x_kwargs = (nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0)

ϕ = dmrg_x(H, ψ; nsites=2, dmrg_x_kwargs...)

Expand All @@ -34,7 +32,7 @@ using Test
# @test abs(loginner(ϕ̃, ϕ) / n) ≈ 0.0 atol = 1e-6
end

@testset "Tree DMRG-X" for conserve_qns in (false, true)
@testset "Tree DMRG-X" for conserve_qns in (false, true)
tooth_lengths = fill(2, 3)
root_vertex = (3, 2)
c = named_comb_tree(tooth_lengths)
Expand All @@ -51,10 +49,8 @@ end
# TODO: Use `TTN(s; states=v -> rand(["↑", "↓"]))` or
# `ttns(s; states=v -> rand(["↑", "↓"]))`
ψ = normalize!(TTN(s, v -> rand(["", ""])))

dmrg_x_kwargs = (
nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0
)

dmrg_x_kwargs = (nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0)

ϕ = dmrg_x(H, ψ; nsites=2, dmrg_x_kwargs...)

Expand Down
55 changes: 37 additions & 18 deletions test/test_treetensornetworks/test_solvers/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,7 @@ using Test
state;
cutoff,
normalize=false,
updater_kwargs=(;
tol=1e-12,
maxiter=500,
krylovdim=25,)
updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25),
)
# TODO: What should `expect` output? Right now
# it outputs a dictionary.
Expand All @@ -250,10 +247,7 @@ using Test
psi2;
cutoff,
normalize=false,
updater_kwargs=(;
tol=1e-12,
maxiter=500,
krylovdim=25,),
updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25),
updater=exponentiate_updater,
)
# TODO: What should `expect` output? Right now
Expand Down Expand Up @@ -318,7 +312,14 @@ using Test

nsites = (step <= 3 ? 2 : 1)
phi = tdvp(
H, -tau * im, phi; nsteps=1, cutoff, nsites, normalize=true, updater_kwargs=(;krylovdim=15)
H,
-tau * im,
phi;
nsteps=1,
cutoff,
nsites,
normalize=true,
updater_kwargs=(; krylovdim=15),
)

Sz1[step] = real(expect("Sz", state; vertices=[c])[c])
Expand Down Expand Up @@ -378,7 +379,14 @@ using Test
for (step, t) in enumerate(trange)
nsites = (step <= 10 ? 2 : 1)
state = tdvp(
H, -tau, state; cutoff, nsites, reverse_step, normalize=true, updater_kwargs=(;krylovdim=15)
H,
-tau,
state;
cutoff,
nsites,
reverse_step,
normalize=true,
updater_kwargs=(; krylovdim=15),
)
#Different backend updaters, default updater_backend = "applyexp"
psi2 = tdvp(
Expand All @@ -389,7 +397,7 @@ using Test
nsites,
reverse_step,
normalize=true,
updater_kwargs=(;krylovdim=15),
updater_kwargs=(; krylovdim=15),
updater=ITensorNetworks.exponentiate_updater,
)
end
Expand Down Expand Up @@ -588,7 +596,7 @@ end
# Should rotate back to original state:
@test abs(inner(ψ0, ψ2)) > 0.99
end
=#
=#
@testset "Accuracy Test" begin
tau = 0.1
ttotal = 1.0
Expand Down Expand Up @@ -625,10 +633,7 @@ end
state;
cutoff,
normalize=false,
updater_kwargs=(;
tol=1e-12,
maxiter=500,
krylovdim=25,)
updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25),
)
push!(Sz_tdvp, real(expect("Sz", state; vertices=[c])[c]))
push!(Sz_exact, real(scalar(dag(prime(statex, s[c])) * Szc * statex)))
Expand Down Expand Up @@ -686,7 +691,14 @@ end

nsites = (step <= 3 ? 2 : 1)
phi = tdvp(
H, -tau * im, phi; nsteps=1, cutoff, nsites, normalize=true, updater_kwargs=(;krylovdim=15)
H,
-tau * im,
phi;
nsteps=1,
cutoff,
nsites,
normalize=true,
updater_kwargs=(; krylovdim=15),
)

Sz1[step] = real(expect("Sz", state; vertices=[c])[c])
Expand Down Expand Up @@ -737,7 +749,14 @@ end
for (step, t) in enumerate(trange)
nsites = (step <= 10 ? 2 : 1)
state = tdvp(
H, -tau, state; cutoff, nsites, reverse_step, normalize=true, updater_kwargs=(;krylovdim=15)
H,
-tau,
state;
cutoff,
nsites,
reverse_step,
normalize=true,
updater_kwargs=(; krylovdim=15),
)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ode_kwargs = (; reltol=1e-8, abstol=1e-8)

ω⃗ = [ω₁, ω₂]
f⃗ = [t -> cos* t) for ω in ω⃗]
ode_updater_kwargs=(;f=f⃗,solver_alg=ode_alg,ode_kwargs)
ode_updater_kwargs = (; f=f⃗, solver_alg=ode_alg, ode_kwargs)

function ode_updater(
init;
Expand All @@ -36,12 +36,12 @@ function ode_updater(
region_kwargs,
updater_kwargs,
)
time_step=region_kwargs.time_step
f⃗=updater_kwargs.f
ode_kwargs=updater_kwargs.ode_kwargs
solver_alg=updater_kwargs.solver_alg
H⃗₀=projected_operator![]
result, info = ode_solver(
time_step = region_kwargs.time_step
f⃗ = updater_kwargs.f
ode_kwargs = updater_kwargs.ode_kwargs
solver_alg = updater_kwargs.solver_alg
H⃗₀ = projected_operator![]
result, info = ode_solver(
-im * TimeDependentSum(f⃗, H⃗₀), time_step, init; solver_alg, ode_kwargs...
)
return result, (; info)
Expand All @@ -54,9 +54,8 @@ function tdvp_ode_solver(H⃗₀, ψ₀; time_step, kwargs...)
return psi_t, (; info)
end


krylov_kwargs = (; tol=1e-8, eager=true)
krylov_updater_kwargs=(;f=f⃗,krylov_kwargs)
krylov_updater_kwargs = (; f=f⃗, krylov_kwargs)

function krylov_solver(H⃗₀, ψ₀; time_step, ishermitian=false, issymmetric=false, kwargs...)
psi_t, info = krylov_solver(
Expand All @@ -81,20 +80,17 @@ function krylov_updater(
region_kwargs,
updater_kwargs,
)
default_updater_kwargs = (;
ishermitian=false,
issymmetric=false,
)
default_updater_kwargs = (; ishermitian=false, issymmetric=false)

updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedenc
time_step=region_kwargs.time_step
f⃗=updater_kwargs.f
krylov_kwargs=updater_kwargs.krylov_kwargs
ishermitian=updater_kwargs.ishermitian
issymmetric=updater_kwargs.issymmetric
H⃗₀=projected_operator![]

result, info = krylov_solver(
time_step = region_kwargs.time_step
f⃗ = updater_kwargs.f
krylov_kwargs = updater_kwargs.krylov_kwargs
ishermitian = updater_kwargs.ishermitian
issymmetric = updater_kwargs.issymmetric
H⃗₀ = projected_operator![]

result, info = krylov_solver(
-im * TimeDependentSum(f⃗, H⃗₀),
time_step,
init;
Expand All @@ -105,7 +101,6 @@ function krylov_updater(
return result, (; info)
end


@testset "MPS: Time dependent Hamiltonian" begin
n = 4
J₁ = 1.0
Expand All @@ -126,9 +121,28 @@ end

ψ₀ = complex(mps(s; states=(j -> isodd(j) ? "" : "")))

ψₜ_ode = tdvp(ode_updater, H⃗₀, time_total, ψ₀; time_step, maxdim, cutoff, nsites, updater_kwargs=ode_updater_kwargs)
ψₜ_ode = tdvp(
ode_updater,
H⃗₀,
time_total,
ψ₀;
time_step,
maxdim,
cutoff,
nsites,
updater_kwargs=ode_updater_kwargs,
)

ψₜ_krylov = tdvp(krylov_updater, H⃗₀, time_total, ψ₀; time_step, cutoff, nsites, updater_kwargs=krylov_updater_kwargs)
ψₜ_krylov = tdvp(
krylov_updater,
H⃗₀,
time_total,
ψ₀;
time_step,
cutoff,
nsites,
updater_kwargs=krylov_updater_kwargs,
)

ψₜ_full, _ = tdvp_ode_solver(contract.(H⃗₀), contract(ψ₀); time_step=time_total)

Expand All @@ -145,7 +159,6 @@ end
@test krylov_err < 1e-2
end


@testset "TTN: Time dependent Hamiltonian" begin
tooth_lengths = fill(2, 3)
root_vertex = (3, 2)
Expand All @@ -170,9 +183,28 @@ end

ψ₀ = TTN(ComplexF64, s, v -> iseven(sum(isodd.(v))) ? "" : "")

ψₜ_ode = tdvp(ode_updater, H⃗₀, time_total, ψ₀; time_step, maxdim, cutoff, nsites, updater_kwargs=ode_updater_kwargs)
ψₜ_ode = tdvp(
ode_updater,
H⃗₀,
time_total,
ψ₀;
time_step,
maxdim,
cutoff,
nsites,
updater_kwargs=ode_updater_kwargs,
)

ψₜ_krylov = tdvp(krylov_updater, H⃗₀, time_total, ψ₀; time_step, cutoff, nsites, updater_kwargs=krylov_updater_kwargs)
ψₜ_krylov = tdvp(
krylov_updater,
H⃗₀,
time_total,
ψ₀;
time_step,
cutoff,
nsites,
updater_kwargs=krylov_updater_kwargs,
)

ψₜ_full, _ = tdvp_ode_solver(contract.(H⃗₀), contract(ψ₀); time_step=time_total)

Expand Down

0 comments on commit 2748927

Please sign in to comment.