Skip to content

Commit

Permalink
Fix ForwardDiff test
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Dec 8, 2024
1 parent c43d2e8 commit 4ff6ddc
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,46 @@ end
end

@testset "LKJCholesky" begin
# Convert Cholesky factor to its free parameters, i.e. its off-diagonal elements
function chol_3by3_to_free_params(x::Cholesky)
if x.uplo == :U
return [x.U[1, 2], x.U[1, 3], x.U[2, 3]]
else
return [x.L[2, 1], x.L[3, 1], x.L[3, 2]]
end
# TODO: Generalise to arbitrary dimension using this code:
# inds = [
# LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
# (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
# ]
end

# Reconstruct Cholesky factor from its free parameters
# Note that x[i, i] is always positive so we don't need to worry about the sign
function free_params_to_chol_3by3(free_params::AbstractVector, uplo::Symbol)
x = UpperTriangular(zeros(eltype(free_params), 3, 3))
x[1, 1] = 1
x[1, 2] = free_params[1]
x[1, 3] = free_params[2]
x[2, 2] = sqrt(1 - free_params[1]^2)
x[2, 3] = free_params[3]
x[3, 3] = sqrt(1 - free_params[2]^2 - free_params[3]^2)
if uplo == :U
return Cholesky(x)
else
return Cholesky(transpose(x))
end
end

@testset "uplo: $uplo" for uplo in [:L, :U]
dist = LKJCholesky(3, 1, uplo)
single_sample_tests(dist)

x = rand(dist)

inds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
(uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]
# Here, we need to pass ForwardDiff only the free parameters of the
# Cholesky factor so that we get a square Jacobian matrix
free_params = chol_3by3_to_free_params(x)
J = ForwardDiff.jacobian(z -> link(dist, free_params_to_chol_3by3(z, uplo)), free_params)
logpdf_turing = logpdf_with_trans(dist, x, true)
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing
end
Expand Down

0 comments on commit 4ff6ddc

Please sign in to comment.