Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround Flux#1027 #4

Merged
merged 21 commits into from
Jul 8, 2021
Merged

Workaround Flux#1027 #4

merged 21 commits into from
Jul 8, 2021

Conversation

ericphanson
Copy link
Member

This example very much does not work:

julia> output
10×1 Matrix{Float32}:
 1.2540959f-6
 5.1405354f-5
 0.00020041084
 3.131654f-5
 5.9973813f-6
 4.5418696f-7
 5.347429f-8
 0.99964094
 3.7541522f-5
 3.067953f-5

julia> output2
10×1 Matrix{Float32}:
 0.086956024
 0.10940457
 0.09150873
 0.098902285
 0.088835925
 0.06954686
 0.06539499
 0.20102207
 0.090939134
 0.0974893

The goal of this PR is to fix things so that we can correctly (de)-serialize this model so the test passes.

@ericphanson ericphanson changed the title add larger example (broken) Workaround Flux#1027 Jul 6, 2021
@ericphanson ericphanson marked this pull request as ready for review July 6, 2021 18:59
@codecov
Copy link

codecov bot commented Jul 6, 2021

Codecov Report

Merging #4 (c9f5b61) into main (002298b) will increase coverage by 3.32%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main       #4      +/-   ##
==========================================
+ Coverage   90.90%   94.23%   +3.32%     
==========================================
  Files           1        2       +1     
  Lines          33       52      +19     
==========================================
+ Hits           30       49      +19     
  Misses          3        3              
Impacted Files Coverage Δ
src/LegolasFlux.jl 91.17% <ø> (+0.26%) ⬆️
src/functors.jl 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 002298b...c9f5b61. Read the comment docs.

@ericphanson
Copy link
Member Author

ericphanson commented Jul 6, 2021

This works around FluxML/Flux.jl#1027 by essentially following the strategy of FluxML/Flux.jl#1509 (comment), where here I've called @ToucheSir's actually_all_params just weights and added loadweights! as an analog to loadparams! that uses these.

This adds a Flux dependency here which is a little unfortunate because downstream code might want to use the LegolasFlux schema without actually touching weights/flux stuff (e.g. looking at a bunch of losses or something), and now such code will pull in the Flux dependency. But when upstream fixes #1027 we can probably remove the flux_workarounds.jl file (and make a breaking release).

Thanks to @hannahilea for some pair debugging on this!

@ericphanson ericphanson requested review from hannahilea and kolia July 6, 2021 19:22
Project.toml Show resolved Hide resolved
examples/digits.jl Show resolved Hide resolved
Copy link

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ping! I've left a couple of comments that might be helpful.

src/flux_workarounds.jl Outdated Show resolved Hide resolved
src/flux_workarounds.jl Outdated Show resolved Hide resolved
Copy link
Contributor

@hannahilea hannahilea left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Other than the renaming issue below, this seems good to go. :)

src/flux_workarounds.jl Outdated Show resolved Hide resolved
test/runtests.jl Show resolved Hide resolved
@ericphanson
Copy link
Member Author

ericphanson commented Jul 7, 2021

Thanks for the reviews!

Thanks especially for the pointers to fcollect and Functors, @ToucheSir! Super helpful. In the last two commits, I switched to using that approach and dropped the Flux dep (for the very light Functors dep). I found that fcollect as-is does not work because of this line: https://github.com/FluxML/Functors.jl/blob/adeb24bc3b2fb3e9959f1157d81f4633a855e207/src/functor.jl#L109. x in arr for arrays arr compares by value, and if we end up with several xs with the same values (say zero'd out arrays from a freshly initialized model or something like that), then we only keep 1 of each, and when it comes time to load the weights later, we don't have enough of them (this isn't hypothetical, it happened with the DigitsModel example here and had me confused for awhile). So instead in fcollect2 I switched to using an IdSet to keep track of the objects we've already added (similarly to Functors.fmap). Since an IdSet is unordered and we definitely care about the ordering, I had to also keep the original Vector{Any} used to collect the weights. Keeping them in two data structures is a little inelegant (I guess I actually want an OrderedIdSet which doesn't exist AFAIK) but should have a negligible perfomance cost in this context (since we just store references to the arrays, not copies of them).

