-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
b-kloss
authored
Mar 18, 2024
1 parent
5dde64c
commit a4f3592
Showing
34 changed files
with
1,310 additions
and
989 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
function alternating_update( | ||
operator, | ||
init_state::AbstractTTN; | ||
nsweeps, # define default for each solver implementation | ||
nsites, # define default for each level of solver implementation | ||
updater, # this specifies the update performed locally | ||
outputlevel=default_outputlevel(), | ||
region_printer=nothing, | ||
sweep_printer=nothing, | ||
(sweep_observer!)=nothing, | ||
(region_observer!)=nothing, | ||
root_vertex=default_root_vertex(init_state), | ||
extracter_kwargs=(;), | ||
extracter=default_extracter(), | ||
updater_kwargs=(;), | ||
inserter_kwargs=(;), | ||
inserter=default_inserter(), | ||
transform_operator_kwargs=(;), | ||
transform_operator=default_transform_operator(), | ||
kwargs..., | ||
) | ||
inserter_kwargs = (; inserter_kwargs..., kwargs...) | ||
sweep_plans = default_sweep_plans( | ||
nsweeps, | ||
init_state; | ||
root_vertex, | ||
extracter, | ||
extracter_kwargs, | ||
updater, | ||
updater_kwargs, | ||
inserter, | ||
inserter_kwargs, | ||
transform_operator, | ||
transform_operator_kwargs, | ||
nsites, | ||
) | ||
return alternating_update( | ||
operator, | ||
init_state, | ||
sweep_plans; | ||
outputlevel, | ||
sweep_observer!, | ||
region_observer!, | ||
sweep_printer, | ||
region_printer, | ||
) | ||
end | ||
|
||
function alternating_update( | ||
projected_operator, | ||
init_state::AbstractTTN, | ||
sweep_plans; | ||
outputlevel=default_outputlevel(), | ||
checkdone=default_checkdone(), # | ||
(sweep_observer!)=nothing, | ||
sweep_printer=default_sweep_printer,#? | ||
(region_observer!)=nothing, | ||
region_printer=nothing, | ||
) | ||
state = copy(init_state) | ||
@assert !isnothing(sweep_plans) | ||
for which_sweep in eachindex(sweep_plans) | ||
sweep_plan = sweep_plans[which_sweep] | ||
|
||
sweep_time = @elapsed begin | ||
for which_region_update in eachindex(sweep_plan) | ||
state, projected_operator = region_update( | ||
projected_operator, | ||
state; | ||
which_sweep, | ||
sweep_plan, | ||
region_printer, | ||
(region_observer!), | ||
which_region_update, | ||
outputlevel, | ||
) | ||
end | ||
end | ||
|
||
update!(sweep_observer!; state, which_sweep, sweep_time, outputlevel, sweep_plans) | ||
!isnothing(sweep_printer) && | ||
sweep_printer(; state, which_sweep, sweep_time, outputlevel, sweep_plans) | ||
checkdone(; | ||
state, | ||
which_sweep, | ||
outputlevel, | ||
sweep_plan, | ||
sweep_plans, | ||
sweep_observer!, | ||
region_observer!, | ||
) && break | ||
end | ||
return state | ||
end | ||
|
||
function alternating_update(operator::AbstractTTN, init_state::AbstractTTN; kwargs...) | ||
check_hascommoninds(siteinds, operator, init_state) | ||
check_hascommoninds(siteinds, operator, init_state') | ||
# Permute the indices to have a better memory layout | ||
# and minimize permutations | ||
operator = ITensors.permute(operator, (linkind, siteinds, linkind)) | ||
projected_operator = ProjTTN(operator) | ||
return alternating_update(projected_operator, init_state; kwargs...) | ||
end | ||
|
||
function alternating_update( | ||
operator::AbstractTTN, init_state::AbstractTTN, sweep_plans; kwargs... | ||
) | ||
check_hascommoninds(siteinds, operator, init_state) | ||
check_hascommoninds(siteinds, operator, init_state') | ||
# Permute the indices to have a better memory layout | ||
# and minimize permutations | ||
operator = ITensors.permute(operator, (linkind, siteinds, linkind)) | ||
projected_operator = ProjTTN(operator) | ||
return alternating_update(projected_operator, init_state, sweep_plans; kwargs...) | ||
end | ||
|
||
#ToDo: Fix docstring. | ||
""" | ||
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...) | ||
tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...) | ||
Use the time dependent variational principle (TDVP) algorithm | ||
to compute `exp(t*H)*init_state` using an efficient algorithm based | ||
on alternating optimization of the MPS tensors and local Krylov | ||
exponentiation of H. | ||
This version of `tdvp` accepts a representation of H as a | ||
Vector of MPOs, Hs = [H1,H2,H3,...] such that H is defined | ||
as H = H1+H2+H3+... | ||
Note that this sum of MPOs is not actually computed; rather | ||
the set of MPOs [H1,H2,H3,..] is efficiently looped over at | ||
each step of the algorithm when optimizing the MPS. | ||
Returns: | ||
* `state::MPS` - time-evolved MPS | ||
""" | ||
function alternating_update( | ||
operators::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs... | ||
) | ||
for operator in operators | ||
check_hascommoninds(siteinds, operator, init_state) | ||
check_hascommoninds(siteinds, operator, init_state') | ||
end | ||
operators .= ITensors.permute.(operators, Ref((linkind, siteinds, linkind))) | ||
projected_operators = ProjTTNSum(operators) | ||
return alternating_update(projected_operators, init_state; kwargs...) | ||
end | ||
|
||
function alternating_update( | ||
operators::Vector{<:AbstractTTN}, init_state::AbstractTTN, sweep_plans; kwargs... | ||
) | ||
for operator in operators | ||
check_hascommoninds(siteinds, operator, init_state) | ||
check_hascommoninds(siteinds, operator, init_state') | ||
end | ||
operators .= ITensors.permute.(operators, Ref((linkind, siteinds, linkind))) | ||
projected_operators = ProjTTNSum(operators) | ||
return alternating_update(projected_operators, init_state, sweep_plans; kwargs...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
#ToDo: generalize beyond 2-site | ||
#ToDo: remove concept of orthogonality center for generality | ||
function current_ortho(sweep_plan, which_region_update) | ||
regions = first.(sweep_plan) | ||
region = regions[which_region_update] | ||
current_verts = support(region) | ||
if !isa(region, AbstractEdge) && length(region) == 1 | ||
return only(current_verts) | ||
end | ||
if which_region_update == length(regions) | ||
# look back by one should be sufficient, but may be brittle? | ||
overlapping_vertex = only( | ||
intersect(current_verts, support(regions[which_region_update - 1])) | ||
) | ||
return overlapping_vertex | ||
else | ||
# look forward | ||
other_regions = filter( | ||
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end]) | ||
) | ||
# find the first region that has overlapping support with current region | ||
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) | ||
if isnothing(ind) | ||
# look backward | ||
other_regions = reverse( | ||
filter( | ||
x -> !(issetequal(x, current_verts)), | ||
support.(regions[1:(which_region_update - 1)]), | ||
), | ||
) | ||
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) | ||
end | ||
@assert !isnothing(ind) | ||
future_verts = union(support(other_regions[ind])) | ||
# return ortho_ceter as the vertex in current region that does not overlap with following one | ||
overlapping_vertex = intersect(current_verts, future_verts) | ||
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex)) | ||
return nonoverlapping_vertex | ||
end | ||
end | ||
|
||
function region_update( | ||
projected_operator, | ||
state; | ||
outputlevel, | ||
which_sweep, | ||
sweep_plan, | ||
which_region_update, | ||
region_printer, | ||
(region_observer!), | ||
) | ||
(region, region_kwargs) = sweep_plan[which_region_update] | ||
(; | ||
extracter, | ||
extracter_kwargs, | ||
updater, | ||
updater_kwargs, | ||
inserter, | ||
inserter_kwargs, | ||
transform_operator, | ||
transform_operator_kwargs, | ||
internal_kwargs, | ||
) = region_kwargs | ||
|
||
# ToDo: remove orthogonality center on vertex for generality | ||
# region carries same information | ||
ortho_vertex = current_ortho(sweep_plan, which_region_update) | ||
if !isnothing(transform_operator) | ||
projected_operator = transform_operator( | ||
state, projected_operator; outputlevel, transform_operator_kwargs... | ||
) | ||
end | ||
state, projected_operator, phi = extracter( | ||
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs | ||
) | ||
# create references, in case solver does (out-of-place) modify PH or state | ||
state! = Ref(state) | ||
projected_operator! = Ref(projected_operator) | ||
# args passed by reference are supposed to be modified out of place | ||
phi, info = updater( | ||
phi; | ||
state!, | ||
projected_operator!, | ||
outputlevel, | ||
which_sweep, | ||
sweep_plan, | ||
which_region_update, | ||
updater_kwargs..., | ||
internal_kwargs, | ||
) | ||
state = state![] | ||
projected_operator = projected_operator![] | ||
if !(phi isa ITensor && info isa NamedTuple) | ||
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))") | ||
error("In alternating_update, solver must return an ITensor and a NamedTuple") | ||
end | ||
# ToDo: implement noise term as updater | ||
#drho = nothing | ||
#ortho = "left" #i guess with respect to ordered vertices that's valid but may be cleaner to use next_region logic | ||
#if noise > 0.0 && isforward(direction) | ||
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees... | ||
# so noiseterm is a solver | ||
#end | ||
state, spec = inserter( | ||
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs | ||
) | ||
|
||
all_kwargs = (; | ||
cutoff, | ||
maxdim, | ||
mindim, | ||
which_region_update, | ||
sweep_plan, | ||
total_sweep_steps=length(sweep_plan), | ||
end_of_sweep=(which_region_update == length(sweep_plan)), | ||
state, | ||
region, | ||
which_sweep, | ||
spec, | ||
outputlevel, | ||
info..., | ||
region_kwargs..., | ||
internal_kwargs..., | ||
) | ||
update!(region_observer!; all_kwargs...) | ||
!(isnothing(region_printer)) && region_printer(; all_kwargs...) | ||
|
||
return state, projected_operator | ||
end |
Oops, something went wrong.