From b11a0f6a4719f7296bda2d176e1266e2e65c4ad5 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 7 Mar 2024 14:27:38 -0600 Subject: [PATCH] Consistent update interface --- src/gauging.jl | 66 ++++++++++++++++++++++++-------------------- test/test_gauging.jl | 2 +- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/gauging.jl b/src/gauging.jl index f65e90e1..50bfb599 100644 --- a/src/gauging.jl +++ b/src/gauging.jl @@ -7,14 +7,14 @@ struct VidalITensorNetwork{V,BTS} <: AbstractITensorNetwork{V} bond_tensors::BTS end -site_tensors(ψv::VidalITensorNetwork) = ψv.itensornetwork -bond_tensors(ψv::VidalITensorNetwork) = ψv.bond_tensors -bond_tensor(ψv::VidalITensorNetwork, e) = bond_tensors(ψv)[e] - -data_graph_type(::Type{VidalITensorNetwork}) = data_graph_type(site_tensors(ψv)) -data_graph(ψv::VidalITensorNetwork) = data_graph(site_tensors(ψv)) -function copy(ψv::VidalITensorNetwork) - return VidalITensorNetwork(copy(site_tensors(ψv)), copy(bond_tensors(ψv))) +site_tensors(ψ::VidalITensorNetwork) = ψ.itensornetwork +bond_tensors(ψ::VidalITensorNetwork) = ψ.bond_tensors +bond_tensor(ψ::VidalITensorNetwork, e) = bond_tensors(ψ)[e] + +data_graph_type(TN::Type{VidalITensorNetwork}) = data_graph_type(site_tensors(TN)) +data_graph(ψ::VidalITensorNetwork) = data_graph(site_tensors(ψ)) +function copy(ψ::VidalITensorNetwork) + return VidalITensorNetwork(copy(site_tensors(ψ)), copy(bond_tensors(ψ))) end function default_norm_cache(ψ::ITensorNetwork) @@ -23,29 +23,35 @@ function default_norm_cache(ψ::ITensorNetwork) end default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5) -function ITensorNetwork(ψv::VidalITensorNetwork; (cache!)=nothing) - ψ = copy(site_tensors(ψv)) +function ITensorNetwork( + ψ_vidal::VidalITensorNetwork; (bp_cache!)=nothing, update_gauge=false, update_kwargs... +) + if update_gauge + ψ_vidal = update(ψ_vidal; update_kwargs...) + end + + ψ = copy(site_tensors(ψ_vidal)) for e in edges(ψ) vsrc, vdst = src(e), dst(e) - root_S = sqrt_diag(bond_tensor(ψv, e)) + root_S = sqrt_diag(bond_tensor(ψ_vidal, e)) setindex_preserve_graph!(ψ, noprime(root_S * ψ[vsrc]), vsrc) setindex_preserve_graph!(ψ, noprime(root_S * ψ[vdst]), vdst) end - if !isnothing(cache!) + if !isnothing(bp_cache!) bp_cache = default_norm_cache(ψ) mts = messages(bp_cache) for e in edges(ψ) vsrc, vdst = src(e), dst(e) pe = partitionedge(bp_cache, (vsrc, 1) => (vdst, 1)) - set!(mts, pe, copy(ITensor[dense(bond_tensor(ψv, e))])) - set!(mts, reverse(pe), copy(ITensor[dense(bond_tensor(ψv, e))])) + set!(mts, pe, copy(ITensor[dense(bond_tensor(ψ_vidal, e))])) + set!(mts, reverse(pe), copy(ITensor[dense(bond_tensor(ψ_vidal, e))])) end bp_cache = set_messages(bp_cache, mts) - cache![] = bp_cache + bp_cache![] = bp_cache end return ψ @@ -61,12 +67,12 @@ function vidalitensornetwork_preserve_cache( edges=NamedGraphs.edges(ψ), svd_kwargs..., ) - ψv_site_tensors = copy(ψ) + ψ_vidal_site_tensors = copy(ψ) bond_tensors = bond_tensors_cache(ψ) for e in edges vsrc, vdst = src(e), dst(e) - ψvsrc, ψvdst = copy(ψv_site_tensors[vsrc]), copy(ψv_site_tensors[vdst]) + ψvsrc, ψvdst = ψ_vidal_site_tensors[vsrc], ψ_vidal_site_tensors[vdst] pe = partitionedge(bp_cache, (vsrc, 1) => (vdst, 1)) edge_ind = commoninds(ψvsrc, ψvdst) @@ -99,8 +105,8 @@ function vidalitensornetwork_preserve_cache( ψvdst = replaceinds(ψvdst, edge_ind, edge_ind_sim) ψvdst = replaceinds(ψvdst * V, commoninds(V, S), new_edge_ind) - setindex_preserve_graph!(ψv_site_tensors, ψvsrc, vsrc) - setindex_preserve_graph!(ψv_site_tensors, ψvdst, vdst) + setindex_preserve_graph!(ψ_vidal_site_tensors, ψvsrc, vsrc) + setindex_preserve_graph!(ψ_vidal_site_tensors, ψvdst, vdst) S = replaceinds( S, @@ -110,12 +116,13 @@ function vidalitensornetwork_preserve_cache( bond_tensors[e] = S end - return VidalITensorNetwork(ψv_site_tensors, bond_tensors) + return VidalITensorNetwork(ψ_vidal_site_tensors, bond_tensors) end function VidalITensorNetwork( ψ::ITensorNetwork; (cache!)=nothing, + update_cache=isnothing(cache!), cache_update_kwargs=default_cache_update_kwargs(cache!), kwargs..., ) @@ -126,26 +133,25 @@ function VidalITensorNetwork( return vidalitensornetwork_preserve_cache(ψ; bp_cache=cache![], kwargs...) end -function update(ψv::VidalITensorNetwork; kwargs...) - return VidalITensorNetwork(ITensorNetwork(ψv); kwargs...) +function update(ψ::VidalITensorNetwork; kwargs...) + return VidalITensorNetwork(ITensorNetwork(ψ; update_gauge=false); kwargs...) end """Function to measure the 'isometries' of a state in the Vidal Gauge""" function vidal_gauge_isometries( - ψv::VidalITensorNetwork; - edges=vcat(NamedGraphs.edges(ψv), reverse.(NamedGraphs.edges(ψv))), + ψ::VidalITensorNetwork; edges=vcat(NamedGraphs.edges(ψ), reverse.(NamedGraphs.edges(ψ))) ) isometries = Dict() for e in edges vsrc, vdst = src(e), dst(e) - ψ_vsrc = copy(ψv[vsrc]) - for vn in setdiff(neighbors(ψv, vsrc), [vdst]) - ψ_vsrc = noprime(ψ_vsrc * bond_tensor(ψv, vn => vsrc)) + ψ_vsrc = copy(ψ[vsrc]) + for vn in setdiff(neighbors(ψ, vsrc), [vdst]) + ψ_vsrc = noprime(ψ_vsrc * bond_tensor(ψ, vn => vsrc)) end ψ_vsrcdag = dag(ψ_vsrc) - replaceind!(ψ_vsrcdag, commonind(ψ_vsrc, ψv[vdst]), commonind(ψ_vsrc, ψv[vdst])') + replaceind!(ψ_vsrcdag, commonind(ψ_vsrc, ψ[vdst]), commonind(ψ_vsrc, ψ[vdst])') isometries[e] = ψ_vsrcdag * ψ_vsrc end @@ -153,9 +159,9 @@ function vidal_gauge_isometries( end """Function to measure the 'distance' of a state from the Vidal Gauge""" -function gauge_error(ψv::VidalITensorNetwork) +function gauge_error(ψ::VidalITensorNetwork) f = 0 - isometries = vidal_gauge_isometries(ψv) + isometries = vidal_gauge_isometries(ψ) for e in keys(isometries) lhs = isometries[e] f += message_diff(ITensor[lhs], ITensor[denseblocks(delta(inds(lhs)))]) diff --git a/test/test_gauging.jl b/test/test_gauging.jl index f0a7d10b..d2b17710 100644 --- a/test/test_gauging.jl +++ b/test/test_gauging.jl @@ -24,7 +24,7 @@ using SplitApplyCombine # Move to symmetric gauge cache_ref = Ref{BeliefPropagationCache}() - ψ_symm = ITensorNetwork(ψ_vidal; (cache!)=cache_ref) + ψ_symm = ITensorNetwork(ψ_vidal; (bp_cache!)=cache_ref) bp_cache = cache_ref[] # Test we just did a gauge transform and didn't change the overall network