-
-
Notifications
You must be signed in to change notification settings - Fork 613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a structural loadparams!
#1875
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
cd06023
Add initial implementation
darsnack 99a18ec
Add more tests
darsnack 492c34e
Fix typo in tests
darsnack dee5842
Refactor to allow better support for loading errors with custom models
darsnack b4fe66b
Add documentation for `loadmodel!`
darsnack 0790f24
Spacing in docs
darsnack a155a44
Fix tests
darsnack fbc9faf
Add NEWS entry for `loadmodel!`
darsnack b2a2664
Better docs
darsnack a6cdfdd
Refactor `loadmodel!` to use a custom recursion instead of `fmap`. Ad…
darsnack 29662b2
Add better support for `loadmodel!` w/ tied parameters and address so…
darsnack c831955
Combine `_bool_tie_check` and `_tie_check`.
darsnack 0d55c00
Remove `_parent`
darsnack 3c82471
Apply suggestions from code review
darsnack 9b06730
Clarify docstrings as per review
darsnack cba299b
More clarification
darsnack a59f688
Use extended help for `loadmodel!` docstring
darsnack d08072a
Updated docstring examples for `loadmodel!` to cover more cases
darsnack 6b533b8
Don't do `loadleaf!` docstring
darsnack File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
loadleaf!(dst, src, err) = dst | ||
loadleaf!(dst::AbstractArray, src, err) = | ||
error("Tried to copy $src into an array destination; this is not allowed.") | ||
loadleaf!(dst, src::AbstractArray, err) = | ||
error("Tried to copy an array to $dst; this is not allowed.") | ||
function loadleaf!(dst::AbstractArray, src::Bool, err) | ||
if iszero(src) | ||
dst .= src | ||
else | ||
error("Cannot copy boolean parameter == true to non-zero parameter.") | ||
end | ||
return dst | ||
end | ||
loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst : | ||
error("Cannot copy non-zero parameter to boolean parameter == true.") | ||
function loadleaf!(dst::AbstractArray, src::AbstractArray, err) | ||
(size(dst) == size(src)) || throw(err) | ||
copyto!(dst, src) | ||
end | ||
|
||
_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) || | ||
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") | ||
_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) || | ||
error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") | ||
_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) || | ||
error("Encountered tied destination parameters with untied and mismatched sources.") | ||
_tie_check(dst, src) = true | ||
|
||
_bool_tie_check(dst, src) = true | ||
|
||
""" | ||
loadmodel!(dst, src) | ||
|
||
Copy all the parameters (trainable and non-trainable) from `src` into `dst`. | ||
|
||
Recursively walks `dst` and `src` together using [`Functors.children`](@ref), | ||
and calling `copyto!` on parameter arrays or throwing an error when there is a mismatch. | ||
Non-array elements (such as activation functions) are not copied and need not match. | ||
Zero bias vectors and `bias=false` are considered equivalent | ||
(see extended help for more details). | ||
|
||
# Examples | ||
```julia | ||
julia> dst = Chain(Dense(Flux.ones32(2, 5, tanh)), Dense(2 => 1; bias = [1f0])) | ||
Chain( | ||
Dense(5 => 2, tanh), # 12 parameters | ||
Dense(2 => 1), # 3 parameters | ||
) # Total: 4 arrays, 15 parameters, 316 bytes. | ||
|
||
julia> dst[1].weight ≈ ones(2, 5) # by construction | ||
true | ||
|
||
julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false)); | ||
|
||
julia> Flux.loadmodel!(dst, src); | ||
|
||
julia> dst[1].weight ≈ ones(2, 5) # values changed | ||
false | ||
|
||
julia> iszero(dst[2].bias) | ||
true | ||
``` | ||
|
||
# Extended help | ||
|
||
Throws an error when: | ||
- `dst` and `src` do not share the same fields (at any level) | ||
- the sizes of leaf nodes are mismatched between `dst` and `src` | ||
- copying non-array values to/from an array parameter | ||
(except inactive parameters described below) | ||
- `dst` is a "tied" parameter (i.e. refers to another parameter) and | ||
loaded into multiple times with mismatched source values | ||
|
||
Inactive parameters can be encoded by using the boolean value `false` instead of an array. | ||
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); | ||
however, attempting to copy a non-zero array to an inactive parameter will throw an error. | ||
Likewise, copying a `src` value of `false` to any `dst` array is valid, | ||
but copying a `src` value of `true` will error. | ||
""" | ||
function loadmodel!(dst, src; cache = Base.IdSet()) | ||
ldsts, _ = functor(dst) | ||
lsrcs, _ = functor(src) | ||
(keys(ldsts) == keys(lsrcs)) || | ||
throw(ArgumentError("Tried to load $src into $dst but the structures do not match.")) | ||
|
||
err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.") | ||
foreach(ldsts, lsrcs) do ldst, lsrc | ||
if ldst in cache # we already loaded this parameter before | ||
_tie_check(ldst, lsrc) && return ldst | ||
elseif Functors.isleaf(ldst) # our first time loading this leaf | ||
push!(cache, ldst) | ||
loadleaf!(ldst, lsrc, err) | ||
else # this isn't a leaf | ||
loadmodel!(ldst, lsrc; cache = cache) | ||
end | ||
end | ||
|
||
return dst | ||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also wonder if there should be more errors here:
I can imagine that allowing
src
to havenothing
means "don't change the existing weight". Which is what #1875 (comment) would generate. But it may also make truncations of branches not just leaves, which aren't allowed right now, but would I think be easy: