Overloading-AD-Friendly Unflatten #39
base: master
Conversation
Codecov Report

```diff
@@            Coverage Diff             @@
##           master      #39      +/-   ##
==========================================
+ Coverage   96.49%   96.57%   +0.08%
==========================================
  Files           4        4
  Lines         171      175       +4
==========================================
+ Hits          165      169       +4
  Misses          6        6
```

Continue to review full report at Codecov.
```diff
     return v, unflatten_to_Integer
 end

 function flatten(::Type{T}, x::R) where {T<:Real,R<:Real}
     v = T[x]
-    unflatten_to_Real(v::Vector{T}) = convert(R, only(v))
+    unflatten_to_Real(v::AbstractVector{<:Real}) = only(v)
```
Note that this line will change the current behaviour of unflatten quite drastically:

```julia
x = (a = 1., b = [2., 3.], c = [4 5; 6 7])
typeof(x.b)  # Vector{Float64}
xvec, unflat = flatten(Float16, x)
x2 = unflat(xvec)
typeof(x2.b)  # Vector{Float16}
```

I don't think there is any other way, though, to facilitate AD while keeping the initial parameter types.
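For concreteness, a minimal sketch of what the relaxed signature buys, assuming this branch's methods (`unflat` is the closure returned by `flatten`):

```julia
using ParameterHandling
using ForwardDiff: Dual

v, unflat = flatten(Float16, 4.0)  # v isa Vector{Float16}

# Under the old `unflatten_to_Real(v::Vector{T})` signature, the call below
# would throw a MethodError; with `v::AbstractVector{<:Real}` the Dual that
# ForwardDiff substitutes during differentiation simply passes through.
unflat([Dual(4.0, 1.0)])
```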
Agreed -- it's very breaking. That said, given how I tend to use ParameterHandling in practice, the new behaviour is more helpful anyway 🤷
I think there is no way around it if you want to work with AD here. The type changes could be a problem if you define a concrete container to collect samples of your model parameters (e.g. in MCMC).
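A minimal sketch of that failure mode, assuming the post-PR behaviour demonstrated above:

```julia
using ParameterHandling

x = (a = 1.0, b = [2.0, 3.0])
xvec, unflat = flatten(Float16, x)

# The collected samples no longer have the eltype of the original parameters,
# so a container typed as Vector{typeof(x)} can't hold them without conversion.
samples = [unflat(xvec) for _ in 1:3]
eltype(samples) == typeof(x)  # false: the fields are now Float16-typed
```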
```diff
@@ -217,7 +219,9 @@ value(X::PositiveDefinite) = A_At(vec_to_tril(X.L))

 function flatten(::Type{T}, X::PositiveDefinite) where {T<:Real}
```
This is not related to the commit, but I think most statistics packages work with the upper triangle, if I am not mistaken?

```julia
using Distributions, LinearAlgebra

Σ1 = UpperTriangular([1. .5; .5 1.])
Σ2 = LowerTriangular([1. .5; .5 1.])
Symmetric(Σ1)  # [1.0 0.5; 0.5 1.0]
Symmetric(Σ2)  # [1.0 0.0; 0.0 1.0]
```
Hmm. I agree with your assertion that most packages in the Julia ecosystem use the upper triangle, but I'm not sure that it's a problem here, because the user should interact with the `PositiveDefinite` type via the `positive_definite` function, which just requires that the user provide a `StridedMatrix` which is positive definite (we should probably widen that to include `Symmetric` matrices...). Once inside that functionality, asking for the `L` field of a `Cholesky` is fine, albeit it may not be totally optimal.

Or maybe I've misunderstood where you're coming from with this?
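For reference, a minimal sketch of that intended usage, assuming the exported `positive_definite`, `flatten`, and `value` API:

```julia
using ParameterHandling, LinearAlgebra

A = [2.0 0.5; 0.5 1.0]     # a positive-definite StridedMatrix
p = positive_definite(A)   # the user-facing entry point; wraps the Cholesky factor
v, unflatten = flatten(p)  # flatten only ever sees the internal representation
value(unflatten(v)) ≈ A    # round-trips (up to floating point) to the original matrix
```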
Yes, you are right. It should probably also be the job of the user to check that their transformations make sense.
> asking for the `L` field of a `Cholesky` is fine, albeit it may not be totally optimal.

Where "not totally optimal" in this case is "an extra `copy()` of the entire matrix" (JuliaLang/julia#42920; you can use workarounds such as `PDMats.chol_lower`, though note that the bugfix for JuliaStats/PDMats.jl#143 unfortunately resulted in yet another AD issue)...
Worked for me - I just added a few comments, but none of them should influence the merge. Thank you for your work!
```diff
@@ -74,26 +78,26 @@ function flatten(::Type{T}, x::Tuple) where {T<:Real}
     x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
     lengths = map(length, x_vecs)
     sz = _cumsum(lengths)
-    function unflatten_to_Tuple(v::Vector{T})
+    function unflatten_to_Tuple(v::AbstractVector{<:Real})
         map(x_backs, lengths, sz) do x_back, l, s
             return x_back(v[(s - l + 1):s])
```
It would be amazing if we could somehow find a way to use a `@view` here, so we do not generate a new vector for each argument in the tuple. The problem is that if we have to call `NamedTuple{names}(v_vec_vec)` instead of `typeof(x)(v_vec_vec)` in the NamedTuple dispatch below, we will get back different types for everything but scalar parameters.
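A small sketch of that trade-off, with hypothetical stand-in values:

```julia
x = (a = [1.0, 2.0],)     # the original parameters
v = [1.0, 2.0]            # the flat vector
vals = ((@view v[1:2]),)  # what unflatten_to_Tuple would produce with @view

# Keeping the view avoids the per-argument copy but changes the field type...
NamedTuple{(:a,)}(vals).a isa SubArray  # true

# ...whereas reconstructing via typeof(x) restores Vector{Float64}, at the
# cost of converting (i.e. copying) the view into a fresh Vector anyway.
typeof(x)(vals).a isa Vector{Float64}   # true
```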
Hmm, I agree that would be nice. Would you mind opening a separate issue to discuss further? I'd rather keep it out of scope for this PR.
@willtebbutt: I think I managed to implement a method that allows us to keep the initial types and still lets us work with AD. I uploaded a version here: https://github.com/paschermayr/Shared-Code/blob/master/parameterhandling.jl

I am not sure if this is ideal for ParameterHandling.jl, as it is optimized for the unflatten part (while flattening, buffers for unflatten are created), but I think you can adjust this easily otherwise. I haven't tested it for all the different parameter types in ParameterHandling.jl. Example:

```julia
using BenchmarkTools

nt = (a = 1, b = [2, 3], c = Float32(4.), d = 5.)
typeof(nt.c)  # Float32

nt_vec, unflat = flatten(Float16, true, nt)  # Vector{Float16} with 2 elements, unflatten_to_NamedTuple
nt2 = unflat(nt_vec)
typeof(nt2.c)  # Float32
@btime $unflat($nt_vec)  # 20.942 ns (0 allocations: 0 bytes)

# For AD, no type conversion:
nt_vec, unflat = flatten(Float64, false, nt)  # Vector{Float64} with 2 elements, unflatten_to_NamedTuple
nt2 = unflat(nt_vec)
typeof(nt2.c)  # Float64
```
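The zero-allocation figure comes from the buffering strategy described above. A stripped-down sketch of the general idea (not the linked implementation; `flatten_buffered` is a hypothetical name):

```julia
# flatten allocates a reusable buffer up front; unflatten writes into it and
# returns it, so repeated calls allocate nothing.
function flatten_buffered(x::Vector{Float64})
    buffer = similar(x)
    v = copy(x)
    function unflatten!(v_new::AbstractVector{<:Real})
        copyto!(buffer, v_new)  # converts eltypes; writing Duals into a Float64 buffer fails
        return buffer
    end
    return v, unflatten!
end

v, unflat! = flatten_buffered([1.0, 2.0])
unflat!(v)  # reuses the buffer: 0 allocations per call
```

That conversion constraint appears to be why the linked version takes a flag: the AD path has to skip the buffer and accept the type change.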
Note: I've not forgotten about this PR. I'm currently swamped with PhD work and will return to it when I get some time. @paschermayr how urgent is this for you? Are you happy to work on this branch / with your workaround for now, or do you need a release?
@willtebbutt Not urgent at all, happy to work with what I have. Thank you in any case!
Is there a lot left to do? I somehow need this feature 😝
To be honest, I'm not entirely sure. It's dropped off my radar somewhat, and I'm not going to have time to properly look at it until I've submitted.
This would be very useful if it were merged, as using

```julia
@btime loss($θ_flat)                         # 188.774 μs (123 allocations: 79.52 KiB)
@btime ForwardDiff.gradient($loss, $θ_flat)  # 2.389 ms (1231 allocations: 5.80 MiB)
@btime ReverseDiff.gradient($loss, $θ_flat)  # 8.220 ms (308089 allocations: 13.17 MiB)
@btime Zygote.gradient($loss, $θ_flat)       # 36.134 ms (421748 allocations: 15.25 MiB)
```
Since this PR started, I have created another package, because my needs were slightly different from the ParameterHandling.jl case: https://github.com/paschermayr/ModelWrappers.jl. I managed to incorporate all possible cases (flatten/unflatten performant / AD compatible / taking Integers into account) by using a separate struct as an argument in the flatten function that holds all kinds of configurations. The exact specifications can be seen here:

Note that this package is optimized for the case where unflatten is performed more often (which can often be done with 0 allocations by creating buffers while flattening), and it is quite dependency heavy, as I integrated some AD use cases; but maybe a similar solution could be implemented in ParameterHandling to take care of most corner cases without performance loss.
@paschermayr that looks great! Thanks for sharing. I was thinking about how to assign priors to parameters... glad someone already started working it out :)
ModelWrappers.jl looks great, but does it provide the same functionality? E.g.
A similar functionality, but the focus is different, and my package is much more dependency heavy at the moment. I just linked it to show one possible solution so that autodiff can be applied for both the flatten and the unflatten case (this was the reason I created ModelWrappers in the first place). I would also like this PR to be merged if possible, ideally so that autodiff works in both directions, but if it only works for flatten for now, that would be fine too. As for the other question, any Bijector for a matrix distribution that satisfies your constraints should work here - I also implemented a CorrelationMatrix and a CovarianceMatrix transformer separately.
This PR does what I'm looking for. A small example that fails before the PR and works after:

```julia
using ParameterHandling, Optim, Test

let
    f((;x)) = x^2
    θ₀ = (;x = 4.0)
    flat_θ, unflatten = ParameterHandling.value_flatten(θ₀)
    opt = optimize(f ∘ unflatten, flat_θ, LBFGS(); autodiff=:forward)
    @test only(Optim.minimizer(opt)) ≈ 0.0
end
```
Addresses #27. @paschermayr could you confirm that it resolves your problem? I've run your example locally, but want to make sure that it does what you expect.

This is breaking. @rofinn can you see any problem doing this? It specifically changes code that you wrote. The type constraints are now only applied in the `flatten` bit -- it's assumed that in `unflatten` you could reasonably want to use numbers that aren't of the exact same type as the one that was requested in `flatten`, e.g. so that you can propagate a `Vector{<:Dual}` through `unflatten` when using ForwardDiff.jl.
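To make that last point concrete, a minimal sketch of the `Vector{<:Dual}` propagation this enables, assuming this branch:

```julia
using ParameterHandling, ForwardDiff

θ = (a = 1.0, b = [2.0, 3.0])
v, unflatten = flatten(θ)  # v == [1.0, 2.0, 3.0]

# ForwardDiff swaps v out for a Vector{<:Dual}; the relaxed unflatten
# signatures let it flow through instead of raising a MethodError.
loss(v) = unflatten(v).a^2 + sum(abs2, unflatten(v).b)
ForwardDiff.gradient(loss, v)  # [2.0, 4.0, 6.0]
```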