Skip to content

Commit

Permalink
Derivative -> Environment
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Mar 20, 2024
1 parent 40d08ad commit 73680b6
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ include(joinpath("treetensornetworks", "solvers", "contract.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
include("apply.jl")
include("derivative.jl")
include("environment.jl")

include("exports.jl")

Expand Down
2 changes: 0 additions & 2 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ end
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioning::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
#We could probably do something cleverer here based on graph partitioning algorithms: https://en.wikipedia.org/wiki/Graph_partition.
default_partitioning(f::AbstractFormNetwork) = group(v -> state_vertex(f, v), vertices(f))
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
Expand Down
15 changes: 9 additions & 6 deletions src/derivative.jl → src/environment.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
default_derivative_algorithm() = "exact"
default_environment_algorithm() = "exact"

function derivative(
ψ::AbstractITensorNetwork, vertices::Vector; alg=default_derivative_algorithm(), kwargs...
function environment(
ψ::AbstractITensorNetwork,
vertices::Vector;
alg=default_environment_algorithm(),
kwargs...,
)
return derivative(Algorithm(alg), ψ, vertices; kwargs...)
return environment(Algorithm(alg), ψ, vertices; kwargs...)
end

function derivative(
function environment(
::Algorithm"exact",
ψ::AbstractITensorNetwork,
vertices::Vector;
Expand All @@ -18,7 +21,7 @@ function derivative(
return ITensor[contract(ψ_reduced; sequence, kwargs...)]
end

function derivative(
function environment(
::Algorithm"bp",
ψ::AbstractITensorNetwork,
verts::Vector;
Expand Down
12 changes: 6 additions & 6 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@ function operator_network(f::AbstractFormNetwork)
)
end

function derivative(
function environment(
f::AbstractFormNetwork,
state_vertices::Vector;
alg=default_derivative_algorithm(),
alg=default_environment_algorithm(),
kwargs...,
)
tn_vertices = derivative_vertices(f, state_vertices)
tn_vertices = environment_vertices(f, state_vertices)
if alg == "bp"
partitions = group(v -> state_vertex(f, v), vertices(f))
return derivative(tensornetwork(f), tn_vertices; alg, partitions, kwargs...)
return environment(tensornetwork(f), tn_vertices; alg, partitions, kwargs...)
else
return derivative(tensornetwork(f), tn_vertices; alg, kwargs...)
return environment(tensornetwork(f), tn_vertices; alg, kwargs...)
end
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
function environment_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
return setdiff(
vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
)
Expand Down
8 changes: 4 additions & 4 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using ITensorNetworks:
bra_network,
ket_network,
operator_network,
derivative,
environment,
BeliefPropagationCache
using Test
using Random
Expand Down Expand Up @@ -52,16 +52,16 @@ using SplitApplyCombine
@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)

∂qf_∂v = only(derivative(qf, [v]))
∂qf_∂v = only(environment(qf, [v]))
@test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) contract(qf)

∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=false)
∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=false)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
∂qf_∂v /= norm(∂qf_∂v)
@test ∂qf_∂v_bp != ∂qf_∂v

∂qf_∂v_bp = derivative(qf, [v]; alg="bp", update_cache=true)
∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=true)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
@test ∂qf_∂v_bp ∂qf_∂v
Expand Down

0 comments on commit 73680b6

Please sign in to comment.