Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sweeps interface #143

Merged
merged 71 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
e2ebc4d
Reorganize extract_local_tensor.
Jan 30, 2024
9f6c7d5
Started reorganizing, currently broken.
Jan 30, 2024
0691786
Merge branch 'main' into refactor-sweeps
Feb 12, 2024
737195d
First working version for tdvp with refactor using region_kwargs.
Feb 16, 2024
bc41f29
Merge branch 'main' into refactor-sweeps
Feb 29, 2024
08edb87
Move unbound kwargs in call to tdvp into inserter_kwargs.
Feb 29, 2024
166645e
Adapt exponentiate solver to new interface
Feb 29, 2024
2dddf87
Add todo in test.
Feb 29, 2024
4de5db7
Format .
Feb 29, 2024
3197845
Add a few ToDos in tree_sweeping.jl .
Feb 29, 2024
4f131c6
Introduce `internal_kwargs`.
Mar 1, 2024
4090dc4
Add `ishermitian=false` for imaginary-time calculations.
Mar 1, 2024
602ff02
Add extend/expand functionality.
Mar 2, 2024
0493807
Rearrange tdvp_sweep_plan, make use of expand/extend, move _compute_n…
Mar 2, 2024
43c59fb
Add StructWalk to package.
Mar 2, 2024
f2084df
Fix various bugs in the previous few commits.
Mar 2, 2024
8955255
Make expand handle empty NamedTuples.
Mar 2, 2024
c2106b4
Format and slight change in test parameters.
Mar 2, 2024
5ea14b4
Remove second definition of default_outputlevel
Mar 3, 2024
87dfa38
Make region constructors independent of edge-direction.
Mar 3, 2024
5d83451
Remove comment.
Mar 3, 2024
9d82a8e
Add missing variable to test.
Mar 3, 2024
85bb21a
Move solvers to local_solvers directory.
Mar 3, 2024
eb4e331
Move treetensornetworks/solvers to solvers
Mar 3, 2024
8b2ea1e
Start reorganizing directory structure.
Mar 3, 2024
6f45275
Reorganize further.
Mar 3, 2024
a84771b
Incorporate into .
Mar 10, 2024
23c7566
Incorporate `sweep_update` into `alternating_update`.
Mar 10, 2024
d74361e
Implement direction-independent logic for `current_ortho`,`inserter` …
Mar 10, 2024
0a927b7
Account for new file structure in module definition.
Mar 10, 2024
a2f9704
Add vector-`interleaving` to `utils.jl`.
Mar 10, 2024
ebfd1c3
Move inserter/extracter to separate files, implement direction indepe…
Mar 10, 2024
fa13de5
Create file for defaults.
Mar 10, 2024
7538d5c
Fix bug when directionality is not reversed in reverse_step.
Mar 10, 2024
f248bb5
Construct reverse step via region_intersections. Currently broken.
Mar 10, 2024
e8f952c
Fix wrong logic for certain patterns at ends of subtrees in 2-site tdvp.
Mar 11, 2024
1db0bdc
Add check for consistency of forward and reverse sweeps.
Mar 11, 2024
7eebcec
Minor change to test to hit previous errors.
Mar 11, 2024
15501cf
Format.
Mar 11, 2024
2006d16
Add todo regarding kwargs for inserters.
Mar 11, 2024
5cfa5f2
Add functionality.
Mar 13, 2024
d956dd9
Make use of being implemented.
Mar 13, 2024
1711418
Add default_sweep_printer to defaults.jl.
Mar 13, 2024
11016f6
Remove `tdvp_sweep_plans`, make `default_sweep_plans` generic, and on…
Mar 13, 2024
865028b
Remove unnecessary NamedTuple processing functionality.
Mar 13, 2024
92d8943
Format.
Mar 13, 2024
1462d7a
Add tests for _compute_nsweeps functionality and fix some bugs.
Mar 13, 2024
5fe701e
Flatten (updater, updater_kwargs) etc, rename some functions, cleanup.
Mar 14, 2024
942fa6d
Adapt test files to new interface. Remove regression test for sweep p…
Mar 14, 2024
73a3c3e
Adapt local_solvers, add default_alternating_updates as template for …
Mar 14, 2024
1df0418
Update solvers. Use default_alternating_updates for everything but tdvp.
Mar 14, 2024
e0638e0
Add t_evolved to tdvp_sweep_plan
Mar 14, 2024
a81aeb2
Throw error when trying to extend a vector to a length shorter than i…
Mar 14, 2024
c17d361
Dispatch alternating update on presence or not of sweep_plans as posi…
Mar 15, 2024
9f84fc5
Adapt solvers to new definition of alternating_update.
Mar 15, 2024
c585291
Fix name of extend(_or_truncate) in sweep_plans.
Mar 15, 2024
9947aba
Add regression test back in.
Mar 15, 2024
7bab94e
Format.
Mar 15, 2024
3929765
Adapt tdvp_time_dependent. Krylov gives better result than ODE solver…
Mar 15, 2024
be9cd62
Adapt linsolve. Untested.
Mar 15, 2024
1e368bf
Add StructWalk to Project.toml.
Mar 15, 2024
a2073e3
Fix cache_to_disk. Still broken due to lack of implementation of disk…
Mar 15, 2024
f52fcb8
Merge branch 'main' into refactor-sweeps
Mar 15, 2024
0acc695
Format.
Mar 15, 2024
8833174
Move apply.jl back to end of includes.
Mar 15, 2024
aa9ad17
Loosen test for imaginary time propagation TDVP.
Mar 15, 2024
f21f721
Remove unnecessary function from eigsolve.
Mar 17, 2024
870f38a
Remove unfinished insert_region_intersections.
Mar 17, 2024
2767d1f
Remove and move update_step .jlto region_update.jl
Mar 17, 2024
e5043fa
Account for changed file name in inclues.
Mar 17, 2024
d2a306c
Format.
Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SparseArrayKit = "a9a3c162-d163-4c15-8926-b8794fbefed2"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructWalk = "31cdf514-beb7-4750-89db-dda9d2eb8d3d"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Expand Down
32 changes: 18 additions & 14 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ using SplitApplyCombine
using StaticArrays
using Suppressor
using TimerOutputs
using StructWalk: StructWalk, WalkStyle, postwalk

