CorrBijector makes posterior improper #228
After further thought, yes, this is detrimental to sampling with NUTS, because it makes it very unlikely that a u-turn will be encountered, so the maximum tree depth of 10 is usually saturated:

```julia
julia> mean(==(10), chns[:tree_depth])
0.961
```

So 96% of transitions saturated the tree depth (took 1024 leapfrog steps). Compare with the same model in Stan:

```julia
julia> using CmdStan, MCMCChains
julia> code = """
data {
int<lower=1> n;
real<lower=0> eta;
}
parameters {
corr_matrix[n] R;
}
model {
R ~ lkj_corr(eta);
}
""";
julia> stanmodel = Stanmodel(name="lkj", model=code, output_format=:mcmcchains);
...
julia> _, chns, _ = stan(stanmodel, Dict("n" => 3, "eta" => 2), "./");
...
Inference for Stan model: lkj_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.
Warmup took (0.013, 0.014, 0.014, 0.014) seconds, 0.055 seconds total
Sampling took (0.027, 0.026, 0.029, 0.026) seconds, 0.11 seconds total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -1.7e+00 3.4e-02 1.4e+00 -4.5 -1.3e+00 -0.20 1.7e+03 1.6e+04 1.0e+00
accept_stat__ 0.91 1.8e-03 0.12 0.64 0.95 1.0 4.7e+03 4.4e+04 1.0e+00
stepsize__ 0.80 5.6e-02 0.080 0.67 0.85 0.88 2.0e+00 1.9e+01 2.8e+13
treedepth__ 2.2 9.1e-02 0.53 1.0 2.0 3.0 3.4e+01 3.1e+02 1.0e+00
n_leapfrog__ 4.5 3.5e-01 2.0 3.0 3.0 7.0 3.4e+01 3.1e+02 1.0e+00
divergent__ 0.00 nan 0.00 0.00 0.00 0.00 nan nan nan
energy__ 3.2 5.0e-02 1.9 0.87 2.8 6.8 1.5e+03 1.3e+04 1.0e+00
R[1,1] 1.0e+00 nan 6.7e-16 1.0 1.0e+00 1.0 nan nan nan
R[1,2] 4.0e-03 7.0e-03 4.1e-01 -0.67 -3.2e-03 0.68 3.5e+03 3.2e+04 1.0e+00
R[1,3] 2.9e-03 6.6e-03 4.1e-01 -0.68 2.2e-03 0.67 3.9e+03 3.6e+04 1.0e+00
R[2,1] 4.0e-03 7.0e-03 4.1e-01 -0.67 -3.2e-03 0.68 3.5e+03 3.2e+04 1.0e+00
R[2,2] 1.0e+00 nan 6.7e-16 1.0 1.0e+00 1.0 nan nan nan
R[2,3] -1.7e-02 7.1e-03 4.1e-01 -0.69 -1.5e-02 0.66 3.3e+03 3.1e+04 1.0e+00
R[3,1] 2.9e-03 6.6e-03 4.1e-01 -0.68 2.2e-03 0.67 3.9e+03 3.6e+04 1.0e+00
R[3,2] -1.7e-02 7.1e-03 4.1e-01 -0.69 -1.5e-02 0.66 3.3e+03 3.1e+04 1.0e+00
R[3,3] 1.0e+00 nan 6.7e-16 1.0 1.0e+00 1.0 nan nan nan
Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at
convergence, R_hat=1).
julia> mean(==(10), chns[:treedepth__])
0.0
julia> mean(chns[:treedepth__])
2.203
```

No transitions hit the maximum tree depth, and on average each transition took only ~2 tree doublings (a mean of 4.5 leapfrog steps).
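For a rough sense of the cost difference implied by these diagnostics, here is a back-of-the-envelope sketch using only the numbers above, with leapfrog steps as a proxy for gradient evaluations (the variable names are illustrative):

```julia
# Approximate leapfrog steps (≈ gradient evaluations) per transition.
turing_steps = 2^10        # saturated tree depth of 10 ≈ 1024 steps, as noted above
stan_steps   = 4.5         # mean n_leapfrog__ from the Stan output above
turing_steps / stan_steps  # ≈ 228× more work per transition
```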
As demonstrated in https://discourse.julialang.org/t/case-study-speeding-up-a-logistic-regression-with-rhs-prior-turing-vs-numpyro-any-tricks-im-missing/87681/34, this impropriety also introduces post-warmup numerical error.

A quick-and-dirty hack to get this working for Turing users might be to edit `Bijectors.jl/src/bijectors/corr.jl` (lines 83 to 95 at `b204712`):

```julia
for j in 1:K, i in j:K
    result -= y[i, j]^2 / 2
end
```

This puts a standard normal prior on the extra DOFs. The downside is that it technically lies about the logdetjac (but so does the current implementation, which lies about the transform being bijective), and the logdetjac of the inverse function will disagree, but this is a band-aid that will work until Bijectors can support inputs and outputs of different dimensions.
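To see why the added term behaves like a standard normal prior: up to an additive constant (which NUTS never sees, since constants drop out of gradients and acceptance ratios), `-y^2/2` is the standard normal log-density. A quick check, using `Distributions` purely for illustration:

```julia
using Distributions

y = randn()
# The standard normal logpdf differs from -y^2/2 only by the constant -log(sqrt(2π)):
logpdf(Normal(0, 1), y) ≈ -y^2 / 2 - log(sqrt(2π))  # true
```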
Okay, so combining #246 and TuringLang/DynamicPPL.jl#462, this now works:

```julia
julia> using Turing, Random
[ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]
julia> @model function model(n, η)
R ~ LKJ(n, η)
end;
julia> mod = model(3, 2.0);
julia> Random.seed!(50);
julia> chns = sample(mod, NUTS(0.99), 1_000; save_state=true)
┌ Info: Found initial step size
└ ϵ = 1.6
Sampling 100%|██████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:09
Chains MCMC chain (1000×21×1 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 11.07 seconds
Compute duration = 11.07 seconds
parameters = R[1,1], R[2,1], R[3,1], R[1,2], R[2,2], R[3,2], R[1,3], R[2,3], R[3,3]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
R[1,1] 1.0000 0.0000 0.0000 0.0000 NaN NaN NaN
R[2,1] -0.0274 0.4129 0.0131 0.0153 773.7963 0.9994 69.8877
R[3,1] 0.0037 0.4011 0.0127 0.0147 642.2541 0.9992 58.0071
R[1,2] -0.0274 0.4129 0.0131 0.0153 773.7963 0.9994 69.8877
R[2,2] 1.0000 0.0000 0.0000 0.0000 996.8495 0.9990 90.0334
R[3,2] 0.0077 0.4061 0.0128 0.0157 670.0528 1.0039 60.5178
R[1,3] 0.0037 0.4011 0.0127 0.0147 642.2541 0.9992 58.0071
R[2,3] 0.0077 0.4061 0.0128 0.0157 670.0528 1.0039 60.5178
R[3,3] 1.0000 0.0000 0.0000 0.0000 672.7942 0.9990 60.7654
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
R[1,1] 1.0000 1.0000 1.0000 1.0000 1.0000
R[2,1] -0.7647 -0.3269 -0.0286 0.2606 0.7454
R[3,1] -0.7393 -0.3006 -0.0072 0.3158 0.7480
R[1,2] -0.7647 -0.3269 -0.0286 0.2606 0.7454
R[2,2] 1.0000 1.0000 1.0000 1.0000 1.0000
R[3,2] -0.7260 -0.2919 0.0068 0.2974 0.7714
R[1,3] -0.7393 -0.3006 -0.0072 0.3158 0.7480
R[2,3] -0.7260 -0.2919 0.0068 0.2974 0.7714
R[3,3] 1.0000 1.0000 1.0000 1.0000 1.0000
```

Though I'm confused as to why some of the diagonal entries get a `NaN` ESS and others don't.
Awesome! Can you also check the estimated metric?

If the computed diagonal entries are even a little offset from exactly 1, then a finite ESS gets computed:

```julia
julia> using MCMCDiagnosticTools
julia> x = ones(1_000, 1, 10);
julia> MCMCDiagnosticTools.ess_rhat(x)[1]
10-element Vector{Float64}:
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
julia> MCMCDiagnosticTools.ess_rhat(x .+ eps() .* randn.())[1]
10-element Vector{Float64}:
955.6941303601072
713.5937319411215
1092.1539947505667
903.5876039869821
925.0133494392358
894.2614035388227
793.3188327568754
1058.2749400295972
988.8133813971027
1096.6615583111022
```

In #246, the final correlation matrix is reconstructed from the sampled unconstrained values, so its diagonal entries can carry exactly this kind of floating-point noise. An interesting question is what `ess_rhat` should return for an exactly constant chain.
Here it is:

```julia
julia> chns.info.samplerstate.hamiltonian.metric.M⁻¹
3-element Vector{Float64}:
0.25406253280310975
0.22661105092259304
0.2953807290308163
```

The underlying tracked variables are now just the 3-dimensional unconstrained subspace, btw, if that's what you were wondering.
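The length of that metric matches the number of genuine degrees of freedom for a 3×3 correlation matrix (continuing the session above, where `chns` is the chain sampled earlier):

```julia
n = 3
# n(n-1)/2 = 3 free parameters, one metric entry each:
binomial(n, 2) == length(chns.info.samplerstate.hamiltonian.metric.M⁻¹)  # true
```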
Yeah figured it had something to do with that, but was then surprised to see one of them not being affected by this 🤷

Tbh it's going to be quite difficult to figure out which variables are "actually" sampled and which aren't, and hence to customize the resulting behavior downstream.
Ah, that happens because the first element won't have any numerical error: it's computed exactly (the first column of the Cholesky factor is just the first standard basis vector, so no rounding enters), whereas the other diagonal entries are sums of squares that equal 1 only up to floating-point error.
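A small self-contained illustration of this point, using the standard Cholesky-factor construction (a sketch, not Bijectors' actual code; the values in `z` are arbitrary partial correlations):

```julia
z = tanh.([0.3, -0.7, 1.2])        # three unconstrained values mapped into (-1, 1)

W = zeros(3, 3)                    # upper-triangular Cholesky factor of R
W[1, 1] = 1.0
W[1, 2] = z[1];  W[2, 2] = sqrt(1 - z[1]^2)
W[1, 3] = z[2];  W[2, 3] = z[3] * sqrt(1 - z[2]^2)
W[3, 3] = sqrt(1 - W[1, 3]^2 - W[2, 3]^2)

R = W' * W
R[1, 1] == 1.0    # exactly true: R[1,1] = 1*1 + 0 + 0, no rounding accumulates
R[2, 2] - 1.0     # typically a few multiples of eps(), not exactly zero
R[3, 3] - 1.0     # likewise
```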
FWIW I think this is really a separate issue, namely what MCMCDiagnosticTools should do when it gets a scalar array. But that's a tangent for this issue.
Aaah true! Thanks!
An $n \times n$ correlation matrix has ${n \choose 2} = \frac{n (n-1)}{2}$ degrees of freedom. This is the same as the number of elements in a strictly upper triangular $n \times n$ matrix. The `CorrBijector` works by mapping from the correlation matrix first to its unique upper Cholesky factor and then to a strictly upper triangular matrix of unconstrained entries.

The trouble is that in the unconstrained space we now have $n \times n$ parameters, of which ${n+1 \choose 2} = \frac{n(n+1)}{2}$ have no impact on the log density. These extra parameters have an implicit improper uniform prior on the reals, which makes the posterior distribution in unconstrained space improper. Because these parameters have infinite variance, HMC will learn this during adaptation, and their values will explode. I don't know if this will have any negative impact on sampling.
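For intuition, here is a minimal sketch of the inverse map from an unconstrained $n \times n$ matrix `y` back to a correlation matrix; the name `corr_from_unconstrained` and the tanh-based construction are illustrative, not Bijectors' exact implementation. Note that it only ever reads the strict upper triangle of `y`, so the remaining $\frac{n(n+1)}{2}$ entries have no effect on the result:

```julia
using LinearAlgebra

# Sketch: unconstrained y (n×n) → correlation matrix R via an upper Cholesky factor W.
function corr_from_unconstrained(y::AbstractMatrix{<:Real})
    n = size(y, 1)
    W = zeros(n, n)
    W[1, 1] = 1.0
    for j in 2:n
        remainder = 1.0                                  # 1 - sum of squares above the current entry
        for i in 1:(j - 1)
            W[i, j] = tanh(y[i, j]) * sqrt(remainder)    # only y[i, j] with i < j is read
            remainder -= W[i, j]^2
        end
        W[j, j] = sqrt(remainder)
    end
    return W' * W
end

y = randn(3, 3)
R1 = corr_from_unconstrained(y)
R2 = corr_from_unconstrained(y + LowerTriangular(ones(3, 3)))  # perturb diagonal + lower triangle
R1 == R2  # true: those 6 entries never enter the computation
```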
In this demo, we're sampling the uniform distribution on correlation matrices.

Note the number of parameters: we should have `3*(3-1)/2 = 3` DOFs, but instead we have `3*3 = 9`. And note that `3*(3+1)/2 = 6` of the degrees of freedom have adapted variances of ~1e20.

There are several ways to solve this, none of which seem to be possible in Bijectors right now. One is to hack `logabsdetjac` to contain a prior term for these extra parameters. However, when I tried this, it seemed to have no effect, since I guess `logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})` is being called instead of `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`. When mapping from `X` to `y`, these extra parameters are all set to 0, so we have no way of setting this prior.
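To illustrate that last point: in the forward direction the extra entries of `y` come out as exactly zero, so a prior term of the kind suggested above contributes nothing from that side. A sketch with made-up values, where only the strict upper triangle is "real":

```julia
# After the forward map X → y, the diagonal and lower triangle of y are zero:
y = [0.0  0.3  -0.5;
     0.0  0.0   1.2;
     0.0  0.0   0.0]
K = size(y, 1)

# Same i ≥ j entries that the -y^2/2 hack above targets:
penalty = -sum(y[i, j]^2 / 2 for j in 1:K for i in j:K)
penalty == 0.0  # true: the prior term vanishes when added in this direction
```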