DRAFT: Add custom variable updater. #21225


Open · wants to merge 1 commit into master

Conversation

cantonios
Contributor

Allows customization for how variables are updated by the optimizer. The base optimizer simply defers to the update handler to do the update, allowing full customization.

Can replace the existing overwrite_with_gradient attribute on variables, which currently is very application-specific.

Eliminates creation of optimizer variables that have custom updaters (including overwrite_with_gradient), since those variables are never used and may be wasteful.

This is an alternative to #21196. It would allow us to add special-handling for large embedding tables, where we do not want to pass around large gradients for tables that might span multiple devices. Instead the tables are updated in-place using a custom update rule.
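The pattern described above can be sketched in a few lines of plain Python. This is an illustration only: the class and method names (`VariableUpdater`, `update_step`) are assumptions for exposition and may not match the PR's final API, and the `Variable` stand-in here is not the Keras variable class.

```python
class Variable:
    """Stand-in for a framework variable holding a plain Python list."""

    def __init__(self, value):
        self.value = list(value)


class VariableUpdater:
    """Base updater: a plain SGD-style step, fully overridable."""

    def update_step(self, gradient, variable, learning_rate):
        variable.value = [
            v - learning_rate * g for v, g in zip(variable.value, gradient)
        ]


class InPlaceTableUpdater(VariableUpdater):
    """Hypothetical updater for large embedding tables: applies a sparse,
    in-place update to only the touched rows, so no dense gradient for the
    whole table is ever materialized or passed around."""

    def update_step(self, gradient, variable, learning_rate):
        # Here `gradient` is assumed to be a dict {row_index: row_gradient}.
        for i, g in gradient.items():
            variable.value[i] -= learning_rate * g
```

A custom updater attached to a sharded embedding table would then receive only the sparse per-row gradients, which is the motivation the PR gives for the in-place update rule.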

@cantonios force-pushed the updater branch 2 times, most recently from 6fb5fb0 to cb79194 on April 29, 2025 17:35
codecov-commenter commented Apr 29, 2025

Codecov Report

Attention: Patch coverage is 94.44444% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.62%. Comparing base (37eacb0) to head (a53a152).

Files with missing lines Patch % Lines
keras/src/backend/common/variables.py 83.33% 1 Missing and 1 partial ⚠️
keras/src/optimizers/optimizer.py 96.00% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21225   +/-   ##
=======================================
  Coverage   82.61%   82.62%           
=======================================
  Files         564      564           
  Lines       54476    54514   +38     
  Branches     8470     8475    +5     
=======================================
+ Hits        45005    45040   +35     
- Misses       7395     7397    +2     
- Partials     2076     2077    +1     
Flag Coverage Δ
keras 82.43% <94.44%> (+<0.01%) ⬆️
keras-jax 63.73% <94.44%> (+0.02%) ⬆️
keras-numpy 58.85% <85.18%> (+0.01%) ⬆️
keras-openvino 32.99% <35.18%> (+<0.01%) ⬆️
keras-tensorflow 64.13% <94.44%> (+0.01%) ⬆️
keras-torch 63.81% <94.44%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown.


@cantonios
Contributor Author

cantonios commented Apr 29, 2025

@hertschuh @fchollet

Potential replacement for #21196. With this approach, layering is mostly preserved (although in our case the custom updater would still actually contain an optimizer - but at least it doesn't need to in the general case).

This change also generalizes: it allows us to remove the overwrite_with_gradient attribute entirely, replacing it with a custom updater. The existing overwrite_with_gradient attribute is very specific to scale factors in fp8 quantization, using a max when accumulating gradients.
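The max-based accumulation mentioned above can be illustrated with a small self-contained sketch. Function names here are illustrative, not Keras API: the point is only the difference between summing gradients and keeping the largest observed value.

```python
def accumulate_sum(acc, grad):
    """Ordinary accumulation: gradient contributions add up."""
    return [a + g for a, g in zip(acc, grad)]


def accumulate_max(acc, grad):
    """Scale-factor accumulation: keep the largest observed value, since an
    fp8 scale must cover the worst case seen, not the running total."""
    return [max(a, g) for a, g in zip(acc, grad)]
```

This is why overwrite_with_gradient is described as application-specific: the max rule makes sense for fp8 scale factors but not for gradients in general, which is what a pluggable updater would fix.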

@cantonios force-pushed the updater branch 5 times, most recently from 4826e7e to 6d1a9a4 on April 30, 2025 18:21
Allows customization for how variables are updated by the optimizer.
The base optimizer simply defers to the update handler to do the
update, allowing full customization.

Can replace the existing `overwrite_with_gradient` attribute on
variables, which currently is very application-specific.
@fchollet (Collaborator) left a comment
This approach looks good to me!

@@ -23,5 +24,90 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
pass


@keras_export("keras.optimizers.VariableUpdater")
class VariableUpdater:
"""Allows special handling of variable updates."""
Collaborator:

If this is intended to be public in the API then we should provide an explanation and a usage example in the docstring

cantonios (Contributor Author):

Great! Yes, I'd like it to be public so I can inherit from it without needing to import internal keras APIs.

Is this location (within optimizer.py) okay, or would you prefer it somewhere else (e.g., in its own file)?

@gbaned gbaned added this to PR Queue May 2, 2025
@github-project-automation github-project-automation bot moved this to Assigned Reviewer in PR Queue May 2, 2025