From 188bd3adaa27a35cf05608e4383037d0ad2cb7e2 Mon Sep 17 00:00:00 2001
From: Ruheena Suhani Shaik
Date: Mon, 23 Sep 2024 10:13:59 +0300
Subject: [PATCH] Added both hpu and gpu specific changes to conftest

---
 tests/lora/conftest.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index d3ebd15510284..099158798aa56 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -84,12 +84,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
 @pytest.fixture
 def dist_init():
     temp_file = tempfile.mkstemp()[1]
+    if is_hpu():
+        backend_type = "hccl"
+    else:
+        backend_type = "nccl"
     init_distributed_environment(
         world_size=1,
         rank=0,
         distributed_init_method=f"file://{temp_file}",
         local_rank=0,
-        backend="nccl",
+        backend=backend_type,
     )
     initialize_model_parallel(1, 1)
     yield
@@ -259,8 +263,13 @@ def get_model_patched(*, model_config, device_config, **kwargs):
                              device_config=device_config,
                              **kwargs)
 
-    with patch("vllm.worker.model_runner.get_model", get_model_patched):
-        engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
+    if is_hpu():
+        with patch("vllm.worker.habana_model_runner.get_model", get_model_patched):
+            engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
+    else:
+        with patch("vllm.worker.model_runner.get_model", get_model_patched):
+            engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
+
     yield engine.llm_engine
     del engine
     cleanup()
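
Note: the patch above relies on an is_hpu() helper being importable in
tests/lora/conftest.py, and on "hccl" being the collective-communication
backend for HPU (the Habana counterpart of NCCL on NVIDIA GPUs). The snippet
below is only a minimal sketch of what such a helper could look like; it
assumes the Habana stack installs the habana_frameworks package and is not
the exact implementation used by vLLM.

    # Hypothetical helper, for illustration only: treat a system as HPU when
    # the habana_frameworks package (installed with the Gaudi software stack)
    # is importable.
    import importlib.util


    def is_hpu() -> bool:
        """Return True when the Habana (HPU) software stack is available."""
        return importlib.util.find_spec("habana_frameworks") is not None

With a helper like this, the dist_init fixture picks "hccl" on Gaudi machines
and falls back to "nccl" elsewhere, and the engine fixture patches whichever
model runner module is actually in use, so the same LoRA tests can run on
both device types.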