Skip to content

Commit

Permalink
Renaming. Fix ambiguity in apply(..., abstractitensornetwork)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Mar 8, 2024
1 parent 262f0de commit 2aa2313
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ include(joinpath("formnetworks", "quadraticformnetwork.jl"))
include(joinpath("caches", "beliefpropagationcache.jl"))
include("contraction_tree_to_graph.jl")
include("gauging.jl")
include("apply.jl")
include("utils.jl")
include("tensornetworkoperators.jl")
include(joinpath("ITensorsExt", "itensorutils.jl"))
Expand All @@ -129,6 +128,7 @@ include(joinpath("treetensornetworks", "solvers", "dmrg_x.jl"))
include(joinpath("treetensornetworks", "solvers", "contract.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
include("apply.jl")

include("exports.jl")

Expand Down
34 changes: 17 additions & 17 deletions src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ end

function ITensors.apply(
o::ITensor,
ψ::ITensorNetwork;
ψ::Union{ITensorNetwork,TreeTensorNetwork};
envs=ITensor[],
normalize=false,
ortho=false,
Expand Down Expand Up @@ -297,30 +297,30 @@ end
#In the future we will try to unify this into apply() above but currently leave it mostly as a separate function
"""Apply() function for an ITN in the Vidal Gauge. Hence the bond tensors are required.
Gate does not necessarily need to be passed. Can supply an edge to do an identity update instead. Uses Simple Update procedure assuming gate is two-site"""
function apply(
o::Union{ITensor,NamedEdge}, ψv::VidalITensorNetwork; normalize=false, apply_kwargs...
function ITensors.apply(
o::Union{ITensor,NamedEdge}, ψ::VidalITensorNetwork; normalize=false, apply_kwargs...
)
updated_ψ = copy(site_tensors(ψv))
updated_bond_tensors = copy(bond_tensors(ψv))
v⃗ = _gate_vertices(o, ψv)
updated_ψ = copy(site_tensors(ψ))
updated_bond_tensors = copy(bond_tensors(ψ))
v⃗ = _gate_vertices(o, ψ)
if length(v⃗) == 2
e = NamedEdge(v⃗[1] => v⃗[2])
ψv1, ψv2 = ψv[src(e)], ψv[dst(e)]
ψv1, ψv2 = ψ[src(e)], ψ[dst(e)]
e_ind = commonind(ψv1, ψv2)

for vn in neighbors(ψv, src(e))
for vn in neighbors(ψ, src(e))
if (vn != dst(e))
ψv1 = noprime(ψv1 * bond_tensor(ψv, vn => src(e)))
ψv1 = noprime(ψv1 * bond_tensor(ψ, vn => src(e)))
end
end

for vn in neighbors(ψv, dst(e))
for vn in neighbors(ψ, dst(e))
if (vn != src(e))
ψv2 = noprime(ψv2 * bond_tensor(ψv, vn => dst(e)))
ψv2 = noprime(ψv2 * bond_tensor(ψ, vn => dst(e)))
end
end

Qᵥ₁, Rᵥ₁, Qᵥ₂, Rᵥ₂, theta = _contract_gate(o, ψv1, bond_tensor(ψv, e), ψv2)
Qᵥ₁, Rᵥ₁, Qᵥ₂, Rᵥ₂, theta = _contract_gate(o, ψv1, bond_tensor(ψ, e), ψv2)

U, S, V = ITensors.svd(
theta,
Expand All @@ -337,22 +337,22 @@ function apply(

ψv1, updated_bond_tensors[e], ψv2 = U * Qᵥ₁, S, V * Qᵥ₂

for vn in neighbors(ψv, src(e))
for vn in neighbors(ψ, src(e))
if (vn != dst(e))
ψv1 = noprime(ψv1 * inv_diag(bond_tensor(ψv, vn => src(e))))
ψv1 = noprime(ψv1 * inv_diag(bond_tensor(ψ, vn => src(e))))
end
end

for vn in neighbors(ψv, dst(e))
for vn in neighbors(ψ, dst(e))
if (vn != src(e))
ψv2 = noprime(ψv2 * inv_diag(bond_tensor(ψv, vn => dst(e))))
ψv2 = noprime(ψv2 * inv_diag(bond_tensor(ψ, vn => dst(e))))
end
end

if normalize
ψv1 /= norm(ψv1)
ψv2 /= norm(ψv2)
normalize!(updated_bond_tensors[e])
updated_bond_tensors[e] /= norm(updated_bond_tensors[e])
end

setindex_preserve_graph!(updated_ψ, ψv1, src(e))
Expand Down
22 changes: 11 additions & 11 deletions src/gauging.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function default_bond_tensor_cache::ITensorNetwork)
function default_bond_tensors::ITensorNetwork)
return DataGraph{vertextype(ψ),Nothing,ITensor}(underlying_graph(ψ))
end

Expand Down Expand Up @@ -62,27 +62,27 @@ end
"""Use an ITensorNetwork ψ, its bond tensors and belief propagation cache to put ψ into the vidal gauge, return the bond tensors and updated_ψ."""
function vidalitensornetwork_preserve_cache(
ψ::ITensorNetwork;
bp_cache::BeliefPropagationCache=default_norm_cache(ψ),
bond_tensors_cache=default_bond_tensor_cache,
cache=default_norm_cache(ψ),
bond_tensors=default_bond_tensors,
message_cutoff=10 * eps(real(scalartype(ψ))),
regularization=10 * eps(real(scalartype(ψ))),
edges=NamedGraphs.edges(ψ),
svd_kwargs...,
)
ψ_vidal_site_tensors = copy(ψ)
bond_tensors = bond_tensors_cache(ψ)
ψ_vidal_bond_tensors = bond_tensors(ψ)

for e in edges
vsrc, vdst = src(e), dst(e)
ψvsrc, ψvdst = ψ_vidal_site_tensors[vsrc], ψ_vidal_site_tensors[vdst]

pe = partitionedge(bp_cache, (vsrc, 1) => (vdst, 1))
pe = partitionedge(cache, (vsrc, 1) => (vdst, 1))
edge_ind = commoninds(ψvsrc, ψvdst)
edge_ind_sim = sim(edge_ind)

X_D, X_U = eigen(only(message(bp_cache, pe)); ishermitian=true, cutoff=message_cutoff)
X_D, X_U = eigen(only(message(cache, pe)); ishermitian=true, cutoff=message_cutoff)
Y_D, Y_U = eigen(
only(message(bp_cache, reverse(pe))); ishermitian=true, cutoff=message_cutoff
only(message(cache, reverse(pe))); ishermitian=true, cutoff=message_cutoff
)
X_D, Y_D = map_diag(x -> x + regularization, X_D),
map_diag(x -> x + regularization, Y_D)
Expand Down Expand Up @@ -115,10 +115,10 @@ function vidalitensornetwork_preserve_cache(
[commoninds(S, U)..., commoninds(S, V)...] =>
[new_edge_ind..., prime(new_edge_ind)...],
)
bond_tensors[e] = S
ψ_vidal_bond_tensors[e] = S
end

return VidalITensorNetwork(ψ_vidal_site_tensors, bond_tensors)
return VidalITensorNetwork(ψ_vidal_site_tensors, ψ_vidal_bond_tensors)
end

function VidalITensorNetwork(
Expand All @@ -132,7 +132,7 @@ function VidalITensorNetwork(
cache! = Ref(default_norm_cache(ψ))
end
cache![] = update(cache![]; cache_update_kwargs...)
return vidalitensornetwork_preserve_cache(ψ; bp_cache=cache![], kwargs...)
return vidalitensornetwork_preserve_cache(ψ; cache=cache![], kwargs...)
end

function update::VidalITensorNetwork; kwargs...)
Expand All @@ -149,7 +149,7 @@ function vidal_gauge_isometry(ψ::VidalITensorNetwork, edge)
end

ψ_vsrcdag = dag(ψ_vsrc)
replaceind!(ψ_vsrcdag, commonind(ψ_vsrc, ψ[vdst]), commonind(ψ_vsrc, ψ[vdst])')
ψ_vsrcdag = replaceind(ψ_vsrcdag, commonind(ψ_vsrc, ψ[vdst]), commonind(ψ_vsrc, ψ[vdst])')

return ψ_vsrcdag * ψ_vsrc
end
Expand Down

0 comments on commit 2aa2313

Please sign in to comment.