Skip to content

Commit

Permalink
Make it easier to run example 3 on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jun 2, 2024
1 parent 3cc10c1 commit 1a09309
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
45 changes: 32 additions & 13 deletions examples/03_tdvp_time_dependent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,24 @@ using Random: Random
include("03_models.jl")
include("03_updaters.jl")

function main()
"""
Run the example on CPU:
```julia
main()
```
Run the example on CPU with single precision:
```julia
main(; eltype=Float32)
```
Run the example on GPU:
```julia
using CUDA: cu
main(; eltype=Float32, device=cu)
```
"""
function main(; eltype=Float64, device=identity)
Random.seed!(1234)

# Time dependent Hamiltonian is:
Expand All @@ -24,16 +41,16 @@ function main()
outputlevel = 3

# Frequency of time dependent terms
ω₁ = 0.1
ω₂ = 0.2
ω₁ = one(eltype) / 10
ω₂ = one(eltype) / 5

# Nearest and next-nearest neighbor
# Heisenberg couplings.
J₁ = 1.0
J₂ = 1.0
J₁ = one(eltype)
J₂ = one(eltype)

time_step = 0.1
time_stop = 1.0
time_step = one(eltype) / 10
time_stop = one(eltype)

# nsite-update TDVP
nsite = 2
Expand All @@ -46,9 +63,9 @@ function main()

# TDVP truncation parameters
maxdim = 100
cutoff = 1e-8
cutoff = (eps(eltype))

tol = 1e-15
tol = 10 * eps(eltype)

@show n
@show ω₁, ω₂
Expand All @@ -61,18 +78,20 @@ function main()
f⃗ = map-> (t -> cos* t)), ω⃗)

# H₀ = H(0) = H₁(0) + H₂(0) + …
ℋ₁₀ = heisenberg(n; J=J₁, J2=0.0)
ℋ₂₀ = heisenberg(n; J=0.0, J2=J₂)
ℋ₁₀ = heisenberg(n; J=J₁, J2=zero(eltype))
ℋ₂₀ = heisenberg(n; J=zero(eltype), J2=J₂)
ℋ⃗₀ = (ℋ₁₀, ℋ₂₀)

s = siteinds("S=1/2", n)

H⃗₀ = map(ℋ₀ -> MPO(ℋ₀, s), ℋ⃗₀)
H⃗₀ = map(ℋ₀ -> device(MPO(eltype, ℋ₀, s)), ℋ⃗₀)

# Initial state, ψ₀ = ψ(0)
# Initialize as complex since that is what OrdinaryDiffEq.jl/DifferentialEquations.jl
# expects.
ψ₀ = complex.(random_mps(s, j -> isodd(j) ? "" : ""; linkdims=start_linkdim))
ψ₀ = device(
complex.(random_mps(eltype, s, j -> isodd(j) ? "" : ""; linkdims=start_linkdim))
)

@show norm(ψ₀)

Expand Down
2 changes: 1 addition & 1 deletion examples/03_updaters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function ode_updater(operator, init; internal_kwargs, alg=Tsit5(), kwargs...)
time_span = typeof(time_step).((current_time, current_time + time_step))
init_vec, to_itensor = to_vec(init)
f(init::ITensor, p, t) = operator(t)(init)
f(init_vec::Vector, p, t) = to_vec(f(to_itensor(init_vec), p, t))[1]
f(init_vec::AbstractArray, p, t) = to_vec(f(to_itensor(init_vec), p, t))[1]
prob = ODEProblem(f, init_vec, time_span)
sol = solve(prob, alg; kwargs...)
state_vec = sol.u[end]
Expand Down

0 comments on commit 1a09309

Please sign in to comment.