
Commit c67588f

Add raw embedding
1 parent 2dd1736 commit c67588f

2 files changed (+158, -24 lines)

Diff for: test/test_sparse_embedding.py

+40
@@ -0,0 +1,40 @@
+import itertools
+import random
+import unittest
+from itertools import product
+
+import torch
+import torch_xla
+import torch.nn as nn
+import torch.nn.functional as FA
+from absl.testing import parameterized
+from torch_xla.experimental import sparse as sparse_xla
+
+index_dtypes = {torch.int32, torch.int64}
+float_dtypes = {torch.float32, torch.float16, torch.bfloat16}
+
+device = torch_xla.device()
+
+
+def functional_embedding(impl):
+
+  def f(input):
+    impl()
+
+@parameterized.parameters(*index_dtypes)
+@parameterized.parameters(*float_dtypes)
+def test_embedding_sparse_basic(self, index_dtype, float_dtype):
+  torch_embedding = nn.Embedding(10, 20, sparse=True, dtype=float_dtype)
+  xla_embedding = sparse_xla.Embedding.from_pretrained(
+      torch_embedding.weight.clone().detach().to(device), sparse=True)
+  input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=index_dtype)
+  input_xla = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]],
+                           dtype=index_dtype,
+                           device=device)
+  ref_out = torch_embedding(input)
+  xla_out = xla_embedding(input_xla)
+  self.assertTrue(torch.allclose(ref_out, xla_out))
+  ref_out.sum().backward()
+  xla_out.sum().backward()
+  torch_weight_grad = torch_embedding.weight.grad
+  xla_weight_grad = xla_embedding.weight.grad
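
The test above collects torch_weight_grad and xla_weight_grad but does not yet compare them. A follow-up assertion might look like the sketch below; both to_dense() calls are assumptions about the gradient types (nn.Embedding with sparse=True yields a torch.sparse_coo gradient, and the SparseCOOTensor subclass is assumed to expose a matching conversion), so this is a sketch rather than part of the commit:

    # Sketch only: densify both weight gradients and compare them.
    # SparseCOOTensor.to_dense() is an assumed method name, not confirmed by this diff.
    self.assertTrue(
        torch.allclose(torch_weight_grad.to_dense(),
                       xla_weight_grad.to_dense().cpu()))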

Diff for: torch_xla/experimental/sparse/embedding_bag.py

+118 -24
@@ -3,11 +3,14 @@
 import warnings
 import torch
 from torch import Tensor
-from torch.nn.modules.sparse import EmbeddingBag as UpstreamEmbeddingBag
 from torch.nn import functional as F
+from torch.nn.modules import sparse
 from torch.autograd import Function
 from .coo import SparseCOOTensor
 
+_NativeEmbeddingModule = sparse.Embedding
+_NativeEmbeddingBagModule = sparse.EmbeddingBag
+
 
 class _EmbeddingBagMode(IntEnum):
   SUM = 0
@@ -132,6 +135,31 @@ def _apply_bag_size_backward(mode: _EmbeddingBagMode, index_grad: Tensor,
   return index_grad
 
 
+def _sparse_embedding_backward(grad: Tensor, indices: Tensor, num_weights: int,
+                               padding_idx: int):
+  if padding_idx != -1:
+    c = indices != padding_idx
+    indices = indices.index(c)
+    grad = grad.index(c)
+
+  num_features = grad.size(-1)
+  weight_size = (num_weights, num_features)
+  dense_options = _tensor_factory_kwargs(grad)
+
+  if grad.numel() == 0:
+    sparse_index = torch.empty((1, 0),
+                               **_tensor_factory_kwargs(indices),
+                               dtype=torch.int64)
+    sparse_values = torch.empty((0, num_features), **dense_options),
+  else:
+    sparse_index = indices.reshape(1, -1)
+    sparse_values = grad.reshape((-1, num_features))
+
+  grad_weight = SparseCOOTensor(sparse_index, sparse_values, weight_size)
+
+  return grad_weight
+
+
 class SparseEmbeddingBag(Function):
   """Alternate embedding bag implementation. Produces sparse gradients using an XLA-compatible tensor subclass
   """
@@ -173,7 +201,7 @@ def setup_context(ctx,
                           bool, Optional[Tensor], bool, Optional[int]],
                     outputs: Tuple[Tensor, Tensor, Tensor, Tensor]):
     weight, indices, offsets, _, _, mode, _, per_sample_weights, _, padding_idx = inputs
-    output, offset2bag, bag_size, max_indices = outputs
+    _, offset2bag, bag_size, max_indices = outputs
     # for completeness, not technically required as integer dtype will make this automatic
     ctx.mark_non_differentiable(offset2bag, bag_size, max_indices)
     ctx.set_materialize_grads(False)
@@ -200,27 +228,40 @@ def backward(ctx, grad_: Tensor, *args: Tuple[Optional[Tensor], ...]):
       torch._check(ctx.mode == _EmbeddingBagMode.SUM)
       index_grad.mul_(per_sample_weights.unsqueeze(1))
 
-    if ctx.padding_idx != -1:
-      c = indices != ctx.padding_idx
-      indices = indices.index(c)
-      grad_ = grad_.index(c)
+    grad_weight = _sparse_embedding_backward(index_grad, indices,
+                                             ctx.num_weights, ctx.padding_idx)
+    return grad_weight, grad_indices, grad_offsets, grad_per_sample_weights
 
-    num_features = grad_.size(-1)
-    weight_size = (ctx.num_weights, num_features)
-    dense_options = _tensor_factory_kwargs(grad_)
 
-    if grad_.numel() == 0:
-      sparse_index = torch.empty((1, 0),
-                                 **_tensor_factory_kwargs(indices),
-                                 dtype=torch.int64)
-      sparse_values = torch.empty((0, num_features), **dense_options),
-    else:
-      sparse_index = indices.reshape(1, -1)
-      sparse_values = grad_.reshape((-1, num_features))
+class SparseEmbedding(Function):
 
-    grad_weight = SparseCOOTensor(sparse_index, sparse_values, weight_size)
+  @staticmethod
+  def forward(input: Tensor,
+              weight: Tensor,
+              padding_idx: Optional[int],
+              scale_grad_by_freq: bool = False,
+              sparse: bool = False) -> Tensor:
+    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq,
+                           sparse)
 
-    return grad_weight, grad_indices, grad_offsets, grad_per_sample_weights
+  @staticmethod
+  def setup_context(ctx, inputs, output):
+    indices, weight, padding_idx, scale_grad_by_freq, sparse = inputs
+    ctx.mark_non_differentiable(indices)
+    ctx.set_materialize_grads(False)
+    ctx.save_for_backward(indices)
+    ctx.num_weights = weight.size(0)
+    ctx.padding_idx = padding_idx
+
+  @staticmethod
+  def backward(ctx, grad_: Tensor):
+    grad_input = grad_weight = None
+    indices = ctx.saved_tensors
+    if grad_ is not None:
+      grad_weight = _sparse_embedding_backward(grad_, indices, ctx.num_weights,
+                                               ctx.padding_idx)
+
+    return grad_input, grad_weight
 
 
 def embedding_bag(
@@ -336,7 +377,44 @@ def embedding_bag(
   return ret
 
 
-class EmbeddingBag(UpstreamEmbeddingBag):
+def embedding(input: Tensor,
+              weight: Tensor,
+              padding_idx: Optional[int] = None,
+              max_norm: Optional[float] = None,
+              norm_type: float = 2.0,
+              scale_grad_by_freq: bool = False,
+              sparse: bool = False) -> Tensor:
+  f"""This is the equivalent of the F.embedding API, but where sparse gradients for :attr:`weight` would be produced, a custom path is taken.
+  {F.embedding.__doc__}
+  """
+  if padding_idx is not None:
+    if padding_idx > 0:
+      assert padding_idx < weight.size(
+          0), "Padding_idx must be within num_embeddings"
+    elif padding_idx < 0:
+      assert padding_idx >= -weight.size(
+          0), "Padding_idx must be within num_embeddings"
+      padding_idx = weight.size(0) + padding_idx
+  else:
+    padding_idx = -1
+  if max_norm is not None:
+    # Note [embedding_renorm contiguous]
+    # `embedding_renorm_` will call .contiguous() on input anyway, so we
+    # call it here and take advantage of the improved locality in the
+    # `embedding` call below too.
+    input = input.contiguous()
+    # Note [embedding_renorm set_grad_enabled]
+    # Note [Modified from F.embedding]
+    # We don't need to support TorchScript, so we can just do what the original says this is equivalent to; original call below
+    #_no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+    with torch.no_grad():
+      torch.embedding_renorm_(weight, input, max_norm, norm_type)
+
+  impl = SparseEmbedding.apply if sparse and weight.requires_grad else torch.embedding
+  return impl(weight, input, padding_idx, scale_grad_by_freq, sparse)
+
+
+class EmbeddingBag(_NativeEmbeddingBagModule):
   f"""Torch-XLA backend compatible Embedding Bag
 
   This implementation modifies the forward function when :attr:`sparse` is ``True``.
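
A minimal usage sketch for the functional embedding wrapper added above. The module path comes from this file's location and the call follows the signature in the diff; whether weight.grad actually materializes as the SparseCOOTensor subclass depends on the autograd integration of that subclass, which is outside this diff:

    import torch
    import torch_xla
    from torch_xla.experimental.sparse.embedding_bag import embedding

    device = torch_xla.device()
    weight = torch.randn(10, 4, device=device, requires_grad=True)
    idx = torch.tensor([[0, 2], [5, 9]], device=device)

    out = embedding(idx, weight, sparse=True)  # intended to dispatch to SparseEmbedding.apply
    out.sum().backward()                       # weight.grad is expected to be sparse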
@@ -345,12 +423,9 @@ class EmbeddingBag(UpstreamEmbeddingBag):
   ``torch.sparse_coo`` layout. If :attr:`sparse` is ``False`` this is identical
   to :class:`~torch.nn.EmbeddingBag`.
 
-  {UpstreamEmbeddingBag.__doc__}
+  {_NativeEmbeddingBagModule.__doc__}
   """
 
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-
   def forward(self,
               input: Tensor,
               offsets: Optional[Tensor] = None,
@@ -362,3 +437,22 @@ def forward(self,
                            self.sparse, per_sample_weights,
                            self.include_last_offset, self.padding_idx)
     return super().forward(input, offsets, per_sample_weights)
+
+
+class Embedding(_NativeEmbeddingModule):
+  f"""Torch-XLA backend compatible Embedding
+
+  This implementation modifies the forward function when :attr:`sparse` is ``True``.
+  The alternate implementation has a backward which will produce sparse
+  gradients w.r.t. :attr:`weight` as a tensor subclass implementation of the
+  ``torch.sparse_coo`` layout. If :attr:`sparse` is ``False`` this is identical
+  to :class:`~torch.nn.Embedding`.
+
+  {_NativeEmbeddingModule.__doc__}
+  """
+
+  def forward(self, input: Tensor) -> Tensor:
+    if self.weight.requires_grad and self.sparse:
+      return embedding(input, self.weight, self.padding_idx, self.max_norm,
+                       self.norm_type, self.scale_grad_by_freq, self.sparse)
+    return super().forward(input)
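
A module-level sketch mirroring the new test, but using the constructor instead of from_pretrained; with sparse=False the class defers entirely to the stock torch.nn.Embedding forward. Whether the sparse gradient lands on weight.grad as the SparseCOOTensor subclass again depends on that subclass's autograd integration, so this is a sketch of the intended behavior only:

    import torch
    import torch_xla
    from torch_xla.experimental import sparse as sparse_xla

    device = torch_xla.device()
    emb = sparse_xla.Embedding(10, 20, sparse=True).to(device)

    out = emb(torch.tensor([[0, 2, 4]], device=device))
    out.sum().backward()  # per the forward override above, this takes the sparse path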
