diff --git a/eagerpy/framework.py b/eagerpy/framework.py index 7fc0dca..d036d63 100644 --- a/eagerpy/framework.py +++ b/eagerpy/framework.py @@ -32,7 +32,7 @@ def square(t: TensorType) -> TensorType: return t.square() -def pow(t: TensorType, exponent: float) -> TensorType: +def pow(t: TensorType, exponent: TensorOrScalar) -> TensorType: return t.pow(exponent) diff --git a/eagerpy/tensor/base.py b/eagerpy/tensor/base.py index 249ddef..41a7da3 100644 --- a/eagerpy/tensor/base.py +++ b/eagerpy/tensor/base.py @@ -106,8 +106,8 @@ def __mod__(self: TensorType, other: TensorOrScalar) -> TensorType: return type(self)(self.raw.__mod__(unwrap1(other))) @final - def __pow__(self: TensorType, exponent: float) -> TensorType: - return type(self)(self.raw.__pow__(exponent)) + def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType: + return type(self)(self.raw.__pow__(unwrap1(exponent))) @final @property diff --git a/eagerpy/tensor/tensor.py b/eagerpy/tensor/tensor.py index d4f9112..02e0181 100644 --- a/eagerpy/tensor/tensor.py +++ b/eagerpy/tensor/tensor.py @@ -190,7 +190,7 @@ def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: ... @abstractmethod - def __pow__(self: TensorType, exponent: float) -> TensorType: + def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType: ... @abstractmethod @@ -527,7 +527,7 @@ def abs(self: TensorType) -> TensorType: return self.__abs__() @final - def pow(self: TensorType, exponent: float) -> TensorType: + def pow(self: TensorType, exponent: TensorOrScalar) -> TensorType: return self.__pow__(exponent) @final diff --git a/tests/test_main.py b/tests/test_main.py index 9e6a8b8..7035390 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -499,6 +499,16 @@ def test_pow_op(t: Tensor) -> Tensor: return t ** 3 +@compare_allclose +def test_pow_tensor(t: Tensor) -> Tensor: + return ep.pow(t, (t + 0.5)) + + +@compare_allclose +def test_pow_op_tensor(t: Tensor) -> Tensor: + return t ** (t + 0.5) + + @compare_all def test_add(t1: Tensor, t2: Tensor) -> Tensor: return t1 + t2