Skip to content

Commit

Permalink
Start moving ITensorTDVP.jl codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Oct 23, 2024
1 parent 286eed2 commit 98d1237
Show file tree
Hide file tree
Showing 48 changed files with 2,497 additions and 25 deletions.
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "0.3.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ITensorTDVP = "25707e16-a4db-4a07-99d9-4d67b7af0342"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Expand All @@ -18,10 +17,15 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SerializedElementArrays = "d3ce8812-9567-47e9-a7b5-65a6d70a3065"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
ITensorMPSChainRulesCoreExt = "ChainRulesCore"

[compat]
Adapt = "4.1.0"
Compat = "4.16.0"
ITensorTDVP = "0.4.1"
ITensors = "0.7"
IsApprox = "2.0.0"
KrylovKit = "0.8.1"
Expand Down
3 changes: 0 additions & 3 deletions examples/solvers/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,3 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
ITensors = "0.6.7"
2 changes: 1 addition & 1 deletion ext/ITensorMPSChainRulesCoreExt/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Adapt: adapt
using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig, rrule_via_ad
using ITensors:
ITensors, ITensor, dag, hassameinds, inds, itensor, mapprime, replaceprime, swapprime
using ITensors.ITensorMPS: ITensorMPS, MPO, MPS, apply, inner, siteinds
using ITensorMPS: ITensorMPS, MPO, MPS, apply, inner, siteinds
using NDTensors: datatype

