Skip to content

Commit

Permalink
Add specialized correlation behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Nov 27, 2024
1 parent 794976a commit 1022c73
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions src/nonlocal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ export tensor_probability
Converts a 2x...x2xmx...xm probability array into
- a mx...xm correlation array (no marginals)
- a (m+1)x...x(m+1) correlation array (marginals).
If `behaviour` is `true` do the transformation for behaviours. Does assume normalization.
If `behaviour` is `true` do the transformation for behaviours. Doesn't assume normalization.
Also accepts the arguments of `tensor_probability` (state and measurements) for convenience.
"""
Expand Down Expand Up @@ -318,8 +318,29 @@ function tensor_correlation(p::AbstractArray{T, N2}, behaviour::Bool = false; ma
return FC
end
# accepts directly the arguments of tensor_probability
function tensor_correlation(rho::Hermitian, all_Aax::Vector{<:Measurement}...; marg::Bool = true)
return tensor_correlation(tensor_probability(rho, all_Aax...), true; marg)
# avoids creating the full probability tensor for performance
function tensor_correlation(
rho::Hermitian{T1, Matrix{T1}},
first_Aax::Vector{Measurement{T2}}, # needed so that T2 is not unbounded
other_Aax::Vector{Measurement{T2}}...;
marg::Bool = true,
) where {T1, T2}
T = real(promote_type(T1, T2))
all_Aax = (first_Aax, other_Aax...)
N = length(all_Aax)
m = Tuple(length.(all_Aax)) # numbers of inputs per party
o = Tuple(broadcast(Aax -> maximum(length.(Aax)), all_Aax)) # numbers of outputs per party
@assert all(o .== 2)
@assert all(broadcast(Aax -> minimum(length.(Aax)), all_Aax) .== 2) # sanity check
size_FC = marg ? m .+ 1 : m
FC = zeros(T, size_FC)
cia = CartesianIndices(o)
cix = CartesianIndices(size_FC)
for a in cia, x in cix
obs = [x[n] > marg ? all_Aax[n][x[n] - marg][1] - all_Aax[n][x[n] - marg][2] : one(all_Aax[n][1][1]) for n in 1:N]
FC[x] = real(dot(Hermitian(kron(obs...)), rho))
end
return FC
end
# shorthand syntax for identical measurements on all parties
function tensor_correlation(rho::Hermitian, Aax::Vector{<:Measurement}, N::Integer; marg::Bool = true)
Expand Down

0 comments on commit 1022c73

Please sign in to comment.