From f50b1065ff4fe0f87aaaa98cc1b4991ee365bbbb Mon Sep 17 00:00:00 2001 From: Roman Bredehoft Date: Mon, 5 Aug 2024 17:33:42 +0200 Subject: [PATCH] chore: clean notebook --- .../lora_finetune/gpt2_finetune_hybrid.ipynb | 184 +++--------------- .../lora_finetune/lora_module.py | 65 +++++++ .../{custom_module.py => remote_module.py} | 0 .../lora_finetune/requirements.txt | 1 - 4 files changed, 97 insertions(+), 153 deletions(-) create mode 100644 use_case_examples/lora_finetune/lora_module.py rename use_case_examples/lora_finetune/{custom_module.py => remote_module.py} (100%) diff --git a/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb b/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb index 0ba0e8bc44..506b3febfe 100644 --- a/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb +++ b/use_case_examples/lora_finetune/gpt2_finetune_hybrid.ipynb @@ -11,8 +11,9 @@ "from pathlib import Path\n", "\n", "import torch\n", - "from custom_module import CustomConv1D\n", + "from lora_module import LoraTraining\n", "from peft import LoraConfig, TaskType, get_peft_model\n", + "from remote_module import CustomConv1D\n", "from tqdm import tqdm\n", "from transformers import (\n", " AutoModelForCausalLM,\n", @@ -37,7 +38,7 @@ "if tokenizer.pad_token is None:\n", " tokenizer.pad_token = tokenizer.eos_token\n", "\n", - "# FREEZE WEIGHTS\n", + "# Freeze weights\n", "for param in model.parameters():\n", " param.requires_grad = False" ] @@ -138,75 +139,6 @@ "replace_conv1d(peft_model, module_index_to_skip=0);" ] }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "class LoraTraining(torch.nn.Module):\n", - " def __init__(self, inference_model, gradient_accumulation_steps) -> None:\n", - " super().__init__()\n", - "\n", - " self.inference_model = inference_model\n", - "\n", - " self.optimizer = None\n", - " self.lr_scheduler = None\n", - "\n", - " self.gradient_accumulation_steps = gradient_accumulation_steps\n", - " self.max_grad_norm = None\n", - "\n", - " self.calibrate = False\n", - " self.run_optimizer = False\n", - "\n", - " def update_training_parameters(self, optimizer, lr_scheduler, training_args):\n", - " assert self.gradient_accumulation_steps == training_args.gradient_accumulation_steps\n", - "\n", - " self.optimizer = optimizer\n", - " self.lr_scheduler = lr_scheduler\n", - " self.max_grad_norm = training_args.max_grad_norm\n", - "\n", - " def forward(self, inputs):\n", - " # FIXME: handle multi-inputs in hybrid model\n", - " x, y = inputs\n", - "\n", - " # some parts on server side\n", - " outputs = self.inference_model(input_ids=x, labels=y)\n", - "\n", - " loss = outputs.loss\n", - " loss = loss / self.gradient_accumulation_steps\n", - "\n", - " # Update gradients\n", - " loss.backward()\n", - "\n", - " grad_norm = None\n", - " if not self.calibrate and self.run_optimizer:\n", - " assert self.optimizer is not None\n", - " assert self.lr_scheduler is not None\n", - " assert self.max_grad_norm is not None\n", - "\n", - " grad_norm = torch.nn.utils.clip_grad_norm_(\n", - " self.inference_model.parameters(), max_norm=self.max_grad_norm, norm_type=2\n", - " )\n", - "\n", - " self.optimizer.step()\n", - " self.lr_scheduler.step()\n", - "\n", - " self.inference_model.zero_grad()\n", - "\n", - " # Clean gradients after calibration\n", - " elif self.calibrate:\n", - " self.inference_model.zero_grad()\n", - "\n", - " return (loss, grad_norm)\n", - "\n", - " def toggle_calibrate(self, enable: bool = True):\n", - " self.calibrate = enable\n", - "\n", - " def toggle_run_optimizer(self, enable: bool = True):\n", - " self.run_optimizer = enable" - ] - }, { "cell_type": "code", "execution_count": 26, @@ -226,18 +158,12 @@ "source": [ "BLOCK_SIZE = 128\n", "\n", - "\n", - "def load_dataset(file_path, tokenizer):\n", - " dataset = TextDataset(\n", - " tokenizer=tokenizer,\n", - " file_path=file_path,\n", - " block_size=BLOCK_SIZE,\n", - " cache_dir=\"cache_dataset\",\n", - " )\n", - " return dataset\n", - "\n", - "\n", - "train_dataset = load_dataset(\"data_finetune/what_is_fhe.txt\", tokenizer)" + "train_dataset = TextDataset(\n", + " tokenizer=tokenizer,\n", + " file_path=\"data_finetune/what_is_fhe.txt\",\n", + " block_size=BLOCK_SIZE,\n", + " cache_dir=\"cache_dataset\",\n", + ")" ] }, { @@ -246,11 +172,9 @@ "metadata": {}, "outputs": [], "source": [ - "tokenizer.parallelism = False\n", - "\n", "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n", "\n", - "EPOCHS = 2\n", + "EPOCHS = 100\n", "PER_DEVICE_TRAIN_BATCH_SIZE = 4\n", "\n", "training_args = TrainingArguments(\n", @@ -291,21 +215,14 @@ "num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)\n", "max_steps = math.ceil(training_args.num_train_epochs * num_update_steps_per_epoch)\n", "\n", - "trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [], - "source": [ + "trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)\n", + "\n", "lora_training.update_training_parameters(trainer.optimizer, trainer.lr_scheduler, training_args)" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -326,8 +243,15 @@ " return remote_names\n", "\n", "\n", - "remote_names = get_remote_names(lora_training)\n", - "\n", + "remote_names = get_remote_names(lora_training)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ "hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)" ] }, @@ -433,56 +357,10 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", - "Training Progress: 50%|█████ | 1/2 [02:53<02:53, 173.99s/it]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/2, Loss: 1.8678, grad norm: 0.20942462980747223, lr: 0.00025\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[35], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(SEED)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mtrain_custom_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhybrid_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraining_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfhe\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msimulate\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[34], line 35\u001b[0m, in \u001b[0;36mtrain_custom_model\u001b[0;34m(hybrid_model, train_dataloader, training_args, fhe)\u001b[0m\n\u001b[1;32m 31\u001b[0m run_optimizer \u001b[38;5;241m=\u001b[39m is_last_batch_step \u001b[38;5;129;01mor\u001b[39;00m accumulate_gradients\n\u001b[1;32m 33\u001b[0m hybrid_model\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mtoggle_run_optimizer(enable\u001b[38;5;241m=\u001b[39mrun_optimizer)\n\u001b[0;32m---> 35\u001b[0m loss, grad_norm \u001b[38;5;241m=\u001b[39m \u001b[43mhybrid_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlabels\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfhe\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfhe\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m total_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m grad_norm \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/concrete/ml/torch/hybrid_model.py:419\u001b[0m, in \u001b[0;36mHybridFHEModel.__call__\u001b[0;34m(self, x, fhe)\u001b[0m\n\u001b[1;32m 417\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mremote_modules\u001b[38;5;241m.\u001b[39mvalues():\n\u001b[1;32m 418\u001b[0m module\u001b[38;5;241m.\u001b[39mfhe_local_mode \u001b[38;5;241m=\u001b[39m HybridFHEMode(fhe)\n\u001b[0;32m--> 419\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 420\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "Cell \u001b[0;32mIn[25], line 34\u001b[0m, in \u001b[0;36mLoraTraining.forward\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 31\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# Update gradients\u001b[39;00m\n\u001b[0;32m---> 34\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 36\u001b[0m grad_norm \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcalibrate \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_optimizer:\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 480\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 481\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 486\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 487\u001b[0m )\n\u001b[0;32m--> 488\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 192\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 197\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/torch/autograd/function.py:267\u001b[0m, in \u001b[0;36mBackwardCFunction.apply\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mImplementing both \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbackward\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvjp\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m for a custom \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFunction is not allowed. You should only implement one \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mof them.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 266\u001b[0m user_fn \u001b[38;5;241m=\u001b[39m vjp_fn \u001b[38;5;28;01mif\u001b[39;00m vjp_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m Function\u001b[38;5;241m.\u001b[39mvjp \u001b[38;5;28;01melse\u001b[39;00m backward_fn\n\u001b[0;32m--> 267\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43muser_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/concrete-ml/use_case_examples/lora_finetune/custom_module.py:36\u001b[0m, in \u001b[0;36mForwardBackwardModule.backward\u001b[0;34m(ctx, grad_output)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbackward\u001b[39m(ctx, grad_output):\n\u001b[1;32m 35\u001b[0m backward_module \u001b[38;5;241m=\u001b[39m ctx\u001b[38;5;241m.\u001b[39mbackward_module\n\u001b[0;32m---> 36\u001b[0m grad_input \u001b[38;5;241m=\u001b[39m \u001b[43mbackward_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgrad_output\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# grad_weight and grad_bias are not needed when computing the backward for lora\u001b[39;00m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m grad_input, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/concrete/ml/torch/hybrid_model.py:253\u001b[0m, in \u001b[0;36mRemoteModule.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfhe_local_mode \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m {\n\u001b[1;32m 245\u001b[0m HybridFHEMode\u001b[38;5;241m.\u001b[39mDISABLE,\n\u001b[1;32m 246\u001b[0m HybridFHEMode\u001b[38;5;241m.\u001b[39mCALIBRATE,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 249\u001b[0m }:\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# Using quantized module\u001b[39;00m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprivate_q_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 252\u001b[0m y \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor(\n\u001b[0;32m--> 253\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprivate_q_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfhe\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfhe_local_mode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 254\u001b[0m )\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfhe_local_mode \u001b[38;5;241m==\u001b[39m HybridFHEMode\u001b[38;5;241m.\u001b[39mDISABLE:\n\u001b[1;32m 257\u001b[0m \u001b[38;5;66;03m# Calling torch\u001b[39;00m\n\u001b[1;32m 258\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprivate_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/concrete/ml/quantization/quantized_module.py:443\u001b[0m, in \u001b[0;36mQuantizedModule.forward\u001b[0;34m(self, fhe, debug, *x)\u001b[0m\n\u001b[1;32m 440\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdequantize_output(\u001b[38;5;241m*\u001b[39mto_tuple(q_y_pred))\n\u001b[1;32m 441\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m y_pred, debug_value_tracker\n\u001b[0;32m--> 443\u001b[0m q_y_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mquantized_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mq_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfhe\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfhe\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;66;03m# De-quantize the output predicted values\u001b[39;00m\n\u001b[1;32m 446\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdequantize_output(\u001b[38;5;241m*\u001b[39mto_tuple(q_y_pred))\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/concrete/ml/quantization/quantized_module.py:486\u001b[0m, in \u001b[0;36mQuantizedModule.quantized_forward\u001b[0;34m(self, fhe, *q_x)\u001b[0m\n\u001b[1;32m 484\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_clear_forward(\u001b[38;5;241m*\u001b[39mq_x)\n\u001b[1;32m 485\u001b[0m simulate \u001b[38;5;241m=\u001b[39m fhe \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msimulate\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 486\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fhe_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mq_x\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msimulate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msimulate\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/concrete/ml/quantization/quantized_module.py:651\u001b[0m, in \u001b[0;36mQuantizedModule._fhe_forward\u001b[0;34m(self, simulate, *q_x)\u001b[0m\n\u001b[1;32m 648\u001b[0m predict_method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfhe_circuit\u001b[38;5;241m.\u001b[39mencrypt_run_decrypt\n\u001b[1;32m 650\u001b[0m \u001b[38;5;66;03m# Execute the forward pass in FHE or with simulation\u001b[39;00m\n\u001b[0;32m--> 651\u001b[0m q_result \u001b[38;5;241m=\u001b[39m to_tuple(\u001b[43mpredict_method\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mq_input\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(q_result) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlen\u001b[39m(q_result_by_output), (\n\u001b[1;32m 654\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNumber of outputs does not match the number of output quantizers.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 655\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(q_result)\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m!=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_quantizers)\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 656\u001b[0m )\n\u001b[1;32m 657\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m elt_index, elt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(q_result):\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/concrete/fhe/compilation/circuit.py:168\u001b[0m, in \u001b[0;36mCircuit.simulate\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 165\u001b[0m ordered_validated_args \u001b[38;5;241m=\u001b[39m validate_input_args(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39mclient_specs, \u001b[38;5;241m*\u001b[39margs)\n\u001b[1;32m 167\u001b[0m exporter \u001b[38;5;241m=\u001b[39m SimulatedValueExporter\u001b[38;5;241m.\u001b[39mnew(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39mclient_specs\u001b[38;5;241m.\u001b[39mclient_parameters)\n\u001b[0;32m--> 168\u001b[0m exported \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 169\u001b[0m (\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m arg \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m Value(\n\u001b[1;32m 173\u001b[0m exporter\u001b[38;5;241m.\u001b[39mexport_tensor(position, arg\u001b[38;5;241m.\u001b[39mflatten()\u001b[38;5;241m.\u001b[39mtolist(), \u001b[38;5;28mlist\u001b[39m(arg\u001b[38;5;241m.\u001b[39mshape))\n\u001b[1;32m 174\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(arg, np\u001b[38;5;241m.\u001b[39mndarray) \u001b[38;5;129;01mand\u001b[39;00m arg\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m ()\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m exporter\u001b[38;5;241m.\u001b[39mexport_scalar(position, \u001b[38;5;28mint\u001b[39m(arg))\n\u001b[1;32m 176\u001b[0m )\n\u001b[1;32m 177\u001b[0m )\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m position, arg \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(ordered_validated_args)\n\u001b[1;32m 179\u001b[0m ]\n\u001b[1;32m 181\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;241m*\u001b[39mexported)\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(results, \u001b[38;5;28mtuple\u001b[39m):\n", - "File \u001b[0;32m~/Documents/concrete-ml/.venv/lib/python3.10/site-packages/concrete/fhe/compilation/circuit.py:173\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 165\u001b[0m ordered_validated_args \u001b[38;5;241m=\u001b[39m validate_input_args(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39mclient_specs, \u001b[38;5;241m*\u001b[39margs)\n\u001b[1;32m 167\u001b[0m exporter \u001b[38;5;241m=\u001b[39m SimulatedValueExporter\u001b[38;5;241m.\u001b[39mnew(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39mclient_specs\u001b[38;5;241m.\u001b[39mclient_parameters)\n\u001b[1;32m 168\u001b[0m exported \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 169\u001b[0m (\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m arg \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m Value(\n\u001b[0;32m--> 173\u001b[0m \u001b[43mexporter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mposition\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43marg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtolist\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43marg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 174\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(arg, np\u001b[38;5;241m.\u001b[39mndarray) \u001b[38;5;129;01mand\u001b[39;00m arg\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m ()\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m exporter\u001b[38;5;241m.\u001b[39mexport_scalar(position, \u001b[38;5;28mint\u001b[39m(arg))\n\u001b[1;32m 176\u001b[0m )\n\u001b[1;32m 177\u001b[0m )\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m position, arg \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(ordered_validated_args)\n\u001b[1;32m 179\u001b[0m ]\n\u001b[1;32m 181\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;241m*\u001b[39mexported)\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(results, \u001b[38;5;28mtuple\u001b[39m):\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ - "torch.manual_seed(SEED)\n", - "\n", "train_custom_model(hybrid_model, train_dataloader, training_args, fhe=\"simulate\")" ] }, @@ -494,7 +372,7 @@ "source": [ "fine_tuned_model = hybrid_model.model.inference_model\n", "\n", - "hybrid_model.set_fhe_mode(\"disable\")" + "hybrid_model.set_fhe_mode(\"simulate\")" ] }, { @@ -525,11 +403,13 @@ "metadata": {}, "outputs": [], "source": [ - "with peft_model.disable_adapter_layers():\n", - " # Example usage\n", - " prompt = \"What is FHE ?\"\n", - " generated_text = generate_text(prompt, fine_tuned_model, tokenizer)\n", - " print(generated_text)" + "peft_model.disable_adapter_layers()\n", + "# Example usage\n", + "prompt = \"What is FHE ?\"\n", + "generated_text = generate_text(prompt, fine_tuned_model, tokenizer)\n", + "print(generated_text)\n", + "\n", + "peft_model.enable_adapter_layers()" ] }, { diff --git a/use_case_examples/lora_finetune/lora_module.py b/use_case_examples/lora_finetune/lora_module.py new file mode 100644 index 0000000000..6fde7d0471 --- /dev/null +++ b/use_case_examples/lora_finetune/lora_module.py @@ -0,0 +1,65 @@ +import torch + + +class LoraTraining(torch.nn.Module): + def __init__(self, inference_model, gradient_accumulation_steps) -> None: + super().__init__() + + self.inference_model = inference_model + + self.optimizer = None + self.lr_scheduler = None + + self.gradient_accumulation_steps = gradient_accumulation_steps + self.max_grad_norm = None + + self.calibrate = False + self.run_optimizer = False + + def update_training_parameters(self, optimizer, lr_scheduler, training_args): + assert self.gradient_accumulation_steps == training_args.gradient_accumulation_steps + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.max_grad_norm = training_args.max_grad_norm + + def forward(self, inputs): + # Remove this once hybrid model supports multiple inputs + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4568 + x, y = inputs + + # some parts on server side + outputs = self.inference_model(input_ids=x, labels=y) + + loss = outputs.loss + loss = loss / self.gradient_accumulation_steps + + # Update gradients + loss.backward() + + grad_norm = None + if not self.calibrate and self.run_optimizer: + assert self.optimizer is not None + assert self.lr_scheduler is not None + assert self.max_grad_norm is not None + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.inference_model.parameters(), max_norm=self.max_grad_norm, norm_type=2 + ) + + self.optimizer.step() + self.lr_scheduler.step() + + self.inference_model.zero_grad() + + # Clean gradients after calibration + elif self.calibrate: + self.inference_model.zero_grad() + + return (loss, grad_norm) + + def toggle_calibrate(self, enable: bool = True): + self.calibrate = enable + + def toggle_run_optimizer(self, enable: bool = True): + self.run_optimizer = enable diff --git a/use_case_examples/lora_finetune/custom_module.py b/use_case_examples/lora_finetune/remote_module.py similarity index 100% rename from use_case_examples/lora_finetune/custom_module.py rename to use_case_examples/lora_finetune/remote_module.py diff --git a/use_case_examples/lora_finetune/requirements.txt b/use_case_examples/lora_finetune/requirements.txt index 27e353aed6..ef6bef917d 100644 --- a/use_case_examples/lora_finetune/requirements.txt +++ b/use_case_examples/lora_finetune/requirements.txt @@ -1,4 +1,3 @@ -# FIXME: Only works with source concrete-ml==1.6.1 transformers==4.41.2 peft==0.11.1