From 6f0ef5c3e74fdec16159563ce568e864f697f904 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 21 Jan 2025 16:40:59 +0000 Subject: [PATCH] MAINT: Simplify `at` implementation --- src/array_api_extra/_lib/_at.py | 111 +++++++++++++++----------------- 1 file changed, 51 insertions(+), 60 deletions(-) diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index f2dd721..927a730 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -185,22 +185,42 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01 raise ValueError(msg) return at(self._x, idx) - def _update_common( + def _op( self, at_op: _AtOp, - y: Array, + in_place_op: Callable[[Array, Array | object], Array] | None, + y: Array | object, /, copy: bool | None, xp: ModuleType | None, - ) -> tuple[Array, None] | tuple[None, Array]: # numpydoc ignore=PR01 + ) -> Array: """ - Perform common prepocessing to all update operations. + Implement all update operations. + + Parameters + ---------- + at_op : _AtOp + Method of JAX's Array.at[]. + in_place_op : Callable[[Array, Array | object], Array] | None + In-place operation to apply on mutable backends:: + + x[idx] = in_place_op(x[idx], y) + + If None:: + + x[idx] = y + + y : array or object + Right-hand side of the operation. + copy : bool or None + Whether to copy the input array. See the class docstring for details. + xp : array_namespace or None + The array namespace for the input array. Returns ------- - tuple - If the operation can be resolved by ``at[]``, ``(return value, None)`` - Otherwise, ``(None, preprocessed x)``. + Array + Updated `x`. """ x, idx = self._x, self._idx @@ -231,7 +251,7 @@ def _update_common( if is_jax_array(x): # Use JAX's at[] func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value)) - return func(y), None + return func(y) # Emulate at[] behaviour for non-JAX arrays # with a copy followed by an update if xp is None: @@ -249,52 +269,25 @@ def _update_common( msg = f"Can't update read-only array {x}" raise ValueError(msg) - return None, x + if in_place_op: + x[self._idx] = in_place_op(x[self._idx], y) + else: # set() + x[self._idx] = y + return x def set( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] = y`` and return the update array.""" - res, x = self._update_common(_AtOp.SET, y, copy=copy, xp=xp) - if res is not None: - return res - assert x is not None - x[self._idx] = y - return x - - def _iop( - self, - at_op: _AtOp, - elwise_op: Callable[[Array, Array], Array], - y: Array, - /, - copy: bool | None, - xp: ModuleType | None, - ) -> Array: # numpydoc ignore=PR01,RT01 - """ - ``x[idx] += y`` or equivalent in-place operation on a subset of x. - - which is the same as saying - x[idx] = x[idx] + y - Note that this is not the same as - operator.iadd(x[idx], y) - Consider for example when x is a numpy array and idx is a fancy index, which - triggers a deep copy on __getitem__. - """ - res, x = self._update_common(at_op, y, copy=copy, xp=xp) - if res is not None: - return res - assert x is not None - x[self._idx] = elwise_op(x[self._idx], y) - return x + return self._op(_AtOp.SET, None, y, copy=copy, xp=xp) def add( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, @@ -304,70 +297,68 @@ def add( # Note for this and all other methods based on _iop: # operator.iadd and operator.add subtly differ in behaviour, as # only iadd will trigger exceptions when y has an incompatible dtype. - return self._iop(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp) + return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp) def subtract( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] -= y`` and return the updated array.""" - return self._iop(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp) + return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp) def multiply( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] *= y`` and return the updated array.""" - return self._iop(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp) + return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp) def divide( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] /= y`` and return the updated array.""" - return self._iop(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp) + return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp) def power( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] **= y`` and return the updated array.""" - return self._iop(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp) + return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp) def min( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array.""" - if xp is None: - xp = array_namespace(self._x) + xp = array_namespace(self._x) if xp is None else xp y = xp.asarray(y) - return self._iop(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp) + return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp) def max( self, - y: Array, + y: Array | object, /, copy: bool | None = None, xp: ModuleType | None = None, ) -> Array: # numpydoc ignore=PR01,RT01 """Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array.""" - if xp is None: - xp = array_namespace(self._x) + xp = array_namespace(self._x) if xp is None else xp y = xp.asarray(y) - return self._iop(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp) + return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)