From d09e73d203be274b5a3d5aec351f66e92b9b28e9 Mon Sep 17 00:00:00 2001
From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Date: Thu, 30 Nov 2023 18:58:46 -0600
Subject: [PATCH] Fix LSTM performance on TF Backend (#18859)

* Wrap only the check on inputs supported in tf.function

* Remove check on inputs supported for lstm / gru

* Remove check on inputs supported for lstm / gru

* Remove check on inputs supported for lstm / gru
---
 keras/backend/tensorflow/rnn.py | 25 ++-----------------------
 1 file changed, 2 insertions(+), 23 deletions(-)

diff --git a/keras/backend/tensorflow/rnn.py b/keras/backend/tensorflow/rnn.py
index 61b4b2912f51..874c9f8e041e 100644
--- a/keras/backend/tensorflow/rnn.py
+++ b/keras/backend/tensorflow/rnn.py
@@ -441,7 +441,6 @@ def _step(time, output_ta_t, *states):
     return last_output, outputs, new_states
 
 
-@tf.function
 def gru(
     inputs,
     initial_state,
@@ -457,7 +456,6 @@ def gru(
     time_major=False,
     reset_after=True,
 ):
-    inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
     cudnn_supported = cudnn_ok(
         activation,
         recurrent_activation,
@@ -465,7 +463,7 @@ def gru(
         use_bias=bias is not None,
         reset_after=reset_after,
     )
-    if not cudnn_supported or not inputs_supported:
+    if not cudnn_supported or mask is not None:
         raise NotImplementedError
 
     from keras.backend.tensorflow import Variable
@@ -535,21 +533,6 @@ def _do_lstm_arguments_support_cudnn(
     )
 
 
-def _do_rnn_inputs_support_cudnn(mask, time_major):
-    if tf.sysconfig.get_build_info()["is_rocm_build"]:
-        if mask is not None:
-            return tf.reduce_all(mask)
-        return True
-    if mask is None:
-        return True
-    if time_major:
-        mask = tf.transpose(mask)
-    return tf.logical_and(
-        _is_sequence_right_padded(mask),
-        tf.logical_not(_has_fully_masked_sequence(mask)),
-    )
-
-
 def _is_sequence_right_padded(mask):
     """Check the mask tensor and see if it right padded.
 
@@ -655,7 +638,6 @@ def _is_gpu_available():
     return bool(tf.config.list_logical_devices("GPU"))
 
 
-@tf.function(autograph=False)
 def _cudnn_gru(
     inputs,
     initial_state,
@@ -802,7 +784,6 @@ def cudnn_ok(
     return args_supported and _is_gpu_available()
 
 
-@tf.function
 def lstm(
     inputs,
     initial_state_h,
@@ -818,11 +799,10 @@ def lstm(
     unroll=False,
     time_major=False,
 ):
-    inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
     cudnn_supported = cudnn_ok(
         activation, recurrent_activation, unroll, use_bias=bias is not None
     )
-    if not cudnn_supported or not inputs_supported:
+    if not cudnn_supported or mask is not None:
         raise NotImplementedError
 
     from keras.backend.tensorflow import Variable
@@ -855,7 +835,6 @@ def lstm(
         raise NotImplementedError
 
 
-@tf.function(autograph=False)
 def _cudnn_lstm(
     inputs,
     initial_state_h,