Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SimplexBijector actually bijective #263

Merged
merged 26 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
36a6b41
Remove unused proj field
sethaxen Jun 5, 2023
fa503cb
Update simplex bijector calls
sethaxen Jun 5, 2023
4abaae9
Update simplex jacobian calls
sethaxen Jun 5, 2023
579c808
Remove proj type entry
sethaxen Jun 5, 2023
3110270
Compute logdetjac from square part of jacobian
sethaxen Jun 5, 2023
f21d328
Increment minor version number
sethaxen Jun 5, 2023
7ba8948
Merge branch 'master' into fixsimplex
sethaxen Jun 6, 2023
7e2927a
Apply suggestions from code review
sethaxen Jun 6, 2023
ba21df5
Apply suggestions from code review
sethaxen Jun 6, 2023
3ce84bb
Update test/interface.jl
sethaxen Jun 6, 2023
e8ad6cb
Merge branch 'master' into fixsimplex
yebai Jun 12, 2023
6614b15
Merge branch 'master' into fixsimplex
torfjelde Jun 17, 2023
776e4af
fixed link and invlink for SimplexBijector
torfjelde Jun 17, 2023
921f818
Update src/Bijectors.jl
torfjelde Jun 17, 2023
d934cfa
super-hacky fix to size issue of TransformedDistribution
torfjelde Jun 17, 2023
8f39b0d
added fixme comment
torfjelde Jun 17, 2023
79544ba
Merge branch 'master' into fixsimplex
torfjelde Jun 18, 2023
8efd243
removed redundant constructor for Stacked
torfjelde Jun 19, 2023
9c16433
added implementation of output_size for SimplexBijector
torfjelde Jun 19, 2023
78f015e
Update src/bijectors/simplex.jl
torfjelde Jun 19, 2023
144df86
fixed tests
torfjelde Jun 19, 2023
a8e6e21
removed more references to old SimplexBijector code
torfjelde Jun 19, 2023
852c826
fixed more dirichlet tests
torfjelde Jun 19, 2023
97af441
formatting
torfjelde Jun 19, 2023
1f8a0f1
possibly fixed weird formatting complaints
torfjelde Jun 19, 2023
5d394bb
Apply suggestions from code review
torfjelde Jun 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions ext/BijectorsDistributionsADExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,19 @@ Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true
Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true
Bijectors.isdirichlet(::TuringDirichlet) = true

function Bijectors.link(
d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return Bijectors.SimplexBijector{proj}()(x)
function Bijectors.link(d::TuringDirichlet, x::AbstractVecOrMat{<:Real})
return Bijectors.SimplexBijector()(x)
end

function Bijectors.link_jacobian(
d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(Bijectors.SimplexBijector{proj}(), x)
function Bijectors.link_jacobian(d::TuringDirichlet, x::AbstractVector{<:Real})
return jacobian(Bijectors.SimplexBijector(), x)
end

function Bijectors.invlink(
d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return inverse(Bijectors.SimplexBijector{proj}())(y)
function Bijectors.invlink(d::TuringDirichlet, y::AbstractVecOrMat{<:Real})
return inverse(Bijectors.SimplexBijector())(y)
end
function Bijectors.invlink_jacobian(
d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(inverse(Bijectors.SimplexBijector{proj}()), y)
function Bijectors.invlink_jacobian(d::TuringDirichlet, y::AbstractVector{<:Real})
return jacobian(inverse(Bijectors.SimplexBijector()), y)
end

Bijectors.ispd(::TuringWishart) = true
Expand Down
23 changes: 5 additions & 18 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,25 +206,12 @@ isdirichlet(::Distribution) = false
# ∑xᵢ = 1 #
###########

function link(d::Dirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)) where {proj}
return SimplexBijector{proj}()(x)
end

function link_jacobian(
d::Dirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(SimplexBijector{proj}(), x)
end
link(d::Dirichlet, x::AbstractVecOrMat{<:Real}) = SimplexBijector()(x)
link_jacobian(d::Dirichlet, x::AbstractVector{<:Real}) = jacobian(SimplexBijector(), x)

function invlink(
d::Dirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return inverse(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::Dirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true)
) where {proj}
return jacobian(inverse(SimplexBijector{proj}()), y)
invlink(d::Dirichlet, y::AbstractVecOrMat{<:Real}) = inverse(SimplexBijector())(y)
function invlink_jacobian(d::Dirichlet, y::AbstractVector{<:Real})
return jacobian(inverse(SimplexBijector()), y)
end

## Matrix
Expand Down
109 changes: 42 additions & 67 deletions src/bijectors/simplex.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
####################
# Simplex bijector #
####################
struct SimplexBijector{T} <: Bijector end
SimplexBijector() = SimplexBijector{true}()
struct SimplexBijector <: Bijector end

# Size of the output given the size `sz` of the input.
# The forward transform has one entry fewer along the first dimension than its input;
# the inverse has one entry more.
# Vector-shaped input: `sz` is a 1-tuple.
output_size(::SimplexBijector, sz::Tuple{Int}) = (first(sz) - 1,)
output_size(::Inverse{SimplexBijector}, sz::Tuple{Int}) = (first(sz) + 1,)

# Matrix-shaped input: only the first dimension changes; the second is passed through.
output_size(::SimplexBijector, sz::Tuple{Int,Int}) = (first(sz) - 1, last(sz))
function output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int})
return (first(sz) + 1, last(sz))
end

# Return the 2-tuple `(transform(b, x), logabsdetjac(b, x))`.
with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x)

# Out-of-place and in-place entry points (`y` is the destination for `transform!`);
# both forward to the `_simplex_bijector` / `_simplex_bijector!` helpers.
transform(b::SimplexBijector, x) = _simplex_bijector(x, b)
transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b)

