
Overloading-AD-Friendly Unflatten #39

Open · willtebbutt wants to merge 4 commits into master

Conversation

willtebbutt (Member)

Addresses #27. @paschermayr could you confirm that it resolves your problem? I've run your example locally, but want to make sure that it does what you expect.

This is breaking. @rofinn can you see any problem with doing this? It specifically changes code that you wrote. The type constraints are now only applied in the flatten bit -- the assumption is that in unflatten you could reasonably want to use numbers that aren't of the exact same type as the one requested in flatten, e.g. so that you can propagate a Vector{<:Dual} through unflatten when using ForwardDiff.jl.
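
For concreteness, a minimal sketch of the usage this is meant to enable (assuming this PR's loosened unflatten signatures; not code from the diff itself):

using ParameterHandling, ForwardDiff

x = (a = 5.0, b = [1.0, 2.0])
v, unflatten = flatten(x)

# ForwardDiff calls the objective with a Vector{<:Dual}; with this PR,
# unflatten accepts it instead of demanding the exact Vector{Float64}
# that flatten produced.
function loss(v)
    p = unflatten(v)
    return p.a^2 + sum(abs2, p.b)
end
ForwardDiff.gradient(loss, v)  # [10.0, 2.0, 4.0]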

codecov bot commented Sep 10, 2021

Codecov Report

Merging #39 (b923d81) into master (21e6ff7) will increase coverage by 0.08%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master      #39      +/-   ##
==========================================
+ Coverage   96.49%   96.57%   +0.08%     
==========================================
  Files           4        4              
  Lines         171      175       +4     
==========================================
+ Hits          165      169       +4     
  Misses          6        6              
Impacted Files      Coverage Δ
src/flatten.jl      98.14% <100.00%> (+0.07%) ⬆️
src/parameters.jl   97.50% <100.00%> (+0.06%) ⬆️
src/test_utils.jl   92.50% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

return v, unflatten_to_Integer
end

function flatten(::Type{T}, x::R) where {T<:Real,R<:Real}
    v = T[x]
-   unflatten_to_Real(v::Vector{T}) = convert(R, only(v))
+   unflatten_to_Real(v::AbstractVector{<:Real}) = only(v)

@paschermayr

Note that this line will change the current behavior for unflatten quite drastically:

x = (a = 1., b = [2., 3.], c = [4 5 ; 6 7])
typeof(x.b) #Vector{Float64}
xvec, unflat = flatten(Float16, x)
x2 = unflat(xvec)
typeof(x2.b) #Vector{Float16}

I don't think there is any other way, though, to facilitate AD while keeping the initial parameter types.

willtebbutt (Member, Author)


Agreed -- very breaking. From my perspective, given how I tend to use ParameterHandling in practice, the new behaviour is more helpful anyway 🤷

@paschermayr

I think there is no way around it if you want to work with AD here. Type changes could be a problem if you define a concrete container to collect samples of your model parameters (e.g. in MCMC).
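
A hypothetical illustration of that failure mode, reusing the example above (the container type is just for the sake of the example):

# Collecting samples in a container typed for the original parameters:
samples = Vector{Float64}[]   # e.g. an MCMC chain of b-values
x2 = unflat(xvec)             # x2.b is now Vector{Float16}, not Vector{Float64}
push!(samples, x2.b)          # works, but silently converts back to Float64
# For non-convertible element types (e.g. Duals) the push! would error.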

@@ -217,7 +219,9 @@ value(X::PositiveDefinite) = A_At(vec_to_tril(X.L))

function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real}

@paschermayr

This is not related to the commit, but I think most statistics packages work with the upper triangle, if I am not mistaken?

using Distributions, LinearAlgebra
Σ1 = UpperTriangular([1. .5; .5 1.])
Σ2 = LowerTriangular([1. .5; .5 1.])
Symmetric(Σ1) # [1.0 0.5; 0.5 1.0]
Symmetric(Σ2) # [1.0 0.0; 0.0 1.0]

willtebbutt (Member, Author)


Hmm. I agree with your assertion that most packages use the upper triangle in the Julia ecosystem, but I'm not sure that it's a problem here, because the user should interact with the PositiveDefinite type via the positive_definite function, which just requires that the user provide a StridedMatrix which is positive definite (we should probably widen that to include Symmetric matrices...). Once inside that functionality, asking for the L field of a Cholesky is fine, albeit it may not be totally optimal.

Or maybe I've misunderstood where you're coming from with this?

@paschermayr

Yes, you are right. It should probably also be the user's job to check whether their transformations make sense.

@st-- (Member) commented Nov 5, 2021


> asking for the L field of a Cholesky is fine, albeit it may not be totally optimal.

Where "not totally optimal" in this case is "an extra copy() of the entire matrix" (JuliaLang/julia#42920, and you can use workarounds such as PDMats.chol_lower, though note that the bugfix for JuliaStats/PDMats.jl#143 resulted in yet another AD issue unfortunately)...

@paschermayr

Worked for me -- I just added a few comments, but none of them should influence the merge. Thank you for your work!

@@ -74,26 +78,26 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real}
    x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    lengths = map(length, x_vecs)
    sz = _cumsum(lengths)
-   function unflatten_to_Tuple(v::Vector{T})
+   function unflatten_to_Tuple(v::AbstractVector{<:Real})
        map(x_backs, lengths, sz) do x_back, l, s
            return x_back(v[(s - l + 1):s])
@paschermayr commented Sep 10, 2021


It would be amazing if we could somehow find a way to use a @view here, so we do not generate a new vector for each argument in the tuple. The problem is that if we have to call NamedTuple{names}(v_vec_vec) instead of typeof(x)(v_vec_vec) in the NamedTuple dispatch below, we will get back different types for everything bar scalar parameters.
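
For instance, a sketch of the view-based version as it would sit inside flatten above (not part of this PR):

function unflatten_to_Tuple(v::AbstractVector{<:Real})
    map(x_backs, lengths, sz) do x_back, l, s
        # @view avoids allocating a fresh vector for every tuple element;
        # x_back then has to cope with a SubArray rather than a Vector.
        return x_back(@view(v[(s - l + 1):s]))
    end
end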

willtebbutt (Member, Author)


Hmm, I agree that would be nice. Would you mind opening a separate issue to discuss further? I'd rather keep it out of scope for this PR.

@rofinn (Contributor) commented Sep 14, 2021

My 2 cents:

  1. I agree that these type constraints were perhaps a bit too restrictive; they were introduced because of instabilities in Zygote that were hard to resolve.
  2. I haven't looked at this code in almost a year, but the general idea of loosening the types to allow things like duals to work with unflatten seems reasonable.
  3. Would it be possible to introduce a strict=false keyword that certain types of performance code could use to enforce symmetry/stability in both flatten and unflatten? I worry that the performance benefits we saw depended on that, though I really don't remember anymore. I'd also be fine with it not being the default, since reducing the precision is a bit of a niche use-case.
  4. Would you mind including a few benchmarks for the case where you want to use reduced precision?

@paschermayr commented Sep 15, 2021

> 3. Would it be possible to introduce a strict=false keyword that certain types of performance code could use to enforce symmetry/stability in both flatten and unflatten? I worry that the performance benefits we saw depended on that, though I really don't remember anymore. I'd also be fine with it not being the default, since reducing the precision is a bit of a niche use-case.

Having strict=true for unflatten in the non-AD case might be a good idea; we could use views in that case. One would have to replace map(flatten, x) with a map(x) do xᵢ; flatten(xᵢ, strict); end block to get the same performance as before.
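
A rough sketch of what that could look like for the scalar method, using a Val so the choice is resolved at compile time (hypothetical API, not part of this PR):

# strict: unflatten accepts exactly the Vector{T} that flatten produced,
# preserving the old type-stable behaviour (and permitting views/buffers).
function flatten(::Type{T}, ::Val{true}, x::R) where {T<:Real,R<:Real}
    v = T[x]
    unflatten_to_Real_strict(v::Vector{T}) = convert(R, only(v))
    return v, unflatten_to_Real_strict
end

# loose: any Real-eltype vector is accepted, so ForwardDiff duals pass through.
function flatten(::Type{T}, ::Val{false}, x::R) where {T<:Real,R<:Real}
    v = T[x]
    unflatten_to_Real(v::AbstractVector{<:Real}) = only(v)
    return v, unflatten_to_Real
end

This would be called as flatten(Float64, Val(true), x), which roughly matches the flatten(Float16, true, nt) API in the gist linked below.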

@paschermayr

@willtebbutt: I think I managed to implement a method that allows us to keep the initial types and still work with AD. I uploaded a version here: https://github.com/paschermayr/Shared-Code/blob/master/parameterhandling.jl

I am not sure if this is ideal for ParameterHandling.jl, as it is optimized for the unflatten part (buffers for unflatten are created while flattening), but I think you can adjust this easily otherwise. I haven't tested it with all the different parameter types in ParameterHandling.jl. Example:

using BenchmarkTools
nt = (a = 1, b = [2, 3], c = Float32(4.), d = 5.)
typeof(nt.c) #Float32
nt_vec, unflat = flatten(Float16, true, nt) #Vector{Float16} with 2 elements, unflatten_to_NamedTuple
nt2 = unflat(nt_vec)
typeof(nt2.c) #Float32
@btime $unflat($nt_vec) #20.942 ns (0 allocations: 0 bytes)

# For AD, no type conversion:
nt_vec, unflat = flatten(Float64, false, nt) #Vector{Float64} with 2 elements, unflatten_to_NamedTuple
nt2 = unflat(nt_vec)
typeof(nt2.c) #Float64

@willtebbutt (Member, Author)

Note: I've not forgotten about this PR. I'm currently swamped with PhD work and will return to it when I get some time. @paschermayr how urgent is this for you? Are you happy to work on this branch / with your workaround for now, or do you need a release?

@paschermayr

@willtebbutt Not urgent at all, happy to work with what I have. Thank you in any case!

@theogf (Member) commented Jan 21, 2022

Is there a lot left to do? I somehow need this feature 😝

@willtebbutt (Member, Author)

To be honest, I'm not entirely sure. It's dropped off my radar somewhat, and I'm not going to have time to look at it properly until I've submitted.

@simsurace (Member)

This would be very useful if it were merged, as using ForwardDiff or ReverseDiff instead of Zygote can lead to a massive improvement in gradient evaluation time. A quick benchmark with a sparse variational GP produced these numbers:

@btime loss($θ_flat) # 188.774 μs (123 allocations: 79.52 KiB)
@btime ForwardDiff.gradient($loss, $θ_flat) # 2.389 ms (1231 allocations: 5.80 MiB)
@btime ReverseDiff.gradient($loss, $θ_flat) # 8.220 ms (308089 allocations: 13.17 MiB)
@btime Zygote.gradient($loss, $θ_flat) # 36.134 ms (421748 allocations: 15.25 MiB)

@paschermayr

Since this PR started, I have created another package, because my needs were slightly different from the ParameterHandling.jl case: https://github.com/paschermayr/ModelWrappers.jl. I managed to incorporate all the cases (performant flatten/unflatten, AD compatibility, taking Integers into account) by passing a separate struct as an argument to the flatten function that holds all kinds of configurations. The exact specifications can be seen here:
https://github.com/paschermayr/ModelWrappers.jl/blob/main/src/Core/constraints/flatten/flatten.jl

Note that this package is optimized for the case where unflatten is performed more often (which can often be done with 0 allocations by creating buffers while flattening), and it is quite dependency-heavy as I integrated some AD use cases, but maybe a similar solution could be implemented in ParameterHandling to take care of most corner cases without performance loss.

@st-- (Member) commented Mar 29, 2022

@paschermayr that looks great! Thanks for sharing. I was thinking about how to assign priors to parameters... glad someone already started working it out :)

@simsurace (Member)

ModelWrappers.jl looks great, but does it provide the same functionality? E.g. ParameterHandling.positive_definite is something that I use a lot. IMHO this PR should still be merged. I wonder what is missing, since all tests are passing. Is there an important test case missing?

@paschermayr commented Mar 30, 2022

> ModelWrappers.jl looks great, but does it provide the same functionality? E.g. ParameterHandling.positive_definite is something that I use a lot. IMHO this PR should still be merged. I wonder what is missing, since all tests are passing. Is there an important test case missing?

A similar functionality, but the focus is different, and my package is much more dependency-heavy at the moment. I just linked it to show one possible solution in which autodiff can be applied in both the flatten and the unflatten case (this was the reason I created ModelWrappers in the first place). I would also like this PR to be merged if possible, ideally so that autodiff works in both directions, but if it only works for flatten for now, that would be fine too.

As for the other question, any Bijector for a Matrix distribution that satisfies your constraints should work here -- I also implemented a CorrelationMatrix and a CovarianceMatrix transformer separately.

using ModelWrappers

mat = [1.0 .2 ; .2 3.0]
constraint = CovarianceMatrix()
model = ModelWrapper((Σ = Param(mat, constraint), ))

mat_flat = flatten(model) #Vector{Float64} with 3 elements 1.00, 0.200, 3.00
mat_unflat = unflatten(model, mat_flat) #(Σ = [1.0 0.2; 0.2 3.0],)

θᵤ = unconstrain_flatten(model) #Vector{Float64} with 3 elements 0.00 0.200 0.543…
unflatten_constrain(model, θᵤ) #(Σ = [1.0 0.2; 0.2 3.0],)

@jariji commented Feb 16, 2024

This PR does what I'm looking for. A small example that fails before the PR and works after:

using ParameterHandling, Optim, Test
let
    f((;x)) = x^2
    θ₀ = (;x = 4.0)
    flat_θ, unflatten = ParameterHandling.value_flatten(θ₀)
    opt = optimize(f ∘ unflatten, flat_θ, LBFGS(); autodiff=:forward)
    @test only(Optim.minimizer(opt)) ≈ 0.0 atol=1e-6
end
