diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 6d1965e29d79..6d671861ef5f 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -170,6 +170,7 @@ from _pytest.outcomes import skip from _pytest.pathlib import import_path from pytest import DoctestItem + import pytest else: Module = object DoctestItem = object @@ -187,6 +188,8 @@ # Not critical, only usable on the sandboxed CI instance. TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" + + if is_torch_available(): import torch @@ -196,7 +199,6 @@ IS_ROCM_SYSTEM = False IS_CUDA_SYSTEM = False - def parse_flag_from_env(key, default=False): try: value = os.environ[key] @@ -235,6 +237,22 @@ def parse_int_from_env(key, default=None): _run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True) _run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False) _run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False) +_test_with_rocm = parse_flag_from_env("TEST_WITH_ROCM", default=False) + +def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): + def dec_fn(fn): + reason = f"skipIfRocm: {msg}" + + @wraps(fn) + def wrapper(*args, **kwargs): + if _test_with_rocm: + pytest.skip(reason) + else: + return fn(*args, **kwargs) + return wrapper + if func: + return dec_fn(func) + return dec_fn def get_device_count():