Skip to content

Commit

Permalink
Merge pull request #1 from harisorgn/ho/vec-lkj-cholesky
Browse files Browse the repository at this point in the history
[WIP] LKJ and LKJCholesky bijectors
  • Loading branch information
harisorgn authored Apr 6, 2023
2 parents 62ae1ac + 222eb6e commit 7f5d0fc
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 44 deletions.
193 changes: 152 additions & 41 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ struct CorrBijector <: Bijector end

with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x)

"""
    transform(b::CorrBijector, X::AbstractMatrix{<:Real})

Map the matrix `X` to its unconstrained representation: take the upper Cholesky
factor of `X` and pass it through `_link_chol_lkj`. The result is returned as a
dense matrix of the same size as `X`.
"""
function transform(b::CorrBijector, X::AbstractMatrix{<:Real})
    # Keeping the triangular wrapper until here can avoid some computation.
    w = upper_triangular(parent(cholesky(X).U))
    r = _link_chol_lkj(w)
    # Adding `zero(X)` yields the dense format, which is required by a test:
    # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
    return r + zero(X)
end
Expand All @@ -78,7 +78,7 @@ function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
return pd_from_upper(w)
end

# Log-abs-det Jacobian of the inverse `CorrBijector`, computed by `_logabsdetjac_inv_corr`.
logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_inv_corr(Y)
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
#=
It may be more efficient if we can use the unconstrained value to avoid calling b
Expand Down Expand Up @@ -173,6 +173,23 @@ end

inverse(::typeof(vec_to_triu1)) = triu1_to_vec

"""
    vec_to_triu1_row_index(idx)

Return the row index associated with position `idx` of a vector that stores the
strict upper triangle of a matrix.

The vector is assumed to be stored in column-major order and to be one-based
indexed.
"""
function vec_to_triu1_row_index(idx)
    # NOTE(review): `_triu1_dim_from_length` is defined elsewhere; presumably it
    # returns the matrix dimension matching a given number of stored entries.
    dim = _triu1_dim_from_length(idx - 1)
    # Subtract the `dim * (dim - 1) / 2` positions counted before this column.
    preceding = (dim * (dim - 1)) ÷ 2
    return idx - preceding
end

# Common supertype for the bijectors whose forward transform is
# `_link_chol_lkj ∘ cholesky_factor`; subtypes only supply the inverse direction.
abstract type AbstractVecCorrBijector <: Bijector end

with_logabsdet_jacobian(b::AbstractVecCorrBijector, x) = transform(b, x), logabsdetjac(b, x)

# Forward transform: extract the Cholesky factor, then apply the link function.
transform(::AbstractVecCorrBijector, X) = (_link_chol_lkj ∘ cholesky_factor)(X)

# Forward log-abs-det Jacobian, obtained as the negated log-abs-det Jacobian of
# the inverse evaluated at the transformed value `b(x)`.
function logabsdetjac(b::AbstractVecCorrBijector, x)
    return -logabsdetjac(inverse(b), b(x))
end

"""
VecCorrBijector <: Bijector
Expand Down Expand Up @@ -205,29 +222,21 @@ julia> y = b(X) # Transform to unconstrained vector representation.
julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse.
true
"""
struct VecCorrBijector <: Bijector end
with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x)
struct VecCorrBijector <: AbstractVecCorrBijector end
# Inverse transform: apply `_inv_link_chol_lkj` to the vector `y`, then build the
# matrix `U' * U` from the result with `pd_from_upper`.
transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = (pd_from_upper ∘ _inv_link_chol_lkj)(y)

function transform(::VecCorrBijector, X::AbstractMatrix{<:Real})
w = upper_triangular(parent(cholesky(X).U))
r = _link_chol_lkj(w)
# Log-abs-det Jacobian of the inverse `VecCorrBijector`, computed by `_logabsdetjac_inv_corr`.
logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_corr(y)