using DataGraphs: IsUnderlyingGraph, edge_data_type, vertex_data_type
using Graphs: AbstractEdge, AbstractGraph, Graph, add_edge!
Expand Down Expand Up @@ -107,27 +108,30 @@ include("tensornetworkoperators.jl")
include(joinpath("ITensorsExt", "itensorutils.jl"))
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("Graphs", "abstractdatagraph.jl"))
include(joinpath("solvers", "eigsolve.jl"))
include(joinpath("solvers", "exponentiate.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("solvers", "local_solvers", "eigsolve.jl"))
include(joinpath("solvers", "local_solvers", "exponentiate.jl"))
include(joinpath("solvers", "local_solvers", "dmrg_x.jl"))
include(joinpath("solvers", "local_solvers", "contract.jl"))
include(joinpath("solvers", "local_solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
include(joinpath("treetensornetworks", "projttns", "abstractprojttn.jl"))
include(joinpath("treetensornetworks", "projttns", "projttn.jl"))
include(joinpath("treetensornetworks", "projttns", "projttnsum.jl"))
include(joinpath("treetensornetworks", "projttns", "projouterprodttn.jl"))
include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg_x.jl"))
include(joinpath("treetensornetworks", "solvers", "contract.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
include(joinpath("solvers", "solver_utils.jl"))
include(joinpath("solvers", "defaults.jl"))
include(joinpath("solvers", "insert", "insert.jl"))
include(joinpath("solvers", "extract", "extract.jl"))
include(joinpath("solvers", "alternating_update", "alternating_update.jl"))
include(joinpath("solvers", "alternating_update", "region_update.jl"))
include(joinpath("solvers", "tdvp.jl"))
include(joinpath("solvers", "dmrg.jl"))
include(joinpath("solvers", "dmrg_x.jl"))
include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("solvers", "sweep_plans", "sweep_plans.jl"))
include("apply.jl")

include("exports.jl")
Expand Down
160 changes: 160 additions & 0 deletions src/solvers/alternating_update/alternating_update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
function alternating_update(
operator,
init_state::AbstractTTN;
nsweeps, # define default for each solver implementation
nsites, # define default for each level of solver implementation
updater, # this specifies the update performed locally
outputlevel=default_outputlevel(),
region_printer=nothing,
sweep_printer=nothing,
(sweep_observer!)=nothing,
(region_observer!)=nothing,
root_vertex=default_root_vertex(init_state),
extracter_kwargs=(;),
extracter=default_extracter(),
updater_kwargs=(;),
inserter_kwargs=(;),
inserter=default_inserter(),
transform_operator_kwargs=(;),
transform_operator=default_transform_operator(),
kwargs...,
)
inserter_kwargs = (; inserter_kwargs..., kwargs...)
sweep_plans = default_sweep_plans(
nsweeps,
init_state;
root_vertex,
extracter,
extracter_kwargs,
updater,
updater_kwargs,
inserter,
inserter_kwargs,
transform_operator,
transform_operator_kwargs,
nsites,
)
return alternating_update(
operator,
init_state,
sweep_plans;
outputlevel,
sweep_observer!,
region_observer!,
sweep_printer,
region_printer,
)
end

function alternating_update(
projected_operator,
init_state::AbstractTTN,
sweep_plans;
outputlevel=default_outputlevel(),
checkdone=default_checkdone(), #
(sweep_observer!)=nothing,
sweep_printer=default_sweep_printer,#?
(region_observer!)=nothing,
region_printer=nothing,
)
state = copy(init_state)
@assert !isnothing(sweep_plans)
for which_sweep in eachindex(sweep_plans)
sweep_plan = sweep_plans[which_sweep]

sweep_time = @elapsed begin
for which_region_update in eachindex(sweep_plan)
state, projected_operator = region_update(
projected_operator,
state;
which_sweep,
sweep_plan,
region_printer,
(region_observer!),
which_region_update,
outputlevel,
)
end
end

update!(sweep_observer!; state, which_sweep, sweep_time, outputlevel, sweep_plans)
!isnothing(sweep_printer) &&
sweep_printer(; state, which_sweep, sweep_time, outputlevel, sweep_plans)
checkdone(;
state,
which_sweep,
outputlevel,
sweep_plan,
sweep_plans,
sweep_observer!,
region_observer!,
) && break
end
return state
end

function alternating_update(operator::AbstractTTN, init_state::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
# Permute the indices to have a better memory layout
# and minimize permutations
operator = ITensors.permute(operator, (linkind, siteinds, linkind))
projected_operator = ProjTTN(operator)
return alternating_update(projected_operator, init_state; kwargs...)
end

function alternating_update(
operator::AbstractTTN, init_state::AbstractTTN, sweep_plans; kwargs...
)
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
# Permute the indices to have a better memory layout
# and minimize permutations
operator = ITensors.permute(operator, (linkind, siteinds, linkind))
projected_operator = ProjTTN(operator)
return alternating_update(projected_operator, init_state, sweep_plans; kwargs...)
end

#ToDo: Fix docstring.
"""
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...)
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...)

Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*init_state` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
exponentiation of H.

This version of `tdvp` accepts a representation of H as a
Vector of MPOs, Hs = [H1,H2,H3,...] such that H is defined
as H = H1+H2+H3+...
Note that this sum of MPOs is not actually computed; rather
the set of MPOs [H1,H2,H3,..] is efficiently looped over at
each step of the algorithm when optimizing the MPS.

Returns:
* `state::MPS` - time-evolved MPS
"""
function alternating_update(
operators::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs...
)
for operator in operators
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
end
operators .= ITensors.permute.(operators, Ref((linkind, siteinds, linkind)))
projected_operators = ProjTTNSum(operators)
return alternating_update(projected_operators, init_state; kwargs...)
end

function alternating_update(
operators::Vector{<:AbstractTTN}, init_state::AbstractTTN, sweep_plans; kwargs...
)
for operator in operators
check_hascommoninds(siteinds, operator, init_state)
check_hascommoninds(siteinds, operator, init_state')
end
operators .= ITensors.permute.(operators, Ref((linkind, siteinds, linkind)))
projected_operators = ProjTTNSum(operators)
return alternating_update(projected_operators, init_state, sweep_plans; kwargs...)
end
129 changes: 129 additions & 0 deletions src/solvers/alternating_update/region_update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#ToDo: generalize beyond 2-site
#ToDo: remove concept of orthogonality center for generality
function current_ortho(sweep_plan, which_region_update)
regions = first.(sweep_plan)
region = regions[which_region_update]
current_verts = support(region)
if !isa(region, AbstractEdge) && length(region) == 1
return only(current_verts)
end
if which_region_update == length(regions)
# look back by one should be sufficient, but may be brittle?
overlapping_vertex = only(
intersect(current_verts, support(regions[which_region_update - 1]))
)
return overlapping_vertex
else
# look forward
other_regions = filter(
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
)
# find the first region that has overlapping support with current region
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
if isnothing(ind)
# look backward
other_regions = reverse(
filter(
x -> !(issetequal(x, current_verts)),
support.(regions[1:(which_region_update - 1)]),
),
)
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
end
@assert !isnothing(ind)
future_verts = union(support(other_regions[ind]))
# return ortho_ceter as the vertex in current region that does not overlap with following one
overlapping_vertex = intersect(current_verts, future_verts)
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
return nonoverlapping_vertex
end
end

function region_update(
projected_operator,
state;
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_printer,
(region_observer!),
)
(region, region_kwargs) = sweep_plan[which_region_update]
(;
extracter,
extracter_kwargs,
updater,
updater_kwargs,
inserter,
inserter_kwargs,
transform_operator,
transform_operator_kwargs,
internal_kwargs,
) = region_kwargs

# ToDo: remove orthogonality center on vertex for generality
# region carries same information
ortho_vertex = current_ortho(sweep_plan, which_region_update)
if !isnothing(transform_operator)
projected_operator = transform_operator(
state, projected_operator; outputlevel, transform_operator_kwargs...
)
end
state, projected_operator, phi = extracter(
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
)
# create references, in case solver does (out-of-place) modify PH or state
state! = Ref(state)
projected_operator! = Ref(projected_operator)
# args passed by reference are supposed to be modified out of place
phi, info = updater(
phi;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
updater_kwargs...,
internal_kwargs,
)
state = state![]
projected_operator = projected_operator![]
if !(phi isa ITensor && info isa NamedTuple)
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))")
error("In alternating_update, solver must return an ITensor and a NamedTuple")
end
# ToDo: implement noise term as updater
#drho = nothing
#ortho = "left" #i guess with respect to ordered vertices that's valid but may be cleaner to use next_region logic
#if noise > 0.0 && isforward(direction)
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
# so noiseterm is a solver
#end
state, spec = inserter(
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
)

all_kwargs = (;
cutoff,
maxdim,
mindim,
which_region_update,
sweep_plan,
total_sweep_steps=length(sweep_plan),
end_of_sweep=(which_region_update == length(sweep_plan)),
state,
region,
which_sweep,
spec,
outputlevel,
info...,
region_kwargs...,
internal_kwargs...,
)
update!(region_observer!; all_kwargs...)
!(isnothing(region_printer)) && region_printer(; all_kwargs...)

return state, projected_operator
end
Loading
Loading