[Draft] Add Experimental limited sparse embedding bag #8905

Draft: wants to merge 4 commits into master
40 changes: 40 additions & 0 deletions test/test_sparse_embedding.py
@@ -0,0 +1,40 @@
import itertools
import random
import unittest
from itertools import product

import torch
import torch_xla
import torch.nn as nn
import torch.nn.functional as F
from absl.testing import parameterized
from torch_xla.experimental import sparse as sparse_xla

index_dtypes = {torch.int32, torch.int64}
float_dtypes = {torch.float32, torch.float16, torch.bfloat16}

device = torch_xla.device()


def functional_embedding(impl):
  # Draft helper: wrap an embedding module so it can be called functionally.

  def f(input):
    return impl(input)

  return f

class EmbeddingSparseTest(parameterized.TestCase):
  # Test class wrapper (name assumed); the parameterized methods below need a
  # TestCase subclass to run under absl's test runner.

  @parameterized.parameters(*product(index_dtypes, float_dtypes))
  def test_embedding_sparse_basic(self, index_dtype, float_dtype):
    torch_embedding = nn.Embedding(10, 20, sparse=True, dtype=float_dtype)
    xla_embedding = sparse_xla.Embedding.from_pretrained(
        torch_embedding.weight.clone().detach().to(device), sparse=True)
    input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=index_dtype)
    input_xla = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]],
                             dtype=index_dtype,
                             device=device)
    ref_out = torch_embedding(input)
    xla_out = xla_embedding(input_xla)
    self.assertTrue(torch.allclose(ref_out, xla_out.cpu()))
    ref_out.sum().backward()
    xla_out.sum().backward()
    torch_weight_grad = torch_embedding.weight.grad
    xla_weight_grad = xla_embedding.weight.grad
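
The draft test stops after reading the two weight gradients without asserting on them. A comparison along the following lines would likely finish the check; it is only a sketch, and the coalesce/compare handling of the XLA-side sparse gradient is an assumption rather than something this diff demonstrates.

    # Hypothetical continuation of test_embedding_sparse_basic (not in the diff):
    ref_grad = torch_weight_grad.coalesce()
    xla_grad = xla_weight_grad.coalesce()
    self.assertTrue(torch.equal(ref_grad.indices(), xla_grad.indices().cpu()))
    self.assertTrue(torch.allclose(ref_grad.values(), xla_grad.values().cpu()))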
4 changes: 4 additions & 0 deletions torch_xla/experimental/sparse/__init__.py
@@ -0,0 +1,4 @@
from .coo import SparseCOOTensor
from .embedding_bag import EmbeddingBag, embedding_bag

__all__ = ["SparseCOOTensor", "EmbeddingBag", "embedding_bag"]
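
For orientation, a minimal usage sketch of the exported surface. The EmbeddingBag constructor is assumed to mirror torch.nn.EmbeddingBag, which this diff does not show, so treat the argument list as illustrative only.

import torch
import torch_xla
from torch_xla.experimental import sparse as sparse_xla

device = torch_xla.device()
# Assumed signature: (num_embeddings, embedding_dim, ...), as in torch.nn.EmbeddingBag.
bag = sparse_xla.EmbeddingBag(10, 3, mode="sum", sparse=True).to(device)
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], device=device)
offsets = torch.tensor([0, 4], device=device)
pooled = bag(inputs, offsets)  # expected shape (2, 3): one pooled row per bag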
158 changes: 158 additions & 0 deletions torch_xla/experimental/sparse/_coo_ops.py
@@ -0,0 +1,158 @@
from typing import Tuple
import torch


def _flatten_indices(inds, shape):
  # Flatten N-D indices to 1-D indices (row-major linearization).
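  # Worked example (illustrative): for shape (3, 4), the index column (1, 2)
  # flattens to 1 * 4 + 2 == 6, i.e. the row-major offset into a dense (3, 4)
  # tensor viewed as 1-D.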
flat_indices = inds.new_zeros(inds.size(1))
for d, sz in enumerate(shape):
flat_indices.mul_(sz)
flat_indices.add_(inds[d])
return flat_indices


def _check_no_kwargs(kwargs, name):
torch._check(
kwargs is None or len(kwargs) == 0,
f"{name}: expected no kwargs, got {kwargs}")


def _check_strided(arg, f_name, arg_name):
torch._check(arg.layout == torch.strided,
f"{f_name}: Expected strided {arg_name}, not {arg.layout}")


def _check_sparse(arg, f_name, arg_name):
torch._check(arg.layout == torch.sparse_coo,
f"{f_name}: Expected sparse {arg_name}, not {arg.layout}")


def coo_sparse_mask(args=(), kwargs=None):
"""
  x.sparse_mask(y)
  Create a new sparse tensor from strided x and sparse y.
  The result holds the values of x at the sparsity pattern of y.
"""
torch._check(
len(args) == 2, f"sparse_mask: Expected two arguments, got {len(args)}")
self, mask = args
_check_strided(self, "sparse_mask", "self")
_check_sparse(mask, "sparse_mask", "mask")
torch._check(
mask.size() == self.size(),
f"sparse_mask: expected mask and self to have the same shape (self: {self.size()}, mask: {mask.size()})"
)
_check_no_kwargs(kwargs, "sparse_mask")
mask_indices = mask.indices()
flat_indices = _flatten_indices(mask_indices, mask.size())
values = self.view(-1)[flat_indices]
  # Access the subclass ctor through the mask's class to avoid a circular import.
  return mask.__class__(
      mask_indices,
      values,
      size=self.size(),
      dtype=values.dtype,
      device=values.device,
      requires_grad=values.requires_grad,
  )


def coo_resize_as_(args, kwargs=None):
  # The aten implementation supports more cases, but the only one needed here is
  # the 0-nnz case, where the sparse/dense dims can be changed freely. Dense
  # dims are not handled anywhere in this module because they are not needed.
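  # Worked example (illustrative): if self currently stores zero values and
  # other has nnz == 5 over a 2-D sparsity pattern, this resizes self._v to
  # shape (5,) and self._i to shape (2, 5) so other's entries can then be
  # written in place.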

torch._check(len(args) == 2, f"resize_as_: expected two args not {len(args)}")
self, other = args
_check_sparse(self, "resize_as_", "self")
_check_sparse(other, "resize_as_", "other")
  torch._check(
      self._nnz() == 0,
      "resize_as_: resizing a sparse tensor with nnz != 0 is not supported")
_check_no_kwargs(kwargs, "resize_as_")

  new_nnz = other._nnz()
values_shape = (new_nnz,)
index_shape = (len(other.shape), new_nnz)

# attribute access to modify the tensor in-place, accessors return a clone
self._v.resize_(values_shape)
self._i.resize_(index_shape)
return self


def coo_coalesce(args, kwargs=None):
_check_no_kwargs(kwargs, "coalesce")
self = args[0]
_check_sparse(self, "coalesce", "self")
if self._is_coalesced:
return self
  if self._nnz() < 2:
self._is_coalesced = True
return self

indices = self._i
values = self._v
sparse_dim = len(self.shape)
  nnz = self._nnz()
indices_scalar = _flatten_indices(indices, self.shape)
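  # Worked example (illustrative): indices [[0, 1, 0], [2, 0, 2]] with values
  # [1., 2., 3.] coalesce to indices [[0, 1], [2, 0]] and values [4., 2.]:
  # repeated coordinates are summed and entries are sorted by flattened index.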

  new_indices = torch.empty_like(indices)
  new_values = torch.empty_like(values)
  indices_buffer, indices_perm = indices_scalar.sort(0)
  # Walk the sorted flat indices, copying each unique coordinate into the next
  # output slot and accumulating values of repeated coordinates into that slot.
  i = 0
  prev = None
  for j in range(nnz):
    pos = indices_perm[j]
    curr = indices_buffer[j]
    if j > 0 and curr == prev:
      if values.numel() > 0:
        new_values[i] += values[pos]
    else:
      if j > 0:
        i += 1
      for d in range(sparse_dim):
        new_indices[d][i] = indices[d][pos]
      if values.numel() > 0:
        new_values[i] = values[pos]
    prev = curr

  res = self.__class__(
      new_indices[:, :i + 1],
      new_values[:i + 1],
      self.shape,
      dtype=self.dtype,
      device=self.device,
      requires_grad=self.requires_grad)
  res._is_coalesced = True
  return res


def coo_add_(args=(), kwargs=None):
  kwargs = kwargs or {}
  alpha = kwargs.pop('alpha', 1)
  torch._check(len(args) == 2, f"add_: expected two operands, got {len(args)}")
  self, other = args
  _check_sparse(self, "add_", "self")
  _check_sparse(other, "add_", "other")

  # TODO: this same-pattern restriction should be sufficient for the SparseAdam
  # optimizer's use of add_.
  torch._check(
      torch.equal(self._i, other._i),
      "add_: operation is only supported when operands have the same sparsity pattern."
  )
  self._v += alpha * other._v
return self
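# Worked example (illustrative): if both operands store values at coordinates
# (0, 2) and (1, 0), with self._v == [1., 2.] and other._v == [3., 4.], then
# add_ with alpha=2 updates self._v in place to [7., 10.].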


def coo_indices(args=(), kwargs=None):
torch._check(len(args) == 1, "indices: expected one argument")
self = args[0]
_check_sparse(self, "indices", "self")
  torch._check(self._is_coalesced, "indices: input must be coalesced")
return self._i.clone()


def coo_values(args=(), kwargs=None):
torch._check(len(args) == 1, "values: expected one argument")
self = args[0]
_check_sparse(self, "values", "self")
  torch._check(self._is_coalesced, "values: input must be coalesced")
return self._v.clone()


def coo_nnz(args=(), kwargs=None):
  torch._check(len(args) == 1, "nnz: expected one argument")
self = args[0]
_check_sparse(self, "nnz", "self")
return self._v.numel()
95 changes: 95 additions & 0 deletions torch_xla/experimental/sparse/coo.py
@@ -0,0 +1,95 @@
from typing import Any, Callable, ClassVar, Dict, Iterable, Tuple
import torch

from torch._ops import OpOverload

from ._coo_ops import (
coo_sparse_mask,
coo_resize_as_,
coo_coalesce,
coo_add_,
coo_indices,
coo_values,
coo_nnz,
)

__all__ = ["SparseCOOTensor"]

aten = torch.ops.aten


class SparseCOOTensor(torch.Tensor):

DISPATCH_TABLE: ClassVar[Dict[OpOverload, Callable]] = {}
_v: torch.Tensor
  _i: torch.Tensor
  _is_coalesced: bool

def __new__(cls,
indices: torch.Tensor,
values: torch.Tensor,
size: Iterable[int],
*,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
requires_grad: bool | None = None):
device = device if device is not None else values.device
dtype = dtype if dtype is not None else values.dtype
requires_grad = requires_grad if requires_grad is not None else values.requires_grad
assert (device == values.device and device == indices.device and
dtype == values.dtype and requires_grad == values.requires_grad)

res = torch.Tensor._make_wrapper_subclass(
cls,
size=size,
strides=None,
layout=torch.sparse_coo,
dtype=dtype,
device=device,
requires_grad=requires_grad)
    res._v = values
    res._i = indices
    # Track the coalesced state explicitly; the _coo_ops handlers read and set it.
    res._is_coalesced = False
    return res

@classmethod
def __torch_dispatch__(cls,
func: OpOverload,
types: Tuple,
args: Tuple = (),
kwargs: Dict[str, Any] | None = None) -> torch.Tensor:
impl = cls._match_func_key_for_dispatch(func)
return impl(args, kwargs)

@classmethod
  def _load_sparse_dispatch(cls):
    # Populate lazily; the class attribute starts out as an empty dict.
    if not cls.DISPATCH_TABLE:
      cls.DISPATCH_TABLE = {
          aten.values: coo_values,
          aten.indices: coo_indices,
          aten.coalesce: coo_coalesce,
          aten.add_: coo_add_,
          aten.sparse_mask: coo_sparse_mask,
          aten.resize_as_: coo_resize_as_,
          aten._nnz: coo_nnz,
      }

@classmethod
def _match_func_key_for_dispatch(cls, func):
    cls._load_sparse_dispatch()
impl = cls.DISPATCH_TABLE.get(func, None)
if impl is None:
impl = cls.DISPATCH_TABLE.get(func._overloadpacket,
cls._make_not_implemented(func))
return impl

@classmethod
def _make_not_implemented(cls, func):

    def impl(args=(), kwargs=None):
      raise NotImplementedError(
          f"Sparse support via {cls.__name__} is limited to a very narrow "
          f"scope. The {func} operator is not currently supported; if you see "
          "this message, you are likely using this class outside its intended "
          "purpose. Supported operators are: "
          f"{list(cls.DISPATCH_TABLE.keys())}")

return impl
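
To make the intended scope concrete, here is a minimal usage sketch of the wrapper subclass on an XLA device. It only exercises ops present in DISPATCH_TABLE above and assumes the draft constructor semantics shown in this file.

import torch
import torch_xla
from torch_xla.experimental.sparse import SparseCOOTensor

device = torch_xla.device()
indices = torch.tensor([[0, 1, 1], [2, 0, 2]], device=device)
values = torch.tensor([1.0, 2.0, 3.0], device=device)
st = SparseCOOTensor(indices, values, size=(2, 3))

st = st.coalesce()                # dispatched to coo_coalesce
print(st.indices(), st.values())  # dispatched to coo_indices / coo_values
st.add_(st, alpha=1.0)            # in-place add; same sparsity pattern required
# st.to_dense()                   # not in the table -> raises NotImplementedError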