Skip to content

Commit

Permalink
Make SimplexBijector actually bijective (#263)
Browse files Browse the repository at this point in the history
* Remove unused proj field

* Update simplex bijector calls

* Update simplex jacobian calls

* Remove proj type entry

* Compute logdetjac from square part of jacobian

* Increment minor version number

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Update test/interface.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed link and invlink for SimplexBijector

* Update src/Bijectors.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* super-hacky fix to size issue of TransformedDistribution

* added fixme comment

* removed redundant constructor for Stacked

* added implementation of output_size for SimplexBijector

* Update src/bijectors/simplex.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed tests

* removed more references to old SimplexBijector code

* fixed more dirichlet tests

* formatting

* possibly fixed weird formatting complaints

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
  • Loading branch information
5 people authored Jun 19, 2023
1 parent 2147089 commit 03bdffb
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 141 deletions.
24 changes: 8 additions & 16 deletions ext/BijectorsDistributionsADExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,19 @@ Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true
Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true
Bijectors.isdirichlet(::TuringDirichlet) = true

function Bijectors.link(
d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return Bijectors.SimplexBijector{proj}()(x)
function Bijectors.link(d::TuringDirichlet, x::AbstractVecOrMat{<:Real})
return Bijectors.SimplexBijector()(x)
end

function Bijectors.link_jacobian(
d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(Bijectors.SimplexBijector{proj}(), x)
function Bijectors.link_jacobian(d::TuringDirichlet, x::AbstractVector{<:Real})
return jacobian(Bijectors.SimplexBijector(), x)
end

function Bijectors.invlink(
d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return inverse(Bijectors.SimplexBijector{proj}())(y)
function Bijectors.invlink(d::TuringDirichlet, y::AbstractVecOrMat{<:Real})
return inverse(Bijectors.SimplexBijector())(y)
end
function Bijectors.invlink_jacobian(
d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(inverse(Bijectors.SimplexBijector{proj}()), y)
function Bijectors.invlink_jacobian(d::TuringDirichlet, y::AbstractVector{<:Real})
return jacobian(inverse(Bijectors.SimplexBijector()), y)
end

Bijectors.ispd(::TuringWishart) = true
Expand Down
23 changes: 5 additions & 18 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,25 +206,12 @@ isdirichlet(::Distribution) = false
# ∑xᵢ = 1 #
###########

function link(d::Dirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)) where {proj}
return SimplexBijector{proj}()(x)
end

function link_jacobian(
d::Dirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(SimplexBijector{proj}(), x)
end
# Map a point (or columns of points) `x` from the support of a `Dirichlet`
# through the forward `SimplexBijector`.
function link(d::Dirichlet, x::AbstractVecOrMat{<:Real})
    return SimplexBijector()(x)
end

# Jacobian of the forward `SimplexBijector` evaluated at the vector `x`.
function link_jacobian(d::Dirichlet, x::AbstractVector{<:Real})
    return jacobian(SimplexBijector(), x)
end

function invlink(
d::Dirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return inverse(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::Dirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(inverse(SimplexBijector{proj}()), y)
invlink(d::Dirichlet, y::AbstractVecOrMat{<:Real}) = inverse(SimplexBijector())(y)
function invlink_jacobian(d::Dirichlet, y::AbstractVector{<:Real})
return jacobian(inverse(SimplexBijector()), y)
end

## Matrix
Expand Down
109 changes: 42 additions & 67 deletions src/bijectors/simplex.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
####################
# Simplex bijector #
####################
struct SimplexBijector{T} <: Bijector end
SimplexBijector() = SimplexBijector{true}()
struct SimplexBijector <: Bijector end

# Output sizes for the simplex bijector and its inverse.
# The forward direction reports one fewer entry along the first dimension;
# the inverse direction reports one more. Any second (column) dimension is
# passed through unchanged.
function output_size(::SimplexBijector, sz::Tuple{Int})
    return (sz[1] - 1,)
end
function output_size(::Inverse{SimplexBijector}, sz::Tuple{Int})
    return (sz[1] + 1,)
end

function output_size(::SimplexBijector, sz::Tuple{Int,Int})
    return (sz[1] - 1, sz[2])
end
function output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int})
    return (sz[1] + 1, sz[2])
end

# Return the transformed value together with `logabsdetjac(b, x)` as a tuple.
function with_logabsdet_jacobian(b::SimplexBijector, x)
    y = transform(b, x)
    logjac = logabsdetjac(b, x)
    return (y, logjac)
end

# Out-of-place forward transform: delegates to `_simplex_bijector`.
function transform(b::SimplexBijector, x)
    return _simplex_bijector(x, b)
end

# In-place forward transform: writes into `y` via `_simplex_bijector!`.
function transform!(b::SimplexBijector, y, x)
    return _simplex_bijector!(y, x, b)
end

function _simplex_bijector(x::AbstractArray, b::SimplexBijector)
return _simplex_bijector!(similar(x), x, b)
sz = size(x)
K = size(x, 1)
y = similar(x, Base.setindex(sz, K - 1, 1))
_simplex_bijector!(y, x, b)
return y
end

# Vector implementation.
function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj}
function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector)
K = length(x)
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(x)
Expand All @@ -29,18 +40,11 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where
z = (x[k] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp)
y[k] = LogExpFunctions.logit(z) + log(T(K - k))
end
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
y[K] = zero(T)
else
y[K] = one(T) - sum_tmp - x[K]
end

return y
end

# Matrix implementation.
function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj}
function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector)
K, N = size(X, 1), size(X, 2)
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(X)
Expand All @@ -54,12 +58,6 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where
z = (X[k, n] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp)
Y[k, n] = LogExpFunctions.logit(z) + log(T(K - k))
end
sum_tmp += X[K - 1, n]
if proj
Y[K, n] = zero(T)
else
Y[K, n] = one(T) - sum_tmp - X[K, n]
end
end

return Y
Expand All @@ -75,10 +73,16 @@ function transform!(
return _simplex_inv_bijector!(x, y, ib.orig)
end

_simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b)
# Out-of-place inverse transform.
# Allocates an array like `y` but with one more entry along the first
# dimension, fills it with `_simplex_inv_bijector!`, and returns it.
function _simplex_inv_bijector(y, b)
    dims = size(y)
    out_dims = Base.setindex(dims, dims[1] + 1, 1)
    out = similar(y, out_dims)
    _simplex_inv_bijector!(out, y, b)
    return out
end

function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj}
K = length(y)
function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector)
K = length(y) + 1
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(y)
ϵ = _eps(T)
Expand All @@ -91,17 +95,12 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj})
x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
end
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
x[K] = _clamp(one(T) - sum_tmp, 0, 1)
else
x[K] = _clamp(one(T) - sum_tmp - y[K], 0, 1)
end

