From ac9b7f14571b66acbf0c397c8c9e62708b1cd12b Mon Sep 17 00:00:00 2001 From: Jonas Rauber Date: Fri, 14 Aug 2020 20:26:56 +0200 Subject: [PATCH 1/2] improved tensorflow's getitem handling --- eagerpy/tensor/tensorflow.py | 24 +++++++++++++++++++++--- tests/test_main.py | 16 ++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/eagerpy/tensor/tensorflow.py b/eagerpy/tensor/tensorflow.py index 49387b1..0fef999 100644 --- a/eagerpy/tensor/tensorflow.py +++ b/eagerpy/tensor/tensorflow.py @@ -535,9 +535,27 @@ def __getitem__(self: TensorType, index: Any) -> TensorType: ) if not basic: # workaround for missing support for this in TensorFlow - # TODO: maybe convert each index individually and then stack them instead - index = tf.convert_to_tensor(index) - index = tf.transpose(index) + index = [tf.convert_to_tensor(x) for x in index] + shapes = [tuple(x.shape) for x in index] + shape = tuple(max(x) for x in zip(*shapes)) + int64 = any(x.dtype == tf.int64 for x in index) + for i in range(len(index)): + t = index[i] + if int64: + t = tf.cast(t, tf.int64) + assert t.ndim == len(shape) + tiling = [] + for b, k in zip(shape, t.shape): + if k == 1: + tiling.append(b) + elif k == b: + tiling.append(1) + else: + raise ValueError( + f"{tuple(t.shape)} cannot be broadcasted to {shape}" + ) + index[i] = tf.tile(t, tiling) + index = tf.stack(index, axis=-1) return type(self)(tf.gather_nd(self.raw, index)) elif ( isinstance(index, range) diff --git a/tests/test_main.py b/tests/test_main.py index a2faf87..cabd1e4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -664,6 +664,22 @@ def test_getitem_tuple_tensors(dummy: Tensor) -> Tensor: return t[rows, indices] +@compare_all +def test_getitem_tuple_tensors_full(dummy: Tensor) -> Tensor: + t = ep.arange(dummy, 32).float32().reshape((8, 4)) + rows = ep.arange(t, len(t))[:, np.newaxis].tile((1, t.shape[-1])) + cols = t.argsort(axis=-1) + return t[rows, cols] + + +@compare_all +def test_getitem_tuple_tensors_full_broadcast(dummy: Tensor) -> Tensor: + t = ep.arange(dummy, 32).float32().reshape((8, 4)) + rows = ep.arange(t, len(t))[:, np.newaxis] + cols = t.argsort(axis=-1) + return t[rows, cols] + + @compare_all def test_getitem_tuple_range_tensor(dummy: Tensor) -> Tensor: t = ep.arange(dummy, 32).float32().reshape((8, 4)) From 7365808a58b59af6665a9637728f4ba3902bf6cb Mon Sep 17 00:00:00 2001 From: Jonas Rauber Date: Fri, 14 Aug 2020 20:36:26 +0200 Subject: [PATCH 2/2] fixed coverage --- eagerpy/tensor/tensorflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eagerpy/tensor/tensorflow.py b/eagerpy/tensor/tensorflow.py index 0fef999..1ad4ccd 100644 --- a/eagerpy/tensor/tensorflow.py +++ b/eagerpy/tensor/tensorflow.py @@ -542,7 +542,7 @@ def __getitem__(self: TensorType, index: Any) -> TensorType: for i in range(len(index)): t = index[i] if int64: - t = tf.cast(t, tf.int64) + t = tf.cast(t, tf.int64) # pragma: no cover assert t.ndim == len(shape) tiling = [] for b, k in zip(shape, t.shape): @@ -551,7 +551,7 @@ def __getitem__(self: TensorType, index: Any) -> TensorType: elif k == b: tiling.append(1) else: - raise ValueError( + raise ValueError( # pragma: no cover f"{tuple(t.shape)} cannot be broadcasted to {shape}" ) index[i] = tf.tile(t, tiling)