Describe the problem
I have a model consisting almost entirely of LSTM layers. If I load the same weights into two copies of the model, one instantiated to run on the CPU and one on the GPU, the results are different.
This issue disappears (the GPU results change to match CPU) if I change any of these:
Move from
SLES + NVIDIA A100 + Driver Version: 550.54.14 + CUDA Version: 12.4
to
In all these cases I'm running the same (official) Docker image, in which my only modification has been to install tf-keras==2.16.0 and plotly.
Standalone code to reproduce the issue.
!pip install plotly
!pip install tf-keras==2.16.0
import os
import tensorflow as tf
import numpy as np
USE_TF_KERAS = True
if USE_TF_KERAS:
    import tf_keras as keras
    from tf_keras import layers
    from tf_keras import initializers
    from tf_keras import backend as K
else:
    import keras
    from keras import layers
    from keras import initializers
    from keras import backend as K
# Setting float64 as default dtype removes the discrepancy between CPU and GPU!
# keras.backend.set_floatx('float64')
from plotly import graph_objects as go
ROOT_DIR = os.getcwd()
n_time_steps = 800
theta = np.linspace(0, 2 * np.pi, n_time_steps).reshape(1, -1)
np.random.seed(42)
tf.random.set_seed(42)
dummy_input_dict = {
    "input_a": 800
    * np.stack((np.cos(theta), np.sin(theta)), axis=-1).astype(np.float32),
    "input_b": np.random.rand(1, n_time_steps, 5).astype(np.float32),
}
def build_model():
    input_a = layers.Input(shape=(n_time_steps, 2), name="input_a")
    input_b = layers.Input(shape=(n_time_steps, 5), name="input_b")
    x = layers.Concatenate()([input_a, input_b])
    for idx in range(8):
        lstm_layer = layers.LSTM(
            1024,
            kernel_initializer=initializers.RandomNormal(seed=42 + idx),
            recurrent_initializer=initializers.RandomNormal(seed=52 + idx),
            return_sequences=True,
        )
        x = lstm_layer(x)
    y = layers.Dense(1)(x)
    model = keras.Model(inputs=[input_a, input_b], outputs=y)
    return model
def main(device):
    with tf.device(device):
        model = build_model()
        model.load_weights("my_initial_weights.h5")
        features = ["input_a", "input_b"]
        dummy_input = [dummy_input_dict[k] for k in features]
        preds = model.predict(dummy_input)
    return preds
# Save one set of weights, so that we can compare the weights of the two models
with tf.device("/device:CPU:0"):
    model = build_model()
    model.save_weights("my_initial_weights.h5")
tf.config.list_logical_devices()
cpu_preds = main("/device:CPU:0")
gpu_preds = main("/device:GPU:0")
cpu_output = cpu_preds[0, :, 0]
gpu_output = gpu_preds[0, :, 0]
fig = go.Figure()
fig.add_trace(go.Scatter(y=cpu_output, name="CPU"))
fig.add_trace(go.Scatter(y=gpu_output, name="GPU"))
fig.show()
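A small addition (not in my original script) to quantify the discrepancy numerically rather than just visually, using the cpu_preds and gpu_preds arrays computed above:

# Quantify the CPU/GPU discrepancy beyond the plot
abs_diff = np.abs(cpu_preds - gpu_preds)
rel_diff = abs_diff / (np.abs(cpu_preds) + 1e-12)
print("max abs diff:", abs_diff.max())
print("max rel diff:", rel_diff.max())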
@tilakrayal - the gist shows a very small difference between CPU/GPU predictions, similar to what I see on my V100 host. I wouldn't be surprised if differences that small were in fact expected.
But on my A100 host the difference becomes orders of magnitude larger. Is there a way to replicate my "problematic system" (NVIDIA A100 + Driver Version: 550.54.14 + CUDA Version: 12.4) on Colab, so that hopefully you can also see the magnitude of the problem, beyond the screenshots I can share?
I've updated the V100 system. It now has the exact same driver + CUDA as the A100 system (Driver Version: 550.54.14 + CUDA Version: 12.4), and it still does not reproduce the issue. So the issue seems specific to execution on the A100. How can we replicate this on Colab? Thanks.
Indeed, if I modify my example script to call tf.config.experimental.enable_tensor_float_32_execution(False), the numerical issues disappear, and the A100 system produces the same output as the V100 and the CPUs.
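Concretely, the modification is just adding this call near the top of the script; I placed it right after importing TensorFlow (the exact placement is my choice, any point before building the model should work):

import tensorflow as tf

# Make float32 matmuls on Ampere GPUs (A100) use full float32 precision instead of TF32
tf.config.experimental.enable_tensor_float_32_execution(False)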
I find it quite concerning that TensorFlow would take such liberties with data types.
In any case, the main question I have at this point is why I don't see the same numerical issues with multi-backend Keras. Is it actually using float32, rather than the new TF32? Which Keras implementation is doing the right thing?
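A quick check that might help answer this (a sketch; I haven't verified what it reports in each setup) is to print whether TF32 execution is active under tf-keras vs. multi-backend Keras:

import tensorflow as tf

# True means float32 matmuls/convolutions may run in TensorFloat-32 on Ampere+ GPUs
print("TF32 enabled:", tf.config.experimental.tensor_float_32_execution_enabled())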
As mentioned at the beginning, either of the following changes also works around the issue, and the GPU prediction then matches the CPU prediction:
keras.backend.set_floatx('float64') (i.e., uncommenting the line already present in the script)
USE_TF_KERAS = False
I also reiterate that all of this has been run in the official tensorflow/tensorflow:2.16.1-gpu-jupyter container, on both hosts.