swap out aria2c for urllib.request
nina-msft committed Nov 6, 2024
1 parent 15bc681 commit 2097639
Showing 5 changed files with 109 additions and 103 deletions.
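At its core, the commit replaces a subprocess invocation of the external `aria2c` binary with a plain-Python download loop built on the standard library's `urllib.request`, dropping the runtime dependency on aria2 entirely. A condensed sketch of the new download path (mirroring the `download_files` helper added in `pyrit/common/download_hf_model.py` below; logging omitted for brevity):

import urllib.request
from pathlib import Path


def download_files(urls: list, token: str, cache_dir: Path):
    """Fetch each URL into cache_dir, authenticating with a Hugging Face token."""
    headers = {"Authorization": f"Bearer {token}"}
    for url in urls:
        local_filename = Path(cache_dir, url.split("/")[-1])
        request = urllib.request.Request(url, headers=headers)
        # Read the response body and write it to the local cache path.
        with urllib.request.urlopen(request) as response, open(local_filename, "wb") as out_file:
            out_file.write(response.read())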
99 changes: 60 additions & 39 deletions doc/code/orchestrators/use_huggingface_chat_target.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "d623d73a",
"id": "dbabb573",
"metadata": {
"lines_to_next_cell": 2
},
@@ -17,7 +17,7 @@
" - This notebook supports the following **instruct models** that follow a structured chat template. These are examples, and more instruct models are available on Hugging Face:\n",
" - `HuggingFaceTB/SmolLM-360M-Instruct`\n",
" - `microsoft/Phi-3-mini-4k-instruct`\n",
" \n",
"\n",
" - `...`\n",
"\n",
"2. **Excluded Models**:\n",
@@ -37,63 +37,86 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0d61a68",
"metadata": {},
"outputs": [],
"execution_count": 1,
"id": "76376ab9",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-06T01:24:35.784894Z",
"iopub.status.busy": "2024-11-06T01:24:35.784894Z",
"iopub.status.idle": "2024-11-06T01:27:31.096637Z",
"shell.execute_reply": "2024-11-06T01:27:31.096637Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running model: HuggingFaceTB/SmolLM-135M-Instruct\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n",
"\n",
"\u001b[22m\u001b[39mConversation ID: ad151fa1-30d4-4b51-976b-acfc6ddc9064\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n",
"\u001b[22m\u001b[39mConversation ID: cf250f01-af21-463e-bd6f-ddb11bdf0a97\n",
"\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 4*4 is a special number because it can be expressed as a product of two numbers,\n",
"HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n"
]
}
],
"source": [
"import time\n",
"from pyrit.prompt_target import HuggingFaceChatTarget \n",
"from pyrit.prompt_target import HuggingFaceChatTarget\n",
"from pyrit.orchestrator import PromptSendingOrchestrator\n",
"\n",
"# models to test\n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\" \n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\"\n",
"\n",
"# List of prompts to send\n",
"prompt_list = [\n",
" \"What is 3*3? Give me the solution.\",\n",
" \"What is 4*4? Give me the solution.\"\n",
" ]\n",
"prompt_list = [\"What is 3*3? Give me the solution.\", \"What is 4*4? Give me the solution.\"]\n",
"\n",
"# Dictionary to store average response times\n",
"model_times = {}\n",
" \n",
"\n",
"print(f\"Running model: {model_id}\")\n",
" \n",
"\n",
"try:\n",
" # Initialize HuggingFaceChatTarget with the current model\n",
" target = HuggingFaceChatTarget(\n",
" model_id=model_id, \n",
" use_cuda=False, \n",
" tensor_format=\"pt\",\n",
" max_new_tokens=30 \n",
" )\n",
" \n",
" target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format=\"pt\", max_new_tokens=30)\n",
"\n",
" # Initialize the orchestrator\n",
" orchestrator = PromptSendingOrchestrator(\n",
" prompt_target=target,\n",
" verbose=False\n",
" )\n",
" \n",
" orchestrator = PromptSendingOrchestrator(prompt_target=target, verbose=False)\n",
"\n",
" # Record start time\n",
" start_time = time.time()\n",
" \n",
"\n",
" # Send prompts asynchronously\n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
" \n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
"\n",
" # Record end time\n",
" end_time = time.time()\n",
" \n",
"\n",
" # Calculate total and average response time\n",
" total_time = end_time - start_time\n",
" avg_time = total_time / len(prompt_list)\n",
" model_times[model_id] = avg_time\n",
" \n",
"\n",
" print(f\"Average response time for {model_id}: {avg_time:.2f} seconds\\n\")\n",
" \n",
"\n",
" # Print the conversations\n",
" await orchestrator.print_conversations() # type: ignore\n",
" \n",
" await orchestrator.print_conversations() # type: ignore\n",
"\n",
"except Exception as e:\n",
" print(f\"An error occurred with model {model_id}: {e}\\n\")\n",
" model_times[model_id] = None\n",
@@ -108,14 +131,12 @@
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "pyrit-dev",
"language": "python",
"name": "python3"
"name": "pyrit-dev"
},
"language_info": {
"codemirror_mode": {
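The notebook cell above uses bare top-level `await`, which Jupyter/IPython supports natively; in a plain Python script the same flow needs an explicit event loop. A minimal sketch under that assumption, reusing the notebook's own API calls:

import asyncio

from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_target import HuggingFaceChatTarget


async def main():
    # Same constructor arguments as the notebook cell above.
    target = HuggingFaceChatTarget(
        model_id="HuggingFaceTB/SmolLM-135M-Instruct", use_cuda=False, tensor_format="pt", max_new_tokens=30
    )
    orchestrator = PromptSendingOrchestrator(prompt_target=target, verbose=False)
    await orchestrator.send_prompts_async(prompt_list=["What is 3*3? Give me the solution."])
    await orchestrator.print_conversations()


asyncio.run(main())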
64 changes: 21 additions & 43 deletions pyrit/common/download_hf_model.py
@@ -3,9 +3,8 @@

import logging
import os
import urllib.request
from pathlib import Path
import subprocess
from typing import Optional

from huggingface_hub import HfApi

@@ -30,47 +29,13 @@ def get_available_files(model_id: str, token: str):
return []


def download_files_with_aria2(urls: list, token: str, download_dir: Optional[Path] = None):
"""Uses aria2 to download files from the given list of URLs."""

# Convert download_dir to string if it's a Path object
download_dir_str = str(download_dir) if isinstance(download_dir, Path) else download_dir

aria2_command = [
"aria2c",
"-d",
download_dir_str,
"-x",
"3", # Number of connections per server for each download.
"-s",
"5", # Number of splits for each file.
"-j",
"4", # Maximum number of parallel downloads.
"--continue=true",
"--enable-http-pipelining=true",
f"--header=Authorization: Bearer {token}",
"-i",
"-", # Use '-' to read input from stdin
]

try:
# Run aria2c with input from stdin
process = subprocess.Popen(aria2_command, stdin=subprocess.PIPE, text=True)
process.communicate("\n".join(urls)) # Pass URLs directly to stdin
if process.returncode == 0:
logger.info(f"\nFiles downloaded successfully to {download_dir}.")
else:
logger.info(f"Error downloading files with aria2, return code: {process.returncode}.")
raise subprocess.CalledProcessError(process.returncode, aria2_command)
except subprocess.CalledProcessError as e:
logger.info(f"Error downloading files with aria2: {e}")
raise


def download_specific_files_with_aria2(model_id: str, file_patterns: list, token: str, cache_dir: Optional[Path]):
def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path) -> list[str]:
"""
Downloads specific files from a Hugging Face model repository using aria2.
Downloads specific files from a Hugging Face model repository.
If file_patterns is None, downloads all files.
Returns:
List of URLs for the downloaded files.
"""
os.makedirs(cache_dir, exist_ok=True)

@@ -90,5 +55,18 @@ def download_specific_files_with_aria2(model_id: str, file_patterns: list, token
base_url = f"https://huggingface.co/{model_id}/resolve/main/"
urls = [base_url + file for file in files_to_download]

# Use aria2c to download the files
download_files_with_aria2(urls, token, cache_dir)
# Download the files
download_files(urls, token, cache_dir)

return urls


def download_files(urls: list, token: str, cache_dir: Path):
headers = {"Authorization": f"Bearer {token}"}
for url in urls:
local_filename = Path(cache_dir, url.split("/")[-1])
request = urllib.request.Request(url, headers=headers)
with urllib.request.urlopen(request) as response, open(local_filename, "wb") as out_file:
data = response.read()
out_file.write(data)
logger.info(f"Downloaded {local_filename}")
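One trade-off in the new helper: `response.read()` buffers each file fully in memory before writing, which is fine for tokenizer and config files but costly for multi-gigabyte weight shards. A hypothetical chunked variant (not part of this commit) would stream to disk instead:

import shutil
import urllib.request
from pathlib import Path


def download_file_streaming(url: str, token: str, cache_dir: Path):
    # Hypothetical alternative: copy the response to disk in fixed-size chunks
    # instead of buffering the whole file in memory.
    request = urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"})
    local_filename = Path(cache_dir, url.split("/")[-1])
    with urllib.request.urlopen(request) as response, open(local_filename, "wb") as out_file:
        shutil.copyfileobj(response, out_file)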
27 changes: 16 additions & 11 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
@@ -10,7 +10,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig

from pyrit.prompt_target import PromptChatTarget
from pyrit.common.download_hf_model_with_aria2 import download_specific_files_with_aria2
from pyrit.common.download_hf_model import download_specific_files
from pyrit.memory import MemoryInterface
from pyrit.models.prompt_request_response import PromptRequestResponse, construct_response_from_request
from pyrit.exceptions import EmptyResponseException, pyrit_target_retry
@@ -116,14 +116,12 @@ def load_model_and_tokenizer(self):

if self.necessary_files is None:
# Download all files if no specific files are provided
logger.info(f"Downloading all files for {self.model_id} using aria2...")
download_specific_files_with_aria2(self.model_id, None, self.huggingface_token, cache_dir)
logger.info(f"Downloading all files for {self.model_id}...")
download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
else:
# Download only the necessary files
logger.info(f"Downloading specific files for {self.model_id} using aria2...")
download_specific_files_with_aria2(
self.model_id, self.necessary_files, self.huggingface_token, cache_dir
)
logger.info(f"Downloading specific files for {self.model_id}...")
download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)

# Load the tokenizer and model from the specified directory
logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...")
@@ -165,20 +163,23 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P

# Apply chat template via the _apply_chat_template method
tokenized_chat = self._apply_chat_template(messages)
input_ids = tokenized_chat["input_ids"]
attention_mask = tokenized_chat["attention_mask"]

logger.info(f"Tokenized chat: {tokenized_chat}")
logger.info(f"Tokenized chat: {input_ids}")

try:
# Ensure model is on the correct device (should already be the case from `load_model_and_tokenizer`)
self.model.to(self.device)

# Record the length of the input tokens to later extract only the generated tokens
input_length = tokenized_chat.shape[-1]
input_length = input_ids.shape[-1]

# Generate the response
logger.info("Generating response from model...")
generated_ids = self.model.generate(
input_ids=tokenized_chat,
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=self.max_new_tokens,
temperature=self.temperature,
top_p=self.top_p,
@@ -219,7 +220,11 @@ def _apply_chat_template(self, messages):

# Apply the chat template to format and tokenize the messages
tokenized_chat = self.tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors=self.tensor_format
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors=self.tensor_format,
return_dict=True,
).to(self.device)
return tokenized_chat
else:
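The other substantive change in this file is passing `return_dict=True` to `apply_chat_template`, so the tokenizer returns both `input_ids` and an `attention_mask`, and the mask is now handed to `model.generate` explicitly rather than left unset. A standalone illustration of the pattern, assuming `transformers` is installed and using one of the notebook's example models:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM-135M-Instruct"  # example model from the notebook
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

messages = [{"role": "user", "content": "What is 3*3? Give me the solution."}]
# return_dict=True yields both input_ids and attention_mask.
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
)
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=30,
)
# Decode only the newly generated tokens, as send_prompt_async does via input_length.
print(tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True))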
14 changes: 9 additions & 5 deletions tests/test_hf_model_downloads.py
@@ -2,11 +2,12 @@
# Licensed under the MIT license.

import os
from pathlib import Path
import pytest
from unittest.mock import patch

# Import functions to test from local application files
from pyrit.common.download_hf_model_with_aria2 import download_specific_files_with_aria2
from pyrit.common.download_hf_model import download_specific_files


# Define constants for testing
@@ -31,8 +32,11 @@ def setup_environment():
yield token


def test_download_specific_files_with_aria2(setup_environment):
"""Test downloading specific files using aria2."""
def test_download_specific_files(setup_environment):
"""Test downloading specific files"""
token = setup_environment # Get the token from the fixture
with pytest.raises(Exception):
download_specific_files_with_aria2(MODEL_ID, FILE_PATTERNS, token)

with patch("os.makedirs"):
with patch("pyrit.common.download_hf_model.download_files"):
urls = download_specific_files(MODEL_ID, FILE_PATTERNS, token, Path(""))
assert urls
8 changes: 3 additions & 5 deletions tests/test_huggingface_chat_target.py
@@ -22,12 +22,10 @@ def mock_get_required_value(request):
yield


# Fixture to mock download_specific_files_with_aria2 globally for all tests
# Fixture to mock download_specific_files globally for all tests
@pytest.fixture(autouse=True)
def mock_download_specific_files_with_aria2():
with patch(
"pyrit.common.download_hf_model_with_aria2.download_specific_files_with_aria2", return_value=None
) as mock_download:
def mock_download_specific_files():
with patch("pyrit.common.download_hf_model.download_specific_files", return_value=None) as mock_download:
yield mock_download

