Skip to content

Commit

Permalink
Refactor expect (single site) (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored May 3, 2024
1 parent 286a048 commit adce5ac
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 196 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2021 Matthew Fishman <[email protected]> and contributors
Copyright (c) 2021 Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]> and contributors"]
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.10.2"

[deps]
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DocMeta.setdocmeta!(

makedocs(;
modules=[ITensorNetworks],
authors="Matthew Fishman <[email protected]> and contributors",
authors="Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors",
repo="https://github.com/mtfishman/ITensorNetworks.jl/blob/{commit}{path}#{line}",
sitename="ITensorNetworks.jl",
format=Documenter.HTML(;
Expand Down
26 changes: 19 additions & 7 deletions src/environment.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
using ITensors: contract
using NamedGraphs.PartitionedGraphs: PartitionedGraph

default_environment_algorithm() = "exact"

function environment(
ψ::AbstractITensorNetwork,
tn::AbstractITensorNetwork,
vertices::Vector;
alg=default_environment_algorithm(),
kwargs...,
)
return environment(Algorithm(alg), ψ, vertices; kwargs...)
return environment(Algorithm(alg), tn, vertices; kwargs...)
end

function environment(
::Algorithm"exact", ψ::AbstractITensorNetwork, verts::Vector; kwargs...
::Algorithm"exact", tn::AbstractITensorNetwork, verts::Vector; kwargs...
)
return [contract(subgraph(ψ, setdiff(vertices(ψ), verts)); kwargs...)]
return [contract(subgraph(tn, setdiff(vertices(tn), verts)); kwargs...)]
end

function environment(
::Algorithm"bp",
ψ::AbstractITensorNetwork,
ptn::PartitionedGraph,
vertices::Vector;
(cache!)=nothing,
partitioned_vertices=default_partitioned_vertices(ψ),
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(ψ, partitioned_vertices))
cache! = Ref(BeliefPropagationCache(ptn))
end

if update_cache
Expand All @@ -34,3 +36,13 @@ function environment(

return environment(cache![], vertices)
end

function environment(
alg::Algorithm"bp",
tn::AbstractITensorNetwork,
vertices::Vector;
partitioned_vertices=default_partitioned_vertices(tn),
kwargs...,
)
return environment(alg, PartitionedGraph(tn, partitioned_vertices), vertices; kwargs...)
end
92 changes: 49 additions & 43 deletions src/expect.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,63 @@
using ITensors.ITensorMPS: ITensorMPS, expect, promote_itensor_eltype, OpSum
using Dictionaries: Dictionary, set!
using ITensors: Op, op, contract, siteinds, which_op
using ITensors.ITensorMPS: ITensorMPS, expect

default_expect_alg() = "bp"

function ITensorMPS.expect(ψIψ::AbstractFormNetwork, op::Op; contract_kwargs=(;), kwargs...)
v = only(op.sites)
ψIψ_v = ψIψ[operator_vertex(ψIψ, v)]
s = commonind(ψIψ[ket_vertex(ψIψ, v)], ψIψ_v)
operator = ITensors.op(op.which_op, s)
∂ψIψ_∂v = environment(ψIψ, operator_vertices(ψIψ, [v]); kwargs...)
numerator = contract(vcat(∂ψIψ_∂v, operator); contract_kwargs...)[]
denominator = contract(vcat(∂ψIψ_∂v, ψIψ_v); contract_kwargs...)[]

return numerator / denominator
end

function ITensorMPS.expect(
op::String,
ψ::AbstractITensorNetwork;
cutoff=nothing,
maxdim=nothing,
ortho=false,
sequence=nothing,
vertices=vertices(ψ),
alg::Algorithm,
ψ::AbstractITensorNetwork,
ops;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
cache_construction_function=tn ->
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
kwargs...,
)
s = siteinds(ψ)
ElT = promote_itensor_eltype(ψ)
# ElT = ishermitian(ITensors.op(op, s[vertices[1]])) ? real(ElT) : ElT
res = Dictionary(vertices, Vector{ElT}(undef, length(vertices)))
if isnothing(sequence)
sequence = contraction_sequence(inner_network(ψ, ψ))
ψIψ = inner_network(ψ, ψ)
if isnothing(cache!)
cache! = Ref(cache_construction_function(ψIψ))
end
normψ² = norm_sqr(ψ; alg="exact", sequence)
for v in vertices
O = ITensor(Op(op, v), s)
= apply(O, ψ; cutoff, maxdim, ortho)
res[v] = inner(ψ, Oψ; alg="exact", sequence) / normψ²

if update_cache
cache![] = update(cache![]; cache_update_kwargs...)
end
return res

return map(op -> expect(ψIψ, op; alg, cache!, update_cache=false, kwargs...), ops)
end

function ITensorMPS.expect(alg::Algorithm"exact", ψ::AbstractITensorNetwork, ops; kwargs...)
ψIψ = inner_network(ψ, ψ)
return map(op -> expect(ψIψ, op; alg, kwargs...), ops)
end

function ITensorMPS.expect(
::OpSum,
ψ::AbstractITensorNetwork;
cutoff=nothing,
maxdim=nothing,
ortho=false,
sequence=nothing,
ψ::AbstractITensorNetwork, op::Op; alg=default_expect_alg(), kwargs...
)
s = siteinds(ψ)
# h⃗ = Vector{ITensor}(ℋ, s)
if isnothing(sequence)
sequence = contraction_sequence(inner_network(ψ, ψ))
end
h⃗ψ = [apply(hᵢ, ψ; cutoff, maxdim, ortho) for hᵢ in ITensors.terms(ℋ)]
ψhᵢψ = [inner(ψ, hᵢψ; alg="exact", sequence) for hᵢψ in h⃗ψ]
ψh⃗ψ = sum(ψhᵢψ)
ψψ = norm_sqr(ψ; alg="exact", sequence)
return ψh⃗ψ / ψψ
return expect(Algorithm(alg), ψ, [op]; kwargs...)
end

function ITensorMPS.expect(
ψ::AbstractITensorNetwork, op::String, vertices; alg=default_expect_alg(), kwargs...
)
return expect(Algorithm(alg), ψ, [Op(op, vertex) for vertex in vertices]; kwargs...)
end

function ITensorMPS.expect(
opsum_sum::Sum{<:OpSum},
ψ::AbstractITensorNetwork;
cutoff=nothing,
maxdim=nothing,
ortho=true,
sequence=nothing,
ψ::AbstractITensorNetwork, op::String; alg=default_expect_alg(), kwargs...
)
return expect(sum(Ops.terms(opsum_sum)), ψ; cutoff, maxdim, ortho, sequence)
return expect(ψ, op, vertices(ψ); alg, kwargs...)
end
22 changes: 5 additions & 17 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
function operator_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == operator_vertex_suffix(f), vertices(f))
end

function bra_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == bra_vertex_suffix(f), vertices(f))
end
Expand All @@ -31,6 +32,10 @@ function ket_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f))
end

function operator_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [operator_vertex_map(f)(osv) for osv in original_state_vertices]
end

function bra_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [bra_vertex_map(f)(osv) for osv in original_state_vertices]
end
Expand Down Expand Up @@ -67,23 +72,6 @@ function operator_network(f::AbstractFormNetwork)
)
end

function environment(
f::AbstractFormNetwork,
original_state_vertices::Vector;
alg=default_environment_algorithm(),
kwargs...,
)
form_vertices = state_vertices(f, original_state_vertices)
if alg == "bp"
partitioned_vertices = group(v -> original_state_vertex(f, v), vertices(f))
return environment(
tensornetwork(f), form_vertices; alg, partitioned_vertices, kwargs...
)
else
return environment(tensornetwork(f), form_vertices; alg, kwargs...)
end
end

operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f))
bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f))
ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f))
Expand Down
Loading

0 comments on commit adce5ac

Please sign in to comment.