diff --git a/doc/code/orchestrators/use_huggingface_chat_target.ipynb b/doc/code/orchestrators/use_huggingface_chat_target.ipynb
index 56cbb926e..d8c6d5ed7 100644
--- a/doc/code/orchestrators/use_huggingface_chat_target.ipynb
+++ b/doc/code/orchestrators/use_huggingface_chat_target.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "dbabb573",
+   "id": "666a9eef",
    "metadata": {
     "lines_to_next_cell": 2
    },
@@ -38,13 +38,13 @@
   {
    "cell_type": "code",
    "execution_count": 1,
-   "id": "76376ab9",
+   "id": "7b080eca",
    "metadata": {
     "execution": {
-     "iopub.execute_input": "2024-11-06T01:24:35.784894Z",
-     "iopub.status.busy": "2024-11-06T01:24:35.784894Z",
-     "iopub.status.idle": "2024-11-06T01:27:31.096637Z",
-     "shell.execute_reply": "2024-11-06T01:27:31.096637Z"
+     "iopub.execute_input": "2024-11-08T02:09:00.673581Z",
+     "iopub.status.busy": "2024-11-08T02:09:00.672562Z",
+     "iopub.status.idle": "2024-11-08T02:10:32.738032Z",
+     "shell.execute_reply": "2024-11-08T02:10:32.738032Z"
     }
    },
    "outputs": [
@@ -59,19 +59,19 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n",
+      "Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 33.10 seconds\n",
       "\n",
-      "\u001b[22m\u001b[39mConversation ID: ad151fa1-30d4-4b51-976b-acfc6ddc9064\n",
-      "\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
-      "\u001b[22m\u001b[33massistant: What a great question!\n",
-      "\n",
-      "The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n",
-      "\u001b[22m\u001b[39mConversation ID: cf250f01-af21-463e-bd6f-ddb11bdf0a97\n",
+      "\u001b[22m\u001b[39mConversation ID: 47d44512-c619-45e5-8140-69ce3e9fc28b\n",
       "\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n",
       "\u001b[22m\u001b[33massistant: What a great question!\n",
       "\n",
       "The number 4*4 is a special number because it can be expressed as a product of two numbers,\n",
-      "HuggingFaceTB/SmolLM-135M-Instruct: 2.61 seconds\n"
+      "\u001b[22m\u001b[39mConversation ID: e3e87055-e761-4672-98c2-9517be5d6027\n",
+      "\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
+      "\u001b[22m\u001b[33massistant: What a great question!\n",
+      "\n",
+      "The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n",
+      "HuggingFaceTB/SmolLM-135M-Instruct: 33.10 seconds\n"
      ]
     }
    ],
@@ -92,6 +92,10 @@
     "print(f\"Running model: {model_id}\")\n",
     "\n",
     "try:\n",
+    "    # Enable cache for the target to save time on subsequent calls\n",
+    "    # Target will use cached model and tokenizer if the model_id is the same\n",
+    "    HuggingFaceChatTarget.enable_cache()\n",
+    "\n",
     "    # Initialize HuggingFaceChatTarget with the current model\n",
     "    target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format=\"pt\", max_new_tokens=30)\n",
     "\n",
diff --git a/doc/code/orchestrators/use_huggingface_chat_target.py b/doc/code/orchestrators/use_huggingface_chat_target.py
index e7bdccc27..250cee2f4 100644
--- a/doc/code/orchestrators/use_huggingface_chat_target.py
+++ b/doc/code/orchestrators/use_huggingface_chat_target.py
@@ -45,6 +45,10 @@
 print(f"Running model: {model_id}")
 
 try:
+    # Enable cache for the target to save time on subsequent calls
+    # Target will use cached model and tokenizer if the model_id is the same
+    HuggingFaceChatTarget.enable_cache()
+
     # Initialize HuggingFaceChatTarget with the current model
     target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format="pt", max_new_tokens=30)
 
diff --git a/pyrit/common/download_hf_model.py b/pyrit/common/download_hf_model.py
index dbd1189a6..bcc7ecc74 100644
--- a/pyrit/common/download_hf_model.py
+++ b/pyrit/common/download_hf_model.py
@@ -1,9 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import asyncio
 import logging
 import os
-import urllib.request
+import httpx
 from pathlib import Path
 
 from huggingface_hub import HfApi
@@ -29,7 +30,7 @@ def get_available_files(model_id: str, token: str):
         return []
 
 
-def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path) -> list[str]:
+async def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path):
     """
     Downloads specific files from a Hugging Face model repository.
     If file_patterns is None, downloads all files.
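
The hunk below replaces the sequential urllib download loop with segmented, concurrent downloads over httpx. As a rough, self-contained sketch of the byte-range pattern it relies on (the URL and file name are hypothetical, and the server is assumed to honor HTTP Range headers):

import asyncio
from pathlib import Path

import httpx


async def fetch_in_ranges(url: str, dest: Path, num_splits: int = 3) -> None:
    async with httpx.AsyncClient(follow_redirects=True) as client:
        # HEAD request to learn the total size, so the file can be split.
        head = await client.head(url)
        head.raise_for_status()
        size = int(head.headers["Content-Length"])
        step = size // num_splits
        # One GET per byte range; the last range runs to the end of the file.
        tasks = []
        for i in range(num_splits):
            start = i * step
            end = size - 1 if i == num_splits - 1 else start + step - 1
            tasks.append(client.get(url, headers={"Range": f"bytes={start}-{end}"}))
        responses = await asyncio.gather(*tasks)
        # Reassemble the file by writing the parts back in request order.
        with open(dest, "wb") as f:
            for r in responses:
                r.raise_for_status()
                f.write(r.content)


# Hypothetical usage:
# asyncio.run(fetch_in_ranges("https://example.com/big.bin", Path("big.bin")))

Each Range request returns a partial body; because the parts are written back in the order they were requested, the concatenated result matches the original file.
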
@@ -56,15 +57,56 @@ def download_specific_files(model_id: str, file_patterns: list, token: str, cach
     urls = [base_url + file for file in files_to_download]
 
     # Download the files
-    download_files(urls, token, cache_dir)
+    await download_files(urls, token, cache_dir)
 
 
-def download_files(urls: list, token: str, cache_dir: Path):
+async def download_chunk(url, headers, start, end, client):
+    """Download a chunk of the file with a specified byte range."""
+    range_header = {"Range": f"bytes={start}-{end}", **headers}
+    response = await client.get(url, headers=range_header)
+    response.raise_for_status()
+    return response.content
+
+
+async def download_file(url, token, download_dir, num_splits):
+    """Download a file in multiple segments (splits) using byte-range requests."""
     headers = {"Authorization": f"Bearer {token}"}
-    for url in urls:
-        local_filename = Path(cache_dir, url.split("/")[-1])
-        request = urllib.request.Request(url, headers=headers)
-        with urllib.request.urlopen(request) as response, open(local_filename, "wb") as out_file:
-            data = response.read()
-            out_file.write(data)
-            logger.info(f"Downloaded {local_filename}")
+    async with httpx.AsyncClient(follow_redirects=True) as client:
+        # Get the file size to determine chunk size
+        response = await client.head(url, headers=headers)
+        response.raise_for_status()
+        file_size = int(response.headers["Content-Length"])
+        chunk_size = file_size // num_splits
+
+        # Prepare tasks for each chunk
+        tasks = []
+        file_name = url.split("/")[-1]
+        file_path = Path(download_dir, file_name)
+
+        for i in range(num_splits):
+            start = i * chunk_size
+            end = start + chunk_size - 1 if i < num_splits - 1 else file_size - 1
+            tasks.append(download_chunk(url, headers, start, end, client))
+
+        # Download all chunks concurrently
+        chunks = await asyncio.gather(*tasks)
+
+        # Write chunks to the file in order
+        with open(file_path, "wb") as f:
+            for chunk in chunks:
+                f.write(chunk)
+        logger.info(f"Downloaded {file_name} to {file_path}")
+
+
+async def download_files(urls: list[str], token: str, download_dir: Path, num_splits=3, parallel_downloads=4):
+    """Download multiple files with parallel downloads and segmented downloading."""
+
+    # Limit the number of parallel downloads
+    semaphore = asyncio.Semaphore(parallel_downloads)
+
+    async def download_with_limit(url):
+        async with semaphore:
+            await download_file(url, token, download_dir, num_splits)
+
+    # Run downloads concurrently, but limit to parallel_downloads at a time
+    await asyncio.gather(*(download_with_limit(url) for url in urls))
diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
index 8088e84ce..4c693d81e 100644
--- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
+++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import asyncio
 import json
 import logging
 import os
@@ -19,7 +20,6 @@
 
 logger = logging.getLogger(__name__)
 
-
 class HuggingFaceChatTarget(PromptChatTarget):
     """The HuggingFaceChatTarget interacts with HuggingFace models, specifically for conducting red teaming activities.
     Inherits from PromptTarget to comply with the current design standards.
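
The hunk below starts model loading as a background task the moment the target is constructed. One caveat: asyncio.create_task() requires an already-running event loop, which is exactly what the test TODO at the end of this diff runs into. A minimal sketch of that constraint (class and method names here are illustrative, not PyRIT APIs):

import asyncio


class EagerLoader:
    def __init__(self):
        # Works only while an event loop is running; in plain synchronous
        # code this raises "RuntimeError: no running event loop".
        self._load_task = asyncio.create_task(self._load())

    async def _load(self):
        await asyncio.sleep(0.1)  # stand-in for model/tokenizer loading


async def main():
    loader = EagerLoader()  # fine here: main() runs inside a loop
    await loader._load_task


asyncio.run(main())  # constructing EagerLoader() outside main() would raise
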
@@ -75,9 +75,9 @@ def __init__(
         if self.use_cuda and not torch.cuda.is_available():
             raise RuntimeError("CUDA requested but not available.")
+
+        self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer())
 
-        # Load the model and tokenizer using the encapsulated method
-        self.load_model_and_tokenizer()
 
     def is_model_id_valid(self) -> bool:
         """
@@ -92,7 +92,7 @@ def is_model_id_valid(self) -> bool:
             logger.error(f"Invalid HuggingFace model ID {self.model_id}: {e}")
             return False
 
-    def load_model_and_tokenizer(self):
+    async def load_model_and_tokenizer(self):
         """Loads the model and tokenizer, downloading if necessary.
 
         Downloads the model to the HF_MODELS_DIR folder if it does not exist,
@@ -117,11 +117,11 @@ def load_model_and_tokenizer(self):
             if self.necessary_files is None:
                 # Download all files if no specific files are provided
                 logger.info(f"Downloading all files for {self.model_id}...")
-                download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
+                await download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
             else:
                 # Download only the necessary files
                 logger.info(f"Downloading specific files for {self.model_id}...")
-                download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)
+                await download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)
 
         # Load the tokenizer and model from the specified directory
         logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...")
@@ -152,6 +152,9 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
         """
         Sends a normalized prompt asynchronously to the HuggingFace model.
         """
+        # Wait for the model and tokenizer loading task (started in __init__) to finish
+        await self.load_model_and_tokenizer_task
+
         self._validate_request(prompt_request=prompt_request)
         request = prompt_request.request_pieces[0]
         prompt_template = request.converted_value
diff --git a/tests/test_huggingface_chat_target.py b/tests/test_huggingface_chat_target.py
index 854ca1a1c..71822385b 100644
--- a/tests/test_huggingface_chat_target.py
+++ b/tests/test_huggingface_chat_target.py
@@ -81,6 +81,7 @@ def test_init_with_no_token_var_raises(monkeypatch):
     assert "Environment variable HUGGINGFACE_TOKEN is required" in str(excinfo.value)
 
 
+# TODO: Run through tests; currently hitting "RuntimeError: no running event loop"
 def test_initialization():
     # Test the initialization without loading the actual models
     hf_chat = HuggingFaceChatTarget(model_id="test_model", use_cuda=False)
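
One way the TODO above could be resolved (a sketch under the assumption that lazy initialization is acceptable, not the PR's actual code): create the loading task on first use instead of in __init__, so construction also works from synchronous code such as test setup. Names here are illustrative, not PyRIT APIs:

import asyncio
from typing import Optional


class LazyLoader:
    def __init__(self):
        self._load_task: Optional[asyncio.Task] = None  # created lazily

    async def _load(self):
        await asyncio.sleep(0.1)  # stand-in for model/tokenizer loading

    async def ensure_loaded(self):
        # The first await creates the task inside a running loop; later
        # awaits wait on the same task, so loading happens exactly once.
        if self._load_task is None:
            self._load_task = asyncio.create_task(self._load())
        await self._load_task


async def main():
    loader = LazyLoader()  # safe to construct in synchronous code
    await loader.ensure_loaded()


asyncio.run(main())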