diff --git a/keras/backend/tensorflow/rnn.py b/keras/backend/tensorflow/rnn.py index 61b4b2912f5..874c9f8e041 100644 --- a/keras/backend/tensorflow/rnn.py +++ b/keras/backend/tensorflow/rnn.py @@ -441,7 +441,6 @@ def _step(time, output_ta_t, *states): return last_output, outputs, new_states -@tf.function def gru( inputs, initial_state, @@ -457,7 +456,6 @@ def gru( time_major=False, reset_after=True, ): - inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major) cudnn_supported = cudnn_ok( activation, recurrent_activation, @@ -465,7 +463,7 @@ def gru( use_bias=bias is not None, reset_after=reset_after, ) - if not cudnn_supported or not inputs_supported: + if not cudnn_supported or mask is not None: raise NotImplementedError from keras.backend.tensorflow import Variable @@ -535,21 +533,6 @@ def _do_lstm_arguments_support_cudnn( ) -def _do_rnn_inputs_support_cudnn(mask, time_major): - if tf.sysconfig.get_build_info()["is_rocm_build"]: - if mask is not None: - return tf.reduce_all(mask) - return True - if mask is None: - return True - if time_major: - mask = tf.transpose(mask) - return tf.logical_and( - _is_sequence_right_padded(mask), - tf.logical_not(_has_fully_masked_sequence(mask)), - ) - - def _is_sequence_right_padded(mask): """Check the mask tensor and see if it right padded. @@ -655,7 +638,6 @@ def _is_gpu_available(): return bool(tf.config.list_logical_devices("GPU")) -@tf.function(autograph=False) def _cudnn_gru( inputs, initial_state, @@ -802,7 +784,6 @@ def cudnn_ok( return args_supported and _is_gpu_available() -@tf.function def lstm( inputs, initial_state_h, @@ -818,11 +799,10 @@ def lstm( unroll=False, time_major=False, ): - inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major) cudnn_supported = cudnn_ok( activation, recurrent_activation, unroll, use_bias=bias is not None ) - if not cudnn_supported or not inputs_supported: + if not cudnn_supported or mask is not None: raise NotImplementedError from keras.backend.tensorflow import Variable @@ -855,7 +835,6 @@ def lstm( raise NotImplementedError -@tf.function(autograph=False) def _cudnn_lstm( inputs, initial_state_h,