Extracting flatten/unflatten into lightweight package? #43
Comments
How would you feel about depending on ParameterHandling once #42 goes in?
I think even with #42 it would still be too heavy - the idea is to keep ChangesOfVariables super-lightweight, so we can convince packages to depend on it, and the stuff in
But now that I think about it, since it looks like we need flatten/unflatten functionality in ChangesOfVariables, how would you feel about moving that there? It's not much more than I have in the draft of JuliaMath/ChangesOfVariables.jl#2 anyway, and I think it would fit the theme of the package - it is a change of variables, after all (with a logabsdet-jacobian of 0, of course). @willtebbutt and @devmotion, what do you think?
Update: ParameterHandling will likely depend on ChangesOfVariables indirectly in the future due to #42, since we hope to get support for ChangesOfVariables into LogExpFunctions (esp for
I'm not sure; it's not clear to me that this functionality should be included in JuliaMath/ChangesOfVariables.jl#2. I would prefer that such wrapping/reshaping/flattening etc. be handled by the user.
Yes, the more I think about it, the more it seems like a somewhat different thing. And we could use
I still think it's very fundamental functionality that may deserve a central package independent of variable-domain transformations. I wonder if we should add an interface
That would form a nice basis for a package that depends on
In principle I could get on board with this kind of thing. My main concern is that if we choose to make this functionality more widely available, it'll be difficult to get consensus on some semantics. For example,
So I would want to generalise what we have here a little bit to allow different semantics depending on your situation. Maybe utilising a trait-based mechanism like the one used for
So you would write something like
```julia
struct Flattener{T} end
struct KeepIntegers end

# These have different behaviour.
with_inverse(Flattener{}(), 1)
with_inverse(Flattener{KeepIntegers}(), 1) # = [1.0]
```
in case you want to work around this stuff?
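A minimal sketch of how such a trait parameter could select integer handling, assuming a hypothetical `flatten(f, x)` that returns a `(vector, unflatten)` pair. `DropIntegers` and this `flatten` are illustrative names, not ParameterHandling's API; only `Flattener` and `KeepIntegers` come from the comment above. Integers are treated as constants by default and only become real-valued entries when `KeepIntegers` is requested:

```julia
# Hypothetical style types; Flattener and KeepIntegers follow the comment above.
struct DropIntegers end
struct KeepIntegers end
struct Flattener{T} end
Flattener() = Flattener{DropIntegers}()   # default: integers are constants

# flatten(f, x) returns (v, unflatten) with v::Vector{Float64}.
flatten(::Flattener{DropIntegers}, x::Integer) = (Float64[], _ -> x)
flatten(::Flattener{KeepIntegers}, x::Integer) = ([Float64(x)], v -> round(typeof(x), only(v)))
flatten(::Flattener, x::AbstractFloat) = ([Float64(x)], v -> convert(typeof(x), only(v)))

function flatten(f::Flattener, x::Tuple)
    parts = map(xi -> flatten(f, xi), x)          # one (v, unflatten) per element
    lens = collect(map(p -> length(first(p)), parts))
    ends = cumsum(lens)
    starts = ends .- lens .+ 1
    v = vcat(Float64[], map(first, parts)...)
    # Each element reads back only its own slice of the flat vector.
    unflatten(w) = ntuple(i -> last(parts[i])(w[starts[i]:ends[i]]), length(parts))
    return v, unflatten
end
```

With `Flattener()`, `(1, 2.5)` flattens to `[2.5]` and the integer is restored unchanged; with `Flattener{KeepIntegers}()` it flattens to `[1.0, 2.5]`.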
Yes, I thought some more about this today and also started to worry about integers. In the autodiff use case, beyond excluding integers, people may also want to exclude other parameters from the gradient calculation (if they are not relevant, to increase performance), so one might want a way to mark parameters as active/constant. That could quickly go beyond a "non-opinionated central API", indeed...
We actually already have a way to mark parameters as being constant (see
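To make the active/constant distinction concrete, here is a minimal sketch assuming a hypothetical `Frozen` wrapper (not ParameterHandling's own mechanism) that the flattener simply skips, so frozen values never reach the vector a gradient would be taken against. `flatten_active` is likewise a made-up name:

```julia
# Hypothetical wrapper: marks a value as constant (inactive) for flattening.
struct Frozen{T}
    value::T
end

# flatten_active(x) returns (v, unflatten); only active Float64 data lands in v.
flatten_active(x::Float64) = ([x], v -> only(v))
flatten_active(x::Vector{Float64}) = (copy(x), v -> collect(v))
flatten_active(x::Frozen) = (Float64[], _ -> x)   # contributes no entries

function flatten_active(x::NamedTuple)
    parts = map(flatten_active, values(x))
    lens = collect(map(p -> length(first(p)), parts))
    ends = cumsum(lens)
    starts = ends .- lens .+ 1
    v = vcat(Float64[], map(first, parts)...)
    unflatten(w) = NamedTuple{keys(x)}(
        ntuple(i -> last(parts[i])(w[starts[i]:ends[i]]), length(parts)))
    return v, unflatten
end
```

For `(a = 1.0, b = [2.0, 3.0], c = Frozen(4.0))` only `a` and `b` appear in the flat vector; `c` is carried through `unflatten` untouched.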
I think it would be nice to explore these options in ParameterHandling (especially since it's going to be more lightweight soon). If we add
A central "flatten-only" package could come later then, maybe? As soon as #27 is sorted out, people will already be able to use ParameterHandling in
Neat! One advantage of the
@devmotion pointed me to https://github.com/JuliaDiff/FiniteDifferences.jl/blob/main/src/to_vec.jl. There are probably equivalents of this in several places in the ecosystem, used internally. If we had a nice, lightweight central API for converting to a real vector and back, people could make their types support it, which would make a lot of AD a bit easier. It's the handling of integers and constant parameters, though, that makes it less straightforward than I initially thought - but I think it could (and maybe should) still be decoupled from actual value transformations.
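The core of such a to-real-vec-and-back API is small. The following is a minimal sketch of the same `(vector, back)` shape that FiniteDifferences' `to_vec` returns, written from scratch rather than copied from that file, and limited (as an assumption) to `Float64` scalars, `Float64` arrays of any shape, and `NamedTuple`s of those. `to_realvec` is a hypothetical name:

```julia
# Hypothetical helper (not the FiniteDifferences.jl code):
# to_realvec(x) returns (v, back) with v::Vector{Float64} and back(v) == x.
to_realvec(x::Float64) = ([x], v -> only(v))

function to_realvec(x::Array{Float64})
    sz = size(x)                                  # remember the shape
    return (vec(copy(x)), v -> reshape(Float64.(v), sz))
end

function to_realvec(x::NamedTuple)
    parts = map(to_realvec, values(x))
    lens = collect(map(p -> length(first(p)), parts))
    ends = cumsum(lens)
    starts = ends .- lens .+ 1
    v = vcat(Float64[], map(first, parts)...)
    back(w) = NamedTuple{keys(x)}(
        ntuple(i -> last(parts[i])(w[starts[i]:ends[i]]), length(parts)))
    return v, back
end
```

Arrays are flattened in Julia's column-major order, so `(w = [1.0 2.0; 3.0 4.0], b = 5.0)` becomes `[1.0, 3.0, 2.0, 4.0, 5.0]`.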
Just to add that in the ChainRules ecosystem there is a vague plan for moving away from
I think there is another copy of this general idea somewhere in ArrayInterface.jl?
@oschulz: There are two issues that come to mind:
Hm, I guess in many cases AD would just run on the flattened result - it would then (e.g.) be combined with the flat gradient and then reconstructed into the original type. So with
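A minimal sketch of that workflow: the loss is written against the structured parameters, the gradient is taken with respect to the flat vector, and the updated vector is rebuilt into the original shape. It assumes a small hand-written `flatten_params` (hypothetical) and uses central finite differences as a stand-in for an AD engine, so no AD package is assumed:

```julia
# Hypothetical flattener for (w = Vector{Float64}, b = Float64).
function flatten_params(p::NamedTuple{(:w, :b)})
    n = length(p.w)
    v = [p.w; p.b]                                # w entries first, then b
    unflatten(u) = (w = u[1:n], b = u[n+1])
    return v, unflatten
end

# Loss defined on the structured parameters.
loss(p) = sum(abs2, p.w) + (p.b - 1.0)^2

# Central finite differences on the flat vector (stand-in for AD).
function flat_gradient(f, v; h = 1e-6)
    g = similar(v)
    for i in eachindex(v)
        e = zeros(length(v))
        e[i] = h
        g[i] = (f(v + e) - f(v - e)) / (2h)
    end
    return g
end

p = (w = [1.0, 2.0], b = 3.0)
v, unflatten = flatten_params(p)
g = flat_gradient(u -> loss(unflatten(u)), v)    # gradient w.r.t. the flat vector
p_new = unflatten(v - 0.1 * g)                   # one gradient step, then rebuild
```

Here the gradient is `[2.0, 4.0, 4.0]`, so one step of size 0.1 gives approximately `(w = [0.8, 1.6], b = 2.6)`.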
Yes, integers and values that are supposed to be constant in general. That's the main challenge, I think - AD use cases will typically want to assume that integers are constant, while other use cases may see this differently. ParameterHandling has
I don't have ready answers, I have to admit. But the fact that we have this "flatten-to-real" in so many places seems like a good motivation to have a lightweight central package that people would be willing to depend on (separate from value-domain transformations). Once we have figured out the answers to those questions, that is. :-)
For the interested, here's how Optimisers.jl handles reverse-mode-friendly (un)flattening: https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl. We opted for the conservative approach and only consider non-integer numeric arrays (due in part to Flux's design constraints). Something like
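To illustrate the conservative approach, here is a minimal sketch (not the Optimisers.jl implementation, which builds on Functors.jl) that treats only arrays with floating-point elements as trainable and passes every other leaf through unchanged. `conservative_flatten`, `gather!` and `rebuild` are hypothetical names, and only nested `Tuple`/`NamedTuple` containers are handled:

```julia
# Only arrays with floating-point elements count as trainable.
istrainable(x) = x isa AbstractArray{<:AbstractFloat}

# Gather trainable leaves, depth-first, into one flat Float64 vector.
function gather!(out::Vector{Float64}, x)
    if istrainable(x)
        append!(out, vec(x))
    elseif x isa Union{Tuple,NamedTuple}
        foreach(xi -> gather!(out, xi), x)
    end
    return out
end

# Rebuild x with trainable leaves read from v, advancing offset[] as we go.
# Leaves come back as Float64 arrays in this sketch.
function rebuild(x, v, offset::Ref{Int})
    if istrainable(x)
        n = length(x)
        leaf = reshape(v[offset[]+1:offset[]+n], size(x))
        offset[] += n
        return leaf
    elseif x isa Union{Tuple,NamedTuple}
        return map(xi -> rebuild(xi, v, offset), x)
    else
        return x                      # integers, integer arrays, etc. pass through
    end
end

conservative_flatten(x) = (gather!(Float64[], x), v -> rebuild(x, v, Ref(0)))
```

In a structure like `(layer = (weight = ..., bias = ...), nsteps = 10, idx = [1, 2])`, only `weight` and `bias` end up in the flat vector; `nsteps` and the integer array `idx` are constants by construction.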
This is built on top of Functors.jl mainly, right?
Yes, Functors.jl is an integral part of it, but there's no reason a similar set of functionality couldn't be developed for another parameter handling library :)
While revamping ForwardDiffPullbacks, I had to "invent" yet another flatten/unflatten mechanism. To get full performance and type stability (this needs to be really fast and allocation-free with deeply nested structures, flattening to static vectors), I ended up with
An approach like
The problem is that with these flatten/unflatten capabilities "hidden" in several packages in different ways, it's nearly impossible for users to specify flatten/unflatten for types that need special handling without creating a dependency nightmare. I wonder if we could come up with a generic lightweight thing, similar to ChainRulesCore, that would satisfy all use cases?
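One way to keep flattening free of heap-allocated vectors for small nested structures is to flatten into a tuple and rebuild by recursing over a prototype. A minimal sketch, assuming the structure is a nested `Tuple` of `Float64` (a static vector could wrap the flat tuple, but no StaticArrays dependency is assumed here); all names are hypothetical, and whether the compiler fully inlines the recursion depends on the tuple sizes involved:

```julia
# Flatten a nested tuple of Float64 into a flat tuple (no Vector is allocated).
flatten_tuple(x::Float64) = (x,)
flatten_tuple(x::Tuple) = _tcat(map(flatten_tuple, x)...)
_tcat() = ()
_tcat(a::Tuple, rest::Tuple...) = (a..., _tcat(rest...)...)

# Rebuild using `proto` for the structure; each step returns (value, remaining).
_unflatten(::Float64, flat::Tuple) = (first(flat), Base.tail(flat))
_unflatten(proto::Tuple, flat::Tuple) = _unflatten_each(proto, flat)
_unflatten_each(::Tuple{}, flat::Tuple) = ((), flat)
function _unflatten_each(protos::Tuple, flat::Tuple)
    y, rest = _unflatten(first(protos), flat)
    ys, rest2 = _unflatten_each(Base.tail(protos), rest)
    return ((y, ys...), rest2)
end

unflatten_tuple(proto, flat::Tuple) = first(_unflatten(proto, flat))
```

The prototype only supplies structure, so `unflatten_tuple((1.0, (2.0, 3.0), 4.0), (2.0, 4.0, 6.0, 8.0))` yields `(2.0, (4.0, 6.0), 8.0)`.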
I guess you could make methods in case
I think I should link to the current discussion regarding
The implementation isn't type stable because of unrelated factors (mostly caching in Functors.jl), but
I've thought about this and it's just genuinely hard because of how broad and nuanced "(un)flattening" is. Some examples to chew on:
I agree, it's a tricky, multi-faceted problem. But I still think we have too many competing solutions in the ecosystem right now. Maybe one could somehow factor this into an API for struct developers that allows them to "annotate" their structs, so we need less heuristics/guesswork, and a set of flatten/unflatten APIs for engine/algorithm developers that make use of those "annotations"? I don't have a concrete proposal; I just feel that some ChainRulesCore-like (in spirit, not functionality) standard in this area is really missing in the ecosystem (or possibly the language itself) at the moment.
AFAIK something like https://github.com/rafaqz/FieldMetadata.jl could be that standard, but the note at the top of the README seems to suggest otherwise :/
@rafaqz, can we pull you in here as well?
FieldMetadata.jl is a cool idea, but it sets metadata by defining new functions on an object. So you have method-table state to think about if you ever want to change anything during use. It's also a lot of confusing and fragile macros for people to understand, which don't scale well to organisation/research-group-level use. ModelParameters.jl is a better solution. You can do most of the same things, but state is contained in the object, and it has a Tables.jl interface. What do you need from this package that you can't do with ModelParameters?
Additionally, truly lightweight recursive rebuilding is possible with my PR to Accessors.jl. And that's probably the most generic place to put it. It hasn't been merged because I'm too busy and we can't get it type stable because of issues with Base no longer inferring recursive methods.
If you have a type that needs a bit of special handling, and you want it to be compatible with a wide variety of Julia ML, statistics & friends packages, and you want CPU + GPU support, you currently have to depend on and specialize functionality defined in Adapt, ConstructionBase, Functors, ParameterHandling, and possibly a few others. To my knowledge, none of those draw on each other for their default implementations. This doesn't compose well and forces packages to either take on a lot of dependencies or be incompatible with parts of the ecosystem. It can also result in a lot of boilerplate code. As a type developer, I'll probably just want to implement a simple, closure-free API like
```julia
get_raw_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
get_semantic_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
reconstruct_from(::Type{<:MyType}, ::ResultOfGetRawContents)
reconstruct_from(::Type{<:MyType}, ::ResultOfGetSemanticContents)
```
The contract here would be
```julia
reconstruct_from(typeof(x), get_raw_contents(x)) == x
reconstruct_from(typeof(x), get_semantic_contents(x)) == x
```
A package like ConstructionBase would seem a natural place to host such an API.
A package like Adapt would probably want to use
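To make the proposed contract concrete, here is a minimal sketch for a hypothetical `MyGaussian` type that caches a derived field: the raw contents include the cache, the semantic contents do not, and `reconstruct_from` tells the two apart by the `NamedTuple` field names (one possible choice, assumed here). None of these functions exist in ConstructionBase; they only follow the names proposed above, and `semantic_realvec` is a made-up generic consumer:

```julia
# Hypothetical example type with a cached, derived field.
struct MyGaussian
    mu::Float64
    sigma::Float64
    logz::Float64                     # cached log-normalizer, derived from sigma
end
MyGaussian(mu, sigma) = MyGaussian(mu, sigma, log(sigma * sqrt(2π)))

get_raw_contents(x::MyGaussian) = (mu = x.mu, sigma = x.sigma, logz = x.logz)
get_semantic_contents(x::MyGaussian) = (mu = x.mu, sigma = x.sigma)

reconstruct_from(::Type{<:MyGaussian}, c::NamedTuple{(:mu, :sigma, :logz)}) =
    MyGaussian(c.mu, c.sigma, c.logz)
reconstruct_from(::Type{<:MyGaussian}, c::NamedTuple{(:mu, :sigma)}) =
    MyGaussian(c.mu, c.sigma)         # recomputes the cache

# A generic consumer that needs no knowledge of MyGaussian itself.
function semantic_realvec(x)
    c = get_semantic_contents(x)
    v = Float64[values(c)...]
    back(w) = reconstruct_from(typeof(x), NamedTuple{keys(c)}(Tuple(w)))
    return v, back
end
```

Going through the semantic contents means a flattener never exposes `logz`, and the rebuilt object still has a consistent cache.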
Yes, ConstructionBase.jl is the natural place for this, and sharing these base methods across more packages was the original reason it was written. Flatten.jl, Setfield.jl and, I think, BangBang.jl all needed it. Fixing any inadequacies it has to make it useful in these other packages is surely within scope. In your schema I guess
Also, Adapt.jl is already kind of redundant. You can replace it with Flatten.jl, or with Accessors.jl once that PR is in. I do this in some GPU-based packages like DynamicGrids.jl, but the current problem with type stability in Base makes this less of an option than it used to be.
I don't disagree, but many packages use/support it, and AFAIK Adapt doesn't use Flatten, Accessors or any other such package in its default implementation. :-(
Yes, that was my idea. If there's room for such a lightweight, closure-free API in ConstructionBase, I'd be very happy to pitch in!
My understanding of ModelParameters is that model struct types must be able to accept
Now this isn't necessarily set in stone, but I've yet to see a satisfactory solution to the rewrite-the-world-to-work-with-[param wrapper type(s)] problem.
In defense of Adapt, I think the type stability is more important than the potentially lost flexibility. There's also the deeper question of whether (un)flattening is the right implementation strategy for something like
@oschulz, the closure-free API proposal above looks very interesting. One question for now: how would you foresee handling multiple different versions of
A concrete motivating example may be found in an NN module system. Getting all parameters of a layer may be handled by
IMO multiple
Thanks! Multiple different versions of get_semantic_contents for the same type in what respect? Could you give a quick example?
I put a relevant example in the comment above, but the gist is that not all traversals we make over a nested object
This is why I mentioned StructWalk earlier: by parameterizing the (un)flatten function with the type of traversal being performed (
That would be part of the nested walk scheme relevant to the application, though, right, and orthogonal to the "get content of this type" API?
Ah, I think I get it, @ToucheSir. So you mean instead of having
Do you think there would be a finite number of such "decomposition styles" and/or a kind of hierarchy between them, so a type developer won't need to know (and specialize for) all of them?
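One way to read the "decomposition style" idea is as a style argument with fallbacks, so a type developer only specializes the styles that matter for their type. A minimal sketch; `get_contents`, `AllContents`, `TrainableContents` and the two toy layer types are all hypothetical names, not an existing API:

```julia
abstract type DecompositionStyle end
struct AllContents <: DecompositionStyle end        # everything needed to rebuild
struct TrainableContents <: DecompositionStyle end  # only what an optimizer updates

# Fallback: unless a type says otherwise, its trainable contents are all contents.
get_contents(::TrainableContents, x) = get_contents(AllContents(), x)

struct Dense
    weight::Matrix{Float64}
    bias::Vector{Float64}
end
get_contents(::AllContents, d::Dense) = (weight = d.weight, bias = d.bias)

struct Norm
    scale::Vector{Float64}
    running_mean::Vector{Float64}   # state, updated outside the optimizer
end
get_contents(::AllContents, n::Norm) = (scale = n.scale, running_mean = n.running_mean)
get_contents(::TrainableContents, n::Norm) = (scale = n.scale,)
```

`Dense` only implements `AllContents` and still answers `TrainableContents` queries through the fallback, while `Norm` opts its running statistics out of the trainable view.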
With such a proposed closure-free API in ConstructionBase, would ParameterHandling adopt it? And what would be needed (see the discussion at the end of ConstructionBase.jl#54)?
Yes. The default could be identical to
In the spirit of creating lightweight interface-defining packages (see TuringLang/Bijectors.jl#199, which resulted in InverseFunctions.jl and ChangesOfVariables.jl):
While adding an interface-test utility to ChangesOfVariables.jl today, I needed a helper function `_to_realvec_and_back` - which does exactly what `ParameterHandling.flatten` does. And I needed it for exactly what #27 is requesting. :-)
It would be nice not to reinvent this, but ParameterHandling would of course be way too heavy a dependency for ChangesOfVariables - and even if it weren't, we'd end up with a circular dependency, since Bijectors will use ChangesOfVariables soon, so ParameterHandling will too.
I think the flatten/unflatten functionality of ParameterHandling does something very fundamental, and is orthogonal to its variable-transformation capabilities. Would you consider splitting it out into a lightweight package (basically the contents of `flatten.jl`)?
A truly lightweight recursive flatten/unflatten interface package could IMHO find use in many places in the ecosystem. `flatten` may be a bit too generic a name for the function (we probably have several `flatten`s in the ecosystem), but how about `flatten_and_back` or so?
CC @willtebbutt, @paschermayr, @devmotion
Update: `ChangesOfVariables.test_with_logabsdet_jacobian` now has an additional optional argument to pass a transformation, which solves the dependency problem. Users will be able to use ParameterHandling for variable transformations during the test without a direct dependency between the two packages (after #27 is solved).
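The dependency-free pattern in this update can be sketched as a test helper that takes the flattening function as an argument instead of importing a package that provides one. This is a hypothetical illustration: `check_inverse`, `default_to_vec` and `nt_to_vec` are made-up names, and the signature is not that of `ChangesOfVariables.test_with_logabsdet_jacobian`. Any function returning a `(vector, unflatten)` pair can be injected:

```julia
# Default flattening for the simple cases the test utility handles itself.
default_to_vec(x::Real) = ([float(x)], only)
default_to_vec(x::AbstractVector{<:Real}) = (collect(float.(x)), identity)

# Compare x with finv(f(x)) on flat real vectors; callers inject a to_vec
# for any other type, so the test utility needs no flattening package.
function check_inverse(f, finv, x; to_vec = default_to_vec, atol = 1e-8)
    v_x, _ = to_vec(x)
    v_roundtrip, _ = to_vec(finv(f(x)))
    return isapprox(v_x, v_roundtrip; atol = atol)
end

# A caller-supplied flattener for flat NamedTuples of reals.
nt_to_vec(x::NamedTuple) = (Float64[values(x)...], v -> NamedTuple{keys(x)}(Tuple(v)))

f(p) = (a = exp(p.a), b = 2 * p.b)
finv(q) = (a = log(q.a), b = q.b / 2)
```

For example, `check_inverse(f, finv, (a = 0.3, b = -1.0); to_vec = nt_to_vec)` returns `true`, while plain numbers and vectors work with the default.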