Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #249 #250

Merged
merged 9 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_mean_covariance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanCovariance, x::AbstractVect
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r) # x' * A * x
return xT_A_y(r, invcov(dist), r) # x' * A * x
end

# The element type is the type parameter `T` the distribution is dispatched on.
function Base.eltype(::MvNormalMeanCovariance{T}) where {T}
    return T
end
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanPrecision, x::AbstractVecto
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r)
return xT_A_y(r, invcov(dist), r) # x' * A * x
end

# The element type is the type parameter `T` the distribution is dispatched on.
function Base.eltype(::MvNormalMeanPrecision{T}) where {T}
    return T
end
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_weighted_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function Distributions.sqmahal!(r, dist::MvNormalWeightedMeanPrecision, x::Abstr
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r)
return xT_A_y(r, invcov(dist), r) # x' * A * x
end

# The element type is the type parameter `T` the distribution is dispatched on.
function Base.eltype(::MvNormalWeightedMeanPrecision{T}) where {T}
    return T
end
Expand Down
7 changes: 6 additions & 1 deletion src/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanVariance})
# Multivariate counterpart of any `NormalMeanPrecision` type.
function promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanPrecision})
    return MvNormalMeanPrecision
end

# Multivariate counterpart of any `NormalWeightedMeanPrecision` type.
function promote_variate_type(::Type{Multivariate}, ::Type{<:NormalWeightedMeanPrecision})
    return MvNormalWeightedMeanPrecision
end

# Conversion to gaussian distributions from `Distributions.jl`

# Builds a `Normal` from the values returned by `mean_std(dist)`.
function Base.convert(::Type{Normal}, dist::UnivariateNormalDistributionsFamily)
    params = mean_std(dist)
    return Normal(params...)
end

# Builds an `MvNormal` from the values returned by `mean_cov(dist)`.
function Base.convert(::Type{MvNormal}, dist::MultivariateNormalDistributionsFamily)
    params = mean_cov(dist)
    return MvNormal(params...)
end

# Conversion to mean - variance parametrisation