x[K] = _clamp(one(T) - sum_tmp, 0, 1)
return x
end

function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj}
K, N = size(Y, 1), size(Y, 2)
function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector)
K, N = size(Y, 1) + 1, size(Y, 2)
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(Y)
ϵ = _eps(T)
Expand All @@ -114,11 +113,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj})
X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
end
sum_tmp += X[K - 1, n]
if proj
X[K, n] = _clamp(one(T) - sum_tmp, 0, 1)
else
X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], 0, 1)
end
X[K, n] = _clamp(one(T) - sum_tmp, 0, 1)
end

return X
Expand Down Expand Up @@ -213,13 +208,10 @@ function simplex_logabsdetjac_gradient(x::AbstractMatrix)
return g
end

function simplex_link_jacobian(
x::AbstractVector{T}, ::Val{proj}=Val(true)
) where {T<:Real,proj}
function simplex_link_jacobian(x::AbstractVector{T}) where {T<:Real}
K = length(x)
@assert K > 1 "x needs to be of length greater than 1"
dydxt = similar(x, length(x), length(x))
@inbounds dydxt .= 0
dydxt = fill!(similar(x, K, K - 1), 0)
ϵ = _eps(T)
sum_tmp = zero(T)

Expand All @@ -237,16 +229,10 @@ function simplex_link_jacobian(
((one(T) + ϵ) - sum_tmp)^2
end
end
@inbounds sum_tmp += x[K - 1]
@inbounds if !proj
@simd for i in 1:K
dydxt[i, K] = -1
end
end
return UpperTriangular(dydxt)'
return dydxt'
end
function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj,T}
return simplex_link_jacobian(x, Val(proj))
function jacobian(b::SimplexBijector, x::AbstractVector{T}) where {T}
return simplex_link_jacobian(x)
end