function _simplex_bijector(x::AbstractArray, b::SimplexBijector)
return _simplex_bijector!(similar(x), x, b)
sz = size(x)
K = size(x, 1)
y = similar(x, Base.setindex(sz, K - 1, 1))
_simplex_bijector!(y, x, b)
return y
end

# Vector implementation.
function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj}
function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector)
K = length(x)
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(x)
Expand All @@ -29,18 +40,11 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where
z = (x[k] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp)
y[k] = LogExpFunctions.logit(z) + log(T(K - k))
end
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
y[K] = zero(T)
else
y[K] = one(T) - sum_tmp - x[K]
end

return y
end

# Matrix implementation.
function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj}
function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector)
K, N = size(X, 1), size(X, 2)
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(X)
Expand All @@ -54,12 +58,6 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where
z = (X[k, n] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp)
Y[k, n] = LogExpFunctions.logit(z) + log(T(K - k))
end
sum_tmp += X[K - 1, n]
if proj
Y[K, n] = zero(T)
else
Y[K, n] = one(T) - sum_tmp - X[K, n]
end
end

return Y
Expand All @@ -75,10 +73,16 @@ function transform!(
return _simplex_inv_bijector!(x, y, ib.orig)
end

_simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b)
# Out-of-place inverse transform: allocate an output whose first dimension is one
# larger than that of `y` (all other dimensions unchanged), fill it in place with
# `_simplex_inv_bijector!`, and return it.
function _simplex_inv_bijector(y, b)
sz = size(y)
# The output has one more entry along the first dimension than `y`.
K = sz[1] + 1
# `Base.setindex` on a tuple returns a new tuple with position 1 replaced by `K`.
x = similar(y, Base.setindex(sz, K, 1))
_simplex_inv_bijector!(x, y, b)
return x
end

function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj}
K = length(y)
function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector)
K = length(y) + 1
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(y)
ϵ = _eps(T)
Expand All @@ -91,17 +95,12 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj})
x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
end
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
x[K] = _clamp(one(T) - sum_tmp, 0, 1)
else
x[K] = _clamp(one(T) - sum_tmp - y[K], 0, 1)
end

x[K] = _clamp(one(T) - sum_tmp, 0, 1)
return x
end

function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj}
K, N = size(Y, 1), size(Y, 2)
function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector)
K, N = size(Y, 1) + 1, size(Y, 2)
@assert K > 1 "x needs to be of length greater than 1"
T = eltype(Y)
ϵ = _eps(T)
Expand All @@ -114,11 +113,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj})
X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
end
sum_tmp += X[K - 1, n]
if proj
X[K, n] = _clamp(one(T) - sum_tmp, 0, 1)
else
X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], 0, 1)
end
X[K, n] = _clamp(one(T) - sum_tmp, 0, 1)
end

return X
Expand Down Expand Up @@ -213,13 +208,10 @@ function simplex_logabsdetjac_gradient(x::AbstractMatrix)
return g
end

function simplex_link_jacobian(
x::AbstractVector{T}, ::Val{proj}=Val(true)
) where {T<:Real,proj}
function simplex_link_jacobian(x::AbstractVector{T}) where {T<:Real}
K = length(x)
@assert K > 1 "x needs to be of length greater than 1"
dydxt = similar(x, length(x), length(x))
@inbounds dydxt .= 0
dydxt = fill!(similar(x, K, K - 1), 0)
ϵ = _eps(T)
sum_tmp = zero(T)

Expand All @@ -237,16 +229,10 @@ function simplex_link_jacobian(
((one(T) + ϵ) - sum_tmp)^2
end
end
@inbounds sum_tmp += x[K - 1]
@inbounds if !proj
@simd for i in 1:K
dydxt[i, K] = -1
end
end
return UpperTriangular(dydxt)'
return dydxt'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The adjoint operation seems a bit annoying but I guess the algorithm should be updated in separate PRs if desired.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, agreed, but this should be fixed separately.

end
function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj,T}
return simplex_link_jacobian(x, Val(proj))
function jacobian(b::SimplexBijector, x::AbstractVector{T}) where {T}
return simplex_link_jacobian(x)
end

#=
Expand Down Expand Up @@ -315,13 +301,10 @@ function add_simplex_link_adjoint!(
end
=#

function simplex_invlink_jacobian(
y::AbstractVector{T}, ::Val{proj}=Val(true)
) where {T<:Real,proj}
K = length(y)
function simplex_invlink_jacobian(y::AbstractVector{T}) where {T<:Real}
K = length(y) + 1
@assert K > 1 "x needs to be of length greater than 1"
dxdy = similar(y, length(y), length(y))
@inbounds dxdy .= 0
dxdy = fill!(similar(y, K, K - 1), 0)

ϵ = _eps(T)
@inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1)))
Expand All @@ -346,28 +329,20 @@ function simplex_invlink_jacobian(
end
end
@inbounds sum_tmp += clamped_x
@inbounds if proj
unclamped_x = one(T) - sum_tmp
clamped_x = _clamp(unclamped_x, 0, 1)
else
unclamped_x = one(T) - sum_tmp - y[K]
clamped_x = _clamp(unclamped_x, 0, 1)
if unclamped_x == clamped_x
dxdy[K, K] = -1
end
end
unclamped_x = one(T) - sum_tmp
clamped_x = _clamp(unclamped_x, 0, 1)
@inbounds if unclamped_x == clamped_x
for i in 1:(K - 1)
@simd for j in i:(K - 1)
dxdy[K, i] += -dxdy[j, i]
end
end
end
return LowerTriangular(dxdy)
return dxdy
end
# jacobian
function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj,T}
return simplex_invlink_jacobian(y, Val(proj))
function jacobian(ib::Inverse{<:SimplexBijector}, y::AbstractVector{T}) where {T}
return simplex_invlink_jacobian(y)
end

#=
Expand Down
3 changes: 0 additions & 3 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ end
end
end

# Avoid mixing tuples and arrays.
Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges)

Functors.@functor Stacked (bs,)

function Base.show(io::IO, b::Stacked)
Expand Down
21 changes: 12 additions & 9 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ end
# verify against AD
# similar to what we do in test/transform.jl for Dirichlet
if dist isa Dirichlet
b = Bijectors.SimplexBijector{false}()
b = Bijectors.SimplexBijector()
# HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]`
# which in turn will lead to differences between `ForwardDiff.jacobian`
# and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`.
Expand All @@ -168,8 +168,9 @@ end
end
y = b(x)
@test b(param(x)) isa TrackedArray
@test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x)
@test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈
@test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1] ≈
logabsdetjac(b, x)
@test logabsdet(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), :])[1] ≈
logabsdetjac(inverse(b), y)
else
b = bijector(dist)
Expand Down Expand Up @@ -420,35 +421,37 @@ end
b = SimplexBijector()
ib = inverse(b)

x = ib(randn(10))
d_x = 10
x = ib(randn(d_x - 1))
y = b(x)

@test Bijectors.jacobian(b, x) ≈ ForwardDiff.jacobian(b, x)
@test Bijectors.jacobian(ib, y) ≈ ForwardDiff.jacobian(ib, y)

# Just some additional computation so we also ensure the pullbacks are the same
weights = randn(10)
weights_x = randn(d_x)
weights_y = randn(d_x - 1)

# Tracker.jl
x_tracked = Tracker.param(x)
z = sum(weights .* b(x_tracked))
z = sum(weights_y .* b(x_tracked))
Tracker.back!(z)
Δ_tracker = Tracker.grad(x_tracked)

# ForwardDiff.jl
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* b(z)), x)
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_y .* b(z)), x)

# Compare
@test Δ_forwarddiff ≈ Δ_tracker

# Tracker.jl
y_tracked = Tracker.param(y)
z = sum(weights .* ib(y_tracked))
z = sum(weights_x .* ib(y_tracked))
Tracker.back!(z)
Δ_tracker = Tracker.grad(y_tracked)

# ForwardDiff.jl
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* ib(z)), y)
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_x .* ib(z)), y)

@test Δ_forwarddiff ≈ Δ_tracker
end
Expand Down
Loading