diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index a6b16853..fc0edce4 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -7,6 +7,7 @@ using Graphs: add_edge!, add_vertex!, bfs_tree, + center, dst, edges, edgetype, @@ -40,7 +41,7 @@ using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rename_vertices, vertextype using NDTensors: NDTensors, dim @@ -584,37 +585,49 @@ function LinearAlgebra.factorize(tn::AbstractITensorNetwork, edge::Pair; kwargs. end # For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged? -function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...) +function orthogonalize_walk(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...) + return orthogonalize_walk(tn, [edge]; kwargs...) +end + +function orthogonalize_walk(tn::AbstractITensorNetwork, edge::Pair; kwargs...) + return orthogonalize_walk(tn, edgetype(tn)(edge); kwargs...) +end + +# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged? +function orthogonalize_walk( + tn::AbstractITensorNetwork, edges::Vector{<:AbstractEdge}; kwargs... +) # tn = factorize(tn, edge; kwargs...) # # TODO: Implement as `only(common_neighbors(tn, src(edge), dst(edge)))` # new_vertex = only(neighbors(tn, src(edge)) ∩ neighbors(tn, dst(edge))) # return contract(tn, new_vertex => dst(edge)) tn = copy(tn) - left_inds = uniqueinds(tn, edge) - ltags = tags(tn, edge) - X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...) - tn[src(edge)] = X - tn[dst(edge)] *= Y + for edge in edges + left_inds = uniqueinds(tn, edge) + ltags = tags(tn, edge) + X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...) + tn[src(edge)] = X + tn[dst(edge)] *= Y + end return tn end -function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...) - return _orthogonalize_edge(tn, edge; kwargs...) +function orthogonalize_walk(tn::AbstractITensorNetwork, edges::Vector{<:Pair}; kwargs...) + return orthogonalize_walk(tn, edgetype(tn).(edges); kwargs...) end -function ITensorMPS.orthogonalize(tn::AbstractITensorNetwork, edge::Pair; kwargs...) - return orthogonalize(tn, edgetype(tn)(edge); kwargs...) +# Orthogonalize an ITensorNetwork towards a region, treating +# the network as a tree spanned by a spanning tree. +function tree_orthogonalize(ψ::AbstractITensorNetwork, region::Vector) + region_center = + length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region) + path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center) + path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) + return orthogonalize_walk(ψ, path) end -# Orthogonalize an ITensorNetwork towards a source vertex, treating -# the network as a tree spanned by a spanning tree. -# TODO: Rename `tree_orthogonalize`. -function ITensorMPS.orthogonalize(ψ::AbstractITensorNetwork, source_vertex) - spanning_tree_edges = post_order_dfs_edges(bfs_tree(ψ, source_vertex), source_vertex) - for e in spanning_tree_edges - ψ = orthogonalize(ψ, e) - end - return ψ +function tree_orthogonalize(ψ::AbstractITensorNetwork, region) + return tree_orthogonalize(ψ, [region]) end # TODO: decide whether to use graph mutating methods when resulting graph is unchanged? diff --git a/src/apply.jl b/src/apply.jl index d38f04f9..6a55f45f 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -200,7 +200,7 @@ function ITensors.apply( v⃗ = neighbor_vertices(ψ, o) if length(v⃗) == 1 if ortho - ψ = orthogonalize(ψ, v⃗[1]) + ψ = tree_orthogonalize(ψ, v⃗[1]) end oψᵥ = apply(o, ψ[v⃗[1]]) if normalize @@ -215,7 +215,7 @@ function ITensors.apply( error("Vertices where the gates are being applied must be neighbors for now.") end if ortho - ψ = orthogonalize(ψ, v⃗[1]) + ψ = tree_orthogonalize(ψ, v⃗[1]) end if variational_optimization_only || !is_product_env ψᵥ₁, ψᵥ₂ = full_update_bp( diff --git a/src/solvers/alternating_update/region_update.jl b/src/solvers/alternating_update/region_update.jl index b92adc8c..c741c82a 100644 --- a/src/solvers/alternating_update/region_update.jl +++ b/src/solvers/alternating_update/region_update.jl @@ -1,44 +1,3 @@ -#ToDo: generalize beyond 2-site -#ToDo: remove concept of orthogonality center for generality -function current_ortho(sweep_plan, which_region_update) - regions = first.(sweep_plan) - region = regions[which_region_update] - current_verts = support(region) - if !isa(region, AbstractEdge) && length(region) == 1 - return only(current_verts) - end - if which_region_update == length(regions) - # look back by one should be sufficient, but may be brittle? - overlapping_vertex = only( - intersect(current_verts, support(regions[which_region_update - 1])) - ) - return overlapping_vertex - else - # look forward - other_regions = filter( - x -> !(issetequal(x, current_verts)), support.(regions[(which_region_update + 1):end]) - ) - # find the first region that has overlapping support with current region - ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - if isnothing(ind) - # look backward - other_regions = reverse( - filter( - x -> !(issetequal(x, current_verts)), - support.(regions[1:(which_region_update - 1)]), - ), - ) - ind = findfirst(x -> !isempty(intersect(support(x), support(region))), other_regions) - end - @assert !isnothing(ind) - future_verts = union(support(other_regions[ind])) - # return ortho_ceter as the vertex in current region that does not overlap with following one - overlapping_vertex = intersect(current_verts, future_verts) - nonoverlapping_vertex = only(setdiff(current_verts, overlapping_vertex)) - return nonoverlapping_vertex - end -end - function region_update( projected_operator, state; @@ -64,14 +23,13 @@ function region_update( # ToDo: remove orthogonality center on vertex for generality # region carries same information - ortho_vertex = current_ortho(sweep_plan, which_region_update) if !isnothing(transform_operator) projected_operator = transform_operator( state, projected_operator; outputlevel, transform_operator_kwargs... ) end state, projected_operator, phi = extracter( - state, projected_operator, region, ortho_vertex; extracter_kwargs..., internal_kwargs + state, projected_operator, region; extracter_kwargs..., internal_kwargs ) # create references, in case solver does (out-of-place) modify PH or state state! = Ref(state) @@ -97,9 +55,8 @@ function region_update( # drho = noise * noiseterm(PH, phi, ortho) # TODO: actually implement this for trees... # so noiseterm is a solver #end - state, spec = inserter( - state, phi, region, ortho_vertex; inserter_kwargs..., internal_kwargs - ) + #if isa(region, AbstractEdge) && + state, spec = inserter(state, phi, region; inserter_kwargs..., internal_kwargs) all_kwargs = (; which_region_update, sweep_plan, diff --git a/src/solvers/extract/extract.jl b/src/solvers/extract/extract.jl index feb57c2f..1013d1bd 100644 --- a/src/solvers/extract/extract.jl +++ b/src/solvers/extract/extract.jl @@ -7,18 +7,20 @@ # insert_local_tensors takes that tensor and factorizes it back # apart and puts it back into the network. # -function default_extracter(state, projected_operator, region, ortho; internal_kwargs) - state = orthogonalize(state, ortho) + +function default_extracter(state, projected_operator, region; internal_kwargs) if isa(region, AbstractEdge) - other_vertex = only(setdiff(support(region), [ortho])) - left_inds = uniqueinds(state[ortho], state[other_vertex]) - #ToDo: replace with call to factorize + # TODO: add functionality for orthogonalizing onto a bond so that can be called instead + vsrc, vdst = src(region), dst(region) + state = orthogonalize(state, vsrc) + left_inds = uniqueinds(state[vsrc], state[vdst]) U, S, V = svd( - state[ortho], left_inds; lefttags=tags(state, region), righttags=tags(state, region) + state[vsrc], left_inds; lefttags=tags(state, region), righttags=tags(state, region) ) - state[ortho] = U + state[vsrc] = U local_tensor = S * V else + state = orthogonalize(state, region) local_tensor = prod(state[v] for v in region) end projected_operator = position(projected_operator, state, region) diff --git a/src/solvers/insert/insert.jl b/src/solvers/insert/insert.jl index 11aed223..01fb35bd 100644 --- a/src/solvers/insert/insert.jl +++ b/src/solvers/insert/insert.jl @@ -6,8 +6,7 @@ function default_inserter( state::AbstractTTN, phi::ITensor, - region, - ortho_vert; + region; normalize=false, maxdim=nothing, mindim=nothing, @@ -16,16 +15,14 @@ function default_inserter( ) state = copy(state) spec = nothing - other_vertex = setdiff(support(region), [ortho_vert]) - if !isempty(other_vertex) - v = only(other_vertex) - e = edgetype(state)(ortho_vert, v) - indsTe = inds(state[ortho_vert]) + if length(region) == 2 + v = last(region) + e = edgetype(state)(first(region), last(region)) + indsTe = inds(state[first(region)]) L, phi, spec = factorize(phi, indsTe; tags=tags(state, e), maxdim, mindim, cutoff) - state[ortho_vert] = L - + state[first(region)] = L else - v = ortho_vert + v = only(region) end state[v] = phi state = set_ortho_region(state, [v]) @@ -36,16 +33,14 @@ end function default_inserter( state::AbstractTTN, phi::ITensor, - region::NamedEdge, - ortho; + region::NamedEdge; cutoff=nothing, maxdim=nothing, mindim=nothing, normalize=false, internal_kwargs, ) - v = only(setdiff(support(region), [ortho])) - state[v] *= phi - state = set_ortho_region(state, [v]) + state[dst(region)] *= phi + state = set_ortho_region(state, [dst(region)]) return state, nothing end diff --git a/src/solvers/sweep_plans/sweep_plans.jl b/src/solvers/sweep_plans/sweep_plans.jl index 69221995..52915e2b 100644 --- a/src/solvers/sweep_plans/sweep_plans.jl +++ b/src/solvers/sweep_plans/sweep_plans.jl @@ -13,10 +13,11 @@ end support(r) = r -function reverse_region(edges, which_edge; nsites=1, region_kwargs=(;)) +function reverse_region(edges, which_edge; reverse_edge=false, nsites=1, region_kwargs=(;)) current_edge = edges[which_edge] if nsites == 1 - return [(current_edge, region_kwargs)] + !reverse_edge && return [(current_edge, region_kwargs)] + reverse_edge && return [(reverse(current_edge), region_kwargs)] elseif nsites == 2 if last(edges) == current_edge return () @@ -62,25 +63,24 @@ function forward_sweep( dir::Base.ForwardOrdering, graph::AbstractGraph; root_vertex=GraphsExtensions.default_root_vertex(graph), + reverse_edges=false, region_kwargs, reverse_kwargs=region_kwargs, reverse_step=false, kwargs..., ) edges = post_order_dfs_edges(graph, root_vertex) - regions = collect( - flatten(map(i -> forward_region(edges, i; region_kwargs, kwargs...), eachindex(edges))) - ) - + regions = map(eachindex(edges)) do i + forward_region(edges, i; region_kwargs, kwargs...) + end + regions = collect(flatten(regions)) if reverse_step - reverse_regions = collect( - flatten( - map( - i -> reverse_region(edges, i; region_kwargs=reverse_kwargs, kwargs...), - eachindex(edges), - ), - ), - ) + reverse_regions = map(eachindex(edges)) do i + reverse_region( + edges, i; reverse_edge=reverse_edges, region_kwargs=reverse_kwargs, kwargs... + ) + end + reverse_regions = collect(flatten(reverse_regions)) _check_reverse_sweeps(regions, reverse_regions, graph; kwargs...) regions = interleave(regions, reverse_regions) end @@ -90,7 +90,7 @@ end #ToDo: is there a better name for this? unidirectional_sweep? traversal? function forward_sweep(dir::Base.ReverseOrdering, args...; kwargs...) - return reverse(forward_sweep(Base.Forward, args...; kwargs...)) + return reverse(forward_sweep(Base.Forward, args...; reverse_edges=true, kwargs...)) end function default_sweep_plans( diff --git a/src/tebd.jl b/src/tebd.jl index edf5a188..d1d96017 100644 --- a/src/tebd.jl +++ b/src/tebd.jl @@ -23,7 +23,7 @@ function tebd( ψ = apply(u⃗, ψ; cutoff, maxdim, normalize=true, ortho, kwargs...) if ortho for v in vertices(ψ) - ψ = orthogonalize(ψ, v) + ψ = tree_orthogonalize(ψ, v) end end end diff --git a/src/treetensornetworks/abstracttreetensornetwork.jl b/src/treetensornetworks/abstracttreetensornetwork.jl index 7a89ceb9..8815b33f 100644 --- a/src/treetensornetworks/abstracttreetensornetwork.jl +++ b/src/treetensornetworks/abstracttreetensornetwork.jl @@ -1,6 +1,12 @@ using Graphs: has_vertex using NamedGraphs.GraphsExtensions: - GraphsExtensions, edge_path, leaf_vertices, post_order_dfs_edges, post_order_dfs_vertices + GraphsExtensions, + edge_path, + leaf_vertices, + post_order_dfs_edges, + post_order_dfs_vertices, + a_star +using NamedGraphs: namedgraph_a_star, steiner_tree using IsApprox: IsApprox, Approx using ITensors: ITensors, @Algorithm_str, directsum, hasinds, permute, plev using ITensorMPS: ITensorMPS, linkind, loginner, lognorm, orthogonalize @@ -29,30 +35,23 @@ function set_ortho_region(tn::AbstractTTN, new_region) return error("Not implemented") end -# -# Orthogonalization -# - -function ITensorMPS.orthogonalize(tn::AbstractTTN, ortho_center; kwargs...) - if isone(length(ortho_region(tn))) && ortho_center == only(ortho_region(tn)) - return tn - end - # TODO: Rewrite this in a more general way. - if isone(length(ortho_region(tn))) - edge_list = edge_path(tn, only(ortho_region(tn)), ortho_center) - else - edge_list = post_order_dfs_edges(tn, ortho_center) - end - for e in edge_list - tn = orthogonalize(tn, e) +function ITensorMPS.orthogonalize(ttn::AbstractTTN, region::Vector; kwargs...) + issetequal(region, ortho_region(ttn)) && return ttn + st = steiner_tree(ttn, union(region, ortho_region(ttn))) + path = post_order_dfs_edges(st, first(region)) + path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) + if !isempty(path) + ttn = typeof(ttn)(orthogonalize_walk(ITensorNetwork(ttn), path; kwargs...)) end - return set_ortho_region(tn, typeof(ortho_region(tn))([ortho_center])) + return set_ortho_region(ttn, region) end -# For ambiguity error +function ITensorMPS.orthogonalize(ttn::AbstractTTN, region; kwargs...) + return orthogonalize(ttn, [region]; kwargs...) +end -function ITensorMPS.orthogonalize(tn::AbstractTTN, edge::AbstractEdge; kwargs...) - return typeof(tn)(orthogonalize(ITensorNetwork(tn), edge; kwargs...)) +function tree_orthogonalize(ttn::AbstractTTN, args...; kwargs...) + return orthogonalize(ttn, args...; kwargs...) end # diff --git a/test/test_itensornetwork.jl b/test/test_itensornetwork.jl index ba3caa01..53e2928f 100644 --- a/test/test_itensornetwork.jl +++ b/test/test_itensornetwork.jl @@ -51,6 +51,7 @@ using ITensorNetworks: orthogonalize, random_tensornetwork, siteinds, + tree_orthogonalize, ttn using LinearAlgebra: factorize using NamedGraphs: NamedEdge @@ -287,13 +288,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test nv(tn_ortho) == 5 @test nv(tn) == 4 @test Z ≈ Z̃ - tn_ortho = orthogonalize(tn, 4 => 3) + tn_ortho = tree_orthogonalize(tn, [3, 4]) Z̃ = norm_sqr(tn_ortho) @test nv(tn_ortho) == 4 @test nv(tn) == 4 @test Z ≈ Z̃ - tn_ortho = orthogonalize(tn, 1) + tn_ortho = tree_orthogonalize(tn, 1) Z̃ = norm_sqr(tn_ortho) @test Z ≈ Z̃ Z̃ = inner(tn_ortho, tn) diff --git a/test/test_treetensornetworks/test_solvers/test_dmrg.jl b/test/test_treetensornetworks/test_solvers/test_dmrg.jl index cf8a1caf..004ec561 100644 --- a/test/test_treetensornetworks/test_solvers/test_dmrg.jl +++ b/test/test_treetensornetworks/test_solvers/test_dmrg.jl @@ -1,7 +1,7 @@ @eval module $(gensym()) using DataGraphs: edge_data, vertex_data using Dictionaries: Dictionary -using Graphs: nv, vertices +using Graphs: nv, vertices, uniform_tree using ITensorMPS: ITensorMPS using ITensorNetworks: ITensorNetworks, @@ -19,6 +19,7 @@ using ITensorNetworks.ITensorsExtensions: replace_vertices using ITensorNetworks.ModelHamiltonians: ModelHamiltonians using ITensors: ITensors using KrylovKit: eigsolve +using NamedGraphs: NamedGraph, rename_vertices using NamedGraphs.NamedGraphGenerators: named_comb_tree using Observers: observer using StableRNGs: StableRNG @@ -313,11 +314,12 @@ end nsites = 2 nsweeps = 10 - c = named_comb_tree((3, 2)) - s = siteinds("S=1/2", c) - os = ModelHamiltonians.heisenberg(c) - H = ttn(os, s) rng = StableRNG(1234) + g = NamedGraph(uniform_tree(10)) + g = rename_vertices(v -> (v, 1), g) + s = siteinds("S=1/2", g) + os = ModelHamiltonians.heisenberg(g) + H = ttn(os, s) psi = random_ttn(rng, s; link_space=5) e, psi = dmrg(H, psi; nsweeps, maxdim, nsites)