# Extract only the upper triangle of `r`.
return triu1_to_vec(r)
end
struct VecTriuBijector <: AbstractVecCorrBijector end
# Inverse transform: apply `_inv_link_chol_lkj` to the vector `y`, wrap the result
# as `UpperTriangular`, and return it as a `Cholesky` object.
transform(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ UpperTriangular ∘ _inv_link_chol_lkj)(y)

function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
Y = vec_to_triu1(y)
w = _inv_link_chol_lkj(Y)
return pd_from_upper(w)
end
# Log-abs-det Jacobian of the inverse `VecTriuBijector`, computed by `_logabsdetjac_inv_chol`.
logabsdetjac(::Inverse{VecTriuBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y)

struct VecTrilBijector <: AbstractVecCorrBijector end
# Inverse transform: apply `_inv_link_chol_lkj` to the vector `y`, transpose the
# result, wrap it as `LowerTriangular`, and return it as a `Cholesky` object.
transform(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = (Cholesky ∘ LowerTriangular ∘ transpose ∘ _inv_link_chol_lkj)(y)

# Log-abs-det Jacobian of the inverse `VecTrilBijector`, computed by `_logabsdetjac_inv_chol`.
logabsdetjac(::Inverse{VecTrilBijector}, y::AbstractVector{<:Real}) = _logabsdetjac_inv_chol(y)

function logabsdetjac(b::VecCorrBijector, x)
return -logabsdetjac(inverse(b), b(x))
end
function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_chol_lkj(vec_to_triu1(y))
end

"""
function _link_chol_lkj(w)
Expand Down Expand Up @@ -261,21 +270,21 @@ and so
which is the above implementation.
"""
function _link_chol_lkj(w)
function _link_chol_lkj(W::AbstractMatrix)
# TODO: Implement adjoint to support reverse-mode AD backends properly.
K = LinearAlgebra.checksquare(w)
K = LinearAlgebra.checksquare(W)

z = similar(w) # z is also UpperTriangular.
z = similar(W) # z is also UpperTriangular.
# Some zero filling can be avoided, though the diagonal still needs to be filled with zeros.

# This block can't be integrated with loop below, because w[1,1] != 0.
# This block can't be integrated with loop below, because W[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j = 2:K
z[1, j] = atanh(w[1, j])
tmp = sqrt(1 - w[1, j]^2)
z[1, j] = atanh(W[1, j])
tmp = sqrt(1 - W[1, j]^2)
for i in 2:(j-1)
p = w[i, j] / tmp
p = W[i, j] / tmp
tmp *= sqrt(1 - p^2)
z[i, j] = atanh(p)
end
Expand All @@ -285,33 +294,102 @@ function _link_chol_lkj(w)
return z
end

# Link function for an upper-triangular factor `W`, returning a vector.
#
# The strict upper triangle of `W` is visited column by column (top to bottom
# within a column) and each visited entry contributes one element to the output,
# so the result has `K * (K - 1) / 2` elements for a `K × K` input.
function _link_chol_lkj(W::UpperTriangular)
    dim = LinearAlgebra.checksquare(W)
    nfree = ((dim - 1) * dim) ÷ 2 # {dim \choose 2} free parameters

    z = zeros(eltype(W), nfree)

    pos = 0
    @inbounds for col in 2:dim
        # The first row of each column is transformed directly.
        head = W[1, col]
        pos += 1
        z[pos] = atanh(head)
        # `remaining` is the running product of sqrt(1 - p^2) over the rows
        # already processed in this column; later entries are divided by it.
        remaining = sqrt(1 - head^2)
        for row in 2:(col - 1)
            ratio = W[row, col] / remaining
            remaining *= sqrt(1 - ratio^2)
            pos += 1
            z[pos] = atanh(ratio)
        end
    end

    return z
end

# Link function for a lower-triangular factor `W`, returning a vector.
#
# The strict lower triangle of `W` is visited row by row (left to right within a
# row) and each visited entry contributes one element to the output, so the
# result has `K * (K - 1) / 2` elements for a `K × K` input.
function _link_chol_lkj(W::LowerTriangular)
    dim = LinearAlgebra.checksquare(W)
    nfree = div((dim - 1) * dim, 2) # {dim \choose 2} free parameters

    z = zeros(eltype(W), nfree)

    pos = 0
    @inbounds for row in 2:dim
        # The first column of each row is transformed directly.
        head = W[row, 1]
        pos += 1
        z[pos] = atanh(head)
        # `remaining` is the running product of sqrt(1 - p^2) over the columns
        # already processed in this row; later entries are divided by it.
        remaining = sqrt(1 - head^2)
        for col in 2:(row - 1)
            ratio = W[row, col] / remaining
            remaining *= sqrt(1 - ratio^2)
            pos += 1
            z[pos] = atanh(ratio)
        end
    end

    return z
end

"""
_inv_link_chol_lkj(y)
Inverse link function for cholesky factor.
"""
function _inv_link_chol_lkj(Y::AbstractMatrix)
    # TODO: Implement adjoint to support reverse-mode AD backends properly.
    K = LinearAlgebra.checksquare(Y)

    W = similar(Y)

    @inbounds for j in 1:K
        # Each column starts with 1 in the first row; every step below moves
        # the fraction `z` of the current value into row `i - 1` and leaves the
        # remainder `sqrt(1 - z^2)` in row `i`, so the column keeps unit norm.
        W[1, j] = 1
        for i in 2:j
            z = tanh(Y[i - 1, j])
            tmp = W[i - 1, j]
            W[i - 1, j] = z * tmp
            W[i, j] = tmp * sqrt(1 - z^2)
        end
        # `similar` leaves memory uninitialised, so zero the part below the diagonal.
        for i in (j + 1):K
            W[i, j] = 0
        end
    end

    return W
end

# Inverse link function taking the unconstrained values as a vector `y` and
# returning a `K × K` matrix whose entries below the diagonal are zero.
# The elements of `y` are consumed column by column (top to bottom per column).
function _inv_link_chol_lkj(y::AbstractVector)
    # TODO: Implement adjoint to support reverse-mode AD backends properly.
    K = _triu1_dim_from_length(length(y))

    W = similar(y, K, K)
    # Zero the whole matrix up front; the loop only writes on and above the diagonal.
    W .= zero(eltype(y))

    idx = 1
    @inbounds for j in 1:K
        # Each column starts with 1 in the first row; every step below moves
        # the fraction `z` of the current value into row `i - 1` and leaves the
        # remainder `sqrt(1 - z^2)` in row `i`.
        W[1, j] = 1
        for i in 2:j
            z = tanh(y[idx])
            idx += 1
            tmp = W[i - 1, j]
            W[i - 1, j] = z * tmp
            W[i, j] = tmp * sqrt(1 - z^2)
        end
    end

    return W
end

function _logabsdetjac_chol_lkj(Y::AbstractMatrix)
function _logabsdetjac_inv_corr(Y::AbstractMatrix)
K = LinearAlgebra.checksquare(Y)

result = float(zero(eltype(Y)))
Expand All @@ -323,3 +401,36 @@ function _logabsdetjac_chol_lkj(Y::AbstractMatrix)
end
return result
end

# Vector-input counterpart of `_logabsdetjac_inv_corr`.
#
# Every element of `y` contributes `logtwo - (|y| + log1pexp(-2|y|))`, weighted
# by `K - row + 1`, where `row` is obtained from `vec_to_triu1_row_index` applied
# to the element's one-based position in `y`.
function _logabsdetjac_inv_corr(y::AbstractVector)
    K = _triu1_dim_from_length(length(y))

    total = float(zero(eltype(y)))
    for (pos, value) in enumerate(y)
        magnitude = abs(value)
        weight = K - vec_to_triu1_row_index(pos) + 1
        total += weight * (
            IrrationalConstants.logtwo - (magnitude + LogExpFunctions.log1pexp(-2 * magnitude))
        )
    end
    return total
end

# Log-abs-det Jacobian of the vector-input inverse link (`_inv_link_chol_lkj`)
# with respect to the strictly-upper-triangular entries of its output.
#
# With `z = tanh(y)`, the entry in row `i` of a column is
# `z_i * prod_{k < i} sqrt(1 - z_k^2)`, so its derivative with respect to its own
# `y_i` is `(1 - z_i^2) * prod_{k < i} sqrt(1 - z_k^2)`. The Jacobian is
# triangular, hence the log-determinant is the sum over all entries of
# `log(1 - z_i^2) + (1/2) * sum_{k < i} log(1 - z_k^2)`.
function _logabsdetjac_inv_chol(y::AbstractVector)
    K = _triu1_dim_from_length(length(y))

    result = float(zero(eltype(y)))
    idx = 1
    @inbounds for j in 2:K
        # Running sum of log(1 - z_k^2) over the rows ABOVE the current one.
        tmp = zero(result)
        for _ in 1:(j - 1)
            z = tanh(y[idx])
            logz = log(1 - z^2)
            # Use `tmp` before adding the current term: the current entry's own
            # `logz` must be counted exactly once, not 1.5 times.
            result += logz + (tmp / 2)
            tmp += logz
            idx += 1
        end
    end

    return result
end
56 changes: 55 additions & 1 deletion src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,59 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM
return y, _transform_inverse_ordered_adjoint
end

# Reverse-mode rule for `_link_chol_lkj(W::UpperTriangular)`.
#
# The primal pass repeats the forward computation and additionally records, for
# every output position, the running value of `tmp` AFTER that position was
# processed (`tmp_vec`). The pullback walks each column bottom-up and uses those
# cached values to propagate the cotangent back to `W`.
function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
    project_W = ChainRulesCore.ProjectTo(W)

    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2

    z = zeros(eltype(W), N)
    tmp_vec = similar(z)

    idx = 1
    @inbounds for j in 2:K
        z[idx] = atanh(W[1, j])
        tmp = sqrt(1 - W[1, j]^2)
        tmp_vec[idx] = tmp
        idx += 1
        for i in 2:(j - 1)
            p = W[i, j] / tmp
            tmp *= sqrt(1 - p^2)
            tmp_vec[idx] = tmp
            z[idx] = atanh(p)
            idx += 1
        end
    end

    function pullback_link_chol_lkj(Δz_thunked)
        Δz = ChainRulesCore.unthunk(Δz_thunked)

        ΔW = similar(W)

        @inbounds ΔW[1, 1] = zero(eltype(Δz))
        @inbounds for j in 2:K
            # Number of output entries produced by columns 2:(j-1); the entry
            # for `W[i, j]` therefore lives at linear position
            # `idx_up_to_prev_column + i` of `z` (and of `Δz`, which is a
            # vector like `z`, not a matrix).
            idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2)
            ΔW[j, j] = zero(eltype(Δz))
            Δtmp = zero(eltype(Δz))
            for i in (j - 1):-1:2
                # Value of `tmp` that was used to compute `p` for row `i`,
                # i.e. the one cached after processing row `i - 1`.
                tmp = tmp_vec[idx_up_to_prev_column + i - 1]
                p = W[i, j] / tmp
                ftmp = sqrt(1 - p^2)
                d_ftmp_p = -p / ftmp
                d_p_tmp = -W[i, j] / tmp^2

                Δp = Δz[idx_up_to_prev_column + i] / (1 - p^2) + Δtmp * tmp * d_ftmp_p
                ΔW[i, j] = Δp / tmp
                Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
            end
            ΔW[1, j] =
                Δz[idx_up_to_prev_column + 1] / (1 - W[1, j]^2) -
                Δtmp / sqrt(1 - W[1, j]^2) * W[1, j]
        end

        return ChainRulesCore.NoTangent(), project_W(ΔW)
    end

    return z, pullback_link_chol_lkj
end

# Fixes Zygote's issues with `@debug`
ChainRulesCore.@non_differentiable _debug(::Any)
ChainRulesCore.@non_differentiable _debug(::Any)
13 changes: 12 additions & 1 deletion src/transformed_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<
TransformedDistribution(d::UnivariateDistribution, b) = new{typeof(d), typeof(b), Univariate}(d, b)
TransformedDistribution(d::MultivariateDistribution, b) = new{typeof(d), typeof(b), Multivariate}(d, b)
TransformedDistribution(d::MatrixDistribution, b) = new{typeof(d), typeof(b), Matrixvariate}(d, b)
TransformedDistribution(d::Distribution{CholeskyVariate}, b) = new{typeof(d), typeof(b), CholeskyVariate}(d, b)
end

# fields may contain nested numerical parameters
Expand Down Expand Up @@ -77,7 +78,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d)
bijector(d::PDMatDistribution) = PDBijector()
bijector(d::MatrixBeta) = PDBijector()

