Skip to content

Commit

Permalink
Merge pull request #118 from crusaderky/simple_at
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley authored Jan 21, 2025
2 parents 6ee70c0 + 6f0ef5c commit b5bf75c
Showing 1 changed file with 51 additions and 60 deletions.
111 changes: 51 additions & 60 deletions src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,22 +185,42 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
raise ValueError(msg)
return at(self._x, idx)

def _update_common(
def _op(
self,
at_op: _AtOp,
y: Array,
in_place_op: Callable[[Array, Array | object], Array] | None,
y: Array | object,
/,
copy: bool | None,
xp: ModuleType | None,
) -> tuple[Array, None] | tuple[None, Array]: # numpydoc ignore=PR01
) -> Array:
"""
Perform common prepocessing to all update operations.
Implement all update operations.
Parameters
----------
at_op : _AtOp
Method of JAX's Array.at[].
in_place_op : Callable[[Array, Array | object], Array] | None
In-place operation to apply on mutable backends::
x[idx] = in_place_op(x[idx], y)
If None::
x[idx] = y
y : array or object
Right-hand side of the operation.
copy : bool or None
Whether to copy the input array. See the class docstring for details.
xp : array_namespace or None
The array namespace for the input array.
Returns
-------
tuple
If the operation can be resolved by ``at[]``, ``(return value, None)``
Otherwise, ``(None, preprocessed x)``.
Array
Updated `x`.
"""
x, idx = self._x, self._idx

Expand Down Expand Up @@ -231,7 +251,7 @@ def _update_common(
if is_jax_array(x):
# Use JAX's at[]
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
return func(y), None
return func(y)
# Emulate at[] behaviour for non-JAX arrays
# with a copy followed by an update
if xp is None:
Expand All @@ -249,52 +269,25 @@ def _update_common(
msg = f"Can't update read-only array {x}"
raise ValueError(msg)

return None, x
if in_place_op:
x[self._idx] = in_place_op(x[self._idx], y)
else: # set()
x[self._idx] = y
return x

def set(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = y`` and return the update array."""
res, x = self._update_common(_AtOp.SET, y, copy=copy, xp=xp)
if res is not None:
return res
assert x is not None
x[self._idx] = y
return x

def _iop(
self,
at_op: _AtOp,
elwise_op: Callable[[Array, Array], Array],
y: Array,
/,
copy: bool | None,
xp: ModuleType | None,
) -> Array: # numpydoc ignore=PR01,RT01
"""
``x[idx] += y`` or equivalent in-place operation on a subset of x.
which is the same as saying
x[idx] = x[idx] + y
Note that this is not the same as
operator.iadd(x[idx], y)
Consider for example when x is a numpy array and idx is a fancy index, which
triggers a deep copy on __getitem__.
"""
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
if res is not None:
return res
assert x is not None
x[self._idx] = elwise_op(x[self._idx], y)
return x
return self._op(_AtOp.SET, None, y, copy=copy, xp=xp)

def add(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
Expand All @@ -304,70 +297,68 @@ def add(
# Note for this and all other methods based on _iop:
# operator.iadd and operator.add subtly differ in behaviour, as
# only iadd will trigger exceptions when y has an incompatible dtype.
return self._iop(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)

def subtract(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] -= y`` and return the updated array."""
return self._iop(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)

def multiply(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] *= y`` and return the updated array."""
return self._iop(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)

def divide(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] /= y`` and return the updated array."""
return self._iop(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)

def power(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] **= y`` and return the updated array."""
return self._iop(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)

def min(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
if xp is None:
xp = array_namespace(self._x)
xp = array_namespace(self._x) if xp is None else xp
y = xp.asarray(y)
return self._iop(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)

def max(
self,
y: Array,
y: Array | object,
/,
copy: bool | None = None,
xp: ModuleType | None = None,
) -> Array: # numpydoc ignore=PR01,RT01
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
if xp is None:
xp = array_namespace(self._x)
xp = array_namespace(self._x) if xp is None else xp
y = xp.asarray(y)
return self._iop(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)
return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)

0 comments on commit b5bf75c

Please sign in to comment.