FIX Remove aria2c dependency from HuggingFace Target #530

Merged
merged 7 commits on Nov 11, 2024
129 changes: 90 additions & 39 deletions doc/code/orchestrators/use_huggingface_chat_target.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "d623d73a",
"id": "066bb566",
"metadata": {
"lines_to_next_cell": 2
},
@@ -17,7 +17,7 @@
" - This notebook supports the following **instruct models** that follow a structured chat template. These are examples, and more instruct models are available on Hugging Face:\n",
" - `HuggingFaceTB/SmolLM-360M-Instruct`\n",
" - `microsoft/Phi-3-mini-4k-instruct`\n",
" \n",
"\n",
" - `...`\n",
"\n",
"2. **Excluded Models**:\n",
@@ -37,63 +37,116 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0d61a68",
"metadata": {},
"outputs": [],
"execution_count": 1,
"id": "940f8d8a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-11T22:41:35.643730Z",
"iopub.status.busy": "2024-11-11T22:41:35.643730Z",
"iopub.status.idle": "2024-11-11T22:43:23.863745Z",
"shell.execute_reply": "2024-11-11T22:43:23.862727Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running model: HuggingFaceTB/SmolLM-135M-Instruct\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 37.12 seconds\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[39mConversation ID: 5223e15e-f21c-4d15-88af-8c02d6558182\n",
"\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 4*4 is a special number because it can be expressed as a product of two numbers,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[39mConversation ID: b0238d3e-ce2e-48c3-a5e1-eaebf2c58e6f\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"HuggingFaceTB/SmolLM-135M-Instruct: 37.12 seconds\n"
]
}
],
"source": [
"import time\n",
"from pyrit.prompt_target import HuggingFaceChatTarget \n",
"from pyrit.prompt_target import HuggingFaceChatTarget\n",
"from pyrit.orchestrator import PromptSendingOrchestrator\n",
"\n",
"# models to test\n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\" \n",
"model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\"\n",
"\n",
"# List of prompts to send\n",
"prompt_list = [\n",
" \"What is 3*3? Give me the solution.\",\n",
" \"What is 4*4? Give me the solution.\"\n",
" ]\n",
"prompt_list = [\"What is 3*3? Give me the solution.\", \"What is 4*4? Give me the solution.\"]\n",
"\n",
"# Dictionary to store average response times\n",
"model_times = {}\n",
" \n",
"\n",
"print(f\"Running model: {model_id}\")\n",
" \n",
"\n",
"try:\n",
" # Initialize HuggingFaceChatTarget with the current model\n",
" target = HuggingFaceChatTarget(\n",
" model_id=model_id, \n",
" use_cuda=False, \n",
" tensor_format=\"pt\",\n",
" max_new_tokens=30 \n",
" )\n",
" \n",
" target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format=\"pt\", max_new_tokens=30)\n",
"\n",
" # Initialize the orchestrator\n",
" orchestrator = PromptSendingOrchestrator(\n",
" prompt_target=target,\n",
" verbose=False\n",
" )\n",
" \n",
" orchestrator = PromptSendingOrchestrator(prompt_target=target, verbose=False)\n",
"\n",
" # Record start time\n",
" start_time = time.time()\n",
" \n",
"\n",
" # Send prompts asynchronously\n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
" \n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
"\n",
" # Record end time\n",
" end_time = time.time()\n",
" \n",
"\n",
" # Calculate total and average response time\n",
" total_time = end_time - start_time\n",
" avg_time = total_time / len(prompt_list)\n",
" model_times[model_id] = avg_time\n",
" \n",
"\n",
" print(f\"Average response time for {model_id}: {avg_time:.2f} seconds\\n\")\n",
" \n",
"\n",
" # Print the conversations\n",
" await orchestrator.print_conversations() # type: ignore\n",
" \n",
" await orchestrator.print_conversations() # type: ignore\n",
"\n",
"except Exception as e:\n",
" print(f\"An error occurred with model {model_id}: {e}\\n\")\n",
" model_times[model_id] = None\n",
Expand All @@ -108,14 +161,12 @@
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "pyrit-dev",
"language": "python",
"name": "python3"
"name": "pyrit-dev"
},
"language_info": {
"codemirror_mode": {
112 changes: 112 additions & 0 deletions pyrit/common/download_hf_model.py
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import logging
import os
import httpx
from pathlib import Path

from huggingface_hub import HfApi


logger = logging.getLogger(__name__)


def get_available_files(model_id: str, token: str):
    """Fetches available files for a model from the Hugging Face repository."""
    api = HfApi()
    try:
        model_info = api.model_info(model_id, token=token)
        available_files = [file.rfilename for file in model_info.siblings]

        # Perform simple validation: raise a ValueError if no files are available
        if not len(available_files):
            raise ValueError(f"No available files found for the model: {model_id}")

        return available_files
    except Exception as e:
        logger.info(f"Error fetching model files for {model_id}: {e}")
        return []


async def download_specific_files(model_id: str, file_patterns: list, token: str, cache_dir: Path):
    """
    Downloads specific files from a Hugging Face model repository.
    If file_patterns is None, downloads all files.

    The selected files are downloaded into the given cache_dir.
    """
    os.makedirs(cache_dir, exist_ok=True)

    available_files = get_available_files(model_id, token)
    # If no file patterns are provided, download all available files
    if file_patterns is None:
        files_to_download = available_files
        logger.info(f"Downloading all files for model {model_id}.")
    else:
        # Filter files based on the patterns provided
        files_to_download = [file for file in available_files if any(pattern in file for pattern in file_patterns)]
        if not files_to_download:
            logger.info(f"No files matched the patterns provided for model {model_id}.")
            return

    # Generate download URLs directly
    base_url = f"https://huggingface.co/{model_id}/resolve/main/"
    urls = [base_url + file for file in files_to_download]

    # Download the files
    await download_files(urls, token, cache_dir)


async def download_chunk(url, headers, start, end, client):
    """Download a chunk of the file with a specified byte range."""
    range_header = {"Range": f"bytes={start}-{end}", **headers}
    response = await client.get(url, headers=range_header)
    response.raise_for_status()
    return response.content


async def download_file(url, token, download_dir, num_splits):
    """Download a file in multiple segments (splits) using byte-range requests."""
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(follow_redirects=True) as client:
        # Get the file size to determine chunk size
        response = await client.head(url, headers=headers)
        response.raise_for_status()
        file_size = int(response.headers["Content-Length"])
        chunk_size = file_size // num_splits

        # Prepare tasks for each chunk
        tasks = []
        file_name = url.split("/")[-1]
        file_path = Path(download_dir, file_name)

        for i in range(num_splits):
            start = i * chunk_size
            end = start + chunk_size - 1 if i < num_splits - 1 else file_size - 1
            tasks.append(download_chunk(url, headers, start, end, client))

        # Download all chunks concurrently
        chunks = await asyncio.gather(*tasks)

        # Write chunks to the file in order
        with open(file_path, "wb") as f:
            for chunk in chunks:
                f.write(chunk)
        logger.info(f"Downloaded {file_name} to {file_path}")


async def download_files(urls: list[str], token: str, download_dir: Path, num_splits=3, parallel_downloads=4):
    """Download multiple files with parallel downloads and segmented downloading."""

    # Limit the number of parallel downloads
    semaphore = asyncio.Semaphore(parallel_downloads)

    async def download_with_limit(url):
        async with semaphore:
            await download_file(url, token, download_dir, num_splits)

    # Run downloads concurrently, but limit to parallel_downloads at a time
    await asyncio.gather(*(download_with_limit(url) for url in urls))
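
A minimal usage sketch for the new helper (not part of the PR's own changes): it assumes the module is importable as pyrit.common.download_hf_model, that a valid Hugging Face token is exported as HUGGINGFACE_TOKEN, and that the model id, file patterns, and cache directory are purely illustrative choices.

# Sketch only: download a small model's weight and tokenizer files without aria2c.
# HUGGINGFACE_TOKEN, the model id, file patterns, and cache directory are assumptions.
import asyncio
import os
from pathlib import Path

from pyrit.common.download_hf_model import download_specific_files


async def main():
    token = os.environ.get("HUGGINGFACE_TOKEN", "")
    cache_dir = Path.home() / ".cache" / "pyrit" / "SmolLM-135M-Instruct"
    await download_specific_files(
        model_id="HuggingFaceTB/SmolLM-135M-Instruct",
        file_patterns=["model.safetensors", "tokenizer", "config.json"],
        token=token,
        cache_dir=cache_dir,
    )


if __name__ == "__main__":
    asyncio.run(main())

The byte-range splitting in download_file combined with the semaphore in download_files keeps multi-connection downloads in pure Python via httpx, which is why the external aria2c binary is no longer needed.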
94 changes: 0 additions & 94 deletions pyrit/common/download_hf_model_with_aria2.py

This file was deleted.