# `LKJ` uses the vector-valued bijector.
bijector(d::LKJ) = VecCorrBijector()
# `LKJCholesky` picks the bijector matching the triangle flagged by `d.uplo`.
bijector(d::LKJCholesky) = d.uplo === 'L' ? VecTrilBijector() : VecTriuBijector()

##############################
# Distributions.jl interface #
Expand Down Expand Up @@ -107,6 +109,11 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real})
return logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac
end

# Log-density of a transformed `LKJ`/`LKJCholesky` distribution at the vector `y`:
# map `y` back through the inverse transform and add the log-abs-det Jacobian of
# that inverse to the log-density of the base distribution.
function logpdf(td::TransformedDistribution{T}, y::AbstractVector{<:Real}) where {T <: Union{LKJ, LKJCholesky}}
    inv_transform = inverse(td.transform)
    original, logjac = with_logabsdet_jacobian(inv_transform, y)
    base_logp = logpdf(td.dist, original)
    return base_logp + logjac
end

function _logpdf(td::MvTransformed, y::AbstractVector{<:Real})
x, logjac = with_logabsdet_jacobian(inverse(td.transform), y)
return logpdf(td.dist, x) + logjac
Expand Down Expand Up @@ -154,6 +161,10 @@ function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real})
x .= td.transform(x)
end