function ChainRulesCore.rrule(
Expand Down
2 changes: 1 addition & 1 deletion ext/ITensorMPSChainRulesCoreExt/indexset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using ITensors:
replacetags,
setprime,
settags
using ITensors.ITensorMPS: MPO, MPS
using ITensorMPS: MPO, MPS

for fname in (
:prime, :setprime, :noprime, :replaceprime, :addtags, :removetags, :replacetags, :settags
Expand Down
2 changes: 1 addition & 1 deletion ext/ITensorMPSChainRulesCoreExt/mpo.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ChainRulesCore: ChainRulesCore, NoTangent
using ITensors: Algorithm, contract, hassameinds, inner, mapprime
using ITensors.ITensorMPS: MPO, MPS, firstsiteinds, siteinds
using ITensorMPS: MPO, MPS, firstsiteinds, siteinds
using LinearAlgebra: tr

function ChainRulesCore.rrule(
Expand Down
2 changes: 1 addition & 1 deletion ext/ITensorMPSChainRulesCoreExt/mps.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ChainRulesCore: @non_differentiable
using ITensors: Index
using ITensors.ITensorMPS: MPS
using ITensorMPS: MPS
@non_differentiable MPS(::Type{<:Number}, sites::Vector{<:Index}, states_)
9 changes: 9 additions & 0 deletions ext/ITensorMPSObserversExt/ITensorMPSObserversExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module ITensorMPSObserversExt
using Observers: Observers
using Observers.DataFrames: AbstractDataFrame
using ITensorMPS: ITensorMPS

function ITensorMPS.update_observer!(observer::AbstractDataFrame; kwargs...)
return Observers.update!(observer; kwargs...)
end
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module ITensorsPackageCompilerExt
include("compile.jl")
end
26 changes: 26 additions & 0 deletions ext/ITensorMPSPackageCompilerExt/compile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using NDTensors: @Algorithm_str
using ITensors: ITensors
using PackageCompiler: PackageCompiler

function ITensors.compile(
::Algorithm"PackageCompiler";
dir::AbstractString=ITensors.default_compile_dir(),
filename::AbstractString=ITensors.default_compile_filename(),
)
if !isdir(dir)
println("""The directory "$dir" doesn't exist yet, creating it now.""")
println()
mkdir(dir)
end
path = joinpath(dir, filename)
println(
"""Creating the system image "$path" containing the compiled version of ITensorMPS. This may take a few minutes.""",
)
PackageCompiler.create_sysimage(
:ITensorMPS;
sysimage_path=path,
precompile_execution_file=joinpath(@__DIR__, "precompile_itensormps.jl"),
)
println(ITensors.compile_note(; dir, filename))
return path
end
28 changes: 28 additions & 0 deletions ext/ITensorMPSPackageCompilerExt/precompile_itensormps.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using ITensorMPS: MPO, OpSum, dmrg, random_mps, siteinds

# TODO: This uses all of the tests to make
# precompile statements, but takes a long time
# (e.g. 700 seconds).
# Try again with later versions of PackageCompiler
#
# include(joinpath(joinpath(dirname(dirname(@__DIR__)),
# test"),
# "runtests.jl"))

function main(; N, dmrg_kwargs)
opsum = OpSum()
for j in 1:(N - 1)
opsum += 0.5, "S+", j, "S-", j + 1
opsum += 0.5, "S-", j, "S+", j + 1
opsum += "Sz", j, "Sz", j + 1
end
for conserve_qns in (false, true)
sites = siteinds("S=1", N; conserve_qns)
H = MPO(opsum, sites)
ψ0 = random_mps(sites, j -> isodd(j) ? "" : ""; linkdims=2)
dmrg(H, ψ0; outputlevel=0, dmrg_kwargs...)
end
return nothing
end

main(; N=6, dmrg_kwargs=(; nsweeps=3, maxdim=10, cutoff=1e-13))
20 changes: 20 additions & 0 deletions src/ITensorMPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,24 @@ include("defaults.jl")
include("update_observer.jl")
include("lattices/lattices.jl")
export Lattice, LatticeBond, square_lattice, triangular_lattice
export TimeDependentSum, dmrg_x, expand, linsolve, tdvp, to_vec
include("solvers/ITensorsExtensions.jl")
using .ITensorsExtensions: to_vec
include("solvers/applyexp.jl")
include("solvers/defaults.jl")
include("solvers/update_observer.jl")
include("solvers/timedependentsum.jl")
include("solvers/tdvporder.jl")
include("solvers/sweep_update.jl")
include("solvers/alternating_update.jl")
include("solvers/tdvp.jl")
include("solvers/dmrg.jl")
include("solvers/dmrg_x.jl")
include("solvers/reducedcontractproblem.jl")
include("solvers/contract.jl")
include("solvers/reducedconstantterm.jl")
include("solvers/reducedlinearproblem.jl")
include("solvers/linsolve.jl")
include("solvers/expand.jl")
include("lib/Experimental/src/Experimental.jl")
end
2 changes: 1 addition & 1 deletion src/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1833,7 +1833,7 @@ function setindex!(
# into MPS tensors
firstsite = first(r)
lastsite = last(r)
@assert firstsite ITensors.orthocenter(ψ) lastsite
@assert firstsite ITensorMPS.orthocenter(ψ) lastsite
@assert firstsite leftlim(ψ) + 1
@assert rightlim(ψ) - 1 lastsite

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module Experimental
using ITensorTDVP: dmrg
include("dmrg.jl")
end
17 changes: 17 additions & 0 deletions src/lib/Experimental/src/dmrg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using ..ITensorMPS:
MPS,
alternating_update,
compose_observers,
default_observer,
eigsolve_updater,
values_observer

function dmrg(
operator, init::MPS; updater=eigsolve_updater, (observer!)=default_observer(), kwargs...
)
info_ref! = Ref{Any}()
info_observer! = values_observer(; info=info_ref!)
observer! = compose_observers(observer!, info_observer!)
state = alternating_update(operator, init; updater, observer!, kwargs...)
return info_ref![].eigval, state
end
9 changes: 9 additions & 0 deletions src/solvers/ITensorsExtensions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module ITensorsExtensions
using ITensors: ITensor, array, inds, itensor
function to_vec(x::ITensor)
function to_itensor(x_vec)
return itensor(x_vec, inds(x))
end
return vec(array(x)), to_itensor
end
end
117 changes: 117 additions & 0 deletions src/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
using ITensors: ITensors, permute

function _extend_sweeps_param(param, nsweeps)
if param isa Number
eparam = fill(param, nsweeps)
else
length(param) == nsweeps && return param
eparam = Vector(undef, nsweeps)
eparam[1:length(param)] = param
eparam[(length(param) + 1):end] .= param[end]
end
return eparam
end

function process_sweeps(; nsweeps, maxdim, mindim, cutoff, noise)
maxdim = _extend_sweeps_param(maxdim, nsweeps)
mindim = _extend_sweeps_param(mindim, nsweeps)
cutoff = _extend_sweeps_param(cutoff, nsweeps)
noise = _extend_sweeps_param(noise, nsweeps)
return (; maxdim, mindim, cutoff, noise)
end

function alternating_update(
operator,
init::MPS;
updater,
updater_kwargs=(;),
nsweeps=default_nsweeps(),
checkdone=default_checkdone(),
write_when_maxdim_exceeds=default_write_when_maxdim_exceeds(),
nsite=default_nsite(),
reverse_step=default_reverse_step(),
time_start=default_time_start(),
time_step=default_time_step(),
order=default_order(),
(observer!)=default_observer(),
(sweep_observer!)=default_sweep_observer(),
outputlevel=default_outputlevel(),
normalize=default_normalize(),
maxdim=default_maxdim(),
mindim=default_mindim(),
cutoff=default_cutoff(ITensors.scalartype(init)),
noise=default_noise(),
)
reduced_operator = ITensorMPS.reduced_operator(operator)
if isnothing(nsweeps)
return error("Must specify `nsweeps`.")
end
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, maxdim, mindim, cutoff, noise)
forward_order = TDVPOrder(order, Base.Forward)
state = copy(init)
# Keep track of the start of the current time step.
# Helpful for tracking the total time, for example
# when using time-dependent updaters.
# This will be passed as a keyword argument to the
# `updater`.
current_time = time_start
info = nothing
for sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds
if outputlevel >= 2
println(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim(sweeps, sw) = $(maxdim(sweeps, sweep)), writing environment tensors to disk",
)
end
reduced_operator = disk(reduced_operator)
end
sweep_elapsed_time = @elapsed begin
state, reduced_operator, info = sweep_update(
forward_order,
reduced_operator,
state;
updater,
updater_kwargs,
nsite,
current_time,
time_step,
reverse_step,
sweep,
observer!,
normalize,
outputlevel,
maxdim=maxdim[sweep],
mindim=mindim[sweep],
cutoff=cutoff[sweep],
noise=noise[sweep],
)
end
if !isnothing(time_step)
current_time += time_step
end
update_observer!(
sweep_observer!; state, reduced_operator, sweep, outputlevel, current_time
)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print(" maxlinkdim=", maxlinkdim(state))
@printf(" maxerr=%.2E", info.maxtruncerr)
if !isnothing(current_time)
print(" current_time=", round(current_time; digits=3))
end
print(" time=", round(sweep_elapsed_time; digits=3))
println()
flush(stdout)
end
isdone = checkdone(;
state, sweep, outputlevel, observer=observer!, sweep_observer=sweep_observer!
)
isdone && break
end
return state
end

# Assume it is already in a reduced basis.
reduced_operator(operator) = operator
reduced_operator(operators::Vector{MPO}) = ProjMPOSum(operators)
reduced_operator(operator::MPO) = ProjMPO(operator)
Loading

0 comments on commit 98d1237

Please sign in to comment.