diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 49c2aefa09260e..c63caae382a963 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -346,6 +346,17 @@ def tooslow(test_case):
     return unittest.skip(reason="test is too slow")(test_case)
 
 
+def skip_if_aqlm_inference_not_fixed(test_case):
+    """
+    Decorator marking tests for inference using aqlm models.
+
+    These tests will be skipped until the issue on the aqlm side is resolved.
+    """
+    return unittest.skip(
+        reason="inference doesn't work with quantized aqlm models using torch.Any type with recent torch versions"
+    )(test_case)
+
+
 def skip_if_not_implemented(test_func):
     @functools.wraps(test_func)
     def wrapper(*args, **kwargs):
diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py
index b79eae54c0c31e..b0e4e68e1cec28 100644
--- a/tests/quantization/aqlm_integration/test_aqlm.py
+++ b/tests/quantization/aqlm_integration/test_aqlm.py
@@ -26,6 +26,7 @@
     require_aqlm,
     require_torch_gpu,
     require_torch_multi_gpu,
+    skip_if_aqlm_inference_not_fixed,
     slow,
     torch_device,
 )
@@ -142,6 +143,7 @@ def test_quantized_model_conversion(self):
 
         self.assertEqual(nb_linears - 1, nb_aqlm_linear)
 
+    @skip_if_aqlm_inference_not_fixed
     def test_quantized_model(self):
         """
         Simple test that checks if the quantized model is working properly
@@ -158,6 +160,7 @@ def test_raise_if_non_quantized(self):
         with self.assertRaises(ValueError):
             _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
 
+    @skip_if_aqlm_inference_not_fixed
     def test_save_pretrained(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
@@ -171,6 +174,7 @@ def test_save_pretrained(self):
         output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
+    @skip_if_aqlm_inference_not_fixed
     @require_torch_multi_gpu
     def test_quantized_model_multi_gpu(self):
         """