function Base.convert(::Type{NormalMeanVariance{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real}
Expand Down Expand Up @@ -312,7 +317,7 @@ function Base.prod(
n = length(left)
v_inv, v_logdet = cholinv_logdet(v)
m = m_left - m_right
return -(v_logdet + n * log2π) / 2 - dot(m, v_inv, m) / 2
return -(v_logdet + n * log2π) / 2 - xT_A_y(m, v_inv, m) / 2
end

## Friendly functions
Expand Down
25 changes: 25 additions & 0 deletions src/helpers/algebra/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,31 @@ function v_a_vT(v1, a, v2)
return result
end

"""
    xT_A_y(x, A, y)

Compute the bilinear form `x' * A * y`, i.e. the same value as `dot(x, A, y)`.

The built-in Julia 3-arg `dot` is not compatible with the auto-differentiation packages,
such as `ForwardDiff`. For an `AbstractVector`, `AbstractMatrix`, `AbstractVector` triple
we use our own loop-based implementation; any other combination of arguments ultimately
falls back to `dot`.

If `A` has no elements the result is the empty sum, i.e. a zero obtained from the element
types of `x`, `A` and `y`.

# Throws
- `DimensionMismatch` if `axes(A)` differs from `(axes(x)..., axes(y)...)`.
"""
xT_A_y(x, A, y) = dot(x, A, y)

function xT_A_y(x::AbstractVector, A::AbstractMatrix, y::AbstractVector)
    expected = (axes(x)..., axes(y)...)
    if expected != axes(A)
        throw(DimensionMismatch("axes of A are $(axes(A)), expected $(expected) from x and y"))
    end

    # `first` throws on an empty collection, so the empty case is answered up front.
    # After the axes check `A` is empty exactly when `x` or `y` is empty.
    if isempty(A)
        return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
    end

    # The accumulator type is taken from one element-level product so that it is fixed
    # before the loop starts.
    s = zero(typeof(dot(first(x), first(A), first(y))))
    i₁ = first(eachindex(x))
    x₁ = first(x)
    @inbounds for j in eachindex(y)
        yj = y[j]
        # temp accumulates A[:, j]' * x; the inner loop runs over the first index of `A`.
        temp = zero(adjoint(A[i₁, j]) * x₁)
        @simd for i in eachindex(x)
            temp += adjoint(A[i, j]) * x[i]
        end
        s += dot(temp, yj)
    end
    return s
end

"""
mvbeta(x)

Expand Down
6 changes: 3 additions & 3 deletions src/nodes/autoregressive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ default_meta(::Type{AR}) = error("Autoregressive node requires meta flag explici
my1, Vy1 = first(myx), first(Vyx)
Vy1x = ar_slice(getvform(meta), Vyx, 1, (order + 1):(2order))

# Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2
# Equivalent to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
AE = (-mean(log, q_γ) + log2π + mγ * (Vy1 + my1^2 - 2 * mθ' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2

# correction
if is_multivariate(meta)
Expand All @@ -76,7 +76,7 @@ end

my1, Vy1 = first(my), first(Vy)

AE = -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))
AE = -0.5mean(log, q_γ) + 0.5log2π + 0.5 * mγ * (Vy1 + my1^2 - 2 * mθ' * mx * my1 + mul_trace(Vθ, Vx) + xT_A_y(mx, Vθ, mx) + xT_A_y(mθ, Vx, mθ) + abs2(dot(mθ, mx)))

# correction
if is_multivariate(meta)
Expand Down
2 changes: 1 addition & 1 deletion src/rules/dot_product/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ end
@rule typeof(dot)(:out, Marginalisation) (m_in1::PointMass, m_in2::NormalDistributionsFamily, meta::AbstractCorrection) = begin
A = mean(m_in1)
in2_mean, in2_cov = mean_cov(m_in2)
return NormalMeanVariance(dot(A, in2_mean), dot(A, in2_cov, A))
return NormalMeanVariance(dot(A, in2_mean), xT_A_y(A, in2_cov, A))
end
2 changes: 1 addition & 1 deletion src/rules/multiplication/A.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
@rule typeof(*)(:A, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_in::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
A = mean(m_in)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, dot(A, W_out, A))
W = correction!(meta, xT_A_y(A, W_out, A))
return NormalWeightedMeanPrecision(dot(A, ξ_out), W)
end

Expand Down
2 changes: 1 addition & 1 deletion src/rules/multiplication/in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
@rule typeof(*)(:in, Marginalisation) (m_out::MultivariateNormalDistributionsFamily, m_A::PointMass{<:AbstractVector}, meta::Union{<:AbstractCorrection, Nothing}) = begin
A = mean(m_A)
ξ_out, W_out = weightedmean_precision(m_out)
W = correction!(meta, dot(A, W_out, A))
W = correction!(meta, xT_A_y(A, W_out, A))
return NormalWeightedMeanPrecision(dot(A, ξ_out), W)
end

Expand Down
12 changes: 12 additions & 0 deletions test/algebra/test_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ using LinearAlgebra
@test ReactiveMP.mul_trace(a, b) ≈ a * b
end
end

@testset "xT_A_y" begin
    import ReactiveMP: xT_A_y

    # The seeded generator is passed to every `rand` call so the inputs are reproducible
    # between runs; without it the draws come from the global RNG.
    rng = MersenneTwister(1234)

    # `n` is used instead of `size` so that `Base.size` is not shadowed inside the loop.
    # Every combination of element types is checked against the 3-arg `dot`.
    for n in 2:5, T1 in (Float32, Float64), T2 in (Float32, Float64), T3 in (Float32, Float64)
        x = rand(rng, T1, n)
        A = rand(rng, T2, n, n)
        y = rand(rng, T3, n)
        @test dot(x, A, y) ≈ xT_A_y(x, A, y)
    end
end
end

end
109 changes: 65 additions & 44 deletions test/distributions/test_normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,42 @@ using ReactiveMP
using Random
using LinearAlgebra
using Distributions
using ForwardDiff

import ReactiveMP: convert_eltype

@testset "Normal" begin
@testset "Univariate conversions" begin
check_basic_statistics = (left, right) -> begin
@test mean(left) ≈ mean(right)
@test median(left) ≈ median(right)
@test mode(left) ≈ mode(right)
@test weightedmean(left) ≈ weightedmean(right)
@test var(left) ≈ var(right)
@test std(left) ≈ std(right)
@test cov(left) ≈ cov(right)
@test invcov(left) ≈ invcov(right)
@test precision(left) ≈ precision(right)
@test entropy(left) ≈ entropy(right)
@test pdf(left, 1.0) ≈ pdf(right, 1.0)
@test pdf(left, -1.0) ≈ pdf(right, -1.0)
@test pdf(left, 0.0) ≈ pdf(right, 0.0)
@test logpdf(left, 1.0) ≈ logpdf(right, 1.0)
@test logpdf(left, -1.0) ≈ logpdf(right, -1.0)
@test logpdf(left, 0.0) ≈ logpdf(right, 0.0)
end
# Test helper for a pair of univariate distributions: asserts that `left` and `right`
# agree on their summary statistics, on `pdf`/`logpdf`, and on the `ForwardDiff`
# gradient and Hessian of `logpdf` at several evaluation points.
# Pass `include_extended_methods = false` to skip the accessors checked in the final
# `if` block.
check_basic_statistics =
(left, right; include_extended_methods = true) -> begin
@test mean(left) ≈ mean(right)
@test median(left) ≈ median(right)
@test mode(left) ≈ mode(right)
@test var(left) ≈ var(right)
@test std(left) ≈ std(right)
@test entropy(left) ≈ entropy(right)

# Evaluation points: three fixed values, both means, and one `rand()` draw
# (the draw does not use a seeded generator, so it differs between runs)
for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand())
@test pdf(left, value) ≈ pdf(right, value)
@test logpdf(left, value) ≈ logpdf(right, value)
# The scalar is wrapped in a one-element vector and read back with `x[1]` inside the closure
@test all(ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value]))
@test all(ForwardDiff.hessian((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value]))
end

# These methods are not defined for distributions from `Distributions.jl`
if include_extended_methods
@test cov(left) ≈ cov(right)
@test invcov(left) ≈ invcov(right)
@test weightedmean(left) ≈ weightedmean(right)
@test precision(left) ≈ precision(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
end
end

types = ReactiveMP.union_types(UnivariateNormalDistributionsFamily{Float64})
etypes = ReactiveMP.union_types(UnivariateNormalDistributionsFamily)
Expand All @@ -36,6 +49,7 @@ import ReactiveMP: convert_eltype

for type in types
left = convert(type, rand(rng, Float64), rand(rng, Float64))
check_basic_statistics(left, convert(Normal, left); include_extended_methods = false)
for type in [types..., etypes...]
right = convert(type, left)
check_basic_statistics(left, right)
Expand All @@ -56,32 +70,38 @@ import ReactiveMP: convert_eltype
end

@testset "Multivariate conversions" begin
check_basic_statistics = (left, right, dims) -> begin
@test mean(left) ≈ mean(right)
@test mode(left) ≈ mode(right)
@test weightedmean(left) ≈ weightedmean(right)
@test var(left) ≈ var(right)
@test cov(left) ≈ cov(right)
@test invcov(left) ≈ invcov(right)
@test logdetcov(left) ≈ logdetcov(right)
@test precision(left) ≈ precision(right)
@test length(left) === length(right)
@test ndims(left) === ndims(right)
@test size(left) === size(right)
@test entropy(left) ≈ entropy(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
@test pdf(left, fill(1.0, dims)) ≈ pdf(right, fill(1.0, dims))
@test pdf(left, fill(-1.0, dims)) ≈ pdf(right, fill(-1.0, dims))
@test pdf(left, fill(0.0, dims)) ≈ pdf(right, fill(0.0, dims))
@test logpdf(left, fill(1.0, dims)) ≈ logpdf(right, fill(1.0, dims))
@test logpdf(left, fill(-1.0, dims)) ≈ logpdf(right, fill(-1.0, dims))
@test logpdf(left, fill(0.0, dims)) ≈ logpdf(right, fill(0.0, dims))
end
# Test helper for a pair of multivariate distributions: asserts that `left` and `right`
# agree on their summary statistics and sizes, on `pdf`/`logpdf`, and on the
# `ForwardDiff` gradient and Hessian of `logpdf` at several evaluation points.
# `dims` is the length of the evaluation vectors built with `fill` and `rand` below.
# Pass `include_extended_methods = false` to skip the accessors checked in the final
# `if` block.
check_basic_statistics =
(left, right, dims; include_extended_methods = true) -> begin
@test mean(left) ≈ mean(right)
@test mode(left) ≈ mode(right)
@test var(left) ≈ var(right)
@test cov(left) ≈ cov(right)
@test logdetcov(left) ≈ logdetcov(right)
@test length(left) === length(right)
@test size(left) === size(right)
@test entropy(left) ≈ entropy(right)

# Evaluation points: three constant vectors, both means, and one `rand(dims)` draw
# (the draw does not use a seeded generator, so it differs between runs)
for value in (fill(1.0, dims), fill(-1.0, dims), fill(0.0, dims), mean(left), mean(right), rand(dims))
@test pdf(left, value) ≈ pdf(right, value)
@test logpdf(left, value) ≈ logpdf(right, value)
# Derivatives are compared element-wise with an explicit absolute tolerance.
# NOTE(review): with `atol > 0` and no `rtol`, `isapprox` uses `rtol = 0`, so this
# is a purely absolute 1e-14 comparison — confirm that is intended.
@test all(isapprox.(ForwardDiff.gradient((x) -> logpdf(left, x), value), ForwardDiff.gradient((x) -> logpdf(right, x), value), atol = 1e-14))
@test all(isapprox.(ForwardDiff.hessian((x) -> logpdf(left, x), value), ForwardDiff.hessian((x) -> logpdf(right, x), value), atol = 1e-14))
end

# These methods are not defined for distributions from `Distributions.jl`
if include_extended_methods
@test ndims(left) === ndims(right)
@test invcov(left) ≈ invcov(right)
@test weightedmean(left) ≈ weightedmean(right)
@test precision(left) ≈ precision(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
end
end

types = ReactiveMP.union_types(MultivariateNormalDistributionsFamily{Float64})
etypes = ReactiveMP.union_types(MultivariateNormalDistributionsFamily)
Expand All @@ -92,6 +112,7 @@ import ReactiveMP: convert_eltype
for dim in dims
for type in types
left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(rand(rng, Float64, dim))))
check_basic_statistics(left, convert(MvNormal, left), dim; include_extended_methods = false)
for type in [types..., etypes...]
right = convert(type, left)
check_basic_statistics(left, right, dim)
Expand Down