Commit: Update notebook training code

jshuadvd committed Jul 12, 2024
1 parent 6fd174f commit 3c67e6f
Showing 1 changed file with 199 additions and 0 deletions.
notebooks/01_LongRoPE_training.ipynb
@@ -169,6 +169,205 @@
" return torch.exp(loss)"
]
},
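{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perplexity is the exponential of the average cross-entropy loss, $\\mathrm{PPL} = e^{\\mathcal{L}}$, so a drop in loss from 2.0 to 1.0 corresponds to perplexity falling from roughly 7.4 to 2.7."
]
},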
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(\n",
" model,\n",
" train_loader,\n",
" val_loader,\n",
" optimizer,\n",
" criterion,\n",
" scheduler,\n",
" tokenizer,\n",
" epochs=10,\n",
" gradient_accumulation_steps=4,\n",
" resume_from_checkpoint=None,\n",
" max_steps=None,\n",
"):\n",
" \"\"\"\n",
" Train the LongRoPE model.\n",
"\n",
" Args:\n",
" model (nn.Module): The LongRoPE model to train.\n",
" train_loader (DataLoader): DataLoader for training data.\n",
" val_loader (DataLoader): DataLoader for validation data.\n",
" optimizer (Optimizer): Optimizer for updating model parameters.\n",
" criterion (nn.Module): Loss function.\n",
" scheduler (LRScheduler): Learning rate scheduler.\n",
" tokenizer: Tokenizer for encoding/decoding text.\n",
" epochs (int): Number of training epochs.\n",
" gradient_accumulation_steps (int): Number of steps to accumulate gradients.\n",
" resume_from_checkpoint (str): Path to a checkpoint to resume training from.\n",
" max_steps (int): Maximum number of steps to train. If None, train for full epochs.\n",
"\n",
" Returns:\n",
" None\n",
" \"\"\"\n",
" # Initialize the gradient scaler for mixed precision training\n",
" scaler = GradScaler()\n",
"\n",
" # Variables for early stopping\n",
" best_val_loss = float(\"inf\")\n",
" patience = 0\n",
" max_patience = 3\n",
" start_epoch = 0\n",
" global_step = 0\n",
"\n",
" # Check if resuming from a checkpoint\n",
" if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):\n",
" checkpoint = accelerator.load_state(resume_from_checkpoint)\n",
" start_epoch = checkpoint.get(\"epoch\", 0) + 1\n",
" global_step = checkpoint.get(\"global_step\", 0)\n",
" best_val_loss = checkpoint.get(\"best_val_loss\", float(\"inf\"))\n",
" logger.info(\n",
" f\"Resumed training from {resume_from_checkpoint} at epoch {start_epoch}, step {global_step}\"\n",
" )\n",
"\n",
" for epoch in range(start_epoch, epochs):\n",
" model.train()\n",
" total_loss = 0\n",
"\n",
" for i, (inputs, targets) in enumerate(train_loader):\n",
" if max_steps and global_step >= max_steps:\n",
" break\n",
"\n",
" # Move data to the appropriate device (CPU or GPU)\n",
" inputs, targets = (\n",
" inputs.to(accelerator.device),\n",
" targets.to(accelerator.device),\n",
" )\n",
"\n",
" # Use mixed precision training\n",
" with autocast():\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs.permute(0, 2, 1), targets)\n",
" # Normalize the loss to account for gradient accumulation\n",
" loss = loss / gradient_accumulation_steps\n",
"\n",
" # Backpropagate and accumulate gradients\n",
" scaler.scale(loss).backward()\n",
"\n",
" if (i + 1) % gradient_accumulation_steps == 0:\n",
" # Update weights and reset gradients\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
" optimizer.zero_grad()\n",
" global_step += 1\n",
"\n",
" total_loss += loss.item()\n",
"\n",
" if max_steps and global_step >= max_steps:\n",
" break\n",
"\n",
" # Calculate average training loss and perplexity\n",
" avg_train_loss = total_loss / len(train_loader)\n",
" train_perplexity = compute_perplexity(avg_train_loss)\n",
"\n",
" # Validation step\n",
" model.eval()\n",
" val_loss = 0\n",
" with torch.no_grad():\n",
" for inputs, targets in val_loader:\n",
" inputs, targets = (\n",
" inputs.to(accelerator.device),\n",
" targets.to(accelerator.device),\n",
" )\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs.permute(0, 2, 1), targets)\n",
" val_loss += loss.item()\n",
"\n",
" # Calculate average validation loss and perplexity\n",
" avg_val_loss = val_loss / len(val_loader)\n",
" val_perplexity = compute_perplexity(avg_val_loss)\n",
"\n",
" # Update learning rate\n",
" scheduler.step()\n",
"\n",
" # Evaluate passkey retrieval at the end of each epoch and log results\n",
" passkey_accuracies = evaluate_passkey_retrieval(model, tokenizer, model.max_len)\n",
" for length, accuracy in passkey_accuracies.items():\n",
" wandb.log({f\"passkey_retrieval_{length}\": accuracy})\n",
" logger.info(\n",
" f\"Passkey retrieval accuracy at {length} tokens: {accuracy:.2f}\"\n",
" )\n",
"\n",
" # Log gradient norm\n",
" total_norm = 0\n",
" for p in model.parameters():\n",
" if p.grad is not None:\n",
" param_norm = p.grad.data.norm(2)\n",
" total_norm += param_norm.item() ** 2\n",
" total_norm = total_norm**0.5\n",
" wandb.log({\"gradient_norm\": total_norm})\n",
" logger.info(f\"Gradient norm: {total_norm:.4f}\")\n",
"\n",
" # Log metrics\n",
" wandb.log(\n",
" {\n",
" \"epoch\": epoch,\n",
" \"global_step\": global_step,\n",
" \"train_loss\": avg_train_loss,\n",
" \"train_perplexity\": train_perplexity,\n",
" \"val_loss\": avg_val_loss,\n",
" \"val_perplexity\": val_perplexity,\n",
" \"learning_rate\": scheduler.get_last_lr()[0],\n",
" }\n",
" )\n",
"\n",
" # Log epoch results\n",
" logger.info(\n",
" f\"Epoch {epoch+1}, Global Step {global_step}, \"\n",
" f\"Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, \"\n",
" f\"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}\"\n",
" )\n",
"\n",
" # Save checkpoint\n",
" accelerator.save_state(\n",
" {\n",
" \"epoch\": epoch,\n",
" \"global_step\": global_step,\n",
" \"best_val_loss\": best_val_loss,\n",
" },\n",
" f\"checkpoint_epoch_{epoch}_step_{global_step}.pt\",\n",
" )\n",
"\n",
" # Save latest checkpoint\n",
" accelerator.save_state(\n",
" {\n",
" \"epoch\": epoch,\n",
" \"global_step\": global_step,\n",
" \"best_val_loss\": best_val_loss,\n",
" },\n",
" \"checkpoint_latest.pt\",\n",
" )\n",
"\n",
" # Early stopping\n",
" if avg_val_loss < best_val_loss:\n",
" best_val_loss = avg_val_loss\n",
" patience = 0\n",
" # Save best model\n",
" accelerator.save_state(\n",
" {\n",
" \"epoch\": epoch,\n",
" \"global_step\": global_step,\n",
" \"best_val_loss\": best_val_loss,\n",
" },\n",
" \"best_model.pt\",\n",
" )\n",
" else:\n",
" patience += 1\n",
" if patience >= max_patience:\n",
" logger.info(\"Early stopping triggered\")\n",
" break\n",
"\n",
" if max_steps and global_step >= max_steps:\n",
" break"
]
},
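{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, hypothetical wiring of the `train` function above. The optimizer, loss, and scheduler choices here are illustrative assumptions, not the notebook's required settings; `model`, `train_loader`, `val_loader`, `tokenizer`, and `accelerator` are expected to come from the earlier cells. Note that the effective batch size is the DataLoader batch size times `gradient_accumulation_steps`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative setup; swap in the hyperparameters used elsewhere in this notebook.\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)\n",
"\n",
"# Let Accelerate place the model, optimizer, loaders, and scheduler on the right device(s)\n",
"model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(\n",
"    model, optimizer, train_loader, val_loader, scheduler\n",
")\n",
"\n",
"train(\n",
"    model,\n",
"    train_loader,\n",
"    val_loader,\n",
"    optimizer,\n",
"    criterion,\n",
"    scheduler,\n",
"    tokenizer,\n",
"    epochs=10,\n",
"    gradient_accumulation_steps=4,\n",
")"
]
},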
{
"cell_type": "code",
"execution_count": null,
