From bbb61aa5ebcc9c883254460faeb1c96ca4b13543 Mon Sep 17 00:00:00 2001
From: "Andrew M. James"
Date: Thu, 27 Mar 2025 23:11:29 +0000
Subject: [PATCH 1/4] Add sparse COO tensor subclass impl and limited ops

---
 torch_xla/experimental/sparse/__init__.py |   3 +
 torch_xla/experimental/sparse/_coo_ops.py | 158 ++++++++++++++++++++++
 torch_xla/experimental/sparse/coo.py      |  93 +++++++++++++
 3 files changed, 254 insertions(+)
 create mode 100644 torch_xla/experimental/sparse/__init__.py
 create mode 100644 torch_xla/experimental/sparse/_coo_ops.py
 create mode 100644 torch_xla/experimental/sparse/coo.py

diff --git a/torch_xla/experimental/sparse/__init__.py b/torch_xla/experimental/sparse/__init__.py
new file mode 100644
index 00000000000..7074143f22d
--- /dev/null
+++ b/torch_xla/experimental/sparse/__init__.py
@@ -0,0 +1,3 @@
+from .coo import SparseCOOTensor
+
+__all__ = ["SparseCOOTensor"]
diff --git a/torch_xla/experimental/sparse/_coo_ops.py b/torch_xla/experimental/sparse/_coo_ops.py
new file mode 100644
index 00000000000..f032bab99b4
--- /dev/null
+++ b/torch_xla/experimental/sparse/_coo_ops.py
@@ -0,0 +1,158 @@
+import torch
+
+
+def _flatten_indices(inds, shape):
+  # Flatten N-D indices to 1-D indices (row-major order)
+  flat_indices = inds.new_zeros(inds.size(1))
+  for d, sz in enumerate(shape):
+    flat_indices.mul_(sz)
+    flat_indices.add_(inds[d])
+  return flat_indices
+
+
+def _check_no_kwargs(kwargs, name):
+  torch._check(kwargs is None or len(kwargs) == 0,
+               lambda: f"{name}: expected no kwargs, got {kwargs}")
+
+
+def _check_strided(arg, f_name, arg_name):
+  torch._check(
+      arg.layout == torch.strided,
+      lambda: f"{f_name}: Expected strided {arg_name}, not {arg.layout}")
+
+
+def _check_sparse(arg, f_name, arg_name):
+  torch._check(
+      arg.layout == torch.sparse_coo,
+      lambda: f"{f_name}: Expected sparse {arg_name}, not {arg.layout}")
+
+
+def coo_sparse_mask(args=(), kwargs=None):
+  """x.sparse_mask(y)
+
+  Create a new sparse tensor from strided ``x`` and sparse ``y``. The result
+  holds the values of ``x`` with the sparsity pattern of ``y``.
+  """
+  torch._check(
+      len(args) == 2,
+      lambda: f"sparse_mask: Expected two arguments, got {len(args)}")
+  self, mask = args
+  _check_strided(self, "sparse_mask", "self")
+  _check_sparse(mask, "sparse_mask", "mask")
+  torch._check(
+      mask.size() == self.size(),
+      lambda: f"sparse_mask: expected mask and self to have the same shape "
+      f"(self: {self.size()}, mask: {mask.size()})")
+  _check_no_kwargs(kwargs, "sparse_mask")
+  mask_indices = mask.indices()
+  flat_indices = _flatten_indices(mask_indices, mask.size())
+  values = self.view(-1)[flat_indices]
+  # access the subclass ctor without a circular import!
+  result = mask.__class__(
+      mask_indices,
+      values,
+      size=self.size(),
+      dtype=values.dtype,
+      device=values.device,
+      requires_grad=values.requires_grad,
+  )
+  # mask.indices() above required the mask to be coalesced, so the result
+  # shares that property
+  result._coalesced = True
+  return result
+
+
+def coo_resize_as_(args=(), kwargs=None):
+  # The aten impl supports more cases, but the only one we need is the 0-nnz
+  # case, where we can freely modify the sparse/dense dims at will. We have
+  # also not implemented dense-dim handling anywhere, as it is not needed.
+  torch._check(
+      len(args) == 2,
+      lambda: f"resize_as_: expected two args, not {len(args)}")
+  self, other = args
+  _check_sparse(self, "resize_as_", "self")
+  _check_sparse(other, "resize_as_", "other")
+  torch._check(
+      self._v.numel() == 0,
+      lambda: "resize_as_: resizing a sparse tensor with nnz != 0 is not supported")
+  _check_no_kwargs(kwargs, "resize_as_")
+
+  new_nnz = other._v.numel()
+  values_shape = (new_nnz,)
+  index_shape = (len(other.shape), new_nnz)
+
+  # attribute access to modify the tensor in-place, accessors return a clone
+  self._v.resize_(values_shape)
+  self._i.resize_(index_shape)
+  return self
+
+
+def coo_coalesce(args=(), kwargs=None):
+  _check_no_kwargs(kwargs, "coalesce")
+  self = args[0]
+  _check_sparse(self, "coalesce", "self")
+  if self._coalesced:
+    return self
+  nnz = self._v.numel()
+  if nnz < 2:
+    self._coalesced = True
+    return self
+
+  indices = self._i
+  values = self._v
+  sparse_dim = len(self.shape)
+  indices_scalar = _flatten_indices(indices, self.shape)
+
+  new_indices = torch.empty_like(indices)
+  new_values = torch.zeros_like(values)
+  indices_buffer, indices_perm = indices_scalar.sort(0)
+  i = -1
+  for j in range(nnz):
+    pos = indices_perm[j]
+    if j == 0 or indices_buffer[j] != indices_buffer[j - 1]:
+      # previously unseen index: start a new output element
+      i += 1
+      for d in range(sparse_dim):
+        new_indices[d][i] = indices[d][pos]
+    if values.numel() > 0:
+      # duplicate indices accumulate into the current output element
+      new_values[i] += values[pos]
+
+  new_nnz = i + 1
+  result = self.__class__(
+      new_indices[:, :new_nnz].contiguous(),
+      new_values[:new_nnz].contiguous(),
+      self.shape,
+      dtype=self.dtype,
+      device=self.device,
+      requires_grad=self.requires_grad)
+  result._coalesced = True
+  return result
+
+
+def coo_add_(args=(), kwargs=None):
+  kwargs = dict(kwargs) if kwargs else {}
+  alpha = kwargs.pop('alpha', 1)
+  torch._check(
+      len(args) == 2, lambda: f"add_: expected two operands, got {len(args)}")
+  self, other = args
+  _check_sparse(self, "add_", "self")
+  _check_sparse(other, "add_", "other")
+  _check_no_kwargs(kwargs, "add_")
+
+  # todo: this should satisfy the needs of the SparseAdam optimizer
+  torch._check(
+      self._i.equal(other._i),
+      lambda: "add_: operation is only supported when operands have the same "
+      "sparsity pattern.")
+  self._v += alpha * other._v
+  return self
+
+
+def coo_indices(args=(), kwargs=None):
+  torch._check(len(args) == 1, lambda: "indices: expected one argument")
+  self = args[0]
+  _check_sparse(self, "indices", "self")
+  torch._check(self._coalesced, lambda: "indices: input must be coalesced")
+  return self._i.clone()
+
+
+def coo_values(args=(), kwargs=None):
+  torch._check(len(args) == 1, lambda: "values: expected one argument")
+  self = args[0]
+  _check_sparse(self, "values", "self")
+  torch._check(self._coalesced, lambda: "values: input must be coalesced")
+  return self._v.clone()
+
+
+def coo_nnz(args=(), kwargs=None):
+  torch._check(len(args) == 1, lambda: "nnz: expected one argument")
+  self = args[0]
+  _check_sparse(self, "nnz", "self")
+  return self._v.numel()
diff --git a/torch_xla/experimental/sparse/coo.py b/torch_xla/experimental/sparse/coo.py
new file mode 100644
index 00000000000..8c246b9ce86
--- /dev/null
+++ b/torch_xla/experimental/sparse/coo.py
@@ -0,0 +1,93 @@
+from typing import Any, Callable, ClassVar, Dict, Iterable, Tuple
+import torch
+
+from torch._ops import OpOverload
+
+from ._coo_ops import (
+    coo_sparse_mask,
+    coo_resize_as_,
+    coo_coalesce,
+    coo_add_,
+    coo_indices,
+    coo_values,
+    coo_nnz,
+)
+
+aten = torch.ops.aten
+
+
+class SparseCOOTensor(torch.Tensor):
+
+  DISPATCH_TABLE: ClassVar[Dict[OpOverload, Callable]] = {}
+  _v: torch.Tensor
+  _i: torch.Tensor
+  _coalesced: bool
+
+  def __new__(cls,
+              indices: torch.Tensor,
+              values: torch.Tensor,
+              size: Iterable[int],
+              *,
+              dtype: torch.dtype | None = None,
+              device: torch.device | None = None,
+              requires_grad: bool | None = None):
+    device = device if device is not None else values.device
+    dtype = dtype if dtype is not None else values.dtype
+    requires_grad = (
+        requires_grad if requires_grad is not None else values.requires_grad)
+    assert (device == values.device and device == indices.device and
+            dtype == values.dtype and requires_grad == values.requires_grad)
+
+    res = torch.Tensor._make_wrapper_subclass(
+        cls,
+        size=size,
+        strides=None,
+        layout=torch.sparse_coo,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad)
+    res._v = values
+    res._i = indices
+    # conservatively assume uncoalesced; `coalesce` flips this flag
+    res._coalesced = False
+    return res
+
+  @classmethod
+  def __torch_dispatch__(cls,
+                         func: OpOverload,
+                         types: Tuple,
+                         args: Tuple = (),
+                         kwargs: Dict[str, Any] | None = None) -> torch.Tensor:
+    impl = cls._match_func_key_for_dispatch(func)
+    return impl(args, kwargs)
+
+  @classmethod
+  def _load_sparse_dispatch(cls):
+    if not cls.DISPATCH_TABLE:
+      cls.DISPATCH_TABLE = {
+          aten.values: coo_values,
+          aten.indices: coo_indices,
+          aten.coalesce: coo_coalesce,
+          aten.add_: coo_add_,
+          aten.sparse_mask: coo_sparse_mask,
+          aten.resize_as_: coo_resize_as_,
+          aten._nnz: coo_nnz,
+      }
+
+  @classmethod
+  def _match_func_key_for_dispatch(cls, func):
+    cls._load_sparse_dispatch()
+    impl = cls.DISPATCH_TABLE.get(func, None)
+    if impl is None:
+      impl = cls.DISPATCH_TABLE.get(func._overloadpacket,
+                                    cls._make_not_implemented(func))
+    return impl
+
+  @classmethod
+  def _make_not_implemented(cls, func):
+
+    def impl(args=(), kwargs=None):
+      raise NotImplementedError(
+          f"Sparse support via {cls.__name__} is limited to a very narrow "
+          f"scope. The {func} operator is not currently supported. It "
+          "is likely you are using this outside of its intended purpose "
+          "if you see this message. Supported operators are: "
+          f"{[str(k) for k in cls.DISPATCH_TABLE.keys()]}")
+
+    return impl
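A minimal usage sketch of the new subclass (not part of the patch; it assumes an XLA device and that the indices and values already live on it). Only the handful of operators registered in DISPATCH_TABLE work; anything else, including repr of the subclass itself, raises the NotImplementedError from _make_not_implemented, so the sketch only prints the plain tensors returned by the accessors:

    import torch
    import torch_xla
    from torch_xla.experimental.sparse import SparseCOOTensor

    device = torch_xla.device()

    # a 4x4 sparse tensor with three explicit entries; (0, 1) is duplicated
    indices = torch.tensor([[0, 0, 2], [1, 1, 3]], device=device)
    values = torch.tensor([1.0, 2.0, 3.0], device=device)
    s = SparseCOOTensor(indices, values, (4, 4))

    s = s.coalesce()    # dispatches to coo_coalesce; sums the duplicate entry
    print(s.values())   # tensor([3., 3.], ...) - accessors require coalesced input
    print(s.indices())  # tensor([[0, 2], [1, 3]], ...)

    # in-place add dispatches to coo_add_ and requires identical patterns
    t = SparseCOOTensor(s.indices(), torch.ones(2, device=device), (4, 4))
    s.add_(t, alpha=2.0)  # values become [5., 5.]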
James" Date: Thu, 27 Mar 2025 23:13:19 +0000 Subject: [PATCH 2/4] create embeding bag module --- torch_xla/experimental/sparse/embedding_bag.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 torch_xla/experimental/sparse/embedding_bag.py diff --git a/torch_xla/experimental/sparse/embedding_bag.py b/torch_xla/experimental/sparse/embedding_bag.py new file mode 100644 index 00000000000..e69de29bb2d From 2dd1736bea229917d931c6f9914ff25f4dbba176 Mon Sep 17 00:00:00 2001 From: "Andrew M. James" Date: Fri, 28 Mar 2025 19:29:06 +0000 Subject: [PATCH 3/4] Add embedding bag implementations --- torch_xla/experimental/sparse/__init__.py | 3 +- torch_xla/experimental/sparse/coo.py | 2 + .../experimental/sparse/embedding_bag.py | 364 ++++++++++++++++++ 3 files changed, 368 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/sparse/__init__.py b/torch_xla/experimental/sparse/__init__.py index 7074143f22d..0237cd09c10 100644 --- a/torch_xla/experimental/sparse/__init__.py +++ b/torch_xla/experimental/sparse/__init__.py @@ -1,3 +1,4 @@ from .coo import SparseCOOTensor +from .embedding_bag import EmbeddingBag, embedding_bag -__all__ = ["SparseCOOTensor"] +__all__ = ["SparseCOOTensor", "EmbeddingBag", "embedding_bag"] diff --git a/torch_xla/experimental/sparse/coo.py b/torch_xla/experimental/sparse/coo.py index 8c246b9ce86..dc0aea99490 100644 --- a/torch_xla/experimental/sparse/coo.py +++ b/torch_xla/experimental/sparse/coo.py @@ -13,6 +13,8 @@ coo_nnz, ) +__all__ = ["SparseCOOTensor"] + aten = torch.ops.aten diff --git a/torch_xla/experimental/sparse/embedding_bag.py b/torch_xla/experimental/sparse/embedding_bag.py index e69de29bb2d..028b34fc924 100644 --- a/torch_xla/experimental/sparse/embedding_bag.py +++ b/torch_xla/experimental/sparse/embedding_bag.py @@ -0,0 +1,364 @@ +from enum import IntEnum +from typing import Optional, Tuple +import warnings +import torch +from torch import Tensor +from torch.nn.modules.sparse import EmbeddingBag as UpstreamEmbeddingBag +from torch.nn import functional as F +from torch.autograd import Function +from .coo import SparseCOOTensor + + +class _EmbeddingBagMode(IntEnum): + SUM = 0 + MEAN = 1 + MAX = 2 + + +# def _promote_index_offset_dtype(indices: Tensor, offsets: Tensor): +# common_dtype = torch.promote_types(indices.dtype, offsets.dtype) +# indices = indices.to(common_dtype) +# offsets = offsets.to(common_dtype) +# return indices, offsets + +# def _check_type(func_name, arg_name, arg, dtypes): +# torch._check( +# arg.dtype in dtypes, +# f"{func_name}: {arg_name} must have dtype from {dtypes}, but found {arg.dtype}" +# ) + +# def _check_embedding_bag_args(weight: Tensor, indices: Tensor, offsets: Tensor, +# mode: _EmbeddingBagMode, +# per_sample_weights: Optional[Tensor], +# include_last_offset: bool): +# _check_type("embedding_bag", "indices", indices, (torch.int32, torch.int64)) +# _check_type("embedding_bag", "offsets", offsets, (torch.int32, torch.int64)) +# _check_type("embedding_bag", "weight", weight, +# (torch.float16, torch.float32, torch.bfloat16, torch.float64)) +# if offsets.size(0) > 0: +# torch.check(offsets[0] == 0, f"offsets[0] must be 0. 
From 2dd1736bea229917d931c6f9914ff25f4dbba176 Mon Sep 17 00:00:00 2001
From: "Andrew M. James"
Date: Fri, 28 Mar 2025 19:29:06 +0000
Subject: [PATCH 3/4] Add embedding bag implementations

---
 torch_xla/experimental/sparse/__init__.py     |   3 +-
 torch_xla/experimental/sparse/coo.py          |   2 +
 .../experimental/sparse/embedding_bag.py      | 364 ++++++++++++++++++
 3 files changed, 368 insertions(+), 1 deletion(-)

diff --git a/torch_xla/experimental/sparse/__init__.py b/torch_xla/experimental/sparse/__init__.py
index 7074143f22d..0237cd09c10 100644
--- a/torch_xla/experimental/sparse/__init__.py
+++ b/torch_xla/experimental/sparse/__init__.py
@@ -1,3 +1,4 @@
 from .coo import SparseCOOTensor
+from .embedding_bag import EmbeddingBag, embedding_bag
 
-__all__ = ["SparseCOOTensor"]
+__all__ = ["SparseCOOTensor", "EmbeddingBag", "embedding_bag"]
diff --git a/torch_xla/experimental/sparse/coo.py b/torch_xla/experimental/sparse/coo.py
index 8c246b9ce86..dc0aea99490 100644
--- a/torch_xla/experimental/sparse/coo.py
+++ b/torch_xla/experimental/sparse/coo.py
@@ -13,6 +13,8 @@
     coo_nnz,
 )
 
+__all__ = ["SparseCOOTensor"]
+
 aten = torch.ops.aten
diff --git a/torch_xla/experimental/sparse/embedding_bag.py b/torch_xla/experimental/sparse/embedding_bag.py
index e69de29bb2d..028b34fc924 100644
--- a/torch_xla/experimental/sparse/embedding_bag.py
+++ b/torch_xla/experimental/sparse/embedding_bag.py
@@ -0,0 +1,364 @@
+from enum import IntEnum
+from typing import Optional, Tuple
+import warnings
+import torch
+from torch import Tensor
+from torch.nn.modules.sparse import EmbeddingBag as UpstreamEmbeddingBag
+from torch.nn import functional as F
+from torch.autograd import Function
+from .coo import SparseCOOTensor
+
+
+class _EmbeddingBagMode(IntEnum):
+  SUM = 0
+  MEAN = 1
+  MAX = 2
+
+
+# def _promote_index_offset_dtype(indices: Tensor, offsets: Tensor):
+#   common_dtype = torch.promote_types(indices.dtype, offsets.dtype)
+#   indices = indices.to(common_dtype)
+#   offsets = offsets.to(common_dtype)
+#   return indices, offsets
+
+# def _check_type(func_name, arg_name, arg, dtypes):
+#   torch._check(
+#       arg.dtype in dtypes,
+#       f"{func_name}: {arg_name} must have dtype from {dtypes}, but found {arg.dtype}"
+#   )
+
+# def _check_embedding_bag_args(weight: Tensor, indices: Tensor, offsets: Tensor,
+#                               mode: _EmbeddingBagMode,
+#                               per_sample_weights: Optional[Tensor],
+#                               include_last_offset: bool):
+#   _check_type("embedding_bag", "indices", indices, (torch.int32, torch.int64))
+#   _check_type("embedding_bag", "offsets", offsets, (torch.int32, torch.int64))
+#   _check_type("embedding_bag", "weight", weight,
+#               (torch.float16, torch.float32, torch.bfloat16, torch.float64))
+#   if offsets.size(0) > 0:
+#     torch._check(offsets[0] == 0, f"offsets[0] must be 0. Got {offsets[0]}")
+#     torch._check(
+#         offsets[-1] < indices.size(0),
+#         f"offsets[-1] can't be greater than input length {indices.size(0)}, but got {offsets[-1]}"
+#     )
+
+#   if per_sample_weights is not None:
+#     torch._check(
+#         mode == _EmbeddingBagMode.SUM,
+#         "embedding_bag: per_sample_weights only supported with mode='sum'")
+#     _check_type("embedding_bag", "per_sample_weights", per_sample_weights,
+#                 (weight.dtype,))
+#     torch._check(per_sample_weights.dim() == 1)
+#     torch._check(per_sample_weights.numel() == indices.numel())
+
+#   if include_last_offset:
+#     torch._check(
+#         offsets.size(0) >= 1,
+#         "include_last_offset: number of offsets should be at least 1")
+
+# def _is_fast_path(weight: Tensor, per_sample_weights: Optional[Tensor],
+#                   output: Tensor, padding_idx: int) -> bool:
+#   is_fast = weight.dtype in (torch.float32, torch.float16, torch.bfloat16)
+#   is_fast &= weight.stride(1) == 1
+#   is_fast &= output.stride(1) == 1
+#   is_fast &= padding_idx < 0
+#   if is_fast and per_sample_weights is not None:
+#     is_fast &= per_sample_weights.stride(1) == 1
+#   return is_fast
+
+# def _make_offset2bag(output: Tensor, weight: Tensor, indices: Tensor,
+#                      offsets: Tensor, mode: _EmbeddingBagMode,
+#                      per_sample_weights: Optional[Tensor],
+#                      padding_idx: int) -> Tensor:
+#   fast_path_sum = _is_fast_path(weight, per_sample_weights, output, padding_idx)
+#   if (mode == _EmbeddingBagMode.MEAN or mode == _EmbeddingBagMode.MAX or
+#       not fast_path_sum):
+#     offsets_size = offsets.size(0)
+#     offset2bag = torch.zeros(indices.size(0), **_tensor_factory_kwargs(offsets))
+#     include_last_offset = output.size(0) == offsets_size - 1
+#     if include_last_offset:
+#       _offsets = offsets.narrow(0, 0, offsets_size - 1)
+#     else:
+#       _offsets = offsets
+#     output.zero_()
+
+#     offset2bag.index_add_(0, _offsets, torch.ones_like(_offsets))
+#     offset2bag[0] -= 1
+#     offset2bag = offset2bag.cumsum(0, offset2bag.dtype)
+#     return offset2bag
+#   else:
+#     # we don't use this on the fast path so never initialize it
+#     return torch.empty(0, **_tensor_factory_kwargs(offsets))
+
+# def _make_bag_size(offsets: Tensor, indices: Tensor, mode: _EmbeddingBagMode,
+#                    include_last_offset: bool) -> Tensor:
+#   last_offset_factor = 1 if include_last_offset else 0
+#   num_bags = offsets.size(0) - last_offset_factor
+#   bag_size = torch.empty(num_bags, **_tensor_factory_kwargs(offsets))
+#   if num_bags != 1:
+#     bag_size[0:bag_size.size(0) -
+#              1] = offsets[1:num_bags] - offsets[0:num_bags - 1]
+#   if num_bags > 0:
+#     bag_size[-1] = indices.size(0) - offsets[num_bags - 1]
+#   return bag_size
+
+# def _make_max_indices(weight: Tensor, indices: Tensor, offsets: Tensor,
+#                       bag_size: Tensor, mode: _EmbeddingBagMode,
+#                       include_last_offset: bool) -> Tensor:
+#   num_bags = offsets.size(0)
+#   if mode == _EmbeddingBagMode.MAX:
+#     if include_last_offset:
+#       torch._check(num_bags >= 1,
+#                    "include_last_offset: num_bags should be at least 1")
+#     _tensor_factory_kwargs(bag_size)
+#     pass
+
+
+def _tensor_factory_kwargs(t):
+  # kwargs shared by the tensor factory functions (torch.empty, torch.zeros,
+  # ...); `memory_format` and `pin_memory` are not readable tensor attributes,
+  # so they cannot be harvested here.
+  return {
+      k: getattr(t, k) for k in ('dtype', 'device', 'layout', 'requires_grad')
+  }
+
+
+def _apply_bag_size_backward(mode: _EmbeddingBagMode, index_grad: Tensor,
+                             offset2bag: Tensor, bag_size: Tensor) -> Tensor:
+  if mode == _EmbeddingBagMode.MEAN:
+    index_grad *= (1 / bag_size.to(
+        dtype=index_grad.dtype,
+        device=index_grad.device).unsqueeze(1).index_select(0, offset2bag))
+  return index_grad
+
+
+class SparseEmbeddingBag(Function):
+  """Alternate embedding bag implementation.
+
+  Produces sparse gradients using an XLA compatible tensor subclass.
+  """
+
+  @staticmethod
+  def forward(
+      weight: Tensor,
+      indices: Tensor,
+      offsets: Tensor,
+      max_norm: Optional[float] = None,
+      scale_grad_by_freq: bool = False,
+      mode: _EmbeddingBagMode = _EmbeddingBagMode(0),
+      sparse: bool = False,
+      per_sample_weights: Optional[Tensor] = None,
+      include_last_offset: bool = False,
+      padding_idx: Optional[int] = None,
+  ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    # otherwise we should not be here
+    assert sparse
+    assert weight.requires_grad
+    torch._check(
+        not scale_grad_by_freq,
+        lambda: "embedding_bag: scale_grad_by_freq is not supported with "
+        "sparse gradients")
+    torch._check(
+        max_norm is None,
+        lambda: "Intermediate renormalization is not supported with sparse=True")
+    # returns: output, offset2bag, bag_size, max_indices
+    # note: autograd is disabled while we are inside here, so we won't trigger
+    # any double-backward behavior. We are not modifying the forward in any way.
+    return torch.embedding_bag(weight, indices, offsets, scale_grad_by_freq,
+                               mode.value, sparse, per_sample_weights,
+                               include_last_offset, padding_idx)
+
+  @staticmethod
+  def setup_context(ctx,
+                    inputs: Tuple[Tensor, Tensor, Optional[Tensor],
+                                  Optional[float], bool, _EmbeddingBagMode,
+                                  bool, Optional[Tensor], bool, Optional[int]],
+                    outputs: Tuple[Tensor, Tensor, Tensor, Tensor]):
+    weight, indices, offsets, _, _, mode, _, per_sample_weights, _, padding_idx = inputs
+    output, offset2bag, bag_size, max_indices = outputs
+    # for completeness, not technically required as integer dtype will make
+    # this automatic
+    ctx.mark_non_differentiable(offset2bag, bag_size, max_indices)
+    ctx.set_materialize_grads(False)
+    ctx.save_for_backward(indices, offsets, offset2bag, bag_size,
+                          per_sample_weights)
+    ctx.num_weights = weight.size(0)
+    ctx.padding_idx = -1 if padding_idx is None else padding_idx
+    ctx.mode = mode
+    # Todo: remove after validating assumptions
+    ctx.n_out = len(outputs)
+
+  @staticmethod
+  def backward(ctx, grad_: Tensor, *args: Tuple[Optional[Tensor], ...]):
+    # All *args will be None; autograd needs those slots but never
+    # materializes grads for the non-differentiable outputs.
+    indices, _offsets, offset2bag, bag_size, per_sample_weights = ctx.saved_tensors
+    grad_weight = grad_indices = grad_offsets = grad_per_sample_weights = None
+    if grad_ is not None:
+      # compute grad_weight
+      index_grad = grad_.index_select(0, offset2bag)
+      index_grad = _apply_bag_size_backward(ctx.mode, index_grad, offset2bag,
+                                            bag_size)
+      if per_sample_weights is not None:
+        torch._check(ctx.mode == _EmbeddingBagMode.SUM)
+        index_grad.mul_(per_sample_weights.unsqueeze(1))
+
+      if ctx.padding_idx != -1:
+        c = indices != ctx.padding_idx
+        indices = indices[c]
+        index_grad = index_grad[c]
+
+      num_features = index_grad.size(-1)
+      weight_size = (ctx.num_weights, num_features)
+      dense_options = _tensor_factory_kwargs(index_grad)
+
+      if index_grad.numel() == 0:
+        sparse_index = torch.empty((1, 0),
+                                   device=indices.device,
+                                   dtype=torch.int64)
+        sparse_values = torch.empty((0, num_features), **dense_options)
+      else:
+        sparse_index = indices.reshape(1, -1)
+        sparse_values = index_grad.reshape((-1, num_features))
+
+      grad_weight = SparseCOOTensor(sparse_index, sparse_values, weight_size)
+
+    # one grad slot in the return for every input to forward; non-tensor
+    # inputs get None
+    return (grad_weight, grad_indices, grad_offsets, None, None, None, None,
+            grad_per_sample_weights, None, None)
+
+
+def embedding_bag(
+    input: Tensor,
+    weight: Tensor,
+    offsets: Optional[Tensor] = None,
+    max_norm: Optional[float] = None,
+    norm_type: float = 2,
+    scale_grad_by_freq: bool = False,
+    mode: str = 'mean',
+    sparse: bool = False,
+    per_sample_weights: Optional[Tensor] = None,
+    include_last_offset: bool = False,
+    padding_idx: Optional[int] = None,
+) -> Tensor:
+  f"""Equivalent API to ``F.embedding_bag``; when sparse gradients for
+  :attr:`weight` would be produced, a custom path is taken.
+
+  {F.embedding_bag.__doc__}
+  """
+  if weight.dtype == torch.long and input.is_floating_point():
+    warnings.warn("Argument order of nn.functional.embedding_bag was changed. "
+                  "Usage `embedding_bag(weight, input, ...)` is deprecated, "
+                  "and should now be `embedding_bag(input, weight, ...)`.")
+    weight, input = input, weight
+
+  if per_sample_weights is not None and input.size() != per_sample_weights.size():
+    raise ValueError(
+        f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, "
+        f"then it must have the same shape as the input ({input.shape})")
+
+  if not weight.dim() == 2:
+    raise ValueError(
+        f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}")
+
+  if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
+    include_last_offset = True
+    offsets = input.offsets()
+    input = input.values().reshape(-1)
+    if per_sample_weights is not None:
+      if not per_sample_weights.is_nested:
+        raise ValueError(
+            "If input is nested, then per_sample_weights must be nested if specified")
+      per_sample_weights = per_sample_weights.values().reshape(-1)
+  elif input.dim() == 2:
+    if offsets is not None:
+      type_str = ""
+      # TODO: Remove this once script supports type() calls
+      if not torch.jit.is_scripting():
+        type_str = str(type(offsets))
+      raise ValueError("if input is 2D, then offsets has to be None"
+                       ", as input is treated as a mini-batch of"
+                       " fixed length sequences. However, found "
+                       f"offsets of type {type_str}")
+    offsets = torch.arange(
+        0, input.numel(), input.size(1), dtype=input.dtype, device=input.device)
+
+    input = input.reshape(-1)
+    if per_sample_weights is not None:
+      per_sample_weights = per_sample_weights.reshape(-1)
+  elif input.dim() == 1:
+    if offsets is None:
+      raise ValueError("offsets has to be a 1D Tensor but got None")
+    if offsets.dim() != 1:
+      raise ValueError("offsets has to be a 1D Tensor")
+  else:
+    raise ValueError(
+        f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}")
+
+  if mode == "sum":
+    mode_enum = _EmbeddingBagMode.SUM
+  elif mode == "mean":
+    mode_enum = _EmbeddingBagMode.MEAN
+  elif mode == "max":
+    mode_enum = _EmbeddingBagMode.MAX
+
+    if scale_grad_by_freq:
+      raise ValueError(
+          "max mode does not support scaling the gradient by the frequency")
+
+    if sparse:
+      raise ValueError("max mode does not support sparse weights")
+
+  else:
+    raise ValueError("mode has to be one of sum, mean or max")
+
+  if max_norm is not None:
+    with torch.no_grad():
+      # modifies weight in place!
+      torch.embedding_renorm_(weight, input, max_norm, norm_type)
+
+  if per_sample_weights is not None and mode != "sum":
+    raise NotImplementedError(
+        "embedding_bag: per_sample_weights was not None. "
+        "per_sample_weights is only supported for mode='sum' "
+        f"(got mode='{mode}'). Please open a feature request on GitHub.")
+
+  if sparse and weight.requires_grad:
+    # renorm (if any) was already applied in-place above, so pass None through
+    ret, _, _, _ = SparseEmbeddingBag.apply(weight, input, offsets, None,
+                                            scale_grad_by_freq, mode_enum,
+                                            sparse, per_sample_weights,
+                                            include_last_offset, padding_idx)
+  else:
+    ret = F.embedding_bag(input, weight, offsets, None, norm_type,
+                          scale_grad_by_freq, mode, sparse, per_sample_weights,
+                          include_last_offset, padding_idx)
+  return ret
+
+
+class EmbeddingBag(UpstreamEmbeddingBag):
+  f"""Torch-XLA backend compatible EmbeddingBag.
+
+  This implementation modifies the forward function when :attr:`sparse` is
+  ``True``. The alternate implementation has a backward which will produce
+  sparse gradients w.r.t. :attr:`weight` as a tensor subclass implementation
+  of the ``torch.sparse_coo`` layout. If :attr:`sparse` is ``False`` this is
+  identical to :class:`~torch.nn.EmbeddingBag`.
+
+  {UpstreamEmbeddingBag.__doc__}
+  """
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+
+  def forward(self,
+              input: Tensor,
+              offsets: Optional[Tensor] = None,
+              per_sample_weights: Optional[Tensor] = None) -> Tensor:
+    # setting sparse without requiring weight grads has no effect
+    if self.weight.requires_grad and self.sparse:
+      return embedding_bag(input, self.weight, offsets, self.max_norm,
+                           self.norm_type, self.scale_grad_by_freq, self.mode,
+                           self.sparse, per_sample_weights,
+                           self.include_last_offset, self.padding_idx)
+    return super().forward(input, offsets, per_sample_weights)
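The backward above hinges on offset2bag, which maps each entry of indices to the bag it came from, so the bag-level gradient can be expanded to one row per lookup before being wrapped in a SparseCOOTensor. A plain-PyTorch sketch of that expansion (mirroring the commented-out _make_offset2bag and the index_select in backward; the numbers are illustrative):

    import torch

    # 3 bags over 6 lookups, as in the earlier example
    offsets = torch.tensor([0, 2, 4])
    indices = torch.tensor([1, 2, 4, 5, 4, 3])
    grad_out = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])  # (num_bags, dim)

    # offset2bag[j] == bag index of lookup j -> [0, 0, 1, 1, 2, 2]
    offset2bag = torch.zeros_like(indices)
    offset2bag.index_add_(0, offsets, torch.ones_like(offsets))
    offset2bag[0] -= 1
    offset2bag = offset2bag.cumsum(0)

    # one gradient row per lookup; rows that repeat an index are summed
    # when the COO gradient is coalesced
    index_grad = grad_out.index_select(0, offset2bag)
    grad_weight = torch.sparse_coo_tensor(indices.reshape(1, -1),
                                          index_grad, (6, 2))
    print(grad_weight.coalesce().to_dense())  # row 4 accumulates to [5., 5.]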
From c67588f16c1f9c881b367f675d160e5c1c95341b Mon Sep 17 00:00:00 2001
From: "Andrew M. James"
Date: Fri, 28 Mar 2025 21:18:40 +0000
Subject: [PATCH 4/4] Add raw embedding

---
 test/test_sparse_embedding.py                  |  40 +++++
 .../experimental/sparse/embedding_bag.py       | 142 +++++++++++++++---
 2 files changed, 158 insertions(+), 24 deletions(-)
 create mode 100644 test/test_sparse_embedding.py

diff --git a/test/test_sparse_embedding.py b/test/test_sparse_embedding.py
new file mode 100644
index 00000000000..97ea69e71a9
--- /dev/null
+++ b/test/test_sparse_embedding.py
@@ -0,0 +1,40 @@
+import unittest
+from itertools import product
+
+import torch
+import torch.nn as nn
+import torch_xla
+from absl.testing import parameterized
+from torch_xla.experimental.sparse.embedding_bag import Embedding
+
+index_dtypes = (torch.int32, torch.int64)
+float_dtypes = (torch.float32, torch.float16, torch.bfloat16)
+
+device = torch_xla.device()
+
+
+class SparseEmbeddingTest(parameterized.TestCase):
+
+  @parameterized.parameters(*product(index_dtypes, float_dtypes))
+  def test_embedding_sparse_basic(self, index_dtype, float_dtype):
+    torch_embedding = nn.Embedding(10, 20, sparse=True, dtype=float_dtype)
+    xla_embedding = Embedding.from_pretrained(
+        torch_embedding.weight.clone().detach().to(device),
+        freeze=False,
+        sparse=True)
+    input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=index_dtype)
+    input_xla = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]],
+                             dtype=index_dtype,
+                             device=device)
+    ref_out = torch_embedding(input)
+    xla_out = xla_embedding(input_xla)
+    self.assertTrue(torch.allclose(ref_out, xla_out.cpu()))
+    ref_out.sum().backward()
+    xla_out.sum().backward()
+    torch_weight_grad = torch_embedding.weight.grad
+    xla_weight_grad = xla_embedding.weight.grad
+    # the XLA grad is a SparseCOOTensor subclass exposing only a narrow op
+    # set; densify both gradients on CPU before comparing
+    dense_ref = torch_weight_grad.coalesce().to_dense()
+    xla_grad = xla_weight_grad.coalesce()
+    dense_xla = torch.sparse_coo_tensor(xla_grad.indices().cpu(),
+                                        xla_grad.values().cpu(),
+                                        dense_ref.shape).to_dense()
+    self.assertTrue(torch.allclose(dense_ref, dense_xla))
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/torch_xla/experimental/sparse/embedding_bag.py b/torch_xla/experimental/sparse/embedding_bag.py
index 028b34fc924..7b8f66947b3 100644
--- a/torch_xla/experimental/sparse/embedding_bag.py
+++ b/torch_xla/experimental/sparse/embedding_bag.py
@@ -3,11 +3,14 @@
 import warnings
 import torch
 from torch import Tensor
-from torch.nn.modules.sparse import EmbeddingBag as UpstreamEmbeddingBag
 from torch.nn import functional as F
+from torch.nn.modules import sparse
 from torch.autograd import Function
 from .coo import SparseCOOTensor
 
+_NativeEmbeddingModule = sparse.Embedding
+_NativeEmbeddingBagModule = sparse.EmbeddingBag
+
 
 class _EmbeddingBagMode(IntEnum):
   SUM = 0
   MEAN = 1
@@ -132,6 +135,31 @@ def _apply_bag_size_backward(mode: _EmbeddingBagMode, index_grad: Tensor,
   return index_grad
 
 
+def _sparse_embedding_backward(grad: Tensor, indices: Tensor, num_weights: int,
+                               padding_idx: int):
+  if padding_idx != -1:
+    c = indices != padding_idx
+    indices = indices[c]
+    grad = grad[c]
+
+  num_features = grad.size(-1)
+  weight_size = (num_weights, num_features)
+  dense_options = _tensor_factory_kwargs(grad)
+
+  if grad.numel() == 0:
+    sparse_index = torch.empty((1, 0), device=indices.device, dtype=torch.int64)
+    sparse_values = torch.empty((0, num_features), **dense_options)
+  else:
+    sparse_index = indices.reshape(1, -1)
+    sparse_values = grad.reshape((-1, num_features))
+
+  grad_weight = SparseCOOTensor(sparse_index, sparse_values, weight_size)
+
+  return grad_weight
+
+
 class SparseEmbeddingBag(Function):
   """Alternate embedding bag implementation.
 
@@ -173,7 +201,7 @@ def setup_context(ctx,
                                   bool, Optional[Tensor], bool, Optional[int]],
                     outputs: Tuple[Tensor, Tensor, Tensor, Tensor]):
     weight, indices, offsets, _, _, mode, _, per_sample_weights, _, padding_idx = inputs
-    output, offset2bag, bag_size, max_indices = outputs
+    _, offset2bag, bag_size, max_indices = outputs
     # for completeness, not technically required as integer dtype will make
     # this automatic
     ctx.mark_non_differentiable(offset2bag, bag_size, max_indices)
@@ -200,27 +228,40 @@ def backward(ctx, grad_: Tensor, *args: Tuple[Optional[Tensor], ...]):
       if per_sample_weights is not None:
         torch._check(ctx.mode == _EmbeddingBagMode.SUM)
         index_grad.mul_(per_sample_weights.unsqueeze(1))
 
-      if ctx.padding_idx != -1:
-        c = indices != ctx.padding_idx
-        indices = indices[c]
-        index_grad = index_grad[c]
-
-      num_features = index_grad.size(-1)
-      weight_size = (ctx.num_weights, num_features)
-      dense_options = _tensor_factory_kwargs(index_grad)
-
-      if index_grad.numel() == 0:
-        sparse_index = torch.empty((1, 0),
-                                   device=indices.device,
-                                   dtype=torch.int64)
-        sparse_values = torch.empty((0, num_features), **dense_options)
-      else:
-        sparse_index = indices.reshape(1, -1)
-        sparse_values = index_grad.reshape((-1, num_features))
-
-      grad_weight = SparseCOOTensor(sparse_index, sparse_values, weight_size)
+      grad_weight = _sparse_embedding_backward(index_grad, indices,
+                                               ctx.num_weights,
+                                               ctx.padding_idx)
 
     # one grad slot in the return for every input to forward; non-tensor
     # inputs get None
     return (grad_weight, grad_indices, grad_offsets, None, None, None, None,
             grad_per_sample_weights, None, None)
+
+
+class SparseEmbedding(Function):
+
+  @staticmethod
+  def forward(input: Tensor,
+              weight: Tensor,
+              padding_idx: Optional[int],
+              scale_grad_by_freq: bool = False,
+              sparse: bool = False) -> Tensor:
+    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq,
+                           sparse)
+
+  @staticmethod
+  def setup_context(ctx, inputs, output):
+    indices, weight, padding_idx, scale_grad_by_freq, sparse = inputs
+    ctx.set_materialize_grads(False)
+    ctx.save_for_backward(indices)
+    ctx.num_weights = weight.size(0)
+    ctx.padding_idx = padding_idx
+
+  @staticmethod
+  def backward(ctx, grad_: Tensor):
+    grad_input = grad_weight = None
+    indices, = ctx.saved_tensors
+    if grad_ is not None:
+      grad_weight = _sparse_embedding_backward(grad_, indices, ctx.num_weights,
+                                               ctx.padding_idx)
+    # one grad slot per forward() input; non-tensor inputs get None
+    return grad_input, grad_weight, None, None, None
@@ -336,7 +377,44 @@ def embedding_bag(
   return ret
 
 
-class EmbeddingBag(UpstreamEmbeddingBag):
+def embedding(input: Tensor,
+              weight: Tensor,
+              padding_idx: Optional[int] = None,
+              max_norm: Optional[float] = None,
+              norm_type: float = 2.0,
+              scale_grad_by_freq: bool = False,
+              sparse: bool = False) -> Tensor:
+  f"""Equivalent API to ``F.embedding``; when sparse gradients for
+  :attr:`weight` would be produced, a custom path is taken.
+
+  {F.embedding.__doc__}
+  """
+  if padding_idx is not None:
+    if padding_idx > 0:
+      assert padding_idx < weight.size(
+          0), "Padding_idx must be within num_embeddings"
+    elif padding_idx < 0:
+      assert padding_idx >= -weight.size(
+          0), "Padding_idx must be within num_embeddings"
+      padding_idx = weight.size(0) + padding_idx
+  else:
+    padding_idx = -1
+  if max_norm is not None:
+    # Note [embedding_renorm contiguous]
+    # `embedding_renorm_` will call .contiguous() on input anyways, so we
+    # call it here and take advantage of the improved locality in the
+    # `embedding` call below too.
+    input = input.contiguous()
+    # Note [embedding_renorm set_grad_enabled]
+    # Note [Modified from F.embedding]
+    # We don't need to support torchscript, so we can simply do what the
+    # original is documented to be equivalent to:
+    # _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+    with torch.no_grad():
+      torch.embedding_renorm_(weight, input, max_norm, norm_type)
+
+  if sparse and weight.requires_grad:
+    return SparseEmbedding.apply(input, weight, padding_idx,
+                                 scale_grad_by_freq, sparse)
+  return torch.embedding(weight, input, padding_idx, scale_grad_by_freq,
+                         sparse)
+
+
+class EmbeddingBag(_NativeEmbeddingBagModule):
   f"""Torch-XLA backend compatible EmbeddingBag.
 
   This implementation modifies the forward function when :attr:`sparse` is
@@ -345,12 +423,9 @@ class EmbeddingBag(_NativeEmbeddingBagModule):
   of the ``torch.sparse_coo`` layout. If :attr:`sparse` is ``False`` this is
   identical to :class:`~torch.nn.EmbeddingBag`.
 
-  {UpstreamEmbeddingBag.__doc__}
+  {_NativeEmbeddingBagModule.__doc__}
   """
 
-  def __init__(self, *args, **kwargs):
-    super().__init__(*args, **kwargs)
-
   def forward(self,
               input: Tensor,
               offsets: Optional[Tensor] = None,
@@ -362,3 +437,22 @@ def forward(self,
                            self.sparse, per_sample_weights,
                            self.include_last_offset, self.padding_idx)
     return super().forward(input, offsets, per_sample_weights)
+
+
+class Embedding(_NativeEmbeddingModule):
+  f"""Torch-XLA backend compatible Embedding.
+
+  This implementation modifies the forward function when :attr:`sparse` is
+  ``True``. The alternate implementation has a backward which will produce
+  sparse gradients w.r.t. :attr:`weight` as a tensor subclass implementation
+  of the ``torch.sparse_coo`` layout. If :attr:`sparse` is ``False`` this is
+  identical to :class:`~torch.nn.Embedding`.
+
+  {_NativeEmbeddingModule.__doc__}
+  """
+
+  def forward(self, input: Tensor) -> Tensor:
+    if self.weight.requires_grad and self.sparse:
+      return embedding(input, self.weight, self.padding_idx, self.max_norm,
+                       self.norm_type, self.scale_grad_by_freq, self.sparse)
+    return super().forward(input)
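Taken together, the series targets optimizer-style flows (e.g. torch.optim.SparseAdam) where the gradient of the embedding weight only touches the rows that were actually looked up. A rough end-to-end sketch of the intended use (illustrative only; a real optimizer may exercise operators beyond the narrow set implemented above):

    import torch
    import torch_xla
    from torch_xla.experimental.sparse import EmbeddingBag

    device = torch_xla.device()

    emb = EmbeddingBag(100, 16, mode='sum', sparse=True).to(device)
    indices = torch.tensor([3, 7, 7, 42], device=device)
    offsets = torch.tensor([0, 2], device=device)

    out = emb(indices, offsets)  # takes the SparseEmbeddingBag path
    out.sum().backward()

    grad = emb.weight.grad       # a SparseCOOTensor, not a dense tensor
    grad = grad.coalesce()       # the duplicated index 7 is accumulated
    print(grad.indices())        # tensor([[ 3,  7, 42]], ...)

    # a bare-bones SGD-style row update using only plain tensor ops
    with torch.no_grad():
      rows = grad.indices()[0]
      emb.weight[rows] -= 0.1 * grad.values()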