Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SimplexBijector actually bijective #263

Merged
merged 26 commits into from
Jun 19, 2023
Merged

Make SimplexBijector actually bijective #263

merged 26 commits into from
Jun 19, 2023

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Jun 5, 2023

Similar to #228, currently the SimplexBijector makes transformed distributions improper. A demo from slack:

julia> using Turing

julia> @model function foo()
           d = Dirichlet(ones(2))
           x ~ filldist(Flat(), length(d))
           Turing.@addlogprob! logpdf(transformed(d), x)
           y = transform(inverse(bijector(d)), x)
           return (; y)
       end;

julia> chns = sample(foo(), NUTS(500, 0.8), MCMCThreads(), 1_000, 4)
┌ Info: Found initial step size
└   ϵ = 3.6
┌ Info: Found initial step size
└   ϵ = 3.6
┌ Info: Found initial step size
└   ϵ = 3.6
┌ Info: Found initial step size
└   ϵ = 12.8
Sampling (4 threads) 100%|█████████████████████████████████████████████████████████████| Time: 0:00:06
Chains MCMC chain (1000×14×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 5.58 seconds
Compute duration  = 20.62 seconds
parameters        = x[1], x[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters                     mean                      std                     mcse    ess_bulk    ess_ta 
      Symbol                  Float64                  Float64                  Float64     Float64     Float 

        x[1]                   0.0359                   1.7472                   0.0299   3450.8940   2517.15 
        x[2]   -1255806938226913.2500   38456907341159336.0000   12941276533029076.0000      9.0031     15.55 
                                                                                              3 columns omitted

Quantiles
  parameters                      2.5%                     25.0%                     50.0%                    
      Symbol                   Float64                   Float64                   Float64                  F 

        x[1]                   -3.4263                   -1.0549                    0.0211                    
        x[2]   -75987690476946768.0000   -23883185857795536.0000   -10795918685407138.0000   2113099877552726 

julia> yvals = permutedims(stack(first.(generated_quantities(foo(), chns))), (2, 3, 1));^C

julia> ess(yvals)
2-element Vector{Float64}:
 3450.8940281885793
 3450.8940281885766

julia> dropdims(mean(yvals; dims=(1, 2)); dims=(1, 2))
2-element Vector{Float64}:
 0.5055473060755683
 0.49445269392443175

This PR changes SimplexBijector to transform a K-vector to a K-1-vector. Since the proj type entry in SimplexBijector only impacted the extra Kth entry of the unconstrained vector, this type entry has been removed. Since the Jacobian is now non-square, triangular return types are no longer used. As a result, the change is marked as breaking.

@sethaxen
Copy link
Member Author

sethaxen commented Jun 5, 2023

Currently tests fails due to these lines, which seem to assume inputs and outputs are the same size

On Slack, @torfjelde confirmed that these should be fixed.

test/interface.jl Outdated Show resolved Hide resolved
test/interface.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
K = length(x)
@assert K > 1 "x needs to be of length greater than 1"
dydxt = similar(x, length(x), length(x))
dydxt = similar(x, K, K - 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe merge it with the next line while you're at it to make it clearer that dydx is initialized with zeros (I had missed that initially when reading the code)?

Suggested change
dydxt = similar(x, K, K - 1)
dydxt = fill!(similar(x, K, K - 1), 0)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(BTW the similar + AbstractVector argument wrongly suggests that the code will work for arbitrary inputs - but the loops below assume that everything uses 1-based indices)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, but this should be fixed in a separate PR.

src/bijectors/simplex.jl Outdated Show resolved Hide resolved
src/bijectors/simplex.jl Outdated Show resolved Hide resolved
end
end
return UpperTriangular(dydxt)'
return dydxt'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The adjoint operation seems a bit annoying but I guess the algorithm should be updated in separate PRs if desired.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, agreed, but this should be fixed separately.

test/interface.jl Outdated Show resolved Hide resolved
test/interface.jl Outdated Show resolved Hide resolved
test/interface.jl Outdated Show resolved Hide resolved
sethaxen and others added 2 commits June 6, 2023 23:39
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde
Copy link
Member

Haven't forgotten about this, but the DPPL integration was set back significantly by some other changes we had made. Should be done soon now 👍

src/Bijectors.jl Outdated Show resolved Hide resolved
torfjelde and others added 3 commits June 17, 2023 21:24
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/bijectors/simplex.jl Outdated Show resolved Hide resolved
torfjelde and others added 3 commits June 19, 2023 02:16
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
ext/BijectorsDistributionsADExt.jl Outdated Show resolved Hide resolved
ext/BijectorsDistributionsADExt.jl Outdated Show resolved Hide resolved
ext/BijectorsDistributionsADExt.jl Outdated Show resolved Hide resolved
ext/BijectorsDistributionsADExt.jl Outdated Show resolved Hide resolved
test/transform.jl Outdated Show resolved Hide resolved
@torfjelde
Copy link
Member

Aaaalrighty! Does someone want to give this a look-over? I think pasts will pass now, an so it would be nice to get this merged.

@sethaxen
Copy link
Member Author

Thanks @torfjelde for the fixes! All LGTM, but someone else needs to review.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, just a minor question.

It's a breaking change it seems, so IMO it would be good to include the correct version bump in the PR to avoid accidentally tagging a non-breaking release.

src/bijectors/simplex.jl Outdated Show resolved Hide resolved
src/bijectors/simplex.jl Outdated Show resolved Hide resolved
@torfjelde
Copy link
Member

It's a breaking change it seems, so IMO it would be good to include the correct version bump in the PR to avoid accidentally tagging a non-breaking release.

We haven't released #master yet, which has been bumped accordingly:) I'm defering release of #master until both this and #271 have gone through.

torfjelde referenced this pull request Jun 19, 2023
* added output_length and output_size to compute output, well, leengths
and sizes for transformations

* added tests for size of transformed dist using VcCorrBijector

* use already constructed transfrormation

* TransformedDistribution should now also have correct variate form

* added proper variateform handling for VecCholeskyBijector too

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added output_size impl for Reshape too

* bump minor version

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Update src/interface.jl

* Update src/bijectors/corr.jl

* reverted removal of length as we'll need it now

* updated Stacked to be compat with changing sizes

* forgot to commit deetion

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add testing of sizes to `test_bijector`

* some more tests for stacked

* Update test/bijectors/stacked.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added awful generated function to determine output ranges for Stacked
with tuple because recursive implementation fail

* added slightly more informative comment

* format

* more fixes to that damned Stacked

* Update test/interface.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* specialized constructors for Stacked further

* fixed bug in output_size for CorrVecBijector

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
@torfjelde torfjelde merged commit 03bdffb into master Jun 19, 2023
@delete-merged-branch delete-merged-branch bot deleted the fixsimplex branch June 19, 2023 15:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants