FIX Remove aria2c dependency from HuggingFace Target #530
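Summary: this PR drops the external aria2c downloader. The helper module in pyrit/common is renamed from download_hf_model_with_aria2 to download_hf_model, download_specific_files_with_aria2 becomes download_specific_files and now fetches files with plain urllib.request, the HuggingFaceChatTarget chat-template tokenization is changed to return a dict so generate() receives an explicit attention_mask, and the demo notebook and tests are updated to match.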

Merged · 7 commits · Nov 11, 2024
Changes from 2 commits
99 changes: 60 additions & 39 deletions doc/code/orchestrators/use_huggingface_chat_target.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "d623d73a",
"id": "dbabb573",
"metadata": {
"lines_to_next_cell": 2
},
@@ -17,7 +17,7 @@
" - This notebook supports the following **instruct models** that follow a structured chat template. These are examples, and more instruct models are available on Hugging Face:\n",
" - `HuggingFaceTB/SmolLM-360M-Instruct`\n",
" - `microsoft/Phi-3-mini-4k-instruct`\n",
" \n",
"\n",
" - `...`\n",
"\n",
"2. **Excluded Models**:\n",
@@ -37,63 +37,86 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0d61a68",
"metadata": {},
"outputs": [],
"execution_count": 1,
"id": "76376ab9",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-06T01:24:35.784894Z",
"iopub.status.busy": "2024-11-06T01:24:35.784894Z",
"iopub.status.idle": "2024-11-06T01:27:31.096637Z",
"shell.execute_reply": "2024-11-06T01:27:31.096637Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running model: HuggingFaceTB/SmolLM-135M-Instruct\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n",
"\n",
"\u001b[22m\u001b[39mConversation ID: ad151fa1-30d4-4b51-976b-acfc6ddc9064\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n",
"\u001b[22m\u001b[39mConversation ID: cf250f01-af21-463e-bd6f-ddb11bdf0a97\n",
"\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 4*4 is a special number because it can be expressed as a product of two numbers,\n",
"HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n"
]
}
],
"source": [
"import time\n",
"from pyrit.prompt_target import HuggingFaceChatTarget \n",
"from pyrit.prompt_target import HuggingFaceChatTarget\n",
"from pyrit.orchestrator import PromptSendingOrchestrator\n",
"\n",
"# models to test\n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\" \n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\"\n",
"\n",
"# List of prompts to send\n",
"prompt_list = [\n",
" \"What is 3*3? Give me the solution.\",\n",
" \"What is 4*4? Give me the solution.\"\n",
" ]\n",
"prompt_list = [\"What is 3*3? Give me the solution.\", \"What is 4*4? Give me the solution.\"]\n",
"\n",
"# Dictionary to store average response times\n",
"model_times = {}\n",
" \n",
"\n",
"print(f\"Running model: {model_id}\")\n",
" \n",
"\n",
"try:\n",
" # Initialize HuggingFaceChatTarget with the current model\n",
" target = HuggingFaceChatTarget(\n",
" model_id=model_id, \n",
" use_cuda=False, \n",
" tensor_format=\"pt\",\n",
" max_new_tokens=30 \n",
" )\n",
" \n",
" target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format=\"pt\", max_new_tokens=30)\n",
"\n",
" # Initialize the orchestrator\n",
" orchestrator = PromptSendingOrchestrator(\n",
" prompt_target=target,\n",
" verbose=False\n",
" )\n",
" \n",
" orchestrator = PromptSendingOrchestrator(prompt_target=target, verbose=False)\n",
"\n",
" # Record start time\n",
" start_time = time.time()\n",
" \n",
"\n",
" # Send prompts asynchronously\n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
" \n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
"\n",
" # Record end time\n",
" end_time = time.time()\n",
" \n",
"\n",
" # Calculate total and average response time\n",
" total_time = end_time - start_time\n",
" avg_time = total_time / len(prompt_list)\n",
" model_times[model_id] = avg_time\n",
" \n",
"\n",
" print(f\"Average response time for {model_id}: {avg_time:.2f} seconds\\n\")\n",
" \n",
"\n",
" # Print the conversations\n",
" await orchestrator.print_conversations() # type: ignore\n",
" \n",
" await orchestrator.print_conversations() # type: ignore\n",
"\n",
"except Exception as e:\n",
" print(f\"An error occurred with model {model_id}: {e}\\n\")\n",
" model_times[model_id] = None\n",
@@ -108,14 +131,12 @@
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "pyrit-dev",
"language": "python",
"name": "python3"
"name": "pyrit-dev"
},
"language_info": {
"codemirror_mode": {
pyrit/common/download_hf_model.py
@@ -3,9 +3,8 @@

import logging
import os
import urllib.request
from pathlib import Path
import subprocess
from typing import Optional

from huggingface_hub import HfApi

@@ -30,47 +29,13 @@ def get_available_files(model_id: str, token: str):
return []


def download_files_with_aria2(urls: list, token: str, download_dir: Optional[Path] = None):
"""Uses aria2 to download files from the given list of URLs."""

# Convert download_dir to string if it's a Path object
download_dir_str = str(download_dir) if isinstance(download_dir, Path) else download_dir

aria2_command = [
"aria2c",
"-d",
download_dir_str,
"-x",
"3", # Number of connections per server for each download.
"-s",
"5", # Number of splits for each file.
"-j",
"4", # Maximum number of parallel downloads.
"--continue=true",
"--enable-http-pipelining=true",
f"--header=Authorization: Bearer {token}",
"-i",
"-", # Use '-' to read input from stdin
]

try:
# Run aria2c with input from stdin
process = subprocess.Popen(aria2_command, stdin=subprocess.PIPE, text=True)
process.communicate("\n".join(urls)) # Pass URLs directly to stdin
if process.returncode == 0:
logger.info(f"\nFiles downloaded successfully to {download_dir}.")
else:
logger.info(f"Error downloading files with aria2, return code: {process.returncode}.")
raise subprocess.CalledProcessError(process.returncode, aria2_command)
except subprocess.CalledProcessError as e:
logger.info(f"Error downloading files with aria2: {e}")
raise


def download_specific_files_with_aria2(model_id: str, file_patterns: list, token: str, cache_dir: Optional[Path]):
def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path) -> list[str]:
"""
Downloads specific files from a Hugging Face model repository using aria2.
Downloads specific files from a Hugging Face model repository.
If file_patterns is None, downloads all files.

Returns:
List of URLs for the downloaded files.
"""
os.makedirs(cache_dir, exist_ok=True)

@@ -90,5 +55,16 @@ def download_specific_files_with_aria2(model_id: str, file_patterns: list, token
base_url = f"https://huggingface.co/{model_id}/resolve/main/"
urls = [base_url + file for file in files_to_download]

# Use aria2c to download the files
download_files_with_aria2(urls, token, cache_dir)
# Download the files
download_files(urls, token, cache_dir)


def download_files(urls: list, token: str, cache_dir: Path):
headers = {"Authorization": f"Bearer {token}"}
for url in urls:
local_filename = Path(cache_dir, url.split("/")[-1])
request = urllib.request.Request(url, headers=headers)
with urllib.request.urlopen(request) as response, open(local_filename, "wb") as out_file:
data = response.read()
out_file.write(data)
logger.info(f"Downloaded {local_filename}")
27 changes: 16 additions & 11 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
@@ -10,7 +10,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig

from pyrit.prompt_target import PromptChatTarget
from pyrit.common.download_hf_model_with_aria2 import download_specific_files_with_aria2
from pyrit.common.download_hf_model import download_specific_files
from pyrit.memory import MemoryInterface
from pyrit.models.prompt_request_response import PromptRequestResponse, construct_response_from_request
from pyrit.exceptions import EmptyResponseException, pyrit_target_retry
@@ -116,14 +116,12 @@ def load_model_and_tokenizer(self):

if self.necessary_files is None:
# Download all files if no specific files are provided
logger.info(f"Downloading all files for {self.model_id} using aria2...")
download_specific_files_with_aria2(self.model_id, None, self.huggingface_token, cache_dir)
logger.info(f"Downloading all files for {self.model_id}...")
download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
else:
# Download only the necessary files
logger.info(f"Downloading specific files for {self.model_id} using aria2...")
download_specific_files_with_aria2(
self.model_id, self.necessary_files, self.huggingface_token, cache_dir
)
logger.info(f"Downloading specific files for {self.model_id}...")
download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)

# Load the tokenizer and model from the specified directory
logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...")
Expand Down Expand Up @@ -165,20 +163,23 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P

# Apply chat template via the _apply_chat_template method
tokenized_chat = self._apply_chat_template(messages)
input_ids = tokenized_chat["input_ids"]
attention_mask = tokenized_chat["attention_mask"]

logger.info(f"Tokenized chat: {tokenized_chat}")
logger.info(f"Tokenized chat: {input_ids}")

try:
# Ensure model is on the correct device (should already be the case from `load_model_and_tokenizer`)
self.model.to(self.device)

# Record the length of the input tokens to later extract only the generated tokens
input_length = tokenized_chat.shape[-1]
input_length = input_ids.shape[-1]

# Generate the response
logger.info("Generating response from model...")
generated_ids = self.model.generate(
input_ids=tokenized_chat,
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=self.max_new_tokens,
temperature=self.temperature,
top_p=self.top_p,
@@ -219,7 +220,11 @@ def _apply_chat_template(self, messages):

# Apply the chat template to format and tokenize the messages
tokenized_chat = self.tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors=self.tensor_format
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors=self.tensor_format,
return_dict=True,
).to(self.device)
return tokenized_chat
else:
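The return_dict=True addition is the substantive change in this hunk: apply_chat_template now returns a dict-like encoding carrying both input_ids and attention_mask, so the mask can be passed to generate() explicitly instead of being inferred. A standalone sketch of the same pattern (the model choice is illustrative, not prescribed by this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM-135M-Instruct"  # illustrative model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

messages = [{"role": "user", "content": "What is 3*3? Give me the solution."}]

# With return_dict=True the call yields a BatchEncoding holding both
# input_ids and attention_mask, rather than a bare tensor of ids.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)

# Supplying the attention mask explicitly spares generate() from having
# to guess which positions are padding.
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=30,
)

# Decode only the newly generated tokens, as the target does.
new_tokens = generated_ids[0][inputs["input_ids"].shape[-1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))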
13 changes: 8 additions & 5 deletions tests/test_hf_model_downloads.py
@@ -2,11 +2,12 @@
# Licensed under the MIT license.

import os
from pathlib import Path
import pytest
from unittest.mock import patch

# Import functions to test from local application files
from pyrit.common.download_hf_model_with_aria2 import download_specific_files_with_aria2
from pyrit.common.download_hf_model import download_specific_files


# Define constants for testing
@@ -31,8 +32,10 @@ def setup_environment():
yield token


def test_download_specific_files_with_aria2(setup_environment):
"""Test downloading specific files using aria2."""
def test_download_specific_files(setup_environment):
"""Test downloading specific files"""
token = setup_environment # Get the token from the fixture
with pytest.raises(Exception):
download_specific_files_with_aria2(MODEL_ID, FILE_PATTERNS, token)

with patch("os.makedirs"):
with patch("pyrit.common.download_hf_model.download_files"):
download_specific_files(MODEL_ID, FILE_PATTERNS, token, Path(""))
8 changes: 3 additions & 5 deletions tests/test_huggingface_chat_target.py
@@ -22,12 +22,10 @@ def mock_get_required_value(request):
yield


# Fixture to mock download_specific_files_with_aria2 globally for all tests
# Fixture to mock download_specific_files globally for all tests
@pytest.fixture(autouse=True)
def mock_download_specific_files_with_aria2():
with patch(
"pyrit.common.download_hf_model_with_aria2.download_specific_files_with_aria2", return_value=None
) as mock_download:
def mock_download_specific_files():
with patch("pyrit.common.download_hf_model.download_specific_files", return_value=None) as mock_download:
yield mock_download

