add support for local model checkpoints and trust_remote_code in HuggingFaceChatTarget
KutalVolkan committed Nov 23, 2024
1 parent 94dd4ec commit 8f50df0
Showing 2 changed files with 427 additions and 29 deletions.
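The notebook diff below exercises the two options this commit adds. As a minimal sketch of the new usage, mirroring the parameters that appear verbatim in the notebook cells further down (the local path is the illustrative one used in the notebook):

    from pyrit.prompt_target import HuggingFaceChatTarget

    # Point the target at a checkpoint saved on disk instead of a Hugging Face model ID,
    # and opt in to executing the model's custom code (required by models such as Phi-3).
    target = HuggingFaceChatTarget(
        model_path="./local_phi_model",  # local checkpoint directory
        use_cuda=False,                  # CPU inference
        tensor_format="pt",              # PyTorch tensors
        max_new_tokens=30,
        trust_remote_code=True,          # needed for models that ship custom modeling code
    )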
366 changes: 364 additions & 2 deletions doc/code/targets/use_huggingface_chat_target.ipynb
@@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "940f8d8a",
"metadata": {
"execution": {
@@ -47,7 +47,28 @@
"shell.execute_reply": "2024-11-11T22:43:23.862727Z"
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running model: HuggingFaceTB/SmolLM-135M-Instruct\n",
"Average response time for HuggingFaceTB/SmolLM-135M-Instruct: 23.72 seconds\n",
"\n",
"\u001b[22m\u001b[39mConversation ID: 56985b2b-cdfb-4d88-8255-71566b0c8c4f\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 3*3 is a fascinating number that has been a subject of fascination for mathematicians and computer scientists for\n",
"\u001b[22m\u001b[39mConversation ID: c5bb0d53-4b28-4048-bd03-251e17781285\n",
"\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: What a great question!\n",
"\n",
"The number 4*4 is a special number because it can be expressed as a product of two numbers,\n",
"HuggingFaceTB/SmolLM-135M-Instruct: 23.72 seconds\n"
]
}
],
"source": [
"import time\n",
"from pyrit.prompt_target import HuggingFaceChatTarget\n",
@@ -100,6 +121,347 @@
"else:\n",
" print(f\"{model_id}: Error occurred, no average time calculated.\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4c563c2e",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c04744679fa4eff9fc70d12c8bea45c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/3.44k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "09db4560c0f84443847fa0102b853913",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.model: 0%| | 0.00/500k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "46c98f49a87643e3a6598f9175772c7e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/1.94M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ee2d4f3fad164a70bb9564a1c6b80218",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"added_tokens.json: 0%| | 0.00/306 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2e8ea987594849b586b5066ed6688462",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/599 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9a96bb463e2e4e0b861d9b26a2147e78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/967 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a75f058be4843d7a4e580ad1ce91300",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"configuration_phi3.py: 0%| | 0.00/11.2k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:\n",
"- configuration_phi3.py\n",
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0e05e6c399204e10a311f356ad96e4ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"modeling_phi3.py: 0%| | 0.00/73.2k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:\n",
"- modeling_phi3.py\n",
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
"`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.\n",
"Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "75ef6f9379a241a1a6baf2ec062dbec2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors.index.json: 0%| | 0.00/16.5k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "22faaefa4687450ab4b227be1fd83ef7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9506aa531b9c4479bf1dd6208efbcf03",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00002.safetensors: 0%| | 0.00/4.97G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2471ffca92ef4b11989d659aac2f0b22",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00002-of-00002.safetensors: 0%| | 0.00/2.67G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1a7ce1cbad1d4b6084c4f776a4138d37",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ea318ac28a8486f9dc4ce0c395e4d13",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/181 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model and tokenizer saved locally at ./local_phi_model\n"
]
}
],
"source": [
"# Download and save the model locally\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"import os\n",
"\n",
"model_id = \"microsoft/Phi-3-mini-4k-instruct\"\n",
"local_model_path = \"./local_phi_model\"\n",
"\n",
"os.makedirs(local_model_path, exist_ok=True)\n",
"\n",
"# Download and save the tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
"tokenizer.save_pretrained(local_model_path)\n",
"\n",
"# Download and save the model\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)\n",
"model.save_pretrained(local_model_path)\n",
"\n",
"print(f\"Model and tokenizer saved locally at {local_model_path}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ab024b19",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.\n",
"You are not running the flash-attention implementation, expect numerical differences.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running model from local path: ./local_phi_model\n",
"Average response time for model at ./local_phi_model: 18.46 seconds\n",
"\n",
"\u001b[22m\u001b[39mConversation ID: 9fd7db42-9c5f-4622-8762-76af5797d0ba\n",
"\u001b[1m\u001b[34muser: What is 3*3? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: The solution to 3*3 is 9.\n",
"\u001b[22m\u001b[39mConversation ID: a60a5bd9-a27d-488f-84ba-f1cb3a742932\n",
"\u001b[1m\u001b[34muser: What is 4*4? Give me the solution.\n",
"\u001b[22m\u001b[33massistant: The solution to 4*4 is 16.\n",
"Model at ./local_phi_model: 18.46 seconds\n"
]
}
],
"source": [
"import time\n",
"from pyrit.prompt_target import HuggingFaceChatTarget\n",
"from pyrit.orchestrator import PromptSendingOrchestrator\n",
"\n",
"# Path to the local phi model\n",
"model_path = \"./local_phi_model\"\n",
"\n",
"# List of prompts to send\n",
"prompt_list = [\"What is 3*3? Give me the solution.\", \"What is 4*4? Give me the solution.\"]\n",
"\n",
"# Dictionary to store average response times\n",
"model_times = {}\n",
"\n",
"print(f\"Running model from local path: {model_path}\")\n",
"\n",
"try:\n",
" # Initialize HuggingFaceChatTarget with the local model path\n",
" target = HuggingFaceChatTarget(\n",
" model_path=model_path,\n",
" use_cuda=False,\n",
" tensor_format=\"pt\",\n",
" max_new_tokens=30,\n",
" trust_remote_code=True # Necessary for this model\n",
" )\n",
"\n",
" # Initialize the orchestrator\n",
" orchestrator = PromptSendingOrchestrator(prompt_target=target, verbose=False)\n",
"\n",
" # Record start time\n",
" start_time = time.time()\n",
"\n",
" # Send prompts asynchronously\n",
" responses = await orchestrator.send_prompts_async(prompt_list=prompt_list) # type: ignore\n",
"\n",
" # Record end time\n",
" end_time = time.time()\n",
"\n",
" # Calculate total and average response time\n",
" total_time = end_time - start_time\n",
" avg_time = total_time / len(prompt_list)\n",
" model_times[model_path] = avg_time\n",
"\n",
" print(f\"Average response time for model at {model_path}: {avg_time:.2f} seconds\\n\")\n",
"\n",
" # Print the conversations\n",
" await orchestrator.print_conversations() # type: ignore\n",
"\n",
"except Exception as e:\n",
" print(f\"An error occurred with model at {model_path}: {e}\\n\")\n",
" model_times[model_path] = None\n",
"\n",
"# Print the model average time\n",
"if model_times[model_path] is not None:\n",
" print(f\"Model at {model_path}: {model_times[model_path]:.2f} seconds\")\n",
"else:\n",
" print(f\"Model at {model_path}: Error occurred, no average time calculated.\")"
]
}
],
"metadata": {
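A note on the trust_remote_code warnings visible in the cell output above: as the transformers warning itself suggests, pinning a revision avoids silently pulling updated remote code files on later downloads. A hedged sketch (the commit SHA is a placeholder, not taken from this diff):

    from transformers import AutoModelForCausalLM

    # Pin the repository revision so trust_remote_code always executes the same audited files.
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3-mini-4k-instruct",
        trust_remote_code=True,
        revision="<commit-sha>",  # placeholder: pin to a commit you have reviewed
    )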
