Skip to content

Commit

Permalink
Rename variables. Extra functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Mar 19, 2024
1 parent e9a6bc5 commit 33f39a2
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 24 deletions.
16 changes: 8 additions & 8 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ function derivative(
::Algorithm"bp",
ψ::AbstractITensorNetwork,
verts::Vector;
(bp_cache!)=nothing,
update_bp_cache=isnothing(bp_cache!),
bp_cache_update_kwargs=default_cache_update_kwargs(bp_cache!),
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
if isnothing(bp_cache!)
bp_cache! = Ref(BeliefPropagationCache(ψ))
if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(ψ))
end

if update_bp_cache
bp_cache![] = update(bp_cache![]; bp_cache_update_kwargs...)
if update_cache
cache![] = update(cache![]; cache_update_kwargs...)
end

return incoming_messages(bp_cache![], setdiff(vertices(ψ), verts))
return incoming_messages(cache![], setdiff(vertices(ψ), verts))
end
3 changes: 3 additions & 0 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,6 @@ operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f)
bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f))
ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f))
inv_vertex_map(f::AbstractFormNetwork) = v -> first(v)
operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v)
bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v)
ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v)
4 changes: 2 additions & 2 deletions src/formnetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function update(
)
blf = copy(blf)
# TODO: Maybe add a check that it really does preserve the graph.
setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex))
setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex))
setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex(blf, state_vertex))
setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex(blf, state_vertex))
return blf
end
27 changes: 13 additions & 14 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
using ITensors
using Graphs
using Graphs: nv
using NamedGraphs
using ITensorNetworks
using ITensorNetworks:
delta_network,
update,
tensornetwork,
bra_vertex_map,
ket_vertex_map,
bra_vertex,
ket_vertex,
dual_index_map,
bra_network,
ket_network,
Expand Down Expand Up @@ -45,25 +45,24 @@ using SplitApplyCombine
new_tensor = randomITensor(inds(ψket[v]))
qf_updated = update(qf, v, copy(new_tensor))

@test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)]
@test tensornetwork(qf_updated)[bra_vertex(qf_updated, v)]
dual_index_map(qf_updated)(dag(new_tensor))
@test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] new_tensor
@test tensornetwork(qf_updated)[ket_vertex(qf_updated, v)] new_tensor

@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)

∂qf_∂v = only(derivative_state(qf, [v]))
@test (∂qf_∂v) * (qf[ket_vertex_map(qf)(v)] * qf[bra_vertex_map(qf)(v)])
ITensors.contract(qf)
@test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) contract(qf)

∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=false)
∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp)
∂qf_∂v_bp = normalize!(∂qf_∂v_bp)
∂qf_∂v = normalize!(∂qf_∂v)
∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=false)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
∂qf_∂v /= norm(∂qf_∂v)
@test ∂qf_∂v_bp != ∂qf_∂v

∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=true)
∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp)
∂qf_∂v_bp = normalize!(∂qf_∂v_bp)
∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=true)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
@test ∂qf_∂v_bp ∂qf_∂v
end

0 comments on commit 33f39a2

Please sign in to comment.