Skip to content

Commit

Permalink
Fix LSTM performance on TF Backend (#18859)
Browse files Browse the repository at this point in the history
* Wrap only the check on inputs supported in tf.function

* Remove check on inputs supported for lstm / gru

* Remove check on inputs supported for lstm / gru

* Remove check on inputs supported for lstm / gru
  • Loading branch information
sampathweb authored Dec 1, 2023
1 parent cd86342 commit d09e73d
Showing 1 changed file with 2 additions and 23 deletions.
25 changes: 2 additions & 23 deletions keras/backend/tensorflow/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ def _step(time, output_ta_t, *states):
return last_output, outputs, new_states


@tf.function
def gru(
inputs,
initial_state,
Expand All @@ -457,15 +456,14 @@ def gru(
time_major=False,
reset_after=True,
):
inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
cudnn_supported = cudnn_ok(
activation,
recurrent_activation,
unroll,
use_bias=bias is not None,
reset_after=reset_after,
)
if not cudnn_supported or not inputs_supported:
if not cudnn_supported or mask is not None:
raise NotImplementedError

from keras.backend.tensorflow import Variable
Expand Down Expand Up @@ -535,21 +533,6 @@ def _do_lstm_arguments_support_cudnn(
)


def _do_rnn_inputs_support_cudnn(mask, time_major):
    """Return whether the RNN inputs' masking layout allows the cuDNN kernel.

    On ROCm builds, cuDNN-style kernels only handle a batch with no masked
    steps at all; on CUDA builds, masking is supported as long as every
    sequence is right-padded and no sequence is fully masked.

    Args:
        mask: Boolean mask tensor of shape (batch, time) — or (time, batch)
            when `time_major=True` — or `None` for no masking.
        time_major: Whether the time axis comes first in `mask`.

    Returns:
        A Python bool or a scalar boolean tensor indicating support.
    """
    if tf.sysconfig.get_build_info()["is_rocm_build"]:
        # ROCm path: supported only when nothing is masked out.
        return True if mask is None else tf.reduce_all(mask)
    if mask is None:
        return True
    # Normalize to (batch, time) so the padding checks see one layout.
    batch_major_mask = tf.transpose(mask) if time_major else mask
    return tf.logical_and(
        _is_sequence_right_padded(batch_major_mask),
        tf.logical_not(_has_fully_masked_sequence(batch_major_mask)),
    )


def _is_sequence_right_padded(mask):
"""Check the mask tensor and see if it right padded.
Expand Down Expand Up @@ -655,7 +638,6 @@ def _is_gpu_available():
return bool(tf.config.list_logical_devices("GPU"))


@tf.function(autograph=False)
def _cudnn_gru(
inputs,
initial_state,
Expand Down Expand Up @@ -802,7 +784,6 @@ def cudnn_ok(
return args_supported and _is_gpu_available()


@tf.function
def lstm(
inputs,
initial_state_h,
Expand All @@ -818,11 +799,10 @@ def lstm(
unroll=False,
time_major=False,
):
inputs_supported = _do_rnn_inputs_support_cudnn(mask, time_major)
cudnn_supported = cudnn_ok(
activation, recurrent_activation, unroll, use_bias=bias is not None
)
if not cudnn_supported or not inputs_supported:
if not cudnn_supported or mask is not None:
raise NotImplementedError

from keras.backend.tensorflow import Variable
Expand Down Expand Up @@ -855,7 +835,6 @@ def lstm(
raise NotImplementedError


@tf.function(autograph=False)
def _cudnn_lstm(
inputs,
initial_state_h,
Expand Down

0 comments on commit d09e73d

Please sign in to comment.