Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Current ortho fix #208

Merged
merged 20 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,37 +584,63 @@ function LinearAlgebra.factorize(tn::AbstractITensorNetwork, edge::Pair; kwargs.
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
function orthogonalize_path(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
return orthogonalize_path(tn, [edge]; kwargs...)
end

function orthogonalize_path(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
return orthogonalize_path(tn, edgetype(tn)(edge); kwargs...)
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
function orthogonalize_path(
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs...
)
# tn = factorize(tn, edge; kwargs...)
# # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))`
# new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge)))
# return contract(tn, new_vertex => dst(edge))
tn = copy(tn)
left_inds = uniqueinds(tn, edge)
ltags = tags(tn, edge)
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
tn[src(edge)] = X
tn[dst(edge)] *= Y
for edge in edges
left_inds = uniqueinds(tn, edge)
ltags = tags(tn, edge)
X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
tn[src(edge)] = X
tn[dst(edge)] *= Y
end
return tn
end

function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
return _orthogonalize_edge(tn, edge; kwargs...)
end

function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
return orthogonalize(tn, edgetype(tn)(edge); kwargs...)
function orthogonalize_path(tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...)
return orthogonalize_path(tn, edgetype(tn).(edges); kwargs...)
end

# Orthogonalize an ITensorNetwork towards a source vertex, treating
# Orthogonalize an ITensorNetwork towards a region, treating
# the network as a tree spanned by a spanning tree.
# TODO: Rename `tree_orthogonalize`.
function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, source_vertex)
spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex)
for e in spanning_tree_edges
ψ = orthogonalize(ψ, e)
end
return ψ
function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, region::Vector)
spanning_tree_edges = post_order_dfs_edges_region(bfs_tree(ψ, first(region)), region)
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
return orthogonalize_path(ψ, spanning_tree_edges)
end

function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, region)
return orthogonalize(ψ, [region])
end

function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edges::Vector{<:AbstractEdge})
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
return orthogonalize(ψ, unique(vcat([src(e) for e in edges], [dst(e) for e in edges])))
end

function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edges::Vector{<:Pair})
return orthogonalize(ψ, edgetype(ψ).(edges))
end

function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edge::AbstractEdge)
return orthogonalize(ψ, [edge])
end

function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, edge::Pair)
return orthogonalize(ψ, edgetype(ψ)(edge))
end

# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
Expand Down
5 changes: 5 additions & 0 deletions src/edge_sequences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ end
@traitfn function edge_sequence(::Algorithm"parallel", g::::(!IsDirected))
return [[e] for e in vcat(edges(g), reverse.(edges(g)))]
end

function post_order_dfs_edges_region(g::AbstractGraph, region)
es = post_order_dfs_edges(g, first(region))
return filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), es)
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
49 changes: 3 additions & 46 deletions src/solvers/alternating_update/region_update.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,3 @@
#ToDo: generalize beyond 2-site
#ToDo: remove concept of orthogonality center for generality
function current_ortho(sweep_plan, which_region_update)
regions = first.(sweep_plan)
region = regions[which_region_update]
current_verts = support(region)
if !isa(region, AbstractEdge) && length(region) == 1
return only(current_verts)
end
if which_region_update == length(regions)
# look back by one should be sufficient, but may be brittle?
overlapping_vertex = only(
intersect(current_verts, support(regions[which_region_update - 1]))
)
return overlapping_vertex
else
# look forward
other_regions = filter(
x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end])
)
# find the first region that has overlapping support with current region
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
if isnothing(ind)
# look backward
other_regions = reverse(
filter(
x -> !(issetequal(x, current_verts)),
support.(regions[1:(which_region_update - 1)]),
),
)
ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions)
end
@assert !isnothing(ind)
future_verts = union(support(other_regions[ind]))
# return ortho_ceter as the vertex in current region that does not overlap with following one
overlapping_vertex = intersect(current_verts, future_verts)
nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex))
return nonoverlapping_vertex
end
end

