import warnings
import torch
from torch import Tensor
- from torch.nn.modules.sparse import EmbeddingBag as UpstreamEmbeddingBag
from torch.nn import functional as F
+ from torch.nn.modules import sparse
from torch.autograd import Function
from .coo import SparseCOOTensor
+ _NativeEmbeddingModule = sparse.Embedding
+ _NativeEmbeddingBagModule = sparse.EmbeddingBag
+


class _EmbeddingBagMode(IntEnum):
    SUM = 0
@@ -132,6 +135,31 @@ def _apply_bag_size_backward(mode: _EmbeddingBagMode, index_grad: Tensor,
    return index_grad


+ def _sparse_embedding_backward(grad: Tensor, indices: Tensor, num_weights: int,
+                                padding_idx: int):
+     if padding_idx != -1:
+         c = indices != padding_idx
+         indices = indices.index(c)
+         grad = grad.index(c)
+
+     num_features = grad.size(-1)
+     weight_size = (num_weights, num_features)
+     dense_options = _tensor_factory_kwargs(grad)
+
+     if grad.numel() == 0:
+         sparse_index = torch.empty((1, 0),
+                                    **_tensor_factory_kwargs(indices),
+                                    dtype=torch.int64)
+         sparse_values = torch.empty((0, num_features), **dense_options)
+     else:
+         sparse_index = indices.reshape(1, -1)
+         sparse_values = grad.reshape((-1, num_features))
+
+     grad_weight = SparseCOOTensor(sparse_index, sparse_values, weight_size)
+
+     return grad_weight
+
+
class SparseEmbeddingBag(Function):
    """Alternate embedding bag implementation. Produces sparse gradients using an XLA compatible tensor subclass
    """
@@ -173,7 +201,7 @@ def setup_context(ctx,
                                    bool, Optional[Tensor], bool, Optional[int]],
                      outputs: Tuple[Tensor, Tensor, Tensor, Tensor]):
        weight, indices, offsets, _, _, mode, _, per_sample_weights, _, padding_idx = inputs
-         output, offset2bag, bag_size, max_indices = outputs
+         _, offset2bag, bag_size, max_indices = outputs
        # for completeness, not technically required as integer dtype will make this automatic
        ctx.mark_non_differentiable(offset2bag, bag_size, max_indices)
        ctx.set_materialize_grads(False)
@@ -200,27 +228,40 @@ def backward(ctx, grad_: Tensor, *args: Tuple[Optional[Tensor], ...]):
            torch._check(ctx.mode == _EmbeddingBagMode.SUM)
            index_grad.mul_(per_sample_weights.unsqueeze(1))

-         if ctx.padding_idx != -1:
-             c = indices != ctx.padding_idx
-             indices = indices.index(c)
-             grad_ = grad_.index(c)
+         grad_weight = _sparse_embedding_backward(index_grad, indices,
+                                                  ctx.num_weights, ctx.padding_idx)
+         return grad_weight, grad_indices, grad_offsets, grad_per_sample_weights

-         num_features = grad_.size(-1)
-         weight_size = (ctx.num_weights, num_features)
-         dense_options = _tensor_factory_kwargs(grad_)

-         if grad_.numel() == 0:
-             sparse_index = torch.empty((1, 0),
-                                        **_tensor_factory_kwargs(indices),
-                                        dtype=torch.int64)
-             sparse_values = torch.empty((0, num_features), **dense_options),
-         else:
-             sparse_index = indices.reshape(1, -1)
-             sparse_values = grad_.reshape((-1, num_features))
+ class SparseEmbedding(Function):

-         grad_weight = SparseCOOTensor(sparse_index, sparse_values, weight_size)
+     @staticmethod
+     def forward(weight: Tensor,
+                 input: Tensor,
+                 padding_idx: Optional[int],
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False) -> Tensor:
+         return torch.embedding(weight, input, padding_idx, scale_grad_by_freq,
+                                sparse)

-         return grad_weight, grad_indices, grad_offsets, grad_per_sample_weights
+     @staticmethod
+     def setup_context(ctx, inputs, output):
+         weight, indices, padding_idx, scale_grad_by_freq, sparse = inputs
+         ctx.mark_non_differentiable(indices)
+         ctx.set_materialize_grads(False)
+         ctx.save_for_backward(indices)
+         ctx.num_weights = weight.size(0)
+         ctx.padding_idx = padding_idx
+
+     @staticmethod
+     def backward(ctx, grad_: Tensor):
+         grad_input = grad_weight = None
+         indices, = ctx.saved_tensors
+         if grad_ is not None:
+             grad_weight = _sparse_embedding_backward(grad_, indices, ctx.num_weights,
+                                                      ctx.padding_idx)
+
+         return grad_weight, grad_input


def embedding_bag(
@@ -336,7 +377,44 @@ def embedding_bag(
    return ret


- class EmbeddingBag(UpstreamEmbeddingBag):
+ def embedding(input: Tensor,
+               weight: Tensor,
+               padding_idx: Optional[int] = None,
+               max_norm: Optional[float] = None,
+               norm_type: float = 2.0,
+               scale_grad_by_freq: bool = False,
+               sparse: bool = False) -> Tensor:
+     f"""This is the equivalent API to F.embedding, but where sparse gradients for :attr:`weight` would be produced, a custom path is taken.
+     {F.embedding.__doc__}
+     """
+     if padding_idx is not None:
+         if padding_idx > 0:
+             assert padding_idx < weight.size(
+                 0), "Padding_idx must be within num_embeddings"
+         elif padding_idx < 0:
+             assert padding_idx >= -weight.size(
+                 0), "Padding_idx must be within num_embeddings"
+             padding_idx = weight.size(0) + padding_idx
+     else:
+         padding_idx = -1
+     if max_norm is not None:
+         # Note [embedding_renorm contiguous]
+         # `embedding_renorm_` will call .contiguous() on input anyways, so we
+         # call it here and take advantage of the improved locality in the
+         # `embedding` call below too.
+         input = input.contiguous()
+         # Note [embedding_renorm set_grad_enabled]
+         # Note [Modified from F.embedding]
+         # We don't need to support TorchScript, so we can just do what they say this is equivalent to; original below
+         #_no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+         with torch.no_grad():
+             torch.embedding_renorm_(weight, input, max_norm, norm_type)
+
+     impl = SparseEmbedding.apply if sparse and weight.requires_grad else torch.embedding
+     return impl(weight, input, padding_idx, scale_grad_by_freq, sparse)
+
+
+ class EmbeddingBag(_NativeEmbeddingBagModule):
    f"""Torch-XLA backend compatible Embedding Bag

    This implementation modifies the forward function when :attr:`sparse` is ``True``.
@@ -345,12 +423,9 @@ class EmbeddingBag(UpstreamEmbeddingBag):
    ``torch.sparse_coo`` layout. If :attr:`sparse` is ``False`` this is identical
    to :class:`~torch.nn.EmbeddingBag`.

-     {UpstreamEmbeddingBag.__doc__}
+     {_NativeEmbeddingBagModule.__doc__}
    """

-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-
    def forward(self,
                input: Tensor,
                offsets: Optional[Tensor] = None,
@@ -362,3 +437,22 @@ def forward(self,
                             self.sparse, per_sample_weights,
                             self.include_last_offset, self.padding_idx)
        return super().forward(input, offsets, per_sample_weights)
+
+
+ class Embedding(_NativeEmbeddingModule):
+     f"""Torch-XLA backend compatible Embedding
+
+     This implementation modifies the forward function when :attr:`sparse` is ``True``.
+     The alternate implementation has a backward which will produce sparse
+     gradients w.r.t. :attr:`weight` as a tensor subclass implementation of the
+     ``torch.sparse_coo`` layout. If :attr:`sparse` is ``False`` this is identical
+     to :class:`~torch.nn.Embedding`.
+
+     {_NativeEmbeddingModule.__doc__}
+     """
+
+     def forward(self, input: Tensor) -> Tensor:
+         if self.weight.requires_grad and self.sparse:
+             return embedding(input, self.weight, self.padding_idx, self.max_norm,
+                              self.norm_type, self.scale_grad_by_freq, self.sparse)
+         return super().forward(input)
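
For readers unfamiliar with the gradient layout involved here, a small stand-alone sketch using only stock PyTorch (not this package's code): `F.embedding(..., sparse=True)` produces the same (indices, values, size) structure that `_sparse_embedding_backward` assembles before wrapping it in `SparseCOOTensor`.

```python
# Stand-alone illustration with the native sparse_coo layout: one COO column per
# lookup, with the corresponding row of the incoming grad as its value.
import torch
import torch.nn.functional as F

num_weights, num_features = 10, 4
weight = torch.randn(num_weights, num_features, requires_grad=True)
indices = torch.tensor([1, 3, 3, 7])

out = F.embedding(indices, weight, sparse=True)
out.sum().backward()

grad = weight.grad            # layout == torch.sparse_coo, uncoalesced
print(grad._indices().shape)  # torch.Size([1, 4]): the raw lookups, unreduced
print(grad._values().shape)   # torch.Size([4, 4]): one grad row per lookup
print(grad.shape)             # torch.Size([10, 4]) == (num_weights, num_features)
```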
0 commit comments
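Similarly, a quick stock-PyTorch check of the padding behaviour that the `indices != padding_idx` mask in `_sparse_embedding_backward` reproduces: lookups of the padding index contribute nothing to the weight gradient.

```python
# Lookups of padding_idx are filtered out of the sparse weight gradient.
import torch
import torch.nn.functional as F

weight = torch.randn(6, 3, requires_grad=True)
indices = torch.tensor([0, 2, 2, 5])

out = F.embedding(indices, weight, padding_idx=2, sparse=True)
out.sum().backward()

g = weight.grad.coalesce()
print(g.indices())                                    # tensor([[0, 5]]): index 2 is dropped
print(torch.equal(g.to_dense()[2], torch.zeros(3)))   # True: no grad for the padding row
```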