Skip to content

Commit

Permalink
Refactor sweeps interface (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss authored Mar 18, 2024
1 parent 5dde64c commit a4f3592
Show file tree
Hide file tree
Showing 34 changed files with 1,310 additions and 989 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SparseArrayKit = "a9a3c162-d163-4c15-8926-b8794fbefed2"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructWalk = "31cdf514-beb7-4750-89db-dda9d2eb8d3d"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Expand Down
32 changes: 18 additions & 14 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using SplitApplyCombine
using StaticArrays
using Suppressor
using TimerOutputs
using StructWalk: StructWalk, WalkStyle, postwalk

using DataGraphs: IsUnderlyingGraph, edge_data_type, vertex_data_type
using Graphs: AbstractEdge, AbstractGraph, Graph, add_edge!
Expand Down Expand Up @@ -107,27 +108,30 @@ include("tensornetworkoperators.jl")
include(joinpath("ITensorsExt", "itensorutils.jl"))
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("Graphs", "abstractdatagraph.jl"))
include(joinpath("solvers", "eigsolve.jl"))
include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("solvers", "local_solvers", "eigsolve.jl"))
include(joinpath("solvers", "local_solvers", "exponentiate.jl"))
include(joinpath("solvers", "local_solvers", "dmrg_x.jl"))
include(joinpath("solvers", "local_solvers", "contract.jl"))
include(joinpath("solvers", "local_solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
include(joinpath("treetensornetworks", "projttns", "abstractprojttn.jl"))
include(joinpath("treetensornetworks", "projttns", "projttn.jl"))
include(joinpath("treetensornetworks", "projttns", "projttnsum.jl"))
include(joinpath("treetensornetworks", "projttns", "projouterprodttn.jl"))
include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg_x.jl"))
include(joinpath("treetensornetworks", "solvers", "contract.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
include(joinpath("solvers", "solver_utils.jl"))
include(joinpath("solvers", "defaults.jl"))
include(joinpath("solvers", "insert", "insert.jl"))
include(joinpath("solvers", "extract", "extract.jl"))
include(joinpath("solvers", "alternating_update", "alternating_update.jl"))
include(joinpath("solvers", "alternating_update", "region_update.jl"))
include(joinpath("solvers", "tdvp.jl"))
include(joinpath("solvers", "dmrg.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("solvers", "sweep_plans", "sweep_plans.jl"))
include("apply.jl")

include("exports.jl")
Expand Down
160 changes: 160 additions & 0 deletions src/solvers/alternating_update/alternating_update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
function alternating_update(
operator,
init_state::AbstractTTN;
nsweeps, # define default for each solver implementation
nsites, # define default for each level of solver implementation
updater, # this specifies the update performed locally
outputlevel=default_outputlevel(),
region_printer=nothing,
sweep_printer=nothing,
(sweep_observer!)=nothing,
(region_observer!)=nothing,
root_vertex=default_root_vertex(init_state),
extracter_kwargs=(;),
extracter=default_extracter(),
updater_kwargs=(;),
inserter_kwargs=(;),
inserter=default_inserter(),
transform_operator_kwargs=(;),
transform_operator=default_transform_operator(),
kwargs...,
)
inserter_kwargs = (; inserter_kwargs..., kwargs...)
sweep_plans = default_sweep_plans(
nsweeps,
init_state;
root_vertex,
extracter,
extracter_kwargs,
updater,
updater_kwargs,
inserter,
inserter_kwargs,
transform_operator,
transform_operator_kwargs,
nsites,
)
return alternating_update(
operator,
init_state,
sweep_plans;
outputlevel,
sweep_observer!,
region_observer!,
sweep_printer,
region_printer,
)
end

function alternating_update(
projected_operator,
init_state::AbstractTTN,
sweep_plans;
outputlevel=default_outputlevel(),
checkdone=default_checkdone(), #
(sweep_observer!)=nothing,
sweep_printer=default_sweep_printer,#?
(region_observer!)=nothing,
region_printer=nothing,
)
state = copy(init_state)
@assert !isnothing(sweep_plans)
for which_sweep in eachindex(sweep_plans)
sweep_plan = sweep_plans[which_sweep]

sweep_time = @elapsed begin
for which_region_update in eachindex(sweep_plan)
state, projected_operator = region_update(
projected_operator,
state;
which_sweep,
sweep_plan,
region_printer,
(region_observer!),
which_region_update,
outputlevel,
)
end
end

update!(sweep_observer!; state, which_sweep, sweep_time, outputlevel, sweep_plans)
!isnothing(sweep_printer) &&
sweep_printer(; state, which_sweep, sweep_time, outputlevel, sweep_plans)
checkdone(;
state,
which_sweep,
outputlevel,
sweep_plan,
sweep_plans,
sweep_observer!,
region_observer!,
) && break
end
return state
end

function alternating_update(operator::AbstractTTN, init_state::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
# Permute the indices to have a better memory layout
# and minimize permutations
operator = ITensors.permute(operator, (linkind, siteinds, linkind))
projected_operator = ProjTTN(operator)
return alternating_update(projected_operator, init_state; kwargs...)
end

function alternating_update(
operator::AbstractTTN, init_state::AbstractTTN, sweep_plans; kwargs...
)
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
# Permute the indices to have a better memory layout
# and minimize permutations
operator = ITensors.permute(operator, (linkind, siteinds, linkind))
projected_operator = ProjTTN(operator)
return alternating_update(projected_operator, init_state, sweep_plans; kwargs...)
end

#ToDo: Fix docstring.
"""
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...)
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...)
Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*init_state` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
exponentiation of H.
This version of `tdvp` accepts a representation of H as a
Vector of MPOs, Hs = [H1,H2,H3,...] such that H is defined
as H = H1+H2+H3+...
Note that this sum of MPOs is not actually computed; rather
the set of MPOs [H1,H2,H3,..] is efficiently looped over at
each step of the algorithm when optimizing the MPS.
Returns:
* `state::MPS` - time-evolved MPS
"""
function alternating_update(
operators::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs...
)
for operator in operators
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
end
operators .= ITensors.permute.(operators, Ref((linkind, siteinds, linkind)))
projected_operators = ProjTTNSum(operators)
return alternating_update(projected_operators, init_state; kwargs...)
end

function alternating_update(
operators::Vector{<:AbstractTTN}, init_state::AbstractTTN, sweep_plans; kwargs...
)
for operator in operators
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
end
operators .= ITensors.permute.(operators, Ref((linkind, siteinds, linkind)))
projected_operators = ProjTTNSum(operators)
return alternating_update(projected_operators, init_state, sweep_plans; kwargs...)
end
129 changes: 129 additions & 0 deletions src/solvers/alternating_update/region_update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#ToDo: generalize beyond 2-site
#ToDo: remove concept of orthogonality center for generality
function current_ortho(sweep_plan, which_region_update)
regions = first.(sweep_plan)
region = regions[which_region_update]
current_verts = support(region)
if !isa(region, AbstractEdge) && length(region) == 1
return only(current_verts)
end
if which_region_update == length(regions)
# look back by one should be sufficient, but may be brittle?
overlapping_vertex = only(
intersect(current_verts, support(regions[which_region_update - 1]))
)
return overlapping_vertex
else
# look forward
other_regions = filter(
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
)
# find the first region that has overlapping support with current region
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
if isnothing(ind)
# look backward
other_regions = reverse(
filter(
x -> !(issetequal(x, current_verts)),
support.(regions[1:(which_region_update - 1)]),
),
)
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
end
@assert !isnothing(ind)
future_verts = union(support(other_regions[ind]))
# return ortho_ceter as the vertex in current region that does not overlap with following one
overlapping_vertex = intersect(current_verts, future_verts)
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
return nonoverlapping_vertex
end
end

function region_update(
projected_operator,
state;
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_printer,
(region_observer!),
)
(region, region_kwargs) = sweep_plan[which_region_update]
(;
extracter,
extracter_kwargs,
updater,
updater_kwargs,
inserter,
inserter_kwargs,
transform_operator,
transform_operator_kwargs,
internal_kwargs,
) = region_kwargs

# ToDo: remove orthogonality center on vertex for generality
# region carries same information
ortho_vertex = current_ortho(sweep_plan, which_region_update)
if !isnothing(transform_operator)
projected_operator = transform_operator(
state, projected_operator; outputlevel, transform_operator_kwargs...
)
end
state, projected_operator, phi = extracter(
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
)
# create references, in case solver does (out-of-place) modify PH or state
state! = Ref(state)
projected_operator! = Ref(projected_operator)
# args passed by reference are supposed to be modified out of place
phi, info = updater(
phi;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
updater_kwargs...,
internal_kwargs,
)
state = state![]
projected_operator = projected_operator![]
if !(phi isa ITensor && info isa NamedTuple)
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))")
error("In alternating_update, solver must return an ITensor and a NamedTuple")
end
# ToDo: implement noise term as updater
#drho = nothing
#ortho = "left" #i guess with respect to ordered vertices that's valid but may be cleaner to use next_region logic
#if noise > 0.0 && isforward(direction)
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
# so noiseterm is a solver
#end
state, spec = inserter(
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
)

all_kwargs = (;
cutoff,
maxdim,
mindim,
which_region_update,
sweep_plan,
total_sweep_steps=length(sweep_plan),
end_of_sweep=(which_region_update == length(sweep_plan)),
state,
region,
which_sweep,
spec,
outputlevel,
info...,
region_kwargs...,
internal_kwargs...,
)
update!(region_observer!; all_kwargs...)
!(isnothing(region_printer)) && region_printer(; all_kwargs...)

return state, projected_operator
end
Loading

0 comments on commit a4f3592

Please sign in to comment.