diff --git a/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb b/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb
index 506b3febfe..6e282b17e0 100644
--- a/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb
+++ b/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -10,6 +10,7 @@
"import shutil\n",
"from pathlib import Path\n",
"\n",
+ "import matplotlib.pyplot as plt\n",
"import torch\n",
"from lora_module import LoraTraining\n",
"from peft import LoraConfig, TaskType, get_peft_model\n",
@@ -45,21 +46,25 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
- "def generate_text(prompt, model, tokenizer, max_length=30, fhe=\"disable\"):\n",
+ "def generate_text(prompt, model, tokenizer, max_new_tokens=30, input_size=None):\n",
" # Encode the input prompt\n",
- " inputs = tokenizer.encode_plus(prompt, return_tensors=\"pt\")\n",
- "\n",
- " attention_mask = inputs[\"attention_mask\"]\n",
+ " inputs = tokenizer.encode_plus(\n",
+ " prompt,\n",
+ " return_tensors=\"pt\",\n",
+ " padding=\"max_length\",\n",
+ " max_length=input_size,\n",
+ " truncation=True,\n",
+ " )\n",
"\n",
" # Generate text\n",
" output = model.generate(\n",
" input_ids=inputs[\"input_ids\"],\n",
- " attention_mask=attention_mask,\n",
- " max_length=max_length,\n",
+ " attention_mask=inputs[\"attention_mask\"],\n",
+ " max_new_tokens=max_new_tokens,\n",
" num_return_sequences=1,\n",
" no_repeat_ngram_size=2,\n",
" top_k=50,\n",
@@ -76,14 +81,14 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "What is FHE? FH: A basic program that is used to calculate the height of an object, and then sets the minimum height to be\n"
+ "What is FHE? FH: A basic program that is used to calculate the height of an object, and then sets the minimum height to be the object's height.\n"
]
}
],
@@ -96,7 +101,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -113,7 +118,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -141,7 +146,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -152,7 +157,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -168,7 +173,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -197,7 +202,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -222,7 +227,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -248,7 +253,7 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -257,7 +262,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -273,7 +278,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -288,9 +293,19 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
+ ]
+ }
+ ],
"source": [
"def train_custom_model(hybrid_model, train_dataloader, training_args, fhe=\"disable\"):\n",
" device = \"cpu\"\n",
@@ -303,6 +318,8 @@
" epoch_pbar = tqdm(total=total_epochs, desc=\"Training Progress\", position=0)\n",
"\n",
" total_batched_samples = 0\n",
+ " epoch_losses = [] # List to store the loss for each epoch\n",
+ "\n",
" for epoch in range(total_epochs):\n",
" total_loss = 0\n",
" grad_norms = []\n",
@@ -316,7 +333,7 @@
" # Gradient accumulation\n",
" is_last_batch_step = (\n",
" steps_in_epoch <= training_args.gradient_accumulation_steps\n",
- " and (step + 1) == steps_in_epoch\n",
+ " and (step + 1) == steps_in_epoch # noqa: W503\n",
" )\n",
" accumulate_gradients = (\n",
" total_batched_samples % training_args.gradient_accumulation_steps == 0\n",
@@ -339,6 +356,9 @@
" # Get last grad norm\n",
" current_grad_norm = grad_norms[-1]\n",
"\n",
+ " # Store the total loss for this epoch\n",
+ " epoch_losses.append(total_loss)\n",
+ "\n",
" # Log epoch results\n",
" print(\n",
" f\"Epoch {epoch + 1}/{training_args.num_train_epochs}, \"\n",
@@ -352,48 +372,66 @@
" save_path = f\"{training_args.output_dir}/checkpoint-{epoch + 1}\"\n",
" hybrid_model.model.inference_model.save_pretrained(save_path)\n",
"\n",
- " epoch_pbar.close()"
+ " epoch_pbar.close()\n",
+ "\n",
+ " # Plot the loss evolution\n",
+ " plt.figure(figsize=(10, 6))\n",
+ " plt.plot(range(1, total_epochs + 1), epoch_losses, marker=\"o\")\n",
+ " plt.title(\"Loss Evolution During Training\")\n",
+ " plt.xlabel(\"Epoch\")\n",
+ " plt.ylabel(\"Total Loss\")\n",
+ " plt.grid(True)\n",
+ " plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
+ ]
+ }
+ ],
"source": [
- "train_custom_model(hybrid_model, train_dataloader, training_args, fhe=\"simulate\")"
+ "train_custom_model(hybrid_model, train_dataloader, training_args, fhe=\"disable\")"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"fine_tuned_model = hybrid_model.model.inference_model\n",
"\n",
- "hybrid_model.set_fhe_mode(\"simulate\")"
+ "# In simulation, we can only generate a single token at a time because of fixed-size circuits\n",
+ "# and how `generate` works (it keeps only the previous step's last token): FHE is disabled here\n",
+ "hybrid_model.set_fhe_mode(\"disable\")"
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "What is FHE?\n",
- "\n",
- "FHE is a cryptographic technique that enables computations on arbitrary data structures. It consists in generating computable FAs\n"
+ "What is FHE?I\n"
]
}
],
"source": [
- "# Example usage\n",
"prompt = \"What is FHE ?\"\n",
- "generated_text = generate_text(prompt, fine_tuned_model, tokenizer)\n",
+ "generated_text = generate_text(prompt, fine_tuned_model, tokenizer, input_size=BLOCK_SIZE)\n",
"print(generated_text)"
]
},
@@ -401,10 +439,20 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
+ ]
+ }
+ ],
"source": [
"peft_model.disable_adapter_layers()\n",
- "# Example usage\n",
+ "\n",
"prompt = \"What is FHE ?\"\n",
"generated_text = generate_text(prompt, fine_tuned_model, tokenizer)\n",
"print(generated_text)\n",
@@ -414,9 +462,19 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
+ ]
+ }
+ ],
"source": [
"def print_weights_and_size(model, print_detail=False):\n",
" total_weights = 0\n",
@@ -438,15 +496,16 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Total number of weights: 124587264\n",
- "Total number of LoRA weights: 147456\n"
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
]
}
],
@@ -456,9 +515,19 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
+ ]
+ }
+ ],
"source": [
"path = Path(\"gpt2_lora_finetuned_hybrid_deployment\")\n",
"\n",
@@ -470,14 +539,16 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Total number of weights: 39569664\n"
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
]
}
],
@@ -487,14 +558,16 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Weights removed: 68.24 %\n"
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n",
+ "\u001b[1;31mClick here for more info."
]
}
],