Implement gradient for vector repetitions #1192

Merged: 2 commits, Feb 11, 2025

2 changes: 1 addition & 1 deletion pytensor/link/vm.py
@@ -118,7 +118,7 @@ def calculate_reallocate_info(
# where gc
for i in range(idx + 1, len(order)):
if reuse_out is not None:
break # type: ignore
break
for out in order[i].outputs:
if (
getattr(out.type, "ndim", None) == 0
176 changes: 110 additions & 66 deletions pytensor/tensor/extra_ops.py
@@ -646,12 +646,17 @@

__props__ = ("axis",)

def __init__(self, axis=None):
def __init__(self, axis: int | None = None):
if axis is not None:
if not isinstance(axis, int) or axis < 0:
raise ValueError(

[Codecov warning: added line pytensor/tensor/extra_ops.py#L652 was not covered by tests]
f"Repeat only accepts positive integer axis or None, got {axis}"
)
self.axis = axis

def make_node(self, x, repeats):
x = ptb.as_tensor_variable(x)
repeats = ptb.as_tensor_variable(repeats)
repeats = ptb.as_tensor_variable(repeats, dtype="int64")

if repeats.dtype not in integer_dtypes:
raise TypeError("repeats.dtype must be an integer.")
@@ -687,58 +692,64 @@
out_shape = list(x.type.shape)
out_shape[self.axis] = None

out_type = TensorType(
x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
)

out_type = TensorType(x.dtype, shape=out_shape)
return Apply(self, [x, repeats], [out_type()])

def perform(self, node, inputs, output_storage):
x = inputs[0]
repeats = inputs[1]
z = output_storage[0]
z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
[x, repeats] = inputs
output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)

def connection_pattern(self, node):
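# The output is connected to x (gradients flow through it), but not to the integer-valued repeats.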
return [[True], [False]]

def grad(self, inputs, gout):
(x, repeats) = inputs
(gz,) = gout
axis = self.axis
if repeats.ndim == 0:
if self.axis is None:
axis = x.ndim
else:
if self.axis >= 0:
axis = self.axis + 1
else:
axis = self.axis + x.ndim + 1

shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats)
# When repeats is a scalar (same number of reps for all elements),
# We can split the repetitions into their own axis with reshape and sum them back
# to the original element location
sum_axis = x.ndim if axis is None else axis + 1
shape = list(x.shape)
shape.insert(sum_axis, repeats)
gx = gz.reshape(shape).sum(axis=sum_axis)

return [
gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
DisconnectedType()(),
]
elif repeats.ndim == 1:
# For this implementation, we would need to specify the length
# of repeats in order to split gz in the right way to sum
# the good part.
raise NotImplementedError()
# To sum the gradients that belong to the same repeated x,
# We create a repeated eye and dot product it with the gradient.
axis_size = x.size if axis is None else x.shape[axis]
repeated_eye = repeat(
ptb.eye(axis_size), repeats, axis=0
) # A sparse repeat would be neat

if axis is None:
gx = gz @ repeated_eye
# Undo the ravelling when axis=None
gx = gx.reshape(x.shape)
else:
# Place gradient axis at end for dot product
gx = ptb.moveaxis(gz, axis, -1)
gx = gx @ repeated_eye
# Place gradient back into the correct axis
gx = ptb.moveaxis(gx, -1, axis)

else:
raise ValueError()

return [gx, DisconnectedType()()]
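An illustrative NumPy sketch of the two branches above (the values and the names `x`, `reps`, `gz` are made up for the example): it checks that multiplying the upstream gradient by a row-repeated identity sums the contributions coming from the same input element, and that the scalar-repeats case reduces to a reshape-and-sum.

```python
import numpy as np

# Forward pass: each x[i] is copied reps[i] times, so dL/dx[i] must be the sum of
# the upstream gradient over the output slots produced by x[i].
x = np.array([1.0, 2.0, 3.0])
reps = np.array([2, 3, 1])
gz = np.arange(1.0, 1.0 + reps.sum())  # upstream gradient, one entry per output slot

# The row-repeated identity maps each output slot back to its source element,
# so `gz @ repeated_eye` sums the slots that belong to the same input element.
repeated_eye = np.repeat(np.eye(x.size), reps, axis=0)  # shape (reps.sum(), x.size)
gx = gz @ repeated_eye

# Reference: sum gz over each element's contiguous block of copies.
blocks = np.split(gz, np.cumsum(reps)[:-1])
assert np.allclose(gx, [b.sum() for b in blocks])

# Scalar repeats are the special case where every block has the same length,
# so the same reduction can be done with a cheap reshape instead of a matmul:
gz_scalar = np.arange(6.0)  # upstream gradient for np.repeat(x, 2)
gx_scalar = gz_scalar.reshape(3, 2).sum(axis=1)
assert np.allclose(gx_scalar, [0 + 1, 2 + 3, 4 + 5])
```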

def infer_shape(self, fgraph, node, ins_shapes):
i0_shapes = ins_shapes[0]
repeats = node.inputs[1]
out_shape = list(i0_shapes)
axis = self.axis

# uint64 shapes are not supported.
dtype = None
if repeats.dtype in ("uint8", "uint16", "uint32"):
dtype = "int64"
if self.axis is None:
if axis is None:
if repeats.ndim == 0:
if len(i0_shapes) == 0:
out_shape = [repeats]
@@ -751,82 +762,115 @@
out_shape = [pt_sum(repeats, dtype=dtype)]
else:
if repeats.ndim == 0:
out_shape[self.axis] = out_shape[self.axis] * repeats
out_shape[axis] = out_shape[axis] * repeats
else:
out_shape[self.axis] = pt_sum(repeats, dtype=dtype)
out_shape[axis] = pt_sum(repeats, dtype=dtype)
return [out_shape]
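For reference, the static shapes that `infer_shape` is expected to produce mirror NumPy's behaviour; a quick sketch with illustrative shapes:

```python
import numpy as np

x = np.zeros((2, 3))

# Scalar repeats: the repeated axis is multiplied by `repeats`.
assert np.repeat(x, 4, axis=1).shape == (2, 12)

# Vector repeats: the repeated axis becomes sum(repeats).
assert np.repeat(x, [1, 0, 5], axis=1).shape == (2, 6)

# axis=None: the input is flattened first, so the result is 1-D
# with x.size * repeats (or sum(repeats)) elements.
assert np.repeat(x, 2, axis=None).shape == (12,)
```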


def repeat(x, repeats, axis=None):
"""Repeat elements of an array.
def repeat(
a: TensorLike, repeats: TensorLike, axis: int | None = None
) -> TensorVariable:
"""Repeat elements of a tensor.

It returns an array which has the same shape as `x`, except along the given
`axis`. The `axis` parameter is used to specify the axis along which values
are repeated. By default, a flattened version of `x` is used.
See :func:`numpy.repeat` for more information.

The number of repetitions for each element is `repeats`. `repeats` is
broadcasted to fit the length of the given `axis`.

Parameters
----------
x
Input data, tensor variable.
repeats
int, scalar or tensor variable
a: tensor_like
[Review thread on this line]
Member, suggested change: replace `a: tensor_like` with `a: TensorLike`.
I think the PyCharm linter checks these for valid types. There's no type called tensor_like, and it's also not fully human readable; kind of a weird middle ground.

Member Author:
I thought we had a glossary like in PyMC: https://www.pymc.io/projects/docs/en/stable/glossary.html#term-tensor_like
In which case we should use tensor_like in the docs.
[End of review thread]

Input tensor
repeats: tensor_like
The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
axis : int, optional
The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.

See Also
Returns
-------
repeated_tensor: TensorVariable
Output tensor which has the same shape as a, except along the given axis.

Examples
--------
tensor.tile

.. testcode::

import pytensor.tensor as pt

a = pt.arange(4).reshape((2, 2))
out = pt.repeat(a, repeats=[2, 3], axis=0)
print(out.eval())

.. testoutput::

[[0 1]
[0 1]
[2 3]
[2 3]
[2 3]]

When axis is None, the array is first flattened and then repeated

.. testcode::

import pytensor.tensor as pt

a = pt.arange(4).reshape((2, 2))
out = pt.repeat(a, repeats=[2, 3, 0, 1], axis=None)
print(out.eval())

.. testoutput::

[0 0 1 1 1 3]


.. versionadded:: 0.6

"""
a = ptb.as_tensor_variable(a)

if axis is not None:
axis = normalize_axis_index(axis, a.ndim)

repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)

if repeats.ndim > 1:
raise ValueError("The dimension of repeats should not exceed 1.")

if repeats.ndim == 1 and not repeats.broadcastable[0]:
return Repeat(axis=axis)(x, repeats)
# We only use the Repeat Op for vector repeats
return Repeat(axis=axis)(a, repeats)
else:
if repeats.ndim == 1:
repeats = repeats[0]

if x.dtype == "uint64":
if a.dtype == "uint64":
# Multiplying int64 (shape) by uint64 (repeats) yields a float64
# Which is not valid for the `reshape` operation at the end
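# (Illustration: np.int64(3) * np.uint64(2) already comes back as a float64 scalar.)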
raise TypeError("repeat doesn't support dtype uint64")

if axis is None:
axis = 0
x = x.flatten()
else:
if axis >= x.ndim:
raise ValueError("Axis should not exceed x.ndim-1.")
if axis < 0:
axis = x.ndim + axis
a = a.flatten()

shape = [x.shape[i] for i in range(x.ndim)]
repeat_shape = list(a.shape)

# shape_ is the shape of the intermediate tensor which has
# alloc_shape is the shape of the intermediate tensor which has
# an additional dimension compared to a. We use alloc to
# allocate space for this intermediate tensor to replicate a
# along that additional dimension.
shape_ = shape[:]
shape_.insert(axis + 1, repeats)
alloc_shape = repeat_shape[:]
alloc_shape.insert(axis + 1, repeats)

# shape is now the shape of output, where shape[axis] becomes
# repeat_shape is now the shape of output, where shape[axis] becomes
# shape[axis]*repeats.
shape[axis] = shape[axis] * repeats

# dims_ is the dimension of that intermediate tensor.
dims_ = list(np.arange(x.ndim))
dims_.insert(axis + 1, "x")
repeat_shape[axis] = repeat_shape[axis] * repeats

# After the original tensor is duplicated along the additional
# dimension, we reshape it to the expected output shape, and
# return the output z.
z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
return z
# dimension, we reshape it to the expected output shape
return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
repeat_shape
)
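
The alloc/expand_dims/reshape path for scalar repeats can be pictured with a small NumPy equivalent (an illustrative sketch; `broadcast_to` stands in for `alloc` here, and the shapes are made up):

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
repeats, axis = 4, 1

expanded = np.expand_dims(a, axis + 1)              # shape (2, 3, 1)
tiled = np.broadcast_to(expanded, (2, 3, repeats))  # replicate along the new axis
out = tiled.reshape(2, 3 * repeats)                 # fold the new axis into `axis`

assert np.array_equal(out, np.repeat(a, repeats, axis=axis))
```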


class Bartlett(Op):
28 changes: 21 additions & 7 deletions tests/tensor/test_extra_ops.py
@@ -595,7 +595,6 @@ def test_basic(self, ndim, dtype):
isinstance(n.op, Repeat) for n in f.maker.fgraph.toposort()
)

@pytest.mark.slow
@pytest.mark.parametrize("ndim", [1, 3])
@pytest.mark.parametrize("dtype", ["int8", "uint8", "uint64"])
def test_infer_shape(self, ndim, dtype):
@@ -606,6 +605,10 @@ def test_infer_shape(self, ndim, dtype):
a = rng.random(shp).astype(config.floatX)

for axis in self._possible_axis(ndim):
if axis is not None and axis < 0:
# Operator does not support negative axis
continue

r_var = scalar(dtype=dtype)
r = np.asarray(3, dtype=dtype)
if dtype in self.numpy_unsupported_dtypes:
@@ -635,12 +638,23 @@ def test_infer_shape(self, ndim, dtype):
self.op_class,
)

@pytest.mark.parametrize("ndim", range(3))
def test_grad(self, ndim):
a = np.random.random((10,) * ndim).astype(config.floatX)

for axis in self._possible_axis(ndim):
utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
@pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}")
@pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}")
@pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}")
def test_grad(self, x_ndim, repeats_ndim, axis):
rng = np.random.default_rng(
[653, x_ndim, 2 if axis is None else axis, repeats_ndim]
)
x_test = rng.normal(size=np.arange(3, 3 + x_ndim))
if repeats_ndim == 0:
repeats_size = ()
else:
repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
repeats = rng.integers(1, 6, size=repeats_size)
utt.verify_grad(
lambda x: Repeat(axis=axis)(x, repeats),
[x_test],
)

def test_broadcastable(self):
x = TensorType(config.floatX, shape=(None, 1, None))()