# Draw from a transformed `LKJ`/`LKJCholesky` distribution: sample the base
# distribution with `rng`, then push the draw through `td.transform`.
function rand(rng::AbstractRNG, td::TransformedDistribution{T}) where {T <: Union{LKJ, LKJCholesky}}
    draw = rand(rng, td.dist)
    return td.transform(draw)
end

# utility stuff
Distributions.params(td::Transformed) = Distributions.params(td.dist)
function Base.maximum(td::UnivariateTransformed)
Expand Down
5 changes: 5 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A))

# Build `L * L'` from the lower triangle of `X` (entries above the diagonal are ignored).
function pd_from_lower(X)
    factor = LowerTriangular(X)
    return factor * factor'
end

# Build `U' * U` from the upper triangle of `X` (entries below the diagonal are ignored).
function pd_from_upper(X)
    factor = UpperTriangular(X)
    return factor' * factor
end

# Return a triangular Cholesky factor for the supported input kinds:
# - a general matrix is factorised with `cholesky` first,
# - a `Cholesky` object yields its `UL` property,
# - a matrix that is already wrapped as triangular is returned unchanged.
cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(X))
cholesky_factor(X::Cholesky) = X.UL
cholesky_factor(X::Union{UpperTriangular, LowerTriangular}) = X
Loading

0 comments on commit 7f5d0fc

Please sign in to comment.