Skip to content

Commit

Permalink
pow exponents can now be tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Feb 22, 2020
1 parent bc7490a commit 91b3424
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
2 changes: 1 addition & 1 deletion eagerpy/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def square(t: TensorType) -> TensorType:
return t.square()


def pow(t: TensorType, exponent: float) -> TensorType:
def pow(t: TensorType, exponent: TensorOrScalar) -> TensorType:
return t.pow(exponent)


Expand Down
4 changes: 2 additions & 2 deletions eagerpy/tensor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def __mod__(self: TensorType, other: TensorOrScalar) -> TensorType:
return type(self)(self.raw.__mod__(unwrap1(other)))

@final
def __pow__(self: TensorType, exponent: float) -> TensorType:
return type(self)(self.raw.__pow__(exponent))
def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType:
return type(self)(self.raw.__pow__(unwrap1(exponent)))

@final
@property
Expand Down
4 changes: 2 additions & 2 deletions eagerpy/tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType:
...

@abstractmethod
def __pow__(self: TensorType, exponent: float) -> TensorType:
def __pow__(self: TensorType, exponent: TensorOrScalar) -> TensorType:
...

@abstractmethod
Expand Down Expand Up @@ -527,7 +527,7 @@ def abs(self: TensorType) -> TensorType:
return self.__abs__()

@final
def pow(self: TensorType, exponent: float) -> TensorType:
def pow(self: TensorType, exponent: TensorOrScalar) -> TensorType:
return self.__pow__(exponent)

@final
Expand Down
10 changes: 10 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,16 @@ def test_pow_op(t: Tensor) -> Tensor:
return t ** 3


@compare_allclose
def test_pow_tensor(t: Tensor) -> Tensor:
return ep.pow(t, (t + 0.5))


@compare_allclose
def test_pow_op_tensor(t: Tensor) -> Tensor:
return t ** (t + 0.5)


@compare_all
def test_add(t1: Tensor, t2: Tensor) -> Tensor:
return t1 + t2
Expand Down

0 comments on commit 91b3424

Please sign in to comment.