DRAFT: Add custom variable updater. #21225
base: master
6fb5fb0 to cb79194
Codecov Report
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21225    +/-   ##
=========================================
  Coverage   82.61%   82.62%
=========================================
  Files         564      564
  Lines       54476    54514    +38
  Branches     8470     8475     +5
=========================================
+ Hits        45005    45040    +35
- Misses       7395     7397     +2
- Partials     2076     2077     +1
Potential replacement for #21196. With this approach, layering is mostly preserved (in our case the custom updater would still contain an optimizer, but it doesn't need to in the general case). This change also generalizes and allows us to remove the …
4826e7e to 6d1a9a4
Allows customization of how variables are updated by the optimizer. The base optimizer simply defers to the update handler to do the update, allowing full customization. Can replace the existing `overwrite_with_gradient` attribute on variables, which is currently very application-specific.
This approach looks good to me!
@@ -23,5 +24,90 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
    pass


@keras_export("keras.optimizers.VariableUpdater")
class VariableUpdater:
    """Allows special handling of variable updates."""
If this is intended to be public in the API, then we should provide an explanation and a usage example in the docstring.
Great! Yes, I'd like it to be public so I can inherit from it without needing to import internal keras APIs.
Is this location (within `optimizer.py`) okay, or would you prefer it somewhere else (e.g. in its own file)?
Allows customization of how variables are updated by the optimizer. The base optimizer simply defers to the update handler to do the update, allowing full customization.
Can replace the existing `overwrite_with_gradient` attribute on variables, which is currently very application-specific.
Eliminates creation of optimizer variables for variables that have custom updaters (including `overwrite_with_gradient`), since those optimizer variables are never used and may be wasteful.
This is an alternative to #21196. It would allow us to add special handling for large embedding tables, where we do not want to pass around large gradients for tables that might span multiple devices. Instead, the tables are updated in place using a custom update rule.
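To make the idea concrete, here is a minimal, framework-free sketch of the pattern described above. It is not the code in this PR and does not use the Keras API: `Variable`, `OverwriteWithGradient`, `MomentumSGD`, and the `update`/`build`/`apply` signatures are all hypothetical names chosen for illustration. It shows an optimizer deferring to a per-variable updater and skipping slot creation for variables that have one.

```python
# Illustrative sketch only: every name here is hypothetical, not the Keras API.


class Variable:
    """Minimal stand-in for a trainable scalar variable."""

    def __init__(self, value, updater=None):
        self.value = float(value)
        # Optional custom updater; None means "use the optimizer's own rule".
        self.updater = updater


class OverwriteWithGradient:
    """Custom updater that replaces the value with the incoming gradient."""

    def update(self, variable, gradient):
        variable.value = float(gradient)


class MomentumSGD:
    """Toy optimizer that defers to a variable's custom updater if present."""

    def __init__(self, learning_rate=0.1, momentum=0.9):
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.velocities = {}  # per-variable slot state, keyed by id()

    def build(self, variables):
        for v in variables:
            # No slot is created for variables with a custom updater,
            # because the optimizer's own rule never touches them.
            if v.updater is None:
                self.velocities[id(v)] = 0.0

    def apply(self, grads_and_vars):
        for g, v in grads_and_vars:
            if v.updater is not None:
                # Defer entirely to the custom update rule.
                v.updater.update(v, g)
                continue
            vel = self.momentum * self.velocities[id(v)] - self.learning_rate * g
            self.velocities[id(v)] = vel
            v.value += vel


dense = Variable(1.0)
scale = Variable(5.0, updater=OverwriteWithGradient())
opt = MomentumSGD(learning_rate=0.1, momentum=0.9)
opt.build([dense, scale])
opt.apply([(2.0, dense), (3.0, scale)])
```

In this sketch the `overwrite_with_gradient` behavior is just one possible updater, and only `dense` gets a momentum slot. An in-place rule for a sharded embedding table could presumably be expressed the same way, though that is an assumption about how the proposed class would be used.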