Skip to content

Commit

Permalink
Merge pull request #406 from ReactiveBayes/dev-ctransition
Browse files Browse the repository at this point in the history
Add MF rules for CTransition
  • Loading branch information
bvdmitri authored Jul 2, 2024
2 parents 824ab2a + 748bb7a commit df215c4
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 14 deletions.
43 changes: 42 additions & 1 deletion src/nodes/predefined/continuous_transition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import LazyArrays
import StatsFuns: log2π

@doc raw"""
The ContinuousTransition node transforms an m-dimensional (dx) vector x into an n-dimensional (dy) vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a`.
The functional form of the ContinuousTransition node is given by:
y ~ Normal(K(a) * x, W⁻¹)
This node transforms an m-dimensional vector x into an n-dimensional vector y via a linear (or nonlinear) transformation with a `n×m`-dimensional matrix `A` that is constructed from a vector `a` via a transformation K(a).
ContinuousTransition node is primarily used in two regimes:
# When no structure on A is specified:
Expand Down Expand Up @@ -37,6 +40,20 @@ Interfaces:
4. W - `n×n`-dimensional precision matrix used to soften the transition and perform variational message passing.
Note that you can set W to a fixed value or put a prior on it to control the amount of jitter.
The ContinuousTransition node support two factorizations:
1. Mean-field factorization:
```julia
@constraints begin
q(y, x, a, W) = q(y)q(x)q(a)q(W)
end
```
2. Structured factorization:
```julia
@constraints begin
q(y, x, a, W) = q(y, x)q(a)q(W)
end
```
"""
struct ContinuousTransition end

Expand Down Expand Up @@ -121,3 +138,27 @@ end

return AE
end

@average_energy ContinuousTransition (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my, Vy = mean_cov(q_y)
mx, Vx = mean_cov(q_x)
mW = mean(q_W)

Fs = getjacobians(meta, ma)
dy = length(Fs)

n = div(ndims(q_y), 2)
mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta)

trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma))
xxt = mx * mx'
for (i, j) in Iterators.product(1:dy, 1:dy)
FjVaFi = Fs[j] * Va * Fs[i]'
trWSU += mW[j, i] * tr(FjVaFi)
trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi)
end
AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2

return AE
end
19 changes: 19 additions & 0 deletions src/rules/continuous_transition/W.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs)
dy = length(my)
G₁ = (my * my' + Vy)

G₂ = ((my * mx' + Vyx) * mA')
G₃ = transpose(G₂)
Ex_xx = rank1update(Vx, mx)
Expand All @@ -15,6 +16,7 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs)
return G₁ - G₂ - G₃ + G₅ + G₆
end

# VMP: Stuctured
@rule ContinuousTransition(:W, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
Fs = getjacobians(meta, ma)
Expand All @@ -33,3 +35,20 @@ end

return WishartFast(dy + 2, Δ)
end

# VMP: Mean-field
@rule ContinuousTransition(:W, Marginalisation) (q_y::Any, q_x::Any, q_a::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my, Vy = mean_cov(q_y)
mx, Vx = mean_cov(q_x)

Fs = getjacobians(meta, ma)
dy = length(Fs)

epsilon = sqrt.(var(q_a))
mA = ctcompanion_matrix(ma, epsilon, meta)

Δ = compute_delta(my, Vy, mx, Vx, zeros(eltype(ma), dy, length(mx)), mA, Va, ma, Fs)

return WishartFast(dy + 2, Δ)
end
28 changes: 28 additions & 0 deletions src/rules/continuous_transition/a.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Important note: ContinuousTransition node requires q(a) as input to compute the update message for a. This is a particular requirement for the ContinuousTransition node as it might need the expansion point for the transformation. This is not a general requirement for the VMP rules.

# VMP: Stuctured
@rule ContinuousTransition(:a, Marginalisation) (q_y_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
ma = mean(q_a)
mW = mean(q_W)
Expand All @@ -23,3 +26,28 @@

return MvNormalWeightedMeanPrecision(xi, W)
end

# VMP: Mean-field
@rule ContinuousTransition(:a, Marginalisation) (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin
mx, Vx = mean_cov(q_x)
mW = mean(q_W)
my = mean(q_y)
ma = mean(q_a)

Fs = getjacobians(meta, ma)
dy = length(Fs)

xi, W = zeros(eltype(ma), length(ma)), zeros(eltype(ma), length(ma), length(ma))

mxmy = mx * my'
Vxmx = rank1update(Vx, mx)

for i in 1:dy
xi += Fs[i]' * mxmy * mW[:, i]
for j in 1:dy
W += mW[j, i] * Fs[i]' * Vxmx * Fs[j]
end
end

return MvNormalWeightedMeanPrecision(xi, W)
end
24 changes: 24 additions & 0 deletions src/rules/continuous_transition/x.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# VMP: Stuctured
@rule ContinuousTransition(:x, Marginalisation) (m_y::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my, Wy = mean_precision(m_y)
Expand All @@ -22,3 +23,26 @@

return MvNormalWeightedMeanPrecision(z, Ξ)
end

# VMP: Mean-field
@rule ContinuousTransition(:x, Marginalisation) (q_y::Any, q_a::Any, q_W::Any, meta::CTMeta) = begin
ma, Va = mean_cov(q_a)
my = mean(q_y)
mW = mean(q_W)

Fs = getjacobians(meta, ma)
dy = length(Fs)

epsilon = sqrt.(var(q_a))
mA = ctcompanion_matrix(ma, epsilon, meta)

Ξ = mA' * mW * mA

for (i, j) in Iterators.product(1:dy, 1:dy)
Ξ += mW[j, i] * Fs[j] * Va * Fs[i]'
end

z = mA' * mW * my

return MvNormalWeightedMeanPrecision(z, Ξ)
end
6 changes: 6 additions & 0 deletions src/rules/continuous_transition/y.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# VMP: Stuctured
@rule ContinuousTransition(:y, Marginalisation) (m_x::MultivariateNormalDistributionsFamily, q_a::MultivariateNormalDistributionsFamily, q_W::Any, meta::CTMeta) = begin
ma = mean(q_a)
mx, Vx = mean_cov(m_x)
Expand All @@ -12,3 +13,8 @@

return MvNormalMeanCovariance(my, Vy)
end

# VMP: Mean-field
@rule ContinuousTransition(:y, Marginalisation) (q_x::Any, q_a::Any, q_W::Any, meta::CTMeta) = MvNormalMeanPrecision(
ctcompanion_matrix(mean(q_a), sqrt.(var(q_a)), meta) * mean(q_x), mean(q_W)
)
19 changes: 13 additions & 6 deletions test/nodes/predefined/continuous_transition_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,25 @@
using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily

import ReactiveMP: getjacobians, gettransformation, ctcompanion_matrix
# TODO: A more rigorous test suit for the average energy of CTransition needs to be added
dy, dx = 2, 3
meta = CTMeta(a -> reshape(a, dy, dx))

@testset "AverageEnergy" begin
q_y_x = MvNormalMeanCovariance(zeros(5), diageye(5))
q_a = MvNormalMeanCovariance(zeros(6), diageye(6))
q_W = Wishart(3, diageye(2))
q_y = MvNormalMeanCovariance(zeros(dy), diageye(dy))
q_x = MvNormalMeanCovariance(zeros(dx), diageye(dx))

marginals = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing))
q_y_x = MvNormalMeanCovariance([mean(q_y); mean(q_x)], [cov(q_y) zeros(dy, dx); zeros(dx, dy) cov(q_x)])
q_a = MvNormalMeanCovariance(zeros(dx * dy), diageye(dx * dy))
q_W = Wishart(dy + 1, diageye(dy))

@test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals, meta) 13.0 atol = 1e-2
@show getjacobians(meta, mean(q_a))
marginals_st = (Marginal(q_y_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing))
marginals_mf = (Marginal(q_y, false, false, nothing), Marginal(q_x, false, false, nothing), Marginal(q_a, false, false, nothing), Marginal(q_W, false, false, nothing))

# 12,992 is a result of manual calculation
@test score(AverageEnergy(), ContinuousTransition, Val{(:y_x, :a, :W)}(), marginals_st, meta) 12.992 atol = 1e-2
# 12,07336 is a result of manual calculation
@test score(AverageEnergy(), ContinuousTransition, Val{(:y, :x, :a, :W)}(), marginals_mf, meta) 12.07736 atol = 1e-2
end

@testset "ContinuousTransition Functionality" begin
Expand Down
39 changes: 37 additions & 2 deletions test/rules/continuous_transition/W_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@testset "Linear transformation" begin
# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule(q_y_x, mA, ΣA, UA)
function benchmark_rule_structured(q_y_x, mA, ΣA, UA)
myx, Vyx = mean_cov(q_y_x)

dy = size(mA, 1)
Expand Down Expand Up @@ -39,7 +39,7 @@
qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA))

@test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [(
input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule(qyx, mA, ΣA, UA)
input = (q_y_x = qyx, q_a = qa, meta = metal), output = benchmark_rule_structured(qyx, mA, ΣA, UA)
)]
end
end
Expand All @@ -62,4 +62,39 @@
)]
end
end

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule_meanfield(q_y, q_x, mA, ΣA, UA)
my, Vy = mean_cov(q_y)
mx, Vx = mean_cov(q_x)

dy = size(mA, 1)

G = tr(Vx * UA) * ΣA + mA * Vx * mA' + Vy + ΣA * mx' * UA * mx + (mA * mx - my) * (mA * mx - my)'

return WishartFast(dy + 2, G)
end

@testset "Mean-field: (q_y::Any, q_x::Any, q_a::Any, meta::CTMeta)" begin
for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
dydx = dy * dx
transformation = (a) -> reshape(a, dy, dx)
mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)

metal = CTMeta(transformation)
Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
μx, Σx = rand(rng, dx), Lx * Lx'
μy, Σy = rand(rng, dy), Ly * Ly'

