From 95b43838f42428e56b05aeb87f599c2b6ff8ad8b Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 10 Dec 2024 17:11:43 -0800 Subject: [PATCH] Improve implementation of TF shuffle and make it XLA compilable --- keras/src/backend/tensorflow/random.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/keras/src/backend/tensorflow/random.py b/keras/src/backend/tensorflow/random.py index 4bd2162fc0d..2039b35c671 100644 --- a/keras/src/backend/tensorflow/random.py +++ b/keras/src/backend/tensorflow/random.py @@ -94,15 +94,10 @@ def dropout(inputs, rate, noise_shape=None, seed=None): def shuffle(x, axis=0, seed=None): - from keras.src.backend.tensorflow.numpy import swapaxes - - seed = _cast_seed(draw_seed(seed)) - if axis == 0: - return tf.random.experimental.stateless_shuffle(x, seed=seed) - x = swapaxes(x, axis1=0, axis2=axis) - x = tf.random.experimental.stateless_shuffle(x, seed=seed) - x = swapaxes(x, axis1=0, axis2=axis) - return x + indices = tf.argsort( + tf.random.stateless_uniform(shape=[tf.shape(x)[axis]], seed=seed) + ) + return tf.gather(x, indices, axis=axis) def gamma(shape, alpha, dtype=None, seed=None):