diff --git a/train.py b/train.py index 70e68b5..8bb95f1 100644 --- a/train.py +++ b/train.py @@ -14,12 +14,14 @@ from importlib import reload import src.main from accelerate import Accelerator +from tqdm import tqdm import wandb import os import logging import hashlib import pickle -from tqdm import tqdm +import GPUtil + from evaluation import evaluate_passkey_retrieval @@ -329,6 +331,12 @@ def train( f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}" ) + # Log GPU memory usage + for gpu in GPUtil.getGPUs(): + gpu_memory_used = gpu.memoryUsed + logger.info(f"GPU {gpu.id} memory use: {gpu_memory_used}MB") + wandb.log({f"GPU_{gpu.id}_memory_used": gpu_memory_used}) + # Save checkpoint accelerator.save_state( {