Skip to content

Commit

Permalink
Current ortho fix (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored Nov 11, 2024
1 parent 33d3006 commit b0d6aa7
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 134 deletions.
53 changes: 33 additions & 20 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Graphs:
add_edge!,
add_vertex!,
bfs_tree,
center,
dst,
edges,
edgetype,
Expand Down Expand Up @@ -40,7 +41,7 @@ using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
using LinearAlgebra: LinearAlgebra, factorize
using MacroTools: @capture
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
using NamedGraphs.GraphsExtensions:
, directed_graph, incident_edges, rename_vertices, vertextype
using NDTensors: NDTensors, dim
Expand Down Expand Up @@ -584,37 +585,49 @@ function LinearAlgebra.factorize(tn::AbstractITensorNetwork, edge::Pair; kwargs.
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
function orthogonalize_walk(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
return orthogonalize_walk(tn, [edge]; kwargs...)
end

function orthogonalize_walk(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
return orthogonalize_walk(tn, edgetype(tn)(edge); kwargs...)
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function orthogonalize_walk(
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
)
# tn = factorize(tn, edge; kwargs...)
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
# new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge)))
# return contract(tn, new_vertex => dst(edge))
tn = copy(tn)
left_inds = uniqueinds(tn, edge)
ltags = tags(tn, edge)
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
tn[src(edge)] = X
tn[dst(edge)] *= Y
for edge in edges
left_inds = uniqueinds(tn, edge)
ltags = tags(tn, edge)
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
tn[src(edge)] = X
tn[dst(edge)] *= Y
end
return tn
end

function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
return _orthogonalize_edge(tn, edge; kwargs...)
function orthogonalize_walk(tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...)
return orthogonalize_walk(tn, edgetype(tn).(edges); kwargs...)
end

function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
return orthogonalize(tn, edgetype(tn)(edge); kwargs...)
# Orthogonalize an ITensorNetwork towards a region, treating
# the network as a tree spanned by a spanning tree.
function tree_orthogonalize::AbstractITensorNetwork, region::Vector)
region_center =
length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region)
path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center)
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
return orthogonalize_walk(ψ, path)
end

# Orthogonalize an ITensorNetwork towards a source vertex, treating
# the network as a tree spanned by a spanning tree.
# TODO: Rename `tree_orthogonalize`.
function ITensorMPS.orthogonalize::AbstractITensorNetwork, source_vertex)
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex)
for e in spanning_tree_edges
ψ = orthogonalize(ψ, e)
end
return ψ
function tree_orthogonalize::AbstractITensorNetwork, region)
return tree_orthogonalize(ψ, [region])
end

# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
Expand Down
4 changes: 2 additions & 2 deletions src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ function ITensors.apply(
v⃗ = neighbor_vertices(ψ, o)
if length(v⃗) == 1
if ortho
ψ = orthogonalize(ψ, v⃗[1])
ψ = tree_orthogonalize(ψ, v⃗[1])
end
oψᵥ = apply(o, ψ[v⃗[1]])
if normalize
Expand All @@ -215,7 +215,7 @@ function ITensors.apply(
error("Vertices where the gates are being applied must be neighbors for now.")
end
if ortho
ψ = orthogonalize(ψ, v⃗[1])
ψ = tree_orthogonalize(ψ, v⃗[1])
end
if variational_optimization_only || !is_product_env
ψᵥ₁, ψᵥ₂ = full_update_bp(
Expand Down
49 changes: 3 additions & 46 deletions src/solvers/alternating_update/region_update.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,3 @@
#ToDo: generalize beyond 2-site
#ToDo: remove concept of orthogonality center for generality
function current_ortho(sweep_plan, which_region_update)
regions = first.(sweep_plan)
region = regions[which_region_update]
current_verts = support(region)
if !isa(region, AbstractEdge) && length(region) == 1
return only(current_verts)
end
if which_region_update == length(regions)
# look back by one should be sufficient, but may be brittle?
overlapping_vertex = only(
intersect(current_verts, support(regions[which_region_update - 1]))
)
return overlapping_vertex
else
# look forward
other_regions = filter(
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
)
# find the first region that has overlapping support with current region
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
if isnothing(ind)
# look backward
other_regions = reverse(
filter(
x -> !(issetequal(x, current_verts)),
support.(regions[1:(which_region_update - 1)]),
),
)
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
end
@assert !isnothing(ind)
future_verts = union(support(other_regions[ind]))
# return ortho_ceter as the vertex in current region that does not overlap with following one
overlapping_vertex = intersect(current_verts, future_verts)
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
return nonoverlapping_vertex
end
end

function region_update(
projected_operator,
state;
Expand All @@ -64,14 +23,13 @@ function region_update(

# ToDo: remove orthogonality center on vertex for generality
# region carries same information
ortho_vertex = current_ortho(sweep_plan, which_region_update)
if !isnothing(transform_operator)
projected_operator = transform_operator(
state, projected_operator; outputlevel, transform_operator_kwargs...
)
end
state, projected_operator, phi = extracter(
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
state, projected_operator, region; extracter_kwargs..., internal_kwargs
)
# create references, in case solver does (out-of-place) modify PH or state
state! = Ref(state)
Expand All @@ -97,9 +55,8 @@ function region_update(
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
# so noiseterm is a solver
#end
state, spec = inserter(
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
)
#if isa(region, AbstractEdge) &&
state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs)
all_kwargs = (;
which_region_update,
sweep_plan,
Expand Down
16 changes: 9 additions & 7 deletions src/solvers/extract/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@
# insert_local_tensors takes that tensor and factorizes it back
# apart and puts it back into the network.
#
function default_extracter(state, projected_operator, region, ortho; internal_kwargs)
state = orthogonalize(state, ortho)

function default_extracter(state, projected_operator, region; internal_kwargs)
if isa(region, AbstractEdge)
other_vertex = only(setdiff(support(region), [ortho]))
left_inds = uniqueinds(state[ortho], state[other_vertex])
#ToDo: replace with call to factorize
# TODO: add functionality for orthogonalizing onto a bond so that can be called instead
vsrc, vdst = src(region), dst(region)
state = orthogonalize(state, vsrc)
left_inds = uniqueinds(state[vsrc], state[vdst])
U, S, V = svd(
state[ortho], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
)
state[ortho] = U
state[vsrc] = U
local_tensor = S * V
else
state = orthogonalize(state, region)
local_tensor = prod(state[v] for v in region)
end
projected_operator = position(projected_operator, state, region)
Expand Down
25 changes: 10 additions & 15 deletions src/solvers/insert/insert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
function default_inserter(
state::AbstractTTN,
phi::ITensor,
region,
ortho_vert;
region;
normalize=false,
maxdim=nothing,
mindim=nothing,
Expand All @@ -16,16 +15,14 @@ function default_inserter(
)
state = copy(state)
spec = nothing
other_vertex = setdiff(support(region), [ortho_vert])
if !isempty(other_vertex)
v = only(other_vertex)
e = edgetype(state)(ortho_vert, v)
indsTe = inds(state[ortho_vert])
if length(region) == 2
v = last(region)
e = edgetype(state)(first(region), last(region))
indsTe = inds(state[first(region)])
L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff)
state[ortho_vert] = L

state[first(region)] = L
else
v = ortho_vert
v = only(region)
end
state[v] = phi
state = set_ortho_region(state, [v])
Expand All @@ -36,16 +33,14 @@ end
function default_inserter(
state::AbstractTTN,
phi::ITensor,
region::NamedEdge,
ortho;
region::NamedEdge;
cutoff=nothing,
maxdim=nothing,
mindim=nothing,
normalize=false,
internal_kwargs,
)
v = only(setdiff(support(region), [ortho]))
state[v] *= phi
state = set_ortho_region(state, [v])
state[dst(region)] *= phi
state = set_ortho_region(state, [dst(region)])
return state, nothing
end
30 changes: 15 additions & 15 deletions src/solvers/sweep_plans/sweep_plans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ end

support(r) = r

function reverse_region(edges, which_edge; nsites=1, region_kwargs=(;))
function reverse_region(edges, which_edge; reverse_edge=false, nsites=1, region_kwargs=(;))
current_edge = edges[which_edge]
if nsites == 1
return [(current_edge, region_kwargs)]
!reverse_edge && return [(current_edge, region_kwargs)]
reverse_edge && return [(reverse(current_edge), region_kwargs)]
elseif nsites == 2
if last(edges) == current_edge
return ()
Expand Down Expand Up @@ -62,25 +63,24 @@ function forward_sweep(
dir::Base.ForwardOrdering,
graph::AbstractGraph;
root_vertex=GraphsExtensions.default_root_vertex(graph),
reverse_edges=false,
region_kwargs,
reverse_kwargs=region_kwargs,
reverse_step=false,
kwargs...,
)
edges = post_order_dfs_edges(graph, root_vertex)
regions = collect(
flatten(map(i -> forward_region(edges, i; region_kwargs, kwargs...), eachindex(edges)))
)

regions = map(eachindex(edges)) do i
forward_region(edges, i; region_kwargs, kwargs...)
end
regions = collect(flatten(regions))
if reverse_step
reverse_regions = collect(
flatten(
map(
i -> reverse_region(edges, i; region_kwargs=reverse_kwargs, kwargs...),
eachindex(edges),
),
),
)
reverse_regions = map(eachindex(edges)) do i
reverse_region(
edges, i; reverse_edge=reverse_edges, region_kwargs=reverse_kwargs, kwargs...
)
end
reverse_regions = collect(flatten(reverse_regions))
_check_reverse_sweeps(regions, reverse_regions, graph; kwargs...)
regions = interleave(regions, reverse_regions)
end
Expand All @@ -90,7 +90,7 @@ end

#ToDo: is there a better name for this? unidirectional_sweep? traversal?
function forward_sweep(dir::Base.ReverseOrdering, args...; kwargs...)
return reverse(forward_sweep(Base.Forward, args...; kwargs...))
return reverse(forward_sweep(Base.Forward, args...; reverse_edges=true, kwargs...))
end

function default_sweep_plans(
Expand Down
2 changes: 1 addition & 1 deletion src/tebd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function tebd(
ψ = apply(u⃗, ψ; cutoff, maxdim, normalize=true, ortho, kwargs...)
if ortho
for v in vertices(ψ)
ψ = orthogonalize(ψ, v)
ψ = tree_orthogonalize(ψ, v)
end
end
end
Expand Down
41 changes: 20 additions & 21 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using Graphs: has_vertex
using NamedGraphs.GraphsExtensions:
GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
GraphsExtensions,
edge_path,
leaf_vertices,
post_order_dfs_edges,
post_order_dfs_vertices,
a_star
using NamedGraphs: namedgraph_a_star, steiner_tree
using IsApprox: IsApprox, Approx
using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev
using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize
Expand Down Expand Up @@ -29,30 +35,23 @@ function set_ortho_region(tn::AbstractTTN, new_region)
return error("Not implemented")
end

#
# Orthogonalization
#

function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...)
if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn))
return tn
end
# TODO: Rewrite this in a more general way.
if isone(length(ortho_region(tn)))
edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center)
else
edge_list = post_order_dfs_edges(tn, ortho_center)
end
for e in edge_list
tn = orthogonalize(tn, e)
function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
issetequal(region, ortho_region(ttn)) && return ttn
st = steiner_tree(ttn, union(region, ortho_region(ttn)))
path = post_order_dfs_edges(st, first(region))
path = filter(e -> !((src(e) region) && (dst(e) region)), path)
if !isempty(path)
ttn = typeof(ttn)(orthogonalize_walk(ITensorNetwork(ttn), path; kwargs...))
end
return set_ortho_region(tn, typeof(ortho_region(tn))([ortho_center]))
return set_ortho_region(ttn, region)
end

# For ambiguity error
function ITensorMPS.orthogonalize(ttn::AbstractTTN, region; kwargs...)
return orthogonalize(ttn, [region]; kwargs...)
end

function ITensorMPS.orthogonalize(tn::AbstractTTN, edge::AbstractEdge; kwargs...)
return typeof(tn)(orthogonalize(ITensorNetwork(tn), edge; kwargs...))
function tree_orthogonalize(ttn::AbstractTTN, args...; kwargs...)
return orthogonalize(ttn, args...; kwargs...)
end

#
Expand Down
Loading

0 comments on commit b0d6aa7

Please sign in to comment.