Also, since we now have loadweights! available in the package, we might be able to make the API a little easier to use (e.g. allow a user to pass a model directly instead of making them get the weights out first themselves). But I think that should be followup work (I think it requires a little more thought about the best approach).

@ToucheSir
Copy link

Nice catch for fcollect. IMO this also a bug in the upstream implementation and we should fix it to work like fcollect2. Will have a look at that when I next find time to work on Functors.

@ericphanson
Copy link
Member Author

Ok good to hear this is likely a bug! I've filed an issue, FluxML/Functors.jl#16.

@ericphanson ericphanson requested a review from hannahilea July 7, 2021 17:14
README.md Outdated Show resolved Hide resolved
examples/digits.jl Outdated Show resolved Hide resolved
examples/digits.jl Outdated Show resolved Hide resolved
examples/digits.jl Outdated Show resolved Hide resolved
examples/digits.jl Outdated Show resolved Hide resolved
src/functors.jl Outdated Show resolved Hide resolved
src/functors.jl Outdated Show resolved Hide resolved
src/functors.jl Outdated Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
Copy link
Contributor

@hannahilea hannahilea left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...so I wrote out a whole big comment and was sure it was attached to my previous review, but now I don't see it anywhere? The gist is that weights has a very specific meaning in the context of ML, so our use of weights (in load_weights(...), weights(...), etc) should probably actually be renamed to something more like parameters or values. (Too bad params is already taken by Flux!)

Basically, while all weights are parameters, not all parameters are weights (some are biases, etc).

@ToucheSir
Copy link

(Too bad params is already taken by Flux!)

Not only that, the name params is misleading because it only returns trainable parameters. Moreover, it conflicts with BenchmarkTools.params as well. I wish we could rename it to something else, but it would be a massive back compat break because params is arguably the most used function in all of Flux (not counting re-exports from Zygote etc.)

@ericphanson
Copy link
Member Author

Ok, we've got some bikeshedding to do! Some ideas...

Getter Setter Column name
Current weights load_weights! weights
state fetch_state load_state! state
learnings fetch_learnings load_learnings! learnings

We're leaning towards state but if anyone has any further ideas or comments let me know! I will pick something and merge tomorrow morning if there aren't other objections.

@hannahilea
Copy link
Contributor

...pytorch uses state:

In PyTorch, the learnable parameters (i.e. weights and biases) of an torch.nn.Module model are contained in the model’s parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer’s state, as well as the hyperparameters used.
(https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)

...tensorflow uses weights (huh):

A Keras model consists of multiple components:

The architecture, or configuration, which specifies what layers the model contain, and how they're connected.
A set of weights values (the "state of the model"). (https://www.tensorflow.org/guide/keras/save_and_serialize#introduction)

so maybe I sent us down this bikeshed prematurely, and it is okay to stick with weights after all... 🤷

Copy link
Contributor

@hannahilea hannahilea left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...in light of the renaming discussion above, I'm going to leave the renaming (or not) to your discretion and let you merge at your leisure.

@ericphanson
Copy link
Member Author

ericphanson commented Jul 8, 2021

Ok cool! Thanks for looking up how others do it, that was a great idea.

I think then since weights seems acceptable I'd like to go with that. We only allow (multi-dimensional) arrays with all the same element type, which is far from arbitrary state. That restriction is on purpose of course, since we are trying to provide a different way to serialize the model rather than include arbitrary state from your julia session, so I think it's good to highlight it a little. Also, it's nice to keep this non-breaking by not adjusting the schema.

But I think I'd like to use fetch_weights instead of weights so that we have fetch_weights, load_weights!, and column name weights. I like that better than reusing weights for both the function and the column name because you can then do

weights = fetch_weights(model)

row = ModelRow(; weights, ...)

I'll add an explanation to the readme to say exactly what we mean by weights here.

@ericphanson ericphanson merged commit cd965c8 into main Jul 8, 2021
@ericphanson ericphanson deleted the eph/debug branch July 8, 2021 12:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants