Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added output_length and output_size #270

Merged
merged 27 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6b46e34
added output_length and output_size to compute output, well, leengths
torfjelde Jun 17, 2023
41d0b06
added tests for size of transformed dist using VcCorrBijector
torfjelde Jun 17, 2023
a76f18a
use already constructed transfrormation
torfjelde Jun 17, 2023
6afc77e
TransformedDistribution should now also have correct variate form
torfjelde Jun 17, 2023
a4c5683
added proper variateform handling for VecCholeskyBijector too
torfjelde Jun 17, 2023
ea724ee
Apply suggestions from code review
torfjelde Jun 17, 2023
387ef5a
added output_size impl for Reshape too
torfjelde Jun 17, 2023
acb5e8f
bump minor version
torfjelde Jun 18, 2023
3391735
Apply suggestions from code review
torfjelde Jun 18, 2023
b524ebb
Update src/interface.jl
torfjelde Jun 18, 2023
d6dc906
Update src/bijectors/corr.jl
torfjelde Jun 18, 2023
280708b
reverted removal of length as we'll need it now
torfjelde Jun 18, 2023
2069d69
updated Stacked to be compat with changing sizes
torfjelde Jun 18, 2023
f533a79
forgot to commit deetion
torfjelde Jun 18, 2023
56b8834
Apply suggestions from code review
torfjelde Jun 18, 2023
098a9c0
add testing of sizes to `test_bijector`
torfjelde Jun 18, 2023
4e14bb2
some more tests for stacked
torfjelde Jun 18, 2023
def7c6f
Update test/bijectors/stacked.jl
torfjelde Jun 18, 2023
fe36875
added awful generated function to determine output ranges for Stacked
torfjelde Jun 18, 2023
bbfaf19
added slightly more informative comment
torfjelde Jun 18, 2023
bf68124
format
torfjelde Jun 18, 2023
45a9850
more fixes to that damned Stacked
torfjelde Jun 18, 2023
1f0c374
Update test/interface.jl
torfjelde Jun 18, 2023
a917c2b
specialized constructors for Stacked further
torfjelde Jun 18, 2023
cdd951a
fixed bug in output_size for CorrVecBijector
torfjelde Jun 18, 2023
5dbd829
Apply suggestions from code review
torfjelde Jun 18, 2023
04f6990
Apply suggestions from code review
torfjelde Jun 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_corr(y)
end

function output_size(::VecCorrBijector, sz::NTuple{2})
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
@assert sz[1] == sz[2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make this a proper, more descriptive error?

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
n = sz[1]
return (n * (n - 1)) ÷ 2
end

function output_size(::Inverse{VecCorrBijector}, sz::NTuple{1})
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
n = _triu1_dim_from_length(first(sz))
return (n, n)
end

"""
VecCholeskyBijector <: Bijector

Expand Down Expand Up @@ -317,6 +328,9 @@ function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_chol(y)
end

output_size(::VecCholeskyBijector, sz::NTuple{2}) = output_size(VecCorrBijector(), sz)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
output_size(::Inverse{<:VecCholeskyBijector}, sz::NTuple{1}) = output_size(inverse(VecCorrBijector()), sz)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

"""
function _link_chol_lkj(w)

Expand Down
20 changes: 20 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ function logabsdetjac(f::Columnwise, x::AbstractMatrix)
end
with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x))

"""
output_size(f, sz)

Returns the output size of `f` given the input size `sz`.
"""
output_size(f, sz) = sz

"""
output_length(f, len::Int)
output_length(f, sz::Tuple)

Returns the output length of `f` given the input length `len` or size `sz`.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
"""
output_length(f, len::Int) = len
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
function output_length(f, len::Tuple)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
sz = output_size(f, len)
@assert length(sz) == 1
return first(sz)
end

######################
# Bijector interface #
######################
Expand Down
28 changes: 13 additions & 15 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
function variateform(d::Distribution, b)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
sz_in = size(d)
sz_out = output_size(b, sz_in)
return ArrayLikeVariate{length(sz_out)}
end

variateform(::MultivariateDistribution, ::Inverse{VecCholeskyBijector}) = CholeskyVariate

# Transformed distributions
struct TransformedDistribution{D,B,V} <:
Distribution{V,Continuous} where {D<:Distribution{V,Continuous},B}
struct TransformedDistribution{D,B,V} <: Distribution{V,Continuous} where {D<:ContinuousDistribution,B}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
dist::D
transform::B

function TransformedDistribution(d::UnivariateDistribution, b)
return new{typeof(d),typeof(b),Univariate}(d, b)
end
function TransformedDistribution(d::MultivariateDistribution, b)
return new{typeof(d),typeof(b),Multivariate}(d, b)
end
function TransformedDistribution(d::MatrixDistribution, b)
return new{typeof(d),typeof(b),Matrixvariate}(d, b)
end
function TransformedDistribution(d::Distribution{CholeskyVariate}, b)
return new{typeof(d),typeof(b),CholeskyVariate}(d, b)
function TransformedDistribution(d::ContinuousDistribution, b)
return new{typeof(d),typeof(b),variateform(d,b)}(d, b)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
end

Expand Down Expand Up @@ -101,8 +99,8 @@ end
##############################

# size
Base.length(td::Transformed) = length(td.dist)
Base.size(td::Transformed) = size(td.dist)
Base.length(td::Transformed) = output_length(td.transform, size(td.dist))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
Base.size(td::Transformed) = output_size(td.transform, size(td.dist))

function logpdf(td::UnivariateTransformed, y::Real)
x, logjac = with_logabsdet_jacobian(inverse(td.transform), y)
Expand Down
20 changes: 20 additions & 0 deletions test/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false)

test_ad(x -> sum(bvec(bvecinv(x))), yvec)

# Check that output sizes are computed correctly.
tdist = transformed(dist)
@test length(tdist) == length(yvec)
@test tdist isa MultivariateDistribution

dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(bvec))
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa MatrixDistribution
end
end

Expand Down Expand Up @@ -60,6 +69,17 @@ end
# test_bijector is commented out for now,
# as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky)
# test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false)

# Check that output sizes are computed correctly.
tdist = transformed(dist)
@test length(tdist) == length(y)
@test tdist isa MultivariateDistribution

dist_unconstrained = transformed(
MvNormal(zeros(length(tdist)), I), inverse(b)
)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa Distribution{CholeskyVariate,Continuous}
end
end
end