function region_update(
projected_operator,
state;
Expand All @@ -64,14 +23,13 @@ function region_update(

# ToDo: remove orthogonality center on vertex for generality
# region carries same information
ortho_vertex = current_ortho(sweep_plan, which_region_update)
if !isnothing(transform_operator)
projected_operator = transform_operator(
state, projected_operator; outputlevel, transform_operator_kwargs...
)
end
state, projected_operator, phi = extracter(
state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs
state, projected_operator, region; extracter_kwargs..., internal_kwargs
)
# create references, in case solver does (out-of-place) modify PH or state
state! = Ref(state)
Expand All @@ -97,9 +55,8 @@ function region_update(
# drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees...
# so noiseterm is a solver
#end
state, spec = inserter(
state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs
)
#if isa(region, AbstractEdge) &&
state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs)
all_kwargs = (;
which_region_update,
sweep_plan,
Expand Down
13 changes: 7 additions & 6 deletions src/solvers/extract/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
# insert_local_tensors takes that tensor and factorizes it back
# apart and puts it back into the network.
#
function default_extracter(state, projected_operator, region, ortho; internal_kwargs)
state = orthogonalize(state, ortho)
function default_extracter(state, projected_operator, region; internal_kwargs)
if isa(region, AbstractEdge)
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
other_vertex = only(setdiff(support(region), [ortho]))
left_inds = uniqueinds(state[ortho], state[other_vertex])
vsrc, vdst = src(region), dst(region)
state = orthogonalize(state, vsrc)
left_inds = uniqueinds(state[vsrc], state[vdst])
#ToDo: replace with call to factorize
U, S, V = svd(
state[ortho], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region)
)
state[ortho] = U
state[vsrc] = U
local_tensor = S * V
else
state = orthogonalize(state, region)
local_tensor = prod(state[v] for v in region)
end
projected_operator = position(projected_operator, state, region)
Expand Down
25 changes: 10 additions & 15 deletions src/solvers/insert/insert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
function default_inserter(
state::AbstractTTN,
phi::ITensor,
region,
ortho_vert;
region;
normalize=false,
maxdim=nothing,
mindim=nothing,
Expand All @@ -16,16 +15,14 @@ function default_inserter(
)
state = copy(state)
spec = nothing
other_vertex = setdiff(support(region), [ortho_vert])
if !isempty(other_vertex)
v = only(other_vertex)
e = edgetype(state)(ortho_vert, v)
indsTe = inds(state[ortho_vert])
if length(region) == 2
v = last(region)
e = edgetype(state)(first(region), last(region))
indsTe = inds(state[first(region)])
L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff)
state[ortho_vert] = L

state[first(region)] = L
else
v = ortho_vert
v = only(region)
end
state[v] = phi
state = set_ortho_region(state, [v])
Expand All @@ -36,16 +33,14 @@ end
function default_inserter(
state::AbstractTTN,
phi::ITensor,
region::NamedEdge,
ortho;
region::NamedEdge;
cutoff=nothing,
maxdim=nothing,
mindim=nothing,
normalize=false,
internal_kwargs,
)
v = only(setdiff(support(region), [ortho]))
state[v] *= phi
state = set_ortho_region(state, [v])
state[dst(region)] *= phi
state = set_ortho_region(state, [dst(region)])
return state, nothing
end
13 changes: 8 additions & 5 deletions src/solvers/sweep_plans/sweep_plans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ end

support(r) = r

function reverse_region(edges, which_edge; nsites=1, region_kwargs=(;))
function reverse_region(edges, which_edge; reverse_edge=false, nsites=1, region_kwargs=(;))
current_edge = edges[which_edge]
if nsites == 1
return [(current_edge, region_kwargs)]
!reverse_edge && return [(current_edge, region_kwargs)]
reverse_edge && return [(reverse(current_edge), region_kwargs)]
elseif nsites == 2
if last(edges) == current_edge
return ()
Expand Down Expand Up @@ -62,6 +63,7 @@ function forward_sweep(
dir::Base.ForwardOrdering,
graph::AbstractGraph;
root_vertex=GraphsExtensions.default_root_vertex(graph),
reverse_edges=false,
region_kwargs,
reverse_kwargs=region_kwargs,
reverse_step=false,
Expand All @@ -71,12 +73,13 @@ function forward_sweep(
regions = collect(
flatten(map(i -> forward_region(edges, i; region_kwargs, kwargs...), eachindex(edges)))
)

if reverse_step
reverse_regions = collect(
flatten(
map(
i -> reverse_region(edges, i; region_kwargs=reverse_kwargs, kwargs...),
i -> reverse_region(
edges, i; reverse_edge=reverse_edges, region_kwargs=reverse_kwargs, kwargs...
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
),
eachindex(edges),
),
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
),
Expand All @@ -90,7 +93,7 @@ end

#ToDo: is there a better name for this? unidirectional_sweep? traversal?
function forward_sweep(dir::Base.ReverseOrdering, args...; kwargs...)
return reverse(forward_sweep(Base.Forward, args...; kwargs...))
return reverse(forward_sweep(Base.Forward, args...; reverse_edges=true, kwargs...))
end

function default_sweep_plans(
Expand Down
38 changes: 14 additions & 24 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using Graphs: has_vertex
using NamedGraphs.GraphsExtensions:
GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices
GraphsExtensions,
edge_path,
leaf_vertices,
post_order_dfs_edges,
post_order_dfs_vertices,
a_star
using NamedGraphs: namedgraph_a_star
using IsApprox: IsApprox, Approx
using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev
using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize
Expand Down Expand Up @@ -29,30 +35,14 @@ function set_ortho_region(tn::AbstractTTN, new_region)
return error("Not implemented")
end

#
# Orthogonalization
#

function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...)
if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn))
return tn
end
# TODO: Rewrite this in a more general way.
if isone(length(ortho_region(tn)))
edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center)
else
edge_list = post_order_dfs_edges(tn, ortho_center)
end
for e in edge_list
tn = orthogonalize(tn, e)
function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...)
new_path = post_order_dfs_edges_region(ttn, region)
existing_path = post_order_dfs_edges_region(ttn, ortho_region(ttn))
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
path = setdiff(new_path, existing_path)
if !isempty(path)
ttn = typeof(ttn)(orthogonalize_path(ITensorNetwork(ttn), path; kwargs...))
end
return set_ortho_region(tn, typeof(ortho_region(tn))([ortho_center]))
end

# For ambiguity error

function ITensorMPS.orthogonalize(tn::AbstractTTN, edge::AbstractEdge; kwargs...)
return typeof(tn)(orthogonalize(ITensorNetwork(tn), edge; kwargs...))
return set_ortho_region(ttn, region)
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
end

#
Expand Down
12 changes: 7 additions & 5 deletions test/test_treetensornetworks/test_solvers/test_dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@eval module $(gensym())
using DataGraphs: edge_data, vertex_data
using Dictionaries: Dictionary
using Graphs: nv, vertices
using Graphs: nv, vertices, uniform_tree
using ITensorMPS: ITensorMPS
using ITensorNetworks:
ITensorNetworks,
Expand All @@ -19,6 +19,7 @@ using ITensorNetworks.ITensorsExtensions: replace_vertices
using ITensorNetworks.ModelHamiltonians: ModelHamiltonians
using ITensors: ITensors
using KrylovKit: eigsolve
using NamedGraphs: NamedGraph, rename_vertices
using NamedGraphs.NamedGraphGenerators: named_comb_tree
using Observers: observer
using StableRNGs: StableRNG
Expand Down Expand Up @@ -313,11 +314,12 @@ end
nsites = 2
nsweeps = 10

c = named_comb_tree((3, 2))
s = siteinds("S=1/2", c)
os = ModelHamiltonians.heisenberg(c)
H = ttn(os, s)
rng = StableRNG(1234)
g = NamedGraph(uniform_tree(10))
g = rename_vertices(v -> (v, 1), g)
s = siteinds("S=1/2", g)
os = ModelHamiltonians.heisenberg(g)
H = ttn(os, s)
psi = random_ttn(rng, s; link_space=5)
e, psi = dmrg(H, psi; nsweeps, maxdim, nsites)

Expand Down
Loading