Extracting flatten/unflatten into lightweight package? #43

Open
oschulz opened this issue Oct 13, 2021 · 38 comments

@oschulz

oschulz commented Oct 13, 2021

In the spirit of creating lightweight interface-defining packages (see TuringLang/Bijectors.jl#199 which resulted in InverseFunctions.jl and ChangesOfVariables.jl):

While adding an interface-test utility to ChangesOfVariables.jl today, I needed a helper function _to_realvec_and_back - which does exactly what ParameterHandling.flatten does. And I needed it for exactly what #27 is requesting. :-)

It would be nice not to reinvent this, but ParameterHandling would of course be way too heavy a dependency for ChangesOfVariables - and even if it wasn't, we'd end up with a circular dependency since Bijectors will use ChangesOfVariables soon, so ParameterHandling will too.

I think the flatten/unflatten functionality of ParameterHandling does something very fundamental, and is orthogonal to its variable transformation capabilities. Would you consider splitting it out into a lightweight package (basically the contents of flatten.jl)?

A truly lightweight recursive flatten/unflatten interface package could IMHO find use in many places in the ecosystem. flatten may be a bit too generic a name for the function (we probably have several flattens in the ecosystem), but how about flatten_and_back or so?

CC @willtebbutt, @paschermayr, @devmotion

Update: ChangesOfVariables.test_with_logabsdet_jacobian now has an additional optional argument to pass a transformation, which solves the dependency problem. Users will be able to use ParameterHandling for variable transformations during the test without a direct dependency between the two packages (after #27 is solved).
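A minimal Julia sketch of the recursive flatten/unflatten idea, using the hypothetical name flatten_and_back suggested above. It is not ParameterHandling's flatten.jl: it only covers Reals, real arrays, Tuples and NamedTuples, converts every value to Float64, and deliberately sidesteps the integer and constant-parameter questions.

```julia
# Hypothetical sketch: recursive flatten-to-real-vector-and-back.
# Each method returns (flat::Vector{Float64}, back), where back(flat) rebuilds the structure.

flatten_and_back(x::Real) = ([Float64(x)], v -> only(v))

function flatten_and_back(x::AbstractArray{<:Real})
    sz = size(x)
    return (Float64.(vec(x)), v -> reshape(v, sz))   # column-major order
end

function flatten_and_back(x::Tuple)
    parts = map(flatten_and_back, x)             # tuple of (flat, back) pairs
    lens = [length(first(p)) for p in parts]     # length of each flattened piece
    stops = cumsum(lens)
    starts = stops .- lens .+ 1
    flat = vcat(Float64[], map(first, parts)...)
    back(v) = ntuple(i -> last(parts[i])(v[starts[i]:stops[i]]), length(x))
    return (flat, back)
end

function flatten_and_back(x::NamedTuple{names}) where {names}
    flat, back_values = flatten_and_back(Tuple(x))
    return (flat, v -> NamedTuple{names}(back_values(v)))
end

x = (a = 1.5, b = [1.0 2.0; 3.0 4.0], c = (2.0, [5.0]))
v, back = flatten_and_back(x)
```

The back function is what makes this "flatten and back"; note that it is a closure, which comes up again further down in this thread in the context of type stability.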

@willtebbutt
Member

How would you feel about depending on ParameterHandling once #42 goes in?

@oschulz
Author

oschulz commented Oct 13, 2021

How would you feel about depending on ParameterHandling once #42 goes in?

I think even with #42 it would still be too heavy - the idea is to keep ChangesOfVariables super-lightweight, so we can convince packages to depend on it, and the stuff in parameters.jl is really a different thing - here, I think, ParameterHandling would depend on ChangesOfVariables, not the other way round. And ParameterHandling is a more opinionated package than ChangesOfVariables.

But now that I think about it, since it looks like we need flatten/unflatten functionality in ChangesOfVariables, how would you like moving that there? It's not much more than I have in the draft of JuliaMath/ChangesOfVariables.jl#2 anyway, and I think it would fit into the theme of the package - it is a change of variables, after all (with a logabsdet-jacobian of 0, of course).

@willtebbutt and @devmotion, what do you think?

Update: ParameterHandling will likely depend on ChangesOfVariables indirectly in the future due to #42, since we hope to get support for ChangesOfVariables into LogExpFunctions (especially for logistic, logit etc.). So ChangesOfVariables must not depend on ParameterHandling.

@devmotion
Member

I'm not sure, it's not clear to me that this functionality should be included in JuliaMath/ChangesOfVariables.jl#2. I would prefer if such wrapping/reshaping/flattening etc. is handled by the user.

@oschulz
Author

oschulz commented Oct 13, 2021

Yes, the more I think about it, it's really a bit of a different thing. And we could use test_with_logabsdet_jacobian(f, x, torealvec_and_back, getjacobian), that would keep the whole thing out of the dependencies of ChangesOfVariables.

@oschulz
Author

oschulz commented Oct 13, 2021

I still think it's a very fundamental functionality that may deserve a central package independent from variable-domain transformations.

I wonder if we should add an interface InverseFunctions.with_inverse(f, x) with a default implementation of InverseFunctions.with_inverse(f, x) = (f(x), inverse(f)), to support use cases like flatten/unflatten that can't do inverse(f) without a value (they would specialize with_inverse)?

That would form a nice basis for a package that depends on InverseFunctions and defines a function realvec and with_inverse(::typeof(realvec), x). There are definitely use cases where one would only want the forward function, and instantiating the back function would be wasteful (memory allocations). This way, we'd have an API that can do both flatten and flatten-and-back.
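A sketch of the proposed with_inverse default together with a value-dependent specialization. with_inverse is only a proposal here, so everything below is hypothetical: inverse is a tiny local stand-in for InverseFunctions.inverse (so the block needs no packages), and realvec is the hypothetical function named above.

```julia
# Hypothetical sketch of the proposed with_inverse(f, x) interface.
# `inverse` is a local stand-in for InverseFunctions.inverse so this block runs without packages.
inverse(::typeof(exp)) = log
inverse(::typeof(log)) = exp

# Proposed default: the inverse does not depend on x.
with_inverse(f, x) = (f(x), inverse(f))

# Hypothetical realvec: the forward direction alone builds no back function.
realvec(x::AbstractArray{<:Real}) = vec(x)

# Specialization for a function whose inverse needs information from x (here, its shape).
function with_inverse(::typeof(realvec), x::AbstractArray{<:Real})
    sz = size(x)
    return (realvec(x), v -> reshape(v, sz))
end
```

Calling realvec(x) on its own skips constructing the back function, which is the allocation saving mentioned above; with_inverse(realvec, x) pays for it only when the caller asks.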

@willtebbutt
Member

In principle I could get on board with this kind of thing.

My main concern is that if we choose to make this functionality more widely available it'll be difficult to get consensus on some semantics. For example, flatten(1) returns an empty vector and a closure. The rationale being that this is what you want if you're doing AD and 1 is a primal (as opposed to a tangent vector), because integers aren't differentiable. I can well imagine that someone somewhere would (quite reasonably) find this choice objectionable, and want different behaviour.

So I would want to generalise what we have here a little bit to allow different semantics depending on your situation. Maybe utilising a trait-based mechanism like the one used for RuleConfigs in ChainRulesCore?

So you would write something like

struct Flattener{T} end
Flattener() = Flattener{Nothing}()   # default semantics

struct KeepIntegers end

# These have different behaviour:
with_inverse(Flattener(), 1)                # flat part is an empty vector
with_inverse(Flattener{KeepIntegers}(), 1)  # flat part is [1.0]

in case you want to work around this stuff?
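A runnable variant of the Flattener idea, assuming the with_inverse(f, x) -> (f(x), back) convention proposed above. DropIntegers, KeepIntegers and all methods below are illustrative names only; the point is that the same value flattens differently depending on the style type parameter. The last method also makes a Flattener callable so that it returns only the flat vector.

```julia
# Hypothetical sketch: flattening semantics selected by a type parameter (all names illustrative).
struct DropIntegers end   # AD-oriented default: integers are constants and are not flattened
struct KeepIntegers end   # integers are stored in the flat vector as Float64

struct Flattener{S} end
Flattener() = Flattener{DropIntegers}()

# with_inverse returns (flat vector, function that rebuilds the value from a flat vector).
with_inverse(::Flattener{DropIntegers}, x::Integer) = (Float64[], _ -> x)
with_inverse(::Flattener{KeepIntegers}, x::Integer) = ([Float64(x)], v -> round(typeof(x), only(v)))
with_inverse(::Flattener, x::AbstractFloat) = ([Float64(x)], v -> oftype(x, only(v)))

# Calling the flattener directly returns only the flat vector.
(f::Flattener)(x) = first(with_inverse(f, x))
```

Floats behave the same under every style, so only the Integer methods need to be specialized per style.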

@oschulz
Author

oschulz commented Oct 14, 2021

Yes, thought some more about this today and also started to worry about integers. With the autodiff use case, as an extension to excluding integers, people may also want to exclude other parameters from the gradient calculation (if they are not relevant, to increase performance), so one might want to have a way to mark parameters as active/constant. That could quickly go beyond "non-opinionated central API" indeed ...

@willtebbutt
Member

so one might want to have a way to mark parameters as active/constant.

We actually already have a way to mark parameters as being constant (see fixed). Doing things the other way around (marking things as "active") is something that we don't currently have support for, although we could almost certainly achieve it using a trait-based system.
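To illustrate how a constant marker interacts with flattening, here is a hypothetical Fixed wrapper that contributes nothing to the flat vector and rebuilds itself unchanged. It only mimics the behaviour described for fixed; it is not ParameterHandling's implementation, and Fixed and flatten_and_back are illustrative names.

```julia
# Hypothetical sketch: a wrapper marking a value as constant, so flattening skips it.
# `Fixed` and `flatten_and_back` are illustrative names, not ParameterHandling's code.
struct Fixed{T}
    value::T
end

flatten_and_back(x::Real) = ([Float64(x)], v -> only(v))
flatten_and_back(x::Fixed) = (Float64[], _ -> x)   # contributes nothing; rebuilds itself unchanged

function flatten_and_back(x::NamedTuple{names}) where {names}
    parts = map(flatten_and_back, Tuple(x))       # tuple of (flat, back) pairs
    lens = [length(first(p)) for p in parts]
    stops = cumsum(lens)
    starts = stops .- lens .+ 1
    flat = vcat(Float64[], map(first, parts)...)
    back(v) = NamedTuple{names}(ntuple(i -> last(parts[i])(v[starts[i]:stops[i]]), length(names)))
    return (flat, back)
end
```

A trait for "active" parameters would invert the default: everything is constant unless a method says otherwise.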

@oschulz
Author

oschulz commented Oct 14, 2021

I think it would be nice to explore these options in ParameterHandling (especially since it's going to be more lightweight soon). If we add InverseFunctions.with_inverse (JuliaMath/InverseFunctions.jl#6 - @devmotion?), your Flattener concept could be both powerful and user-friendly, I think.

A central "flatten-only" package could come later then, maybe? As soon as #27 is sorted out, people will already be able to use ParameterHandling in ChangesOfVariables.test_with_logabsdet_jacobian (changed now), without ChangesOfVariables depending on ParameterHandling. And if we go for the with_inverse approach, I would simplify the rv_and_back = x -> (x, identity) option to transform = identity - much more elegant.

@oschulz
Author

oschulz commented Oct 14, 2021

We actually already have a way to mark parameters as being constant (see fixed)

Neat!

One advantage with the InverseFunctions.with_inverse approach would be that Flattener() itself would just return the value, not a tuple of value and back-function. So it could support ChangesOfVariables.with_logabsdet_jacobian. While Flattener() would have a trivial volume-element, it could be interesting for ValueFlattener() as a successor of value_flatten.

@oschulz
Author

oschulz commented Oct 15, 2021

@devmotion pointed me to https://github.com/JuliaDiff/FiniteDifferences.jl/blob/main/src/to_vec.jl

There are probably equivalents of this in several places in the ecosystem, internally. If we had a nice, lightweight central API for a to-real-vec-and-back, people could make their types support it, and that would make a lot of AD a bit easier. It's handling integers and constant parameters, though, that makes it less straightforward than I initially thought - but I think it could (and maybe should) still be decoupled from actual value transformations.

CC @mzgubic @oxinabox

@mzgubic

mzgubic commented Oct 15, 2021

Just to add that in the ChainRules ecosystem there is a vague plan for moving away from to_vec.

@oxinabox

I think there is another copy of this general idea somewhere in ArrayInterface.jl?
cc @ChrisRackauckas

@paschermayr

@oschulz : There are 2 issues that come to my mind:

  1. How do you handle the unflatten part? If AD is in mind, then the output type is determined by the input type, but for pretty much anything else, the return type should be given by the initial container type that was an argument in flatten. We had a quick discussion here: Overloading-AD-Friendly Unflatten #39. I believe the main use case for ParameterHandling.jl is the flatten part, so in this case that would be not much of an issue.

  2. How do you handle integers? Many people work with Integers/Vectors of Integers as parameter arguments, but at the moment they are not flattened. I think this is correct, because the unflatten part would cause all sorts of problems (Integers and Floats would probably be flattened to Floats), but I can see arguments against this case as well.
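The two reconstruction policies in point 1 can be contrasted in a few lines. A hedged sketch with hypothetical function names, restricted to arrays: one lets the flat vector decide the element type (what AD and optimizers usually want), the other converts back to the original container's element type.

```julia
# Hypothetical sketch of two reconstruction policies for unflatten (names illustrative).

# Policy 1: the element type follows the flat vector, so values produced by an
# AD tool or an optimizer pass through unchanged.
unflatten_follow_input(x_orig::AbstractArray, v::AbstractVector) = reshape(v, size(x_orig))

# Policy 2: the element type follows the original container.
unflatten_keep_type(x_orig::AbstractArray, v::AbstractVector) =
    reshape(convert(Vector{eltype(x_orig)}, v), size(x_orig))

x32 = Float32[1 2; 3 4]
v = [1.0, 3.0, 2.0, 4.0]
```

With an integer original, the type-preserving policy throws an InexactError as soon as a value leaves the integers, which is one concrete form of the problem raised in point 2.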

@oschulz
Copy link
Author

oschulz commented Oct 15, 2021

How do you handle the unflatten part? If AD is in mind, then the output type is determined

Hm, I guess in many cases, AD would just run on the flattened result - it would then (e.g.) be combined with the flat-gradient and then reconstructed into the original type. So with ForwardDiff there wouldn't be a problem. Reverse mode though - good question.

How do you handle integers? Many people work with Integers/Vectors of Integers as parameter arguments

Yes, integers and values that are supposed to be constant in general. That's the main challenge, I think - AD use-cases will typically want to assume that integers be constant, other use cases may see this differently. ParameterHandling has fixed for explicit control.

I don't have ready answers, I have to admit. But the fact that we have this "flatten-to-real" in so many places seems like a good motivation to have a lightweight central package that people would be willing to depend on (separate from value domain transformations). Once we have figured out the answers to those questions, that is. :-)

@ToucheSir

For the interested, here's how Optimisers.jl handles reverse mode friendly (un)flattening: https://github.com/FluxML/Optimisers.jl/blob/master/src/destructure.jl. We opted for the conservative approach and only consider non-integer numeric arrays (due in part to Flux's design constraints). Something like fixed would be interesting, but the main challenge with such a wrapper is making it transparent to non-parameter handling code.

@oschulz
Copy link
Author

oschulz commented Mar 8, 2022

For the interested, here's how Optimisers.jl handles reverse mode friendly (un)flattening

This is built on top of Functors.jl mainly, right?

@ToucheSir

Yes, Functors.jl is an integral part of it but there's no reason a similar set of functionality couldn't be developed for another parameter handling library :)

@oschulz
Author

oschulz commented Apr 4, 2022

While revamping ForwardDiffPullbacks, I had to invent yet another flatten/unflatten mechanism. To get full performance and type stability (this needs to be really fast and allocation-free with deeply nested structures, flattening to static vectors), I ended up with

  • flatten(x)
  • unflatten(x_orig, x_flat)
  • unflatten_tangent(x_orig, dx_flat)

An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise, there was trouble when things were nested more deeply (though maybe I did it wrong, I also tried via Flatten.jl but that also didn't work out). Also the x_flat, re = ... pattern didn't provide tangent-unflatten capability.

The problem is that with these flatten/unflatten capabilities "hidden" in several packages in different ways, it's near impossible for users to specify flatten/unflatten for types that need special handling without creating a dependency nightmare. I wonder if we could come up with a generic lightweight thing similar to ChainRulesCore that would satisfy all use cases?
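A small sketch of what a closure-free flatten(x) / unflatten(x_orig, x_flat) / unflatten_tangent(x_orig, dx_flat) triple could look like. It is not ForwardDiffPullbacks' code: nested Tuples stand in for static vectors so that no packages are needed, Integer leaves are assumed to be constants, and the value nothing marks a missing tangent. The original value supplies the structure, so nothing has to be captured in a closure.

```julia
# Hypothetical closure-free sketch; nested Tuples of Reals stand in for static vectors.
# Assumption for illustration: Integer leaves are constants and are not flattened.

flatten(::Integer) = ()
flatten(x::Real) = (x,)
flatten(x::Tuple) = _cat(map(flatten, x)...)
_cat() = ()
_cat(a::Tuple, rest::Tuple...) = (a..., _cat(rest...)...)

struct PrimalMode end
struct TangentMode end
_constant(::PrimalMode, x) = x         # primal: restore the constant from the original
_constant(::TangentMode, x) = nothing  # tangent: a constant has no tangent

# Walk x_orig, consuming flat values from the front; returns (rebuilt, remaining flat values).
_rebuild(m, x::Integer, flat::Tuple) = (_constant(m, x), flat)
_rebuild(m, ::Real, flat::Tuple) = (first(flat), Base.tail(flat))
_rebuild(m, ::Tuple{}, flat::Tuple) = ((), flat)
function _rebuild(m, x::Tuple, flat::Tuple)
    head, rest = _rebuild(m, first(x), flat)
    tl, rest = _rebuild(m, Base.tail(x), rest)
    return ((head, tl...), rest)
end

unflatten(x_orig, x_flat::Tuple) = first(_rebuild(PrimalMode(), x_orig, x_flat))
unflatten_tangent(x_orig, dx_flat::Tuple) = first(_rebuild(TangentMode(), x_orig, dx_flat))
```

The primal and tangent reconstructions share one traversal and differ only in how constant leaves are filled in, which is why a separate unflatten_tangent is cheap to provide.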

@paschermayr

An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise

I guess you could make methods in case reconstruct_function is not needed.

@oschulz
Author

oschulz commented Apr 4, 2022

I think I should link to the current discussion regarding ConstructionBase.getfields/ConstructionBase.getproperties (JuliaObjects/ConstructionBase.jl#54) here, I feel these issues are very connected.

@ToucheSir

ToucheSir commented Apr 4, 2022

An approach like x_flat, reconstruct_function = flatten(x) didn't work out performance/type-stability wise, there was trouble when things were nested more deeply (though maybe I did it wrong, I also tried via Flatten.jl but that also didn't work out). Also the x_flat, re = ... pattern didn't provide tangent-unflatten capability.

The implementation isn't type stable because of unrelated factors (mostly caching in Functors.jl), but Optimisers.destructure does basically this. Saving some auxiliary state during flattening pays off when the number of reconstructions exceeds the number of flattenings. What does not work well is returning a plain closure, but I'm preaching to the choir here :)
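A sketch of keeping that auxiliary state in a callable struct rather than a plain closure, for a NamedTuple of real arrays. Restructure and destructure_nt are hypothetical names, and this is not how Optimisers.destructure is implemented; it only shows the shape of the idea: shapes and offsets are computed once during flattening and reused by every reconstruction.

```julia
# Hypothetical sketch: reconstruction state stored in a callable struct instead of a closure.
struct Restructure{names, S<:Tuple}
    shapes::S              # size of each array, e.g. ((2, 2), (2,))
    offsets::Vector{Int}   # number of elements before each array in the flat vector
end

function destructure_nt(x::NamedTuple{names}) where {names}
    arrays = Tuple(x)
    shapes = map(size, arrays)
    lens = [prod(s) for s in shapes]
    offsets = cumsum(lens) .- lens
    flat = vcat(Float64[], map(vec, arrays)...)
    return (flat, Restructure{names, typeof(shapes)}(shapes, offsets))
end

# Rebuild the NamedTuple from any flat vector of the right length.
function (re::Restructure{names})(v::AbstractVector) where {names}
    arrays = ntuple(length(re.shapes)) do i
        sz = re.shapes[i]
        reshape(v[re.offsets[i] .+ (1:prod(sz))], sz)
    end
    return NamedTuple{names}(arrays)
end

x = (W = [1.0 2.0; 3.0 4.0], b = [5.0, 6.0])
flat, re = destructure_nt(x)
```

Because re is an ordinary struct, its type is visible to the compiler and it can be inspected, compared or given extra methods, which a closure does not allow.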

I wonder if we could come up with a generic lightweight thing similar to ChainRulesCore that would satisfy all use cases?

I've thought about this and it's just genuinely hard because of how broad and nuanced "(un)flattening" is. Some examples to chew on:

  • Flux wants to flatten arbitrary structs with non-flattenable fields without forcing them to be generic in those fields like Flatten.jl requests.
  • Accounting for parameters that are "shared" or "tied" is hard. Annotating them with a wrapper in the source parameter collection can work, but messes with dispatch/type constraints in unsuspecting downstream consumers of that collection. Trying to do so automatically turned out to be a rabbit hole we didn't want to mess with in Optimisers.jl.
  • The same point about wrappers and dispatch applies to constrained parameters as exist in ParameterHandling.jl or ModelParameters.jl.
  • It may not be possible to derive the fixed-ness of a given subtree's children via the type alone. Here too we see a range of very distinct solutions, including ParameterHandling.fixed and Optimisers.trainable. The latter brings up another point: for many applications we cannot assume that all parameter-visiting traversals of an object tree(/DAG) will care about the same set of fields flattening does. This is the motivation behind StructWalk.jl's WalkStyle.

@oschulz
Author

oschulz commented Apr 4, 2022

I agree, it's a tricky, multi-faceted problem. But I still think we have too many competing solutions in the ecosystem right now. Maybe one could somehow factorize this into an API for struct-developers that allows them to "annotate" their structs so we need less heuristics/guesswork, and a set of flatten/unflatten APIs for engine/algorithms-developers that make use of those "annotations"? I don't have a concrete proposal, I just feel that some ChainRulesCore-like (in spirit, not functionality) standard in this area is really missing in the ecosystem (or possibly the language itself) at the moment.

@ToucheSir

AFAIK something like https://github.com/rafaqz/FieldMetadata.jl could be that standard, but the note at the top of the README seems to suggest differently :/

@oschulz
Author

oschulz commented Apr 4, 2022

@rafaqz can we pull you in here as well?

@rafaqz

rafaqz commented Apr 5, 2022

FieldMetadata.jl is a cool idea, but it sets metadata by defining new functions on an object. So you have method table state to think about if you ever want to change anything during use. It's also a lot of confusing and fragile macros for people to understand, which don't scale to organisation/research group level use very well.

ModelParameters.jl is a better solution. You can do most of the same things but state is contained in the object, and has a Tables.jl interface.

What do you need from this package that you can't do with ModelParameters?

@rafaqz

rafaqz commented Apr 5, 2022

Additionally, truly lightweight recursive rebuilding is possible with my PR to Accessors.jl. And that's probably the most generic place to put it. It hasn't been merged because I'm too busy and we can't get it type stable because of issues with Base no longer inferring recursive methods.

@oschulz
Author

oschulz commented Apr 5, 2022

What do you need from this package that you can't do with ModelParameters?

If you have a type that needs a bit of special handling, and you want it to be compatible with a wide variety of Julia ML, statistics & friends packages, and you want CPU + GPU support, you currently have to depend on and specialize functionality defined in: Adapt, ConstructionBase, Functors, ParameterHandling, and possibly a few others. To my knowledge, none of those draw on each other for their default implementations. This doesn't compose well and forces packages to either take on a lot of dependencies or not be compatible with parts of the ecosystem. And it can result in a lot of boilerplate code.

As a type developer, I'll probably just want to implement a simple, closure-free API like

get_raw_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
get_semantic_contents(x::MyType)::Union{Real,Tuple,AbstractArray,NamedTuple}
reconstruct_from(::Type{<:MyType}, ::ResultOfGetRawContents)
reconstruct_from(::Type{<:MyType}, ::ResultOfGetSemanticContents)

The contract here would be

reconstruct_from(typeof(x), get_raw_contents(x)) == x
reconstruct_from(typeof(x), get_semantic_contents(x)) == x

reconstruct_from should accept contents with a different numeric precision and array types as the original x, of course, if supported by the type.

A package like ConstructionBase would seem a natural place to host such an API.

get_raw_contents and get_semantic_contents could return tuples, NamedTuples, Reals (even all Numbers?) and Arrays as-is, and return a NamedTuple for structs - get_raw_contents could use fieldcount/fieldname/getfield and get_semantic_contents could use propertynames/getproperty.

A package like Adapt would probably want to use get_raw_contents for its default implementation, whereas packages like Functors and ParameterHandling would want get_semantic_contents, I guess.
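A minimal sketch of the proposed API for plain structs, using the names from the proposal (they do not exist in ConstructionBase or any released package). The generic fallbacks use fieldnames/getfield for get_raw_contents and assume that a type's name is bound to a constructor that accepts all fields positionally; that assumption is what lets reconstruct_from accept contents with a different numeric precision.

```julia
# Hypothetical sketch of the proposed API; the names come from the proposal above,
# not from ConstructionBase or any released package.
const AsIs = Union{Real, Tuple, AbstractArray, NamedTuple}

# Values that already are "contents" are returned and accepted as-is.
get_raw_contents(x::AsIs) = x
reconstruct_from(::Type{<:AsIs}, contents) = contents

# Generic struct fallback: field names and values via reflection.
function get_raw_contents(x)
    fnames = fieldnames(typeof(x))
    return NamedTuple{fnames}(map(n -> getfield(x, n), fnames))
end

# Assumption: the type's name is bound to a constructor taking all fields positionally,
# so e.g. Float32 contents rebuild a Point{Float32} from a Point{Float64} type.
reconstruct_from(::Type{T}, contents) where {T} =
    getfield(parentmodule(T), nameof(T))(Tuple(contents)...)

struct Point{T<:Real}
    x::T
    y::T
end

p = Point(1.0, 2.0)
```

A get_semantic_contents fallback would look the same with propertynames/getproperty in place of fieldnames/getfield.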

@rafaqz

rafaqz commented Apr 5, 2022

Yes ConstructionBase.jl is the natural place for this, and sharing these base methods across more packages was the original reason for it to be written. Flatten.jl, Setfield.jl and I think BangBang.jl all needed it. Fixing any inadequacies it has to make it useful in these other packages is surely within scope.

In your schema I guess get_raw_contents == ConstructionBase.getfields (in the current PR by @jw3126) and get_semantic_contents == ConstructionBase.getproperties?

Also, Adapt.jl is already kind of redundant. You can replace it with Flatten.jl or Accessors.jl with that PR. I do this in some GPU-based packages like DynamicGrids.jl, but the current problem with type stability in Base makes this less of an option than it used to be.

@oschulz
Author

oschulz commented Apr 5, 2022

Also, Adapt.jl is already kind of redundant.

I don't disagree, but many packages use/support it and AFAIK Adapt doesn't use Flatten, Accessors or any of the others in its default implementation. :-(

@oschulz
Author

oschulz commented Apr 5, 2022

In your schema I guess get_raw_contents == ConstructionBase.getfields (in the current PR by @jw3126) and get_semantic_contents == ConstructionBase.getproperties?

Yes, that was my idea. If there's room for such a lightweight, closure-free API in ConstructionBase I'd be very happy to pitch in!

@ToucheSir

What do you need from this package that you can't do with ModelParameters?

My understanding of ModelParameters is that model struct types must be able to accept Params as fields. For philosophical and practical reasons, Flux can not require model structs to use framework-defined types for the purpose of parameter tracking. Instead, this information is currently kept out-of-band via functions like functor or trainable.

Now this isn't necessarily set in stone, but I've yet to see a satisfactory solution to the rewrite-the-world-to-work-with-[param wrapper type(s)] problem.


I don't disagree, but many packages use/support it and AFAIK Adapt doesn't use Flatten, Accessors or any of the others in its default implementation. :-(

In defense of Adapt, I think the type stability is more important than the potentially lost flexibility. There's also the deeper question of whether (un)flattening is the right implementation strategy for something like adapt_structure. In FluxML/Functors.jl#27, I tried exploring what would happen if we used a more FP-inspired structural map (think Flatten.modify) as the core primitive. This turned out to have its own set of challenges, but the point is that both the problem and solution space are quite broad!


@oschulz the closure-free API proposal above looks very interesting, one question for now: how would you foresee handling multiple different versions of get_semantic_contents for the same type?

A concrete motivating example may be found in a NN module system. Getting all parameters of a layer may be handled by get_raw_contents, but we'd also want get_trainable_contents for optimization and get_trivially_serializable_contents for saving layer state.

@jw3126

jw3126 commented Apr 5, 2022

how would you foresee handling multiple different versions of get_semantic_contents for the same type?

IMO multiple get_semantic_contents on the same type are not the job of ConstructionBase. Analogous to Base not providing multiple variants of getproperty on a single type. Instead I think lenses are a great abstraction for this, so you could use Accessors.jl or Setfield.jl

@oschulz
Author

oschulz commented Apr 5, 2022

@oschulz the closure-free API proposal above looks very interesting, one question for now: how would you foresee handling multiple different versions of get_semantic_contents for the same type?

Thanks!

Multiple different versions of get_semantic_contents for the same type in what respect? Could you give a quick example?

@ToucheSir

ToucheSir commented Apr 5, 2022

I put a relevant example in the comment above, but the gist is that not all traversals we make over a nested object DAG tree (I'm being cheeky here, but handling "shared" nodes is a big part of the Functors codebase) will want to access the same set of fields. Thus defining a single get_semantic_contents function has limited value for our use cases unless the definition of that function is sufficiently general. But then if it's too general, there doesn't seem to be much of a difference between get_semantic_contents and get_raw_contents, AIUI.

This is why I mentioned StructWalk earlier: by parameterizing the (un)flatten function with the type of traversal being performed (WalkStyle), the traversal code can be decoupled from the code that extracts the contents of each node. Generic fallbacks mean that no convenience is lost for types that return the same contents regardless of the purpose of the traversal. Is it possible to do something similar with lenses?

@oschulz
Author

oschulz commented Apr 5, 2022

but handling "shared" nodes is a big part of the Functors codebase

That would be part of the nested walk scheme relevant to the application though, right, and orthogonal to the "get content of this type" API?

by parameterizing the (un)flatten function with the type of traversal being performed

Ah, I think I get it @ToucheSir . So you mean instead of having get_raw_contents(x) and get_semantic_contents(x) we'll need something like get_contents(x, decomposition_style), with a matching reconstruct_from(T, content, decomposition_style)?

Do you think there would be a finite number of such "decomposition styles" and/or a kind of hierarchy between them so a type developer won't need to know (and specialize for) all of them?
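A hedged sketch of the get_contents(x, decomposition_style) idea with a fallback hierarchy: any style a type does not specialize behaves like SemanticStyle, so a type developer only writes methods where their type differs. All style names and the Dense example type are hypothetical; a matching reconstruct_from(T, contents, style) would follow the same dispatch pattern.

```julia
# Hypothetical sketch: contents parameterized by a decomposition style (all names illustrative).
abstract type DecompositionStyle end
struct RawStyle <: DecompositionStyle end
struct SemanticStyle <: DecompositionStyle end
struct TrainableStyle <: DecompositionStyle end
struct SerializableStyle <: DecompositionStyle end

# Fallback hierarchy: any style a type does not specialize behaves like SemanticStyle,
# which in turn defaults to the type's properties.
get_contents(x, ::DecompositionStyle) = get_contents(x, SemanticStyle())
function get_contents(x, ::SemanticStyle)
    pnames = propertynames(x)
    return NamedTuple{pnames}(map(n -> getproperty(x, n), pnames))
end
function get_contents(x, ::RawStyle)
    fnames = fieldnames(typeof(x))
    return NamedTuple{fnames}(map(n -> getfield(x, n), fnames))
end

# A layer-like type only specializes the style where it differs from the default.
struct Dense
    W::Matrix{Float64}
    b::Vector{Float64}
    activation::Function
end
get_contents(d::Dense, ::TrainableStyle) = (W = d.W, b = d.b)

d = Dense([1.0 2.0], [0.5], tanh)
```

Here Dense needs one method to opt its activation out of trainable traversals, while serialization and any future style fall back to the semantic contents automatically.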

@oschulz
Author

oschulz commented Apr 6, 2022

With such a proposed closure-free API in ConstructionBase, would ParameterHandling adopt it? And what would be needed? (See the discussion at the end of ConstructionBase.jl#54.)

@ToucheSir

Do you think there would be a finite number of such "decomposition styles" and/or a kind of hierarchy between them so a type developer won't need to know (and specialize for) all of them?

Yes. The default could be identical to get_semantic_contents(x). I'll pick up the rest on the ConstructionBase thread.


9 participants