#=
Expand Down Expand Up @@ -315,13 +301,10 @@ function add_simplex_link_adjoint!(
end
=#

function simplex_invlink_jacobian(
y::AbstractVector{T}, ::Val{proj}=Val(true)
) where {T<:Real,proj}
K = length(y)
function simplex_invlink_jacobian(y::AbstractVector{T}) where {T<:Real}
K = length(y) + 1
@assert K > 1 "x needs to be of length greater than 1"
dxdy = similar(y, length(y), length(y))
@inbounds dxdy .= 0
dxdy = fill!(similar(y, K, K - 1), 0)

ϵ = _eps(T)
@inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1)))
Expand All @@ -346,28 +329,20 @@ function simplex_invlink_jacobian(
end
end
@inbounds sum_tmp += clamped_x
@inbounds if proj
unclamped_x = one(T) - sum_tmp
clamped_x = _clamp(unclamped_x, 0, 1)
else
unclamped_x = one(T) - sum_tmp - y[K]
clamped_x = _clamp(unclamped_x, 0, 1)
if unclamped_x == clamped_x
dxdy[K, K] = -1
end
end
unclamped_x = one(T) - sum_tmp
clamped_x = _clamp(unclamped_x, 0, 1)
@inbounds if unclamped_x == clamped_x
for i in 1:(K - 1)
@simd for j in i:(K - 1)
dxdy[K, i] += -dxdy[j, i]
end
end
end
return LowerTriangular(dxdy)
return dxdy
end
# jacobian
function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj,T}
return simplex_invlink_jacobian(y, Val(proj))
function jacobian(ib::Inverse{<:SimplexBijector}, y::AbstractVector{T}) where {T}
return simplex_invlink_jacobian(y)
end

#=
Expand Down
3 changes: 0 additions & 3 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ end
end
end

# Avoid mixing tuples and arrays.
Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges)

Functors.@functor Stacked (bs,)

function Base.show(io::IO, b::Stacked)
Expand Down
21 changes: 12 additions & 9 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ end
# verify against AD
# similar to what we do in test/transform.jl for Dirichlet
if dist isa Dirichlet
b = Bijectors.SimplexBijector{false}()
b = Bijectors.SimplexBijector()
# HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]`
# which in turn will lead to differences between `ForwardDiff.jacobian`
# and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`.
Expand All @@ -168,8 +168,9 @@ end
end
y = b(x)
@test b(param(x)) isa TrackedArray
@test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x)
@test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈
@test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1] ≈
logabsdetjac(b, x)
@test logabsdet(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), :])[1] ≈
logabsdetjac(inverse(b), y)
else
b = bijector(dist)
Expand Down Expand Up @@ -420,35 +421,37 @@ end
b = SimplexBijector()
ib = inverse(b)

x = ib(randn(10))
d_x = 10
x = ib(randn(d_x - 1))
y = b(x)

@test Bijectors.jacobian(b, x) ≈ ForwardDiff.jacobian(b, x)
@test Bijectors.jacobian(ib, y) ≈ ForwardDiff.jacobian(ib, y)

# Just some additional computation so we also ensure the pullbacks are the same
weights = randn(10)
weights_x = randn(d_x)
weights_y = randn(d_x - 1)

# Tracker.jl
x_tracked = Tracker.param(x)
z = sum(weights .* b(x_tracked))
z = sum(weights_y .* b(x_tracked))
Tracker.back!(z)
Δ_tracker = Tracker.grad(x_tracked)

# ForwardDiff.jl
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* b(z)), x)
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_y .* b(z)), x)

# Compare
@test Δ_forwarddiff ≈ Δ_tracker

# Tracker.jl
y_tracked = Tracker.param(y)
z = sum(weights .* ib(y_tracked))
z = sum(weights_x .* ib(y_tracked))
Tracker.back!(z)
Δ_tracker = Tracker.grad(y_tracked)

# ForwardDiff.jl
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* ib(z)), y)
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_x .* ib(z)), y)

@test Δ_forwarddiff ≈ Δ_tracker
end
Expand Down
Loading

3 comments on commit 03bdffb

@torfjelde
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/85884

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.0 -m "<description of version>" 03bdffb50ac4b40567b0129774a3f6fe06916215
git push origin v0.13.0

@torfjelde
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nvm I'll abort this. I forgot #271 will also be breaking.

Please sign in to comment.