Skip to content

Commit

Permalink
Consistent update interface
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Mar 7, 2024
1 parent 1ede385 commit b11a0f6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 31 deletions.
66 changes: 36 additions & 30 deletions src/gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ struct VidalITensorNetwork{V,BTS} <: AbstractITensorNetwork{V}
bond_tensors::BTS
end

site_tensors(ψv::VidalITensorNetwork) = ψv.itensornetwork
bond_tensors(ψv::VidalITensorNetwork) = ψv.bond_tensors
bond_tensor(ψv::VidalITensorNetwork, e) = bond_tensors(ψv)[e]

data_graph_type(::Type{VidalITensorNetwork}) = data_graph_type(site_tensors(ψv))
data_graph(ψv::VidalITensorNetwork) = data_graph(site_tensors(ψv))
function copy(ψv::VidalITensorNetwork)
return VidalITensorNetwork(copy(site_tensors(ψv)), copy(bond_tensors(ψv)))
site_tensors(ψ::VidalITensorNetwork) = ψ.itensornetwork
bond_tensors(ψ::VidalITensorNetwork) = ψ.bond_tensors
bond_tensor(ψ::VidalITensorNetwork, e) = bond_tensors(ψ)[e]

data_graph_type(TN::Type{VidalITensorNetwork}) = data_graph_type(site_tensors(TN))
data_graph(ψ::VidalITensorNetwork) = data_graph(site_tensors(ψ))
function copy(ψ::VidalITensorNetwork)
return VidalITensorNetwork(copy(site_tensors(ψ)), copy(bond_tensors(ψ)))
end

function default_norm_cache::ITensorNetwork)
Expand All @@ -23,29 +23,35 @@ function default_norm_cache(ψ::ITensorNetwork)
end
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function ITensorNetwork(ψv::VidalITensorNetwork; (cache!)=nothing)
ψ = copy(site_tensors(ψv))
function ITensorNetwork(
ψ_vidal::VidalITensorNetwork; (bp_cache!)=nothing, update_gauge=false, update_kwargs...
)
if update_gauge
ψ_vidal = update(ψ_vidal; update_kwargs...)
end

ψ = copy(site_tensors(ψ_vidal))

for e in edges(ψ)
vsrc, vdst = src(e), dst(e)
root_S = sqrt_diag(bond_tensor(ψv, e))
root_S = sqrt_diag(bond_tensor(ψ_vidal, e))
setindex_preserve_graph!(ψ, noprime(root_S * ψ[vsrc]), vsrc)
setindex_preserve_graph!(ψ, noprime(root_S * ψ[vdst]), vdst)
end

if !isnothing(cache!)
if !isnothing(bp_cache!)
bp_cache = default_norm_cache(ψ)
mts = messages(bp_cache)

for e in edges(ψ)
vsrc, vdst = src(e), dst(e)
pe = partitionedge(bp_cache, (vsrc, 1) => (vdst, 1))
set!(mts, pe, copy(ITensor[dense(bond_tensor(ψv, e))]))
set!(mts, reverse(pe), copy(ITensor[dense(bond_tensor(ψv, e))]))
set!(mts, pe, copy(ITensor[dense(bond_tensor(ψ_vidal, e))]))
set!(mts, reverse(pe), copy(ITensor[dense(bond_tensor(ψ_vidal, e))]))
end

bp_cache = set_messages(bp_cache, mts)
cache![] = bp_cache
bp_cache![] = bp_cache
end

return ψ
Expand All @@ -61,12 +67,12 @@ function vidalitensornetwork_preserve_cache(
edges=NamedGraphs.edges(ψ),
svd_kwargs...,
)
ψv_site_tensors = copy(ψ)
ψ_vidal_site_tensors = copy(ψ)
bond_tensors = bond_tensors_cache(ψ)

for e in edges
vsrc, vdst = src(e), dst(e)
ψvsrc, ψvdst = copy(ψv_site_tensors[vsrc]), copy(ψv_site_tensors[vdst])
ψvsrc, ψvdst = ψ_vidal_site_tensors[vsrc], ψ_vidal_site_tensors[vdst]

