From 2b073b608daebbf8184fa2ee009b7cae9fa72ed0 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Sun, 29 Dec 2024 05:08:06 +0800 Subject: [PATCH] Fix torch gpu CI (#20696) --- .kokoro/github/ubuntu/gpu/build.sh | 1 - keras/src/export/export_lib_test.py | 6 ++++++ keras/src/models/model_test.py | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index ae1b3b48326..a70f28a062a 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -74,7 +74,6 @@ then # TODO: keras/src/export/export_lib_test.py update LD_LIBRARY_PATH pytest keras --ignore keras/src/applications \ - --ignore keras/src/export/export_lib_test.py \ --cov=keras \ --cov-config=pyproject.toml diff --git a/keras/src/export/export_lib_test.py b/keras/src/export/export_lib_test.py index c0fa09891da..9ee2d6fc512 100644 --- a/keras/src/export/export_lib_test.py +++ b/keras/src/export/export_lib_test.py @@ -57,6 +57,9 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None): ), ) @pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.torch_uses_gpu(), reason="Leads to core dumps on CI" +) class ExportSavedModelTest(testing.TestCase): @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) @@ -344,6 +347,9 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): reason="Export only currently supports the TF and JAX backends.", ) @pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.torch_uses_gpu(), reason="Leads to core dumps on CI" +) class ExportArchiveTest(testing.TestCase): @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index d1277948160..212fbad5887 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1229,6 +1229,9 @@ def test_functional_deeply_nested_outputs_struct_losses(self): @pytest.mark.skipif( testing.jax_uses_gpu(), reason="Leads to core dumps on CI" ) + @pytest.mark.skipif( + testing.torch_uses_gpu(), reason="Leads to core dumps on CI" + ) def test_export(self): import tensorflow as tf