diff --git a/distributed_embeddings/python/layers/dist_model_parallel_test.py b/distributed_embeddings/python/layers/dist_model_parallel_test.py
index 249f3c9..7c68983 100644
--- a/distributed_embeddings/python/layers/dist_model_parallel_test.py
+++ b/distributed_embeddings/python/layers/dist_model_parallel_test.py
@@ -189,7 +189,7 @@ def run_and_test(self, ref_model, ref_inputs, test_model, test_inputs):
     for ref_w, test_w in zip(ref_weights, test_weights):
       # assert close here since order of accumulations(inputs and batch dim) might have changed
-      self.assertAllClose(tf.convert_to_tensor(ref_w), tf.convert_to_tensor(test_w))
+      self.assertAllClose(tf.convert_to_tensor(ref_w), tf.convert_to_tensor(test_w), 1e-05, 1e-05)

  def test_broadcast(self):
    tf.keras.utils.set_random_seed(int(time.time()) + self.hvd_rank)
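
Note on the change: tf.test.TestCase.assertAllClose takes rtol and atol as its third and fourth positional arguments, with defaults of 1e-06 each, so passing 1e-05, 1e-05 loosens both tolerances by an order of magnitude. A minimal standalone sketch of the effect is below; the class name and the drift values are illustrative assumptions, not taken from the PR.

    # Sketch (hypothetical values) of why the explicit rtol/atol matter.
    # assertAllClose checks |a - b| <= atol + rtol * |b| elementwise; a
    # reordering of floating-point accumulations that shifts weights by
    # ~5e-06 fails the default 1e-06 tolerances but passes at 1e-05.
    import tensorflow as tf

    class ToleranceSketch(tf.test.TestCase):

      def test_loosened_tolerance(self):
        ref = tf.constant([1.0, 2.0, 3.0])
        # Simulate accumulation-order drift on the order of 5e-06.
        test = ref + tf.constant([5e-06, -5e-06, 5e-06])
        # Fails with the defaults (rtol=1e-06, atol=1e-06):
        # self.assertAllClose(ref, test)
        # Passes with the explicit tolerances used in the change:
        self.assertAllClose(ref, test, 1e-05, 1e-05)

    if __name__ == "__main__":
      tf.test.main()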