qy = MvNormalMeanCovariance(μy, Σy)
qx = MvNormalMeanCovariance(μx, Σx)

qa = MvNormalMeanCovariance(vec(mA), kron(UA, ΣA))

@test_rules [check_type_promotion = true, atol = 1e-5] ContinuousTransition(:W, Marginalisation) [(
input = (q_y = qy, q_x = qx, q_a = qa, meta = metal), output = benchmark_rule_meanfield(qy, qx, mA, ΣA, UA)
)]
end
end
end
38 changes: 36 additions & 2 deletions test/rules/continuous_transition/a_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule(q_y_x, q_W)
function benchmark_rule_structured(q_y_x, q_W)
myx, Vyx = mean_cov(q_y_x)
dy = size(q_W.S, 1)
Vx = Vyx[(dy + 1):end, (dy + 1):end]
Expand All @@ -36,7 +36,7 @@
qa = MvNormalMeanCovariance(a0, diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))
@test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [(
input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qyx, qW)
input = (q_y_x = qyx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_structured(qyx, qW)
)]
end
end
Expand All @@ -60,4 +60,38 @@
)]
end
end

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
# NOTE: this test rule does not allow q_x as a PointMass as it involves the covariance matrix of q_x
function benchmark_rule_meanfield(q_y, q_x, q_W)
my = mean(q_y)
mx, Vx = mean_cov(q_x)
mW = mean(q_W)
Λ = kron(Vx + mx * mx', mW)
return MvNormalWeightedMeanPrecision* (vec(my * mx' * inv(Vx + mx * mx'))), Λ)
end

@testset "Mean-field: (q_y::Any, q_x::Any, q_a::Any, q_W::Any, meta::CTMeta)" begin
for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
dydx = dy * dx
transformation = (a) -> reshape(a, dy, dx)
a0 = rand(Float32, dydx)
metal = CTMeta(transformation)
Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
μx, Σx = rand(rng, dx), Lx * Lx'
μy, Σy = rand(rng, dy), Ly * Ly'
qy = MvNormalMeanCovariance(μy, Σy)
qx = MvNormalMeanCovariance(μx, Σx)
qa = MvNormalMeanCovariance(a0, diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))
@test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [(
input = (q_y = qy, q_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(qy, qx, qW)
)]

@test_rules [check_type_promotion = false] ContinuousTransition(:a, Marginalisation) [(
input = (q_y = PointMass(μy), q_x = qx, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(PointMass(μy), qx, qW)
)]
end
end
end
37 changes: 34 additions & 3 deletions test/rules/continuous_transition/x_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@testset "Linear transformation" begin
# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule(q_y, q_W, mA, ΣA, UA)
function benchmark_rule_strucutred(q_y, q_W, mA, ΣA, UA)
my, Vy = mean_cov(q_y)

mW = mean(q_W)
Expand All @@ -27,15 +27,15 @@
mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)

metal = CTMeta(transformation)
Lx, Ly = rand(rng, dx, dx), rand(rng, dy, dy)
Ly = rand(rng, dy, dy)
μy, Σy = rand(rng, dy), Ly * Ly'

qy = MvNormalMeanCovariance(μy, Σy)
qa = MvNormalMeanCovariance(vec(mA), diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))

@test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [(
input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule(qy, qW, mA, ΣA, UA)
input = (m_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_strucutred(qy, qW, mA, ΣA, UA)
)]
end
end
Expand All @@ -59,4 +59,35 @@
)]
end
end

# the following rule is used for testing purposes only
# It is derived separately by Thijs van de Laar
function benchmark_rule_meanfield(q_y, q_W, mA, ΣA, UA)
mW = mean(q_W)

Λ = mA'mW * mA + tr(mW * ΣA) * UA
ξ = mA' * mW * mean(q_y)
return MvNormalWeightedMeanPrecision(ξ, Λ)
end

@testset "Mean-field: (q_y::Any, q_a::Any, q_W::Any, meta::CTMeta)" begin
for (dy, dx) in [(1, 3), (2, 3), (3, 2), (2, 2)]
dydx = dy * dx
transformation = (a) -> reshape(a, dy, dx)

mA, ΣA, UA = rand(rng, dy, dx), diageye(dy), diageye(dx)

metal = CTMeta(transformation)
Ly = rand(rng, dy, dy)
μy, Σy = rand(rng, dy), Ly * Ly'

qy = MvNormalMeanCovariance(μy, Σy)
qa = MvNormalMeanCovariance(vec(mA), diageye(dydx))
qW = Wishart(dy + 1, diageye(dy))

@test_rules [check_type_promotion = true, atol = 1e-4] ContinuousTransition(:x, Marginalisation) [(
input = (q_y = qy, q_a = qa, q_W = qW, meta = metal), output = benchmark_rule_meanfield(qy, qW, mA, ΣA, UA)
)]
end
end
end
Loading

0 comments on commit df215c4

Please sign in to comment.