Skip to content

Commit

Permalink
fix up unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nina-msft committed Nov 9, 2024
1 parent 403683d commit 0788a73
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions tests/test_huggingface_chat_target.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from asyncio import Task
import pytest
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, AsyncMock

from pyrit.prompt_target import HuggingFaceChatTarget
from pyrit.models.prompt_request_response import PromptRequestResponse, PromptRequestPiece
Expand Down Expand Up @@ -71,6 +72,17 @@ def mock_pretrained_config():
yield


class AwaitableMock(AsyncMock):
    """AsyncMock variant whose instances can themselves be awaited.

    Awaiting an instance completes immediately and resolves to None,
    which lets it stand in for objects like asyncio.Task.
    """

    def __await__(self):
        # Hand back an already-exhausted iterator so ``await`` finishes
        # at once without needing a running event loop.
        return iter(())


@pytest.fixture(autouse=True)
def mock_create_task():
    """Stub out ``asyncio.create_task`` for every test in this module.

    The real ``create_task`` requires a running event loop; the stand-in
    returns an AwaitableMock posing as an ``asyncio.Task`` so code under
    test can schedule and await "tasks" without one.
    """
    task_stub = AwaitableMock(spec=Task)
    with patch("asyncio.create_task", return_value=task_stub):
        yield


def test_init_with_no_token_var_raises(monkeypatch):
# Ensure the environment variable is unset
monkeypatch.delenv("HUGGINGFACE_TOKEN", raising=False)
Expand All @@ -81,13 +93,15 @@ def test_init_with_no_token_var_raises(monkeypatch):
assert "Environment variable HUGGINGFACE_TOKEN is required" in str(excinfo.value)


# TODO: Run through tests, currently hitting RuntimeError: no running event loop
def test_initialization():
@pytest.mark.asyncio
async def test_initialization():
    """The constructor records configuration without loading anything;
    model and tokenizer become available after an explicit load."""
    target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)

    # Constructor should have stored the settings verbatim.
    assert target.model_id == "test_model"
    assert not target.use_cuda
    assert target.device == "cpu"

    await target.load_model_and_tokenizer()
    assert target.model is not None
    assert target.tokenizer is not None

Expand All @@ -105,15 +119,18 @@ def test_is_model_id_valid_false():
assert not hf_chat.is_model_id_valid()


def test_load_model_and_tokenizer():
@pytest.mark.asyncio
async def test_load_model_and_tokenizer():
    """Loading should populate both the model and tokenizer attributes."""
    target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
    await target.load_model_and_tokenizer()

    assert target.model is not None
    assert target.tokenizer is not None


@pytest.mark.asyncio
async def test_send_prompt_async():
hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
await hf_chat.load_model_and_tokenizer()

request_piece = PromptRequestPiece(
role="user",
Expand All @@ -133,6 +150,7 @@ async def test_send_prompt_async():
@pytest.mark.asyncio
async def test_missing_chat_template_error():
hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
await hf_chat.load_model_and_tokenizer()
hf_chat.tokenizer.chat_template = None

request_piece = PromptRequestPiece(
Expand Down Expand Up @@ -168,8 +186,11 @@ def test_invalid_prompt_request_validation():
assert "This target only supports a single prompt request piece." in str(excinfo.value)


def test_load_with_missing_files():
@pytest.mark.asyncio
async def test_load_with_missing_files():
    """Loading with extra ``necessary_files`` still yields a model and
    tokenizer (presumably the downloads are mocked by this module's
    autouse fixtures — verify against the fixture setup)."""
    target = HuggingFaceChatTarget(
        model_id="test_model", use_cuda=False, necessary_files=["file1", "file2"]
    )
    await target.load_model_and_tokenizer()

    assert target.model is not None
    assert target.tokenizer is not None

Expand Down

0 comments on commit 0788a73

Please sign in to comment.