diff --git a/src/derivative.jl b/src/derivative.jl index 731c6663..347f9c5d 100644 --- a/src/derivative.jl +++ b/src/derivative.jl @@ -22,17 +22,17 @@ function derivative( ::Algorithm"bp", ψ::AbstractITensorNetwork, verts::Vector; - (bp_cache!)=nothing, - update_bp_cache=isnothing(bp_cache!), - bp_cache_update_kwargs=default_cache_update_kwargs(bp_cache!), + (cache!)=nothing, + update_cache=isnothing(cache!), + cache_update_kwargs=default_cache_update_kwargs(cache!), ) - if isnothing(bp_cache!) - bp_cache! = Ref(BeliefPropagationCache(ψ)) + if isnothing(cache!) + cache! = Ref(BeliefPropagationCache(ψ)) end - if update_bp_cache - bp_cache![] = update(bp_cache![]; bp_cache_update_kwargs...) + if update_cache + cache![] = update(cache![]; cache_update_kwargs...) end - return incoming_messages(bp_cache![], setdiff(vertices(ψ), verts)) + return incoming_messages(cache![], setdiff(vertices(ψ), verts)) end diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index ccd4aa67..a8187751 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -72,3 +72,6 @@ operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f) bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f)) ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f)) inv_vertex_map(f::AbstractFormNetwork) = v -> first(v) +operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v) +bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v) +ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v) diff --git a/src/formnetworks/bilinearformnetwork.jl b/src/formnetworks/bilinearformnetwork.jl index 356b0ed1..f93d89ee 100644 --- a/src/formnetworks/bilinearformnetwork.jl +++ b/src/formnetworks/bilinearformnetwork.jl @@ -56,7 +56,7 @@ function update( ) blf = copy(blf) # TODO: Maybe add a check that it really does preserve the graph. - setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex)) - setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex)) + setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex(blf, state_vertex)) + setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex(blf, state_vertex)) return blf end diff --git a/test/test_forms.jl b/test/test_forms.jl index 7f74c8a7..a7652c52 100644 --- a/test/test_forms.jl +++ b/test/test_forms.jl @@ -1,13 +1,13 @@ using ITensors -using Graphs +using Graphs: nv using NamedGraphs using ITensorNetworks using ITensorNetworks: delta_network, update, tensornetwork, - bra_vertex_map, - ket_vertex_map, + bra_vertex, + ket_vertex, dual_index_map, bra_network, ket_network, @@ -45,25 +45,24 @@ using SplitApplyCombine new_tensor = randomITensor(inds(ψket[v])) qf_updated = update(qf, v, copy(new_tensor)) - @test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)] ≈ + @test tensornetwork(qf_updated)[bra_vertex(qf_updated, v)] ≈ dual_index_map(qf_updated)(dag(new_tensor)) - @test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] ≈ new_tensor + @test tensornetwork(qf_updated)[ket_vertex(qf_updated, v)] ≈ new_tensor @test underlying_graph(ket_network(qf)) == underlying_graph(ψket) @test underlying_graph(operator_network(qf)) == underlying_graph(A) ∂qf_∂v = only(derivative_state(qf, [v])) - @test (∂qf_∂v) * (qf[ket_vertex_map(qf)(v)] * qf[bra_vertex_map(qf)(v)]) ≈ - ITensors.contract(qf) + @test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) ≈ contract(qf) - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=false) - ∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp) - ∂qf_∂v_bp = normalize!(∂qf_∂v_bp) - ∂qf_∂v = normalize!(∂qf_∂v) + ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=false) + ∂qf_∂v_bp = contract(∂qf_∂v_bp) + ∂qf_∂v_bp /= norm(∂qf_∂v_bp) + ∂qf_∂v /= norm(∂qf_∂v) @test ∂qf_∂v_bp != ∂qf_∂v - ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_bp_cache=true) - ∂qf_∂v_bp = ITensors.contract(∂qf_∂v_bp) - ∂qf_∂v_bp = normalize!(∂qf_∂v_bp) + ∂qf_∂v_bp = derivative_state(qf, [v]; alg="bp", update_cache=true) + ∂qf_∂v_bp = contract(∂qf_∂v_bp) + ∂qf_∂v_bp /= norm(∂qf_∂v_bp) @test ∂qf_∂v_bp ≈ ∂qf_∂v end