Skip to content

Commit

Permalink
Merge pull request #250 from biaslab/dev-issue-249
Browse files Browse the repository at this point in the history
Fix issue #249
  • Loading branch information
bvdmitri authored Jan 27, 2023
2 parents f9d83ac + f5f42b3 commit 351dea1
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 48 deletions.
8 changes: 7 additions & 1 deletion docs/src/extra/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ In addition tests can be evaluated by running following command in the ReactiveM
make test
```

### Fixes to external libraries

If a bug has been discovered in an external dependencies of the `ReactiveMP.jl` it is the best to open an issue
directly in the dependency's github repository. You use can use the `fixes.jl` file for hot-fixes before
a new release of the broken dependecy is available.

### Makefile

`ReactiveMP.jl` uses `Makefile` for most common operations:
Expand All @@ -80,4 +86,4 @@ make test
- `make docs`: Compile documentation
- `make benchmark`: Run simple benchmark
- `make lint`: Check codestyle
- `make format`: Check and fix codestyle
- `make format`: Check and fix codestyle
1 change: 1 addition & 0 deletions src/ReactiveMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using TinyHugeNumbers
# Reexport `tiny` and `huge` from the `TinyHugeNumbers`
export tiny, huge

include("fixes.jl")
include("helpers/macrohelpers.jl")
include("helpers/helpers.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanPrecision, x::AbstractVecto
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r)
return dot(r, invcov(dist), r) # x' * A * x
end

Base.eltype(::MvNormalMeanPrecision{T}) where {T} = T
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mv_normal_weighted_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function Distributions.sqmahal!(r, dist::MvNormalWeightedMeanPrecision, x::Abstr
for i in 1:length(r)
@inbounds r[i] = μ[i] - x[i]
end
return dot(r, invcov(dist), r)
return dot(r, invcov(dist), r) # x' * A * x
end

Base.eltype(::MvNormalWeightedMeanPrecision{T}) where {T} = T
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanVariance})
promote_variate_type(::Type{Multivariate}, ::Type{<:NormalMeanPrecision}) = MvNormalMeanPrecision
promote_variate_type(::Type{Multivariate}, ::Type{<:NormalWeightedMeanPrecision}) = MvNormalWeightedMeanPrecision

# Conversion to gaussian distributions from `Distributions.jl`

Base.convert(::Type{Normal}, dist::UnivariateNormalDistributionsFamily) = Normal(mean_std(dist)...)
Base.convert(::Type{MvNormal}, dist::MultivariateNormalDistributionsFamily) = MvNormal(mean_cov(dist)...)

# Conversion to mean - variance parametrisation

function Base.convert(::Type{NormalMeanVariance{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real}
Expand Down
27 changes: 27 additions & 0 deletions src/fixes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# This file implements various hot-fixes for external dependencies
# This file can be empty, which is fine. It only means that all external dependecies released a new version
# that is now fixed

# Fix for 3-argument `dot` product and `ForwardDiff.hessian`, see
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/551
# https://github.com/JuliaDiff/ForwardDiff.jl/pull/481
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/480
import LinearAlgebra: dot
import ForwardDiff

function dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector{D}) where {D <: ForwardDiff.Dual}
(axes(x)..., axes(y)...) == axes(A) || throw(DimensionMismatch())
T = typeof(dot(first(x), first(A), first(y)))
s = zero(T)
i₁ = first(eachindex(x))
x₁ = first(x)
@inbounds for j in eachindex(y)
yj = y[j]
temp = zero(adjoint(A[i₁, j]) * x₁)
@simd for i in eachindex(x)
temp += adjoint(A[i, j]) * x[i]
end
s += dot(temp, yj)
end
return s
end
2 changes: 1 addition & 1 deletion src/nodes/autoregressive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ default_meta(::Type{AR}) = error("Autoregressive node requires meta flag explici
my1, Vy1 = first(myx), first(Vyx)
Vy1x = ar_slice(getvform(meta), Vyx, 1, (order + 1):(2order))

# Euivalento to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
# Equivalent to AE = (-mean(log, q_γ) + log2π + mγ*(Vy1+my1^2 - 2*mθ'*(Vy1x + mx*my1) + tr(Vθ*Vx) + mx'*Vθ*mx + mθ'*(Vx + mx*mx')*mθ)) / 2
AE = (-mean(log, q_γ) + log2π +* (Vy1 + my1^2 - 2 *' * (Vy1x + mx * my1) + mul_trace(Vθ, Vx) + dot(mx, Vθ, mx) + dot(mθ, Vx, mθ) + abs2(dot(mθ, mx)))) / 2

# correction
Expand Down
21 changes: 21 additions & 0 deletions test/approximations/test_grad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module ForwardDiffGradTest

using Test
using ReactiveMP
using Random
using LinearAlgebra
using Distributions
using ForwardDiff

import ReactiveMP: convert_eltype

@testset "ForwardDiffGrad" begin
grad = ForwardDiffGrad()

for i in 1:100
@test ReactiveMP.compute_gradient(grad, (x) -> sum(x)^2, [i]) [2 * i]
@test ReactiveMP.compute_hessian(grad, (x) -> sum(x)^2, [i]) [2;;]
end
end

end
109 changes: 65 additions & 44 deletions test/distributions/test_normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,42 @@ using ReactiveMP
using Random
using LinearAlgebra
using Distributions
using ForwardDiff

import ReactiveMP: convert_eltype

@testset "Normal" begin
@testset "Univariate conversions" begin
check_basic_statistics = (left, right) -> begin
@test mean(left) mean(right)
@test median(left) median(right)
@test mode(left) mode(right)
@test weightedmean(left) weightedmean(right)
@test var(left) var(right)
@test std(left) std(right)
@test cov(left) cov(right)
@test invcov(left) invcov(right)
@test precision(left) precision(right)
@test entropy(left) entropy(right)
@test pdf(left, 1.0) pdf(right, 1.0)
@test pdf(left, -1.0) pdf(right, -1.0)
@test pdf(left, 0.0) pdf(right, 0.0)
@test logpdf(left, 1.0) logpdf(right, 1.0)
@test logpdf(left, -1.0) logpdf(right, -1.0)
@test logpdf(left, 0.0) logpdf(right, 0.0)
end
check_basic_statistics =
(left, right; include_extended_methods = true) -> begin
@test mean(left) mean(right)
@test median(left) median(right)
@test mode(left) mode(right)
@test var(left) var(right)
@test std(left) std(right)
@test entropy(left) entropy(right)

for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand())
@test pdf(left, value) pdf(right, value)
@test logpdf(left, value) logpdf(right, value)
@test all(ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value]))
@test all(ForwardDiff.hessian((x) -> logpdf(left, x[1]), [value]) .≈ ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value]))
end

# These methods are not defined for distributions from `Distributions.jl
if include_extended_methods
@test cov(left) cov(right)
@test invcov(left) invcov(right)
@test weightedmean(left) weightedmean(right)
@test precision(left) precision(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
end
end

types = ReactiveMP.union_types(UnivariateNormalDistributionsFamily{Float64})
etypes = ReactiveMP.union_types(UnivariateNormalDistributionsFamily)
Expand All @@ -36,6 +49,7 @@ import ReactiveMP: convert_eltype

for type in types
left = convert(type, rand(rng, Float64), rand(rng, Float64))
check_basic_statistics(left, convert(Normal, left); include_extended_methods = false)
for type in [types..., etypes...]
right = convert(type, left)
check_basic_statistics(left, right)
Expand All @@ -56,32 +70,38 @@ import ReactiveMP: convert_eltype
end

@testset "Multivariate conversions" begin
check_basic_statistics = (left, right, dims) -> begin
@test mean(left) mean(right)
@test mode(left) mode(right)
@test weightedmean(left) weightedmean(right)
@test var(left) var(right)
@test cov(left) cov(right)
@test invcov(left) invcov(right)
@test logdetcov(left) logdetcov(right)
@test precision(left) precision(right)
@test length(left) === length(right)
@test ndims(left) === ndims(right)
@test size(left) === size(right)
@test entropy(left) entropy(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
@test pdf(left, fill(1.0, dims)) pdf(right, fill(1.0, dims))
@test pdf(left, fill(-1.0, dims)) pdf(right, fill(-1.0, dims))
@test pdf(left, fill(0.0, dims)) pdf(right, fill(0.0, dims))
@test logpdf(left, fill(1.0, dims)) logpdf(right, fill(1.0, dims))
@test logpdf(left, fill(-1.0, dims)) logpdf(right, fill(-1.0, dims))
@test logpdf(left, fill(0.0, dims)) logpdf(right, fill(0.0, dims))
end
check_basic_statistics =
(left, right, dims; include_extended_methods = true) -> begin
@test mean(left) mean(right)
@test mode(left) mode(right)
@test var(left) var(right)
@test cov(left) cov(right)
@test logdetcov(left) logdetcov(right)
@test length(left) === length(right)
@test size(left) === size(right)
@test entropy(left) entropy(right)

for value in (fill(1.0, dims), fill(-1.0, dims), fill(0.0, dims), mean(left), mean(right), rand(dims))
@test pdf(left, value) pdf(right, value)
@test logpdf(left, value) logpdf(right, value)
@test all(isapprox.(ForwardDiff.gradient((x) -> logpdf(left, x), value), ForwardDiff.gradient((x) -> logpdf(right, x), value), atol = 1e-14))
@test all(isapprox.(ForwardDiff.hessian((x) -> logpdf(left, x), value), ForwardDiff.hessian((x) -> logpdf(right, x), value), atol = 1e-14))
end

# These methods are not defined for distributions from `Distributions.jl
if include_extended_methods
@test ndims(left) === ndims(right)
@test invcov(left) invcov(right)
@test weightedmean(left) weightedmean(right)
@test precision(left) precision(right)
@test all(mean_cov(left) .≈ mean_cov(right))
@test all(mean_invcov(left) .≈ mean_invcov(right))
@test all(mean_precision(left) .≈ mean_precision(right))
@test all(weightedmean_cov(left) .≈ weightedmean_cov(right))
@test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right))
@test all(weightedmean_precision(left) .≈ weightedmean_precision(right))
end
end

types = ReactiveMP.union_types(MultivariateNormalDistributionsFamily{Float64})
etypes = ReactiveMP.union_types(MultivariateNormalDistributionsFamily)
Expand All @@ -92,6 +112,7 @@ import ReactiveMP: convert_eltype
for dim in dims
for type in types
left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(rand(rng, Float64, dim))))
check_basic_statistics(left, convert(MvNormal, left), dim; include_extended_methods = false)
for type in [types..., etypes...]
right = convert(type, left)
check_basic_statistics(left, right, dim)
Expand Down

0 comments on commit 351dea1

Please sign in to comment.