async httpx implementation
nina-msft committed Nov 8, 2024
1 parent 43bf1c3 commit 403683d
Showing 5 changed files with 85 additions and 31 deletions.
32 changes: 18 additions & 14 deletions doc/code/orchestrators/use_huggingface_chat_target.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "dbabb573",
"id": "666a9eef",
"metadata": {
"lines_to_next_cell": 2
},
@@ -38,13 +38,13 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "76376ab9",
"id": "7b080eca",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-06T01:24:35.784894Z",
"iopub.status.busy": "2024-11-06T01:24:35.784894Z",
"iopub.status.idle": "2024-11-06T01:27:31.096637Z",
"shell.execute_reply": "2024-11-06T01:27:31.096637Z"
"iopub.execute_input": "2024-11-08T02:09:00.673581Z",
"iopub.status.busy": "2024-11-08T02:09:00.672562Z",
"iopub.status.idle": "2024-11-08T02:10:32.738032Z",
"shell.execute_reply": "2024-11-08T02:10:32.738032Z"
}
},
"outputs": [
@@ -59,19 +59,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n",
"Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 33.10 seconds\n",
"\n",
"\u001b[22m\u001b[39mConversation ID: ad151fa1-30d4-4b51-976b-acfc6ddc9064\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n",
"\u001b[22m\u001b[39mConversation ID: cf250f01-af21-463e-bd6f-ddb11bdf0a97\n",
"\u001b[22m\u001b[39mConversation ID: 47d44512-c619-45e5-8140-69ce3e9fc28b\n",
"\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 4*4 is a special number because it can be expressed as a product of two numbers,\n",
"HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n"
"\u001b[22m\u001b[39mConversation ID: e3e87055-e761-4672-98c2-9517be5d6027\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n",
"HuggingFaceTB/SmolLM-135M-Instruct: 33.10 seconds\n"
]
}
],
@@ -92,6 +92,10 @@
"print(f\"Running model: {model_id}\")\n",
"\n",
"try:\n",
" # Enable cache for the target to save time on subsequent calls\n",
" # Target will use cached model and tokenizer if the model_id is the same\n",
" HuggingFaceChatTarget.enable_cache()\n",
"\n",
" # Initialize HuggingFaceChatTarget with the current model\n",
" target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format=\"pt\", max_new_tokens=30)\n",
"\n",
4 changes: 4 additions & 0 deletions doc/code/orchestrators/use_huggingface_chat_target.py
@@ -45,6 +45,10 @@
print(f"Running model: {model_id}")

try:
    # Enable cache for the target to save time on subsequent calls
    # Target will use cached model and tokenizer if the model_id is the same
    HuggingFaceChatTarget.enable_cache()

    # Initialize HuggingFaceChatTarget with the current model
    target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format="pt", max_new_tokens=30)

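The enable_cache call used above is a class-level switch: once enabled, later instances reuse an already-loaded model and tokenizer whenever the model_id matches, which avoids paying the load time again. A minimal sketch of that caching pattern follows; the class and method names here are illustrative, not the actual PyRIT internals.

from typing import Any, Callable, Optional, Tuple


class IllustrativeModelCache:
    """Sketch of a class-level model/tokenizer cache; not the real HuggingFaceChatTarget code."""

    _cache_enabled: bool = False
    _cached_model_id: Optional[str] = None
    _cached_model: Any = None
    _cached_tokenizer: Any = None

    @classmethod
    def enable_cache(cls) -> None:
        cls._cache_enabled = True

    @classmethod
    def get_or_load(cls, model_id: str, loader: Callable[[str], Tuple[Any, Any]]) -> Tuple[Any, Any]:
        # Reload only when caching is off or a different model_id is requested
        if not cls._cache_enabled or cls._cached_model_id != model_id:
            cls._cached_model, cls._cached_tokenizer = loader(model_id)
            cls._cached_model_id = model_id
        return cls._cached_model, cls._cached_tokenizer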
64 changes: 53 additions & 11 deletions pyrit/common/download_hf_model.py
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import logging
import os
import urllib.request
import httpx
from pathlib import Path

from huggingface_hub import HfApi
@@ -29,7 +30,7 @@ def get_available_files(model_id: str, token: str):
        return []


def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path) -> list[str]:
async def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path):
"""
Downloads specific files from a Hugging Face model repository.
If file_patterns is None, downloads all files.
@@ -56,15 +57,56 @@ def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path) -> list[str]:
    urls = [base_url + file for file in files_to_download]

    # Download the files
    download_files(urls, token, cache_dir)
    await download_files(urls, token, cache_dir)


def download_files(urls: list, token: str, cache_dir: Path):
    headers = {"Authorization": f"Bearer {token}"}
    for url in urls:
        local_filename = Path(cache_dir, url.split("/")[-1])
        request = urllib.request.Request(url, headers=headers)
        with urllib.request.urlopen(request) as response, open(local_filename, "wb") as out_file:
            data = response.read()
            out_file.write(data)
        logger.info(f"Downloaded {local_filename}")


async def download_chunk(url, headers, start, end, client):
    """Download a chunk of the file with a specified byte range."""
    range_header = {"Range": f"bytes={start}-{end}", **headers}
    response = await client.get(url, headers=range_header)
    response.raise_for_status()
    return response.content


async def download_file(url, token, download_dir, num_splits):
    """Download a file in multiple segments (splits) using byte-range requests."""
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(follow_redirects=True) as client:
        # Get the file size to determine chunk size
        response = await client.head(url, headers=headers)
        response.raise_for_status()
        file_size = int(response.headers["Content-Length"])
        chunk_size = file_size // num_splits

        # Prepare tasks for each chunk
        tasks = []
        file_name = url.split("/")[-1]
        file_path = Path(download_dir, file_name)

        for i in range(num_splits):
            start = i * chunk_size
            end = start + chunk_size - 1 if i < num_splits - 1 else file_size - 1
            tasks.append(download_chunk(url, headers, start, end, client))

        # Download all chunks concurrently
        chunks = await asyncio.gather(*tasks)

        # Write chunks to the file in order
        with open(file_path, "wb") as f:
            for chunk in chunks:
                f.write(chunk)
        logger.info(f"Downloaded {file_name} to {file_path}")


async def download_files(urls: list[str], token: str, download_dir: Path, num_splits=3, parallel_downloads=4):
    """Download multiple files with parallel downloads and segmented downloading."""

    # Limit the number of parallel downloads
    semaphore = asyncio.Semaphore(parallel_downloads)

    async def download_with_limit(url):
        async with semaphore:
            await download_file(url, token, download_dir, num_splits)

    # Run downloads concurrently, but limit to parallel_downloads at a time
    await asyncio.gather(*(download_with_limit(url) for url in urls))
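Since download_files is now a coroutine, callers need a running event loop. A minimal usage sketch, with a placeholder token and an assumed pre-existing download directory:

import asyncio
from pathlib import Path

urls = [
    "https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct/resolve/main/config.json",
]
# The token value and directory are placeholders; download_dir must already exist
asyncio.run(download_files(urls, token="hf_...", download_dir=Path("./hf_models"), num_splits=3, parallel_downloads=4))

Note the two levels of concurrency: the semaphore caps concurrent files at parallel_downloads, and each file is itself split into num_splits range requests, so the peak number of in-flight requests is roughly their product.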
15 changes: 9 additions & 6 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import json
import logging
import os
@@ -19,7 +20,6 @@

logger = logging.getLogger(__name__)


class HuggingFaceChatTarget(PromptChatTarget):
"""The HuggingFaceChatTarget interacts with HuggingFace models, specifically for conducting red teaming activities.
Inherits from PromptTarget to comply with the current design standards.
@@ -75,9 +75,9 @@ def __init__(

        if self.use_cuda and not torch.cuda.is_available():
            raise RuntimeError("CUDA requested but not available.")

        # Load the model and tokenizer using the encapsulated method
        self.load_model_and_tokenizer()

        self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer())

    def is_model_id_valid(self) -> bool:
        """
@@ -92,7 +92,7 @@ def is_model_id_valid(self) -> bool:
            logger.error(f"Invalid HuggingFace model ID {self.model_id}: {e}")
            return False

    def load_model_and_tokenizer(self):
    async def load_model_and_tokenizer(self):
        """Loads the model and tokenizer, downloading if necessary.

        Downloads the model to the HF_MODELS_DIR folder if it does not exist,
@@ -117,11 +117,11 @@ def load_model_and_tokenizer(self):
        if self.necessary_files is None:
            # Download all files if no specific files are provided
            logger.info(f"Downloading all files for {self.model_id}...")
            download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
            await download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
        else:
            # Download only the necessary files
            logger.info(f"Downloading specific files for {self.model_id}...")
            download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)
            await download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)

        # Load the tokenizer and model from the specified directory
        logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...")
@@ -152,6 +152,9 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
"""
Sends a normalized prompt asynchronously to the HuggingFace model.
"""
# Load the model and tokenizer using the encapsulated method
await self.load_model_and_tokenizer_task

self._validate_request(prompt_request=prompt_request)
request = prompt_request.request_pieces[0]
prompt_template = request.converted_value
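The change above starts the model load eagerly via asyncio.create_task in __init__ and awaits the stored task at the top of send_prompt_async, so only the first prompt pays the load cost. This works because an asyncio.Task can be awaited any number of times; awaits after completion return the stored result immediately. A standalone sketch of that property:

import asyncio


async def load_model() -> str:
    await asyncio.sleep(0.1)  # stand-in for a slow model load
    return "model"


async def main() -> None:
    task = asyncio.create_task(load_model())  # starts running as soon as the loop yields
    print(await task)  # the first await blocks until the load finishes
    print(await task)  # later awaits return the stored result immediately


asyncio.run(main())

The trade-off is that asyncio.create_task requires a running event loop at construction time, which is exactly what the test TODO below runs into.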
1 change: 1 addition & 0 deletions tests/test_huggingface_chat_target.py
@@ -81,6 +81,7 @@ def test_init_with_no_token_var_raises(monkeypatch):
assert "Environment variable HUGGINGFACE_TOKEN is required" in str(excinfo.value)


# TODO: Run through tests, currently hitting RuntimeError: no running event loop
def test_initialization():
    # Test the initialization without loading the actual models
    hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
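One way past that RuntimeError is to construct the target inside a running loop, for example with the pytest-asyncio plugin; this is a sketch and assumes model loading is mocked out as in the existing tests:

import pytest


@pytest.mark.asyncio
async def test_initialization_async():
    # An async test runs inside an event loop, so __init__ can call asyncio.create_task
    hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
    assert hf_chat.model_id == "test_model"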
