
[Draft] Add Experimental limited sparse embedding bag #8905

Draft · amjames wants to merge 4 commits into master from amjames/sparse_embedding_bag
Conversation

@amjames (Collaborator) commented Mar 28, 2025

Users of torch_xla encounter an issue when using the sparse=True option with the Embedding or EmbeddingBag modules.

The gradient for weight is created as a sparse tensor, but no dispatch is registered for the sparse creation APIs with the XLA key, nor for the Sparse functionality key used in conjunction with the XLA backend key.
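A minimal reproduction might look like the following sketch (the exact error message depends on the torch_xla version):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# EmbeddingBag with sparse=True produces a sparse COO gradient for `weight`.
emb = nn.EmbeddingBag(10, 4, mode="sum", sparse=True).to(device)
inp = torch.tensor([1, 2, 4, 5], device=device)
offsets = torch.tensor([0, 2], device=device)

out = emb(inp, offsets)
# backward() tries to construct a native sparse tensor on the XLA device,
# for which no kernel is registered, so this step fails.
out.sum().backward()
```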

This is a workaround that can be removed, ported to C++, or extended later:

  • SparseCOOTensor: a tensor subclass implementing the optimizations and semantics of the upstream SparseTensor. It is composable with the XLA device (see the sketch after this list).
  • Drop-in replacements for F.embedding, F.embedding_bag, nn.Embedding, and nn.EmbeddingBag, which forward to a custom implementation of the backward pass and produce the above tensor subclass rather than a native torch sparse tensor (a backward sketch follows further below).
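A minimal sketch of such a subclass (illustrative only, not the PR's actual implementation) could use `torch.Tensor._make_wrapper_subclass` so the shell reports a dense layout to dispatch while the COO components live on the XLA device:

```python
import torch

class SparseCOOTensor(torch.Tensor):
    """Illustrative wrapper subclass holding COO components on any device."""

    @staticmethod
    def __new__(cls, indices, values, size):
        # The wrapper reports a dense layout, so no sparse dispatch key is
        # involved; only the component tensors carry the XLA device.
        return torch.Tensor._make_wrapper_subclass(
            cls, size, dtype=values.dtype, device=values.device
        )

    def __init__(self, indices, values, size):
        self._indices = indices  # shape (ndim, nnz)
        self._values = values    # shape (nnz, ...) for hybrid tensors

    def to_dense(self):
        # Scatter-add the values into a dense tensor. For an embedding
        # weight gradient, indices is (1, nnz) and values is (nnz, dim).
        out = torch.zeros(
            self.shape, dtype=self._values.dtype, device=self._values.device
        )
        out.index_put_(tuple(self._indices), self._values, accumulate=True)
        return out
```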

The tensor subclass's component tensors, indices and values, can be on the XLA device without issue.
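To illustrate how a drop-in replacement could hand this subclass back to autograd, here is a hedged sketch of a custom autograd Function for the sum-mode embedding-bag backward. All names here are hypothetical, and whether autograd accepts a subclass as the grad for a plain weight depends on the subclass implementing enough `__torch_dispatch__` overrides, which this sketch elides:

```python
import torch
import torch.nn.functional as F

class _EmbeddingBagSparseGrad(torch.autograd.Function):
    """Sketch: sum-mode embedding_bag whose weight grad is a SparseCOOTensor."""

    @staticmethod
    def forward(ctx, weight, input, offsets):
        ctx.save_for_backward(input, offsets)
        ctx.weight_shape = weight.shape
        return F.embedding_bag(input, weight, offsets, mode="sum")

    @staticmethod
    def backward(ctx, grad_out):
        input, offsets = ctx.saved_tensors
        # In sum mode, each index in a bag receives that bag's output grad.
        end = torch.tensor([input.numel()], device=offsets.device)
        bag_sizes = torch.diff(offsets, append=end)
        grad_rows = torch.repeat_interleave(grad_out, bag_sizes, dim=0)
        # SparseCOOTensor is the wrapper subclass sketched earlier.
        grad_weight = SparseCOOTensor(
            input.unsqueeze(0), grad_rows, ctx.weight_shape
        )
        return grad_weight, None, None
```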

fixes #8719

@amjames requested a review from ysiraichi on Mar 28, 2025 22:23
@amjames force-pushed the amjames/sparse_embedding_bag branch from aacfd1b to c67588f on Mar 28, 2025 22:27