pe = partitionedge(bp_cache, (vsrc, 1) => (vdst, 1))
edge_ind = commoninds(ψvsrc, ψvdst)
Expand Down Expand Up @@ -99,8 +105,8 @@ function vidalitensornetwork_preserve_cache(
ψvdst = replaceinds(ψvdst, edge_ind, edge_ind_sim)
ψvdst = replaceinds(ψvdst * V, commoninds(V, S), new_edge_ind)

setindex_preserve_graph!(ψv_site_tensors, ψvsrc, vsrc)
setindex_preserve_graph!(ψv_site_tensors, ψvdst, vdst)
setindex_preserve_graph!(ψ_vidal_site_tensors, ψvsrc, vsrc)
setindex_preserve_graph!(ψ_vidal_site_tensors, ψvdst, vdst)

S = replaceinds(
S,
Expand All @@ -110,12 +116,13 @@ function vidalitensornetwork_preserve_cache(
bond_tensors[e] = S
end

return VidalITensorNetwork(ψv_site_tensors, bond_tensors)
return VidalITensorNetwork(ψ_vidal_site_tensors, bond_tensors)
end

function VidalITensorNetwork(
ψ::ITensorNetwork;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
kwargs...,
)
Expand All @@ -126,36 +133,35 @@ function VidalITensorNetwork(
return vidalitensornetwork_preserve_cache(ψ; bp_cache=cache![], kwargs...)
end

function update(ψv::VidalITensorNetwork; kwargs...)
return VidalITensorNetwork(ITensorNetwork(ψv); kwargs...)
function update(ψ::VidalITensorNetwork; kwargs...)
return VidalITensorNetwork(ITensorNetwork(ψ; update_gauge=false); kwargs...)
end

"""Function to measure the 'isometries' of a state in the Vidal Gauge"""
function vidal_gauge_isometries(
ψv::VidalITensorNetwork;
edges=vcat(NamedGraphs.edges(ψv), reverse.(NamedGraphs.edges(ψv))),
ψ::VidalITensorNetwork; edges=vcat(NamedGraphs.edges(ψ), reverse.(NamedGraphs.edges(ψ)))
)
isometries = Dict()

for e in edges
vsrc, vdst = src(e), dst(e)
ψ_vsrc = copy(ψv[vsrc])
for vn in setdiff(neighbors(ψv, vsrc), [vdst])
ψ_vsrc = noprime(ψ_vsrc * bond_tensor(ψv, vn => vsrc))
ψ_vsrc = copy(ψ[vsrc])
for vn in setdiff(neighbors(ψ, vsrc), [vdst])
ψ_vsrc = noprime(ψ_vsrc * bond_tensor(ψ, vn => vsrc))
end

ψ_vsrcdag = dag(ψ_vsrc)
replaceind!(ψ_vsrcdag, commonind(ψ_vsrc, ψv[vdst]), commonind(ψ_vsrc, ψv[vdst])')
replaceind!(ψ_vsrcdag, commonind(ψ_vsrc, ψ[vdst]), commonind(ψ_vsrc, ψ[vdst])')
isometries[e] = ψ_vsrcdag * ψ_vsrc
end

return isometries
end

"""Function to measure the 'distance' of a state from the Vidal Gauge"""
function gauge_error(ψv::VidalITensorNetwork)
function gauge_error(ψ::VidalITensorNetwork)
f = 0
isometries = vidal_gauge_isometries(ψv)
isometries = vidal_gauge_isometries(ψ)
for e in keys(isometries)
lhs = isometries[e]
f += message_diff(ITensor[lhs], ITensor[denseblocks(delta(inds(lhs)))])
Expand Down
2 changes: 1 addition & 1 deletion test/test_gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using SplitApplyCombine

# Move to symmetric gauge
cache_ref = Ref{BeliefPropagationCache}()
ψ_symm = ITensorNetwork(ψ_vidal; (cache!)=cache_ref)
ψ_symm = ITensorNetwork(ψ_vidal; (bp_cache!)=cache_ref)
bp_cache = cache_ref[]

# Test we just did a gauge transform and didn't change the overall network
Expand Down

0 comments on commit b11a0f6

Please sign in to comment.