From 7d4b7bb4e2d95bba58e31a07458d5bd85cc13bf5 Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Mon, 13 Jan 2025 14:50:24 -0800 Subject: [PATCH] Fix flaky `JaxLayer` test. (#20756) The `DTypePolicy` test produces lower precision results. --- keras/src/utils/jax_layer_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index acd28b2ea70..96c74809d13 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -324,11 +324,13 @@ def verify_identical_model(model): model2.export(path, format="tf_saved_model") model4 = tf.saved_model.load(path) output4 = model4.serve(x_test) + # The output difference is greater when using the GPU or bfloat16 + lower_precision = testing.jax_uses_gpu() or "dtype" in layer_init_kwargs self.assertAllClose( output1, output4, - # The output difference might be significant when using the GPU - atol=1e-2 if testing.jax_uses_gpu() else 1e-6, + atol=1e-2 if lower_precision else 1e-6, + rtol=1e-3 if lower_precision else 1e-6, ) # test subclass model building without a build method