From efee62c629ebba4a2ebdcce091b8d0ead567b9da Mon Sep 17 00:00:00 2001 From: Roman Bredehoft Date: Tue, 6 Aug 2024 10:53:45 +0200 Subject: [PATCH] chore: add loss plot --- .../lora_finetune/gpt2_finetune_hybrid.ipynb | 185 ++++++++++++------ 1 file changed, 129 insertions(+), 56 deletions(-) diff --git a/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb b/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb index 506b3febfe..6e282b17e0 100644 --- a/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb +++ b/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 20, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -10,6 +10,7 @@ "import shutil\n", "from pathlib import Path\n", "\n", + "import matplotlib.pyplot as plt\n", "import torch\n", "from lora_module import LoraTraining\n", "from peft import LoraConfig, TaskType, get_peft_model\n", @@ -45,21 +46,25 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ - "def generate_text(prompt, model, tokenizer, max_length=30, fhe=\"disable\"):\n", + "def generate_text(prompt, model, tokenizer, max_new_tokens=30, input_size=None):\n", " # Encode the input prompt\n", - " inputs = tokenizer.encode_plus(prompt, return_tensors=\"pt\")\n", - "\n", - " attention_mask = inputs[\"attention_mask\"]\n", + " inputs = tokenizer.encode_plus(\n", + " prompt,\n", + " return_tensors=\"pt\",\n", + " padding=\"max_length\",\n", + " max_length=input_size,\n", + " truncation=True,\n", + " )\n", "\n", " # Generate text\n", " output = model.generate(\n", " input_ids=inputs[\"input_ids\"],\n", - " attention_mask=attention_mask,\n", - " max_length=max_length,\n", + " attention_mask=inputs[\"attention_mask\"],\n", + " max_new_tokens=max_new_tokens,\n", " num_return_sequences=1,\n", " no_repeat_ngram_size=2,\n", " top_k=50,\n", @@ -76,14 +81,14 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "What is FHE? FH: A basic program that is used to calculate the height of an object, and then sets the minimum height to be\n" + "What is FHE? FH: A basic program that is used to calculate the height of an object, and then sets the minimum height to be the object's height.\n" ] } ], @@ -96,7 +101,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -113,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -141,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -152,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -168,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -197,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -222,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -248,7 +253,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -257,7 +262,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -273,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -288,9 +293,19 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." + ] + } + ], "source": [ "def train_custom_model(hybrid_model, train_dataloader, training_args, fhe=\"disable\"):\n", " device = \"cpu\"\n", @@ -303,6 +318,8 @@ " epoch_pbar = tqdm(total=total_epochs, desc=\"Training Progress\", position=0)\n", "\n", " total_batched_samples = 0\n", + " epoch_losses = [] # List to store the loss for each epoch\n", + "\n", " for epoch in range(total_epochs):\n", " total_loss = 0\n", " grad_norms = []\n", @@ -316,7 +333,7 @@ " # Gradient accumulation\n", " is_last_batch_step = (\n", " steps_in_epoch <= training_args.gradient_accumulation_steps\n", - " and (step + 1) == steps_in_epoch\n", + " and (step + 1) == steps_in_epoch # noqa: W503\n", " )\n", " accumulate_gradients = (\n", " total_batched_samples % training_args.gradient_accumulation_steps == 0\n", @@ -339,6 +356,9 @@ " # Get last grad norm\n", " current_grad_norm = grad_norms[-1]\n", "\n", + " # Store the total loss for this epoch\n", + " epoch_losses.append(total_loss)\n", + "\n", " # Log epoch results\n", " print(\n", " f\"Epoch {epoch + 1}/{training_args.num_train_epochs}, \"\n", @@ -352,48 +372,66 @@ " save_path = f\"{training_args.output_dir}/checkpoint-{epoch + 1}\"\n", " hybrid_model.model.inference_model.save_pretrained(save_path)\n", "\n", - " epoch_pbar.close()" + " epoch_pbar.close()\n", + "\n", + " # Plot the loss evolution\n", + " plt.figure(figsize=(10, 6))\n", + " plt.plot(range(1, total_epochs + 1), epoch_losses, marker=\"o\")\n", + " plt.title(\"Loss Evolution During Training\")\n", + " plt.xlabel(\"Epoch\")\n", + " plt.ylabel(\"Total Loss\")\n", + " plt.grid(True)\n", + " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." + ] + } + ], "source": [ - "train_custom_model(hybrid_model, train_dataloader, training_args, fhe=\"simulate\")" + "train_custom_model(hybrid_model, train_dataloader, training_args, fhe=\"disable\")" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "fine_tuned_model = hybrid_model.model.inference_model\n", "\n", - "hybrid_model.set_fhe_mode(\"simulate\")" + "# In simulation, we can only generate a single token at a time because of fixed size circuits\n", + "# and how `generate` works (only the last token from the previous generation is kept)\n", + "hybrid_model.set_fhe_mode(\"disable\")" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "What is FHE?\n", - "\n", - "FHE is a cryptographic technique that enables computations on arbitrary data structures. It consists in generating computable FAs\n" + "What is FHE?I\n" ] } ], "source": [ - "# Example usage\n", "prompt = \"What is FHE ?\"\n", - "generated_text = generate_text(prompt, fine_tuned_model, tokenizer)\n", + "generated_text = generate_text(prompt, fine_tuned_model, tokenizer, input_size=BLOCK_SIZE)\n", "print(generated_text)" ] }, @@ -401,10 +439,20 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." + ] + } + ], "source": [ "peft_model.disable_adapter_layers()\n", - "# Example usage\n", + "\n", "prompt = \"What is FHE ?\"\n", "generated_text = generate_text(prompt, fine_tuned_model, tokenizer)\n", "print(generated_text)\n", @@ -414,9 +462,19 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." + ] + } + ], "source": [ "def print_weights_and_size(model, print_detail=False):\n", " total_weights = 0\n", @@ -438,15 +496,16 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Total number of weights: 124587264\n", - "Total number of LoRA weights: 147456\n" + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." ] } ], @@ -456,9 +515,19 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." + ] + } + ], "source": [ "path = Path(\"gpt2_lora_finetuned_hybrid_deployment\")\n", "\n", @@ -470,14 +539,16 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Total number of weights: 39569664\n" + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." ] } ], @@ -487,14 +558,16 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Weights removed: 68.24 %\n" + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe file '.venv_lora/lib/python3.10/site-packages/typing_extensions.py' seems to be overriding built in modules and interfering with the startup of the kernel. Consider renaming the file and starting the kernel again.\n", + "\u001b[1;31mClick here for more info." ] } ],