From 0788a73aab2312a8086df3c13b391f7eacdbc043 Mon Sep 17 00:00:00 2001 From: Nina Chikanov Date: Fri, 8 Nov 2024 17:50:55 -0800 Subject: [PATCH] fix up unit tests --- tests/test_huggingface_chat_target.py | 31 ++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/tests/test_huggingface_chat_target.py b/tests/test_huggingface_chat_target.py index 71822385b..ece9de2a5 100644 --- a/tests/test_huggingface_chat_target.py +++ b/tests/test_huggingface_chat_target.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from asyncio import Task import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock from pyrit.prompt_target import HuggingFaceChatTarget from pyrit.models.prompt_request_response import PromptRequestResponse, PromptRequestPiece @@ -71,6 +72,17 @@ def mock_pretrained_config(): yield +class AwaitableMock(AsyncMock): + def __await__(self): + return iter([]) + + +@pytest.fixture(autouse=True) +def mock_create_task(): + with patch("asyncio.create_task", return_value=AwaitableMock(spec=Task)): + yield + + def test_init_with_no_token_var_raises(monkeypatch): # Ensure the environment variable is unset monkeypatch.delenv("HUGGINGFACE_TOKEN", raising=False) @@ -81,13 +93,15 @@ def test_init_with_no_token_var_raises(monkeypatch): assert "Environment variable HUGGINGFACE_TOKEN is required" in str(excinfo.value) -# TODO: Run through tests, currently hitting RuntimeError: no running event loop -def test_initialization(): +@pytest.mark.asyncio +async def test_initialization(): # Test the initialization without loading the actual models hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) assert hf_chat.model_id == "test_model" assert not hf_chat.use_cuda assert hf_chat.device == "cpu" + + await hf_chat.load_model_and_tokenizer() assert hf_chat.model is not None assert hf_chat.tokenizer is not None @@ -105,8 +119,10 @@ def test_is_model_id_valid_false(): assert not hf_chat.is_model_id_valid() -def test_load_model_and_tokenizer(): +@pytest.mark.asyncio +async def test_load_model_and_tokenizer(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + await hf_chat.load_model_and_tokenizer() assert hf_chat.model is not None assert hf_chat.tokenizer is not None @@ -114,6 +130,7 @@ def test_load_model_and_tokenizer(): @pytest.mark.asyncio async def test_send_prompt_async(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + await hf_chat.load_model_and_tokenizer() request_piece = PromptRequestPiece( role="user", @@ -133,6 +150,7 @@ async def test_send_prompt_async(): @pytest.mark.asyncio async def test_missing_chat_template_error(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + await hf_chat.load_model_and_tokenizer() hf_chat.tokenizer.chat_template = None request_piece = PromptRequestPiece( @@ -168,8 +186,11 @@ def test_invalid_prompt_request_validation(): assert "This target only supports a single prompt request piece." in str(excinfo.value) -def test_load_with_missing_files(): +@pytest.mark.asyncio +async def test_load_with_missing_files(): hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False, necessary_files=["file1", "file2"]) + await hf_chat.load_model_and_tokenizer() + assert hf_chat.model is not None assert hf_chat.tokenizer is not None