From 5e8361c06f4652ec3d98d6ab7cb1760e75896cb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Fri, 20 Sep 2024 05:26:08 +0000 Subject: [PATCH] adapt it to the current lighteval main --- create_config.py | 108 ++---- launcher.py | 249 ++++++++------ pyproject.toml | 2 +- slurm/run_eval.slurm.jinja | 27 +- slurm/run_eval_s3.slurm.jinja | 27 +- src/nanotron/config/config.py | 49 ++- src/nanotron/config/lighteval_config.py | 61 +--- src/nanotron/config/utils_config.py | 2 +- src/nanotron/lighteval/evaluation_tasks.py | 361 ++++++++++----------- src/nanotron/lighteval/one_job_runner.py | 86 +++-- src/nanotron/lighteval/run_evals.py | 12 +- src/nanotron/trainer.py | 64 ++-- 12 files changed, 520 insertions(+), 528 deletions(-) diff --git a/create_config.py b/create_config.py index 31242c90..f09df36e 100644 --- a/create_config.py +++ b/create_config.py @@ -1,57 +1,42 @@ -import os -from pathlib import Path -import subprocess -from datetime import datetime -import math -import torch - import argparse +import math +from datetime import datetime +from pathlib import Path -from nanotron.models.llama import LlamaConfig - +import torch from nanotron.config import ( + AdamWOptimizerArgs, + CheckpointsArgs, Config, DataArgs, - NanosetDatasetsArgs, - PretrainDatasetsArgs, - S3UploadArgs, - CheckpointsArgs, + DatasetStageArgs, GeneralArgs, - LightEvalConfig, - LightEvalLoggingArgs, - LightEvalTasksArgs, LoggingArgs, LRSchedulerArgs, ModelArgs, OptimizerArgs, - AdamWOptimizerArgs, ParallelismArgs, + PretrainDatasetsArgs, RandomInit, TokenizerArgs, TokensArgs, - DatasetStageArgs, ) +from nanotron.models.llama import LlamaConfig if __name__ == "__main__": ########################################### ## ADAPT TO YOUR ENVIRONMENT (toy example of smollm-135M on 1 GPU) - HF_USER_OR_ORG = "eliebak" + HF_USER_OR_ORG = None TRAIN_STEPS = 100 CHECKPOINT_INTERVAL = 200 - SAVE_NAME="smollm-135M-1gpu-toy" - + SAVE_NAME = "smollm-135M-1gpu-toy" ########################################### parser = argparse.ArgumentParser() parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml") parser.add_argument("--seed", help="seed", type=int, default=8) - parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") - parser.add_argument("--logs-path", help="path to the logs folder", type=str) - parser.add_argument("--run", help="name of the run", type=str) - parser.add_argument("--slurm", help="use slurm", action="store_true") - parser.add_argument("--nodes", help="specify the number of nodes", type=int) args = parser.parse_args() timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -78,7 +63,7 @@ rope_scaling=None, tie_word_embeddings=True, use_cache=True, - vocab_size=49152, + vocab_size=49152, ) # Uncomment to evaluate the model on a set of tasks with lighteval during the training. 
@@ -100,24 +85,16 @@ # ), # batch_size=16, # logging=LightEvalLoggingArgs( - # local_output_path="lighteval-logs", - # private=True, - # push_details_to_hub=True, - # push_results_to_hub=True, - # push_results_to_tensorboard=True, - # hf_user_or_org=HF_USER_OR_ORG, - # hub_repo_results="lighteval-results", - # hub_repo_details="lighteval-details", - # hub_repo_tensorboard="smollm-evals-visualization", + # output_dir=None, + # push_to_hub=True, + # push_to_tensorboard=True, + # public_run=False, + # results_org=HF_USER_OR_ORG, # tensorboard_metric_prefix="eval", # ), - # temp_dir = "temp_dir", - # slurm_template="slurm/run_eval.slurm.jinja", - # # slurm_template="slurm/run_eval_s3.slurm.jinja", if s3 - # ) - lighteval = None + # lighteval = None checkpoints = CheckpointsArgs( # checkpoints_path="checkpoints", @@ -137,7 +114,7 @@ ) tokens = TokensArgs( - batch_accumulation_per_replica=8, + batch_accumulation_per_replica=1, micro_batch_size=8, sequence_length=2048, train_steps=TRAIN_STEPS, @@ -147,7 +124,7 @@ model = ModelArgs( model_config=model_config, init_method=RandomInit( - std=1/math.sqrt(model_config.hidden_size), + std=1 / math.sqrt(model_config.hidden_size), ), dtype=torch.bfloat16, ) @@ -164,12 +141,11 @@ lr_warmup_steps=10, lr_warmup_style="linear", lr_decay_style="linear", - lr_decay_steps = 20, - lr_decay_starting_step=80 , + lr_decay_steps=20, + lr_decay_starting_step=80, min_decay_lr=0, ) - optimizer = OptimizerArgs( zero_stage=0, weight_decay=0.01, @@ -197,12 +173,12 @@ # s5cmd_path="PATH_TO_S5CMD", # ) - data_stages=[ + data_stages = [ DatasetStageArgs( data=DataArgs( # 1. Un-tokenized dataset from HuggingFace dataset=PretrainDatasetsArgs( - hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus", # feel free to replace it by a smaller one if you don't have enough memory + hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus", # feel free to replace it by a smaller one if you don't have enough memory hf_dataset_splits="train", hf_dataset_config_name="cosmopedia-v2", text_column_name="text", @@ -250,42 +226,12 @@ lighteval=lighteval, ) - save_path= Path(args.save_path) + save_path = Path(args.save_path) save_path.mkdir(parents=True, exist_ok=True) config_path_yaml = save_path / f"{SAVE_NAME}.yaml" config.save_as_yaml(config_path_yaml) print(f"πŸ’Ύ Configuration saved in: {str(save_path)}") - - if args.launch: - - # Sanity check for logs_path and run - if not args.logs_path: - raise ValueError("--logs_path must be defined. Please provide a path for the logs.") - if not args.run: - raise ValueError("--run must be defined. Please provide a name for the run.") - - launcher_path = Path("launcher.py") - if not launcher_path.exists(): - raise FileNotFoundError(f"Launcher not found at {launcher_path}. 
Please ensure the file exists or change the launcher path in the create_config.py file.") - launch_command = [ - "python", str(launcher_path), - "--config-path", str(config_path_yaml), - ] - launch_command.extend([ - "--logs-path", args.logs_path, - "--run", args.run - ]) - if args.slurm: - launch_command.append("--slurm") - - if args.nodes: - launch_command.extend(["--nodes", str(args.nodes)]) - - - print(f"πŸ§ͺ Launching configuration with command: {' '.join(launch_command)}") - subprocess.run(launch_command, check=True) - else: - print("To launch this configuration, run:") - print(f"python 'launcher.py' configs/{str(config_path_yaml)}") \ No newline at end of file + print("To launch this configuration, run:") + print(f"python launcher.py --config-path configs/{str(config_path_yaml)}") diff --git a/launcher.py b/launcher.py index 49df6c1f..c8dd4a75 100644 --- a/launcher.py +++ b/launcher.py @@ -1,23 +1,26 @@ +import argparse +import json import os -from pathlib import Path import subprocess import tempfile from datetime import datetime +from pathlib import Path + import torch -import argparse -import json from jinja2 import Template - -from nanotron.logging import human_format - from nanotron.config import ( Config, get_config_from_file, + save_as_yaml, ) +from nanotron.config.lighteval_config import LightEvalConfig +from nanotron.logging import human_format + def count_subdirectories(path): return sum(os.path.isdir(os.path.join(path, item)) for item in os.listdir(path)) + def launch_slurm_job(launch_file_contents, *args): """ Small helper function to save a sbatch script and call it. @@ -33,22 +36,26 @@ def launch_slurm_job(launch_file_contents, *args): f.flush() return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1] + def set_nested_attribute(obj, path, value): - parts = path.split('.') + parts = path.split(".") for part in parts[:-1]: if not hasattr(obj, part): - setattr(obj, part, type('', (), {})()) + setattr(obj, part, type("", (), {})()) obj = getattr(obj, part) setattr(obj, parts[-1], value) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config-path", help="path to the configuration file", type=str, default=None, required=True) - parser.add_argument("--run", help="name of the run", type=str, required=True) + parser.add_argument("--project", help="name of the project", type=str) + parser.add_argument("--run", help="name of the run", type=str) parser.add_argument("--logs-path", help="path to the logs folder", type=str, default="logs") - parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", - help="Override config values. Use dot notation for nested keys.") + parser.add_argument( + "--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys." 
+ ) parser.add_argument("--slurm", action="store_true", help="Launch the job on Slurm") parser.add_argument("--nodes", type=int, help="Number of nodes to use for the job") args = parser.parse_args() @@ -65,27 +72,34 @@ def set_nested_attribute(obj, path, value): if config.general.logs_path is None and args.logs_path is None: raise ValueError("Please provide a logs path") + if config.general.project is None and args.project is None: + raise ValueError("Please provide a project name") + elif args.project is not None: + config.general.project = args.project + + if config.general.run is None and args.run is None: + raise ValueError("Please provide a run name") + elif args.run is not None: + config.general.run = args.run - num_params = human_format( - config.model.model_config.get_llama_param_count() - ).replace(".", ",") + num_params = human_format(config.model.model_config.get_llama_param_count()).replace(".", ",") if args.override: for item in args.override: - if '=' not in item: + if "=" not in item: raise ValueError(f"Invalid override format: {item}. Use KEY=VALUE.") - key, value = item.split('=', 1) + key, value = item.split("=", 1) try: value = eval(value) except: pass - + set_nested_attribute(config, key, value) print("⇄ Applied overrides:") for item in args.override: print(f" {item}") - + # Calculate and print learning rate and global batch size information lr_initial = config.optimizer.learning_rate_scheduler.learning_rate lr_min = config.optimizer.learning_rate_scheduler.min_decay_lr @@ -99,13 +113,14 @@ def set_nested_attribute(obj, path, value): bs_gpu_token = bs_gpu_sample * config.tokens.sequence_length # Sample/Token in one step - gbs_sample = bs_gpu_sample * config.parallelism.dp*config.tokens.batch_accumulation_per_replica + gbs_sample = bs_gpu_sample * config.parallelism.dp * config.tokens.batch_accumulation_per_replica gbs_token = gbs_sample * config.tokens.sequence_length total_tokens = config.tokens.train_steps * gbs_token total_tokens_billions = human_format(total_tokens).replace(".", ",") - print(f""" + print( + f""" πŸ‹οΈ Model Parameters: β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ Total Parameters β”‚ {num_params:>22} β”‚ @@ -117,30 +132,36 @@ def set_nested_attribute(obj, path, value): β”‚ Tokenizer β”‚ {config.tokenizer.tokenizer_name_or_path[:22]:>22} β”‚ β”‚ Vocab Size β”‚ {config.model.model_config.vocab_size:>22d} β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -""") +""" + ) num_nodes = args.nodes if args.slurm else 1 - print(f""" + print( + f""" πŸŽ›οΈ Parallelism Configuration: β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ Nodes β”‚ {num_nodes:>22d} β”‚ β”‚ Total GPUs β”‚ {config.parallelism.dp*config.parallelism.pp*config.parallelism.tp:>22d} β”‚ β”‚ Data Parallel (DP) β”‚ {config.parallelism.dp:>22d} β”‚ β”‚ Pipeline Parallel (PP)β”‚ {config.parallelism.pp:>22d} β”‚ -β”‚ Tensor Parallel (TP) β”‚ {config.parallelism.tp:>22d} β”‚ +β”‚ Tensor Parallel (TP) β”‚ {config.parallelism.tp:>22d} β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -""") +""" + ) - print(f""" + print( + f""" πŸ“™ Training Configuration: 
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ Total Tokens β”‚ {total_tokens_billions:>22} β”‚ β”‚ Batch Size (per GPU) β”‚ {bs_gpu_token:>15,d} Tokens β”‚ β”‚ Global Batch Size β”‚ {gbs_token:>15,d} Tokens β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -""") +""" + ) - print(f""" + print( + f""" πŸ“Š Learning Rate Schedule: β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ Initial LR β”‚ {lr_initial:>22.2e} β”‚ @@ -151,8 +172,10 @@ def set_nested_attribute(obj, path, value): β”‚ Decay Steps β”‚ {lr_decay_steps:>22d} β”‚ β”‚ Final LR β”‚ {lr_min:>22.2e} β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -""") - print(f""" +""" + ) + print( + f""" πŸ”§ Optimization Configuration: β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ Optimizer β”‚ {config.optimizer.optimizer_factory.__class__.__name__:>22} β”‚ @@ -164,88 +187,109 @@ def set_nested_attribute(obj, path, value): β”‚ ZeRO Stage β”‚ {config.optimizer.zero_stage:>22d} β”‚ β”‚ FP32 Grad Accumulationβ”‚ {str(config.optimizer.accumulate_grad_in_fp32):>22} β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -""") +""" + ) config.general.logs_path = args.logs_path - config.general.run = args.run - path = Path(args.logs_path) / f"{args.run}" + path = Path(args.logs_path) / f"{config.general.run}" path.mkdir(parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - run_number = count_subdirectories(f"{args.logs_path}/{args.run}") + 1 + run_number = count_subdirectories(f"{args.logs_path}/{config.general.run}") + 1 timestamp_with_run = f"run{run_number:03d}_{timestamp}" config.general.timestamp_with_run = timestamp_with_run - config.general.config_logs_path = str(Path(config.general.logs_path) / args.run / timestamp_with_run / "config") + config.general.config_logs_path = str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "config" + ) Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) if config.checkpoints.checkpoints_path is None: - config.checkpoints.checkpoints_path = str(Path(config.general.logs_path) / args.run / timestamp_with_run / "checkpoints") + config.checkpoints.checkpoints_path = str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "checkpoints" + ) Path(config.checkpoints.checkpoints_path).mkdir(parents=True, exist_ok=True) - if args.slurm: - + nodes = args.nodes launch_slurm_config_path = Path("slurm/launch_slurm_config.json") - eval_slurm_config_path = Path("slurm/eval_slurm_config.json") - - with open(launch_slurm_config_path, 'r') as f: + if config.lighteval is not None: + eval_slurm_config_path = Path("slurm/eval_slurm_config.json") + if eval_slurm_config_path.exists(): + config.general.eval_slurm_config = str(eval_slurm_config_path.resolve()) + else: + raise ValueError("Lighteval SLURM configuration is required but not provided.") + if config.general.is_s3_available: + 
config.general.eval_slurm_template = "slurm/run_eval_s3.slurm.jinja" + else: + config.general.eval_slurm_template = "slurm/run_eval.slurm.jinja" + + with open(launch_slurm_config_path, "r") as f: launch_slurm_config = json.load(f) - - + total_gpus = config.parallelism.dp * config.parallelism.pp * config.parallelism.tp - gpus_per_node = launch_slurm_config.get('gpus_per_node') - required_nodes = (total_gpus + gpus_per_node - 1) // gpus_per_node # Ceiling division + gpus_per_node = launch_slurm_config.get("gpus_per_node") + if total_gpus < gpus_per_node: + required_nodes = 1 + gpus_per_node = total_gpus + print( + "Warning: The total number of GPUs is less than the GPUs per node. You need to adjust to use all available GPUs." + ) + else: + required_nodes = (total_gpus + gpus_per_node - 1) // gpus_per_node # Ceiling division if args.nodes != required_nodes: - raise ValueError(f"Number of nodes in config ({args.nodes}) does not match the required number of nodes ({required_nodes}) based on the parallelism configuration.") + raise ValueError( + f"Number of nodes in config ({args.nodes}) does not match the required number of nodes ({required_nodes}) based on the parallelism configuration." + ) - # Create necessary folders project_log_folder = Path(config.general.logs_path) - log_folder = project_log_folder / f"{args.run}"/ f"{timestamp_with_run}" - subfolders = ['launch-script', 'slurm-logs'] - if hasattr(config, 'lighteval') and config.lighteval is not None: - subfolders.append('evals') + log_folder = project_log_folder / f"{config.general.run}" / f"{timestamp_with_run}" + subfolders = ["launch-script", "slurm-logs"] + if hasattr(config, "lighteval") and config.lighteval is not None: + subfolders.append("evals") for subfolder in subfolders: folder_path = str(log_folder / subfolder) Path(folder_path).mkdir(parents=True, exist_ok=True) - if subfolder == 'launch-script': + if subfolder == "launch-script": config.general.launch_script_path = folder_path - elif subfolder == 'slurm-logs': + elif subfolder == "slurm-logs": config.general.slurm_logs_path = folder_path - elif subfolder == 'evals': + elif subfolder == "evals": config.general.evals_logs_path = folder_path - for evals_subfolder in ['launch-config', 'logs',"lighteval-logs"]: + for evals_subfolder in ["launch-config", "logs", "lighteval-logs"]: if evals_subfolder == "lighteval-logs": - if config.lighteval.logging.local_output_path is None: + if config.lighteval.logging.output_dir is None: evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) - config.lighteval.logging.local_output_path = evals_subfolder_path + config.lighteval.logging.output_dir = evals_subfolder_path else: evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) torchrun_args = "" - if 'torchrun_args' in launch_slurm_config and launch_slurm_config['torchrun_args']: - torchrun_args = " ".join([f"--{k} {v}" for k, v in launch_slurm_config['torchrun_args'].items()]) - - launch_slurm_config.update({ - "job_name": f"{config.general.project}-{config.general.run}", - "nodes": args.nodes, - "slurm_logs_path": config.general.slurm_logs_path, - "path_to_trainer_python_file": os.path.join(os.path.dirname(__file__), "run_train.py"), - "config_path_yaml": f"{config.general.config_logs_path}/launch.yaml", - "torchrun_args": torchrun_args, - }) + if "torchrun_args" in launch_slurm_config and 
launch_slurm_config["torchrun_args"]: + torchrun_args = " ".join([f"--{k} {v}" for k, v in launch_slurm_config["torchrun_args"].items()]) + + launch_slurm_config.update( + { + "job_name": f"{config.general.project}-{config.general.run}", + "nodes": args.nodes, + "slurm_logs_path": config.general.slurm_logs_path, + "path_to_trainer_python_file": os.path.join(os.path.dirname(__file__), "run_train.py"), + "config_path_yaml": f"{config.general.config_logs_path}/launch_config.yaml", + "torchrun_args": torchrun_args, + } + ) # Load Jinja2 template template_path = Path("slurm/launch_training.slurm.jinja") - with open(template_path, 'r') as f: + with open(template_path, "r") as f: template = Template(f.read()) # Render the template @@ -254,15 +298,18 @@ def set_nested_attribute(obj, path, value): config.general.launch_slurm_config = str(launch_slurm_config_path.resolve()) else: config.general.launch_slurm_config = None - if eval_slurm_config_path.exists(): - config.general.eval_slurm_config = str(eval_slurm_config_path.resolve()) - else: - config.general.eval_slurm_config = None - config_path_yaml = str(Path(config.general.config_logs_path) / "launch.yaml") + if config.lighteval is not None: + # Save the lighteval configuration + lighteval_config = config.lighteval + Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) + config.general.lighteval_config_path = str(Path(config.general.config_logs_path) / "lighteval_config.yaml") + save_as_yaml(lighteval_config, LightEvalConfig, config.general.lighteval_config_path) + + config_path_yaml = str(Path(config.general.config_logs_path) / "launch_config.yaml") Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) config.save_as_yaml(config_path_yaml) - + # Launch the Slurm job job_id = launch_slurm_job(sbatch_script) print(f"πŸš€ Slurm job launched with id={job_id}") @@ -270,34 +317,36 @@ def set_nested_attribute(obj, path, value): # Save the Slurm script if a path is provided if config.general.launch_script_path: Path(config.general.launch_script_path).mkdir(parents=True, exist_ok=True) - script_filename = f"slurm_launch_script.slurm" + script_filename = "slurm_launch_script.slurm" script_path = str(Path(config.general.launch_script_path) / script_filename) script_path = os.path.join(config.general.launch_script_path, script_filename) - - with open(script_path, 'w') as f: + + with open(script_path, "w") as f: f.write(sbatch_script) - print(f" πŸ€– Slurm Configuration Details:") + print(" πŸ€– Slurm Configuration Details:") - slurm_config_keys = ['qos', 'gpus_per_node', 'cpus_per_task', 'constraint', 'account', 'reservation'] + slurm_config_keys = ["qos", "gpus_per_node", "cpus_per_task", "constraint", "account", "reservation"] for key in slurm_config_keys: if key in launch_slurm_config: if launch_slurm_config[key] is not None: print(f" {key}: {launch_slurm_config[key]}") - + print(" ") print(" πŸ“ Log structure:") print(f" {config.general.logs_path}/{config.general.run}/") print(f" └── {timestamp_with_run}/") - if config.checkpoints.checkpoints_path == str(Path(config.general.logs_path) / args.run / timestamp_with_run / "checkpoints"): + if config.checkpoints.checkpoints_path == str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "checkpoints" + ): print(" β”œβ”€β”€ checkpoints/") print(" β”œβ”€β”€ config/") print(" β”œβ”€β”€ launch-script/") print(" β”œβ”€β”€ slurm-logs/") - if hasattr(config, 'lighteval') and config.lighteval is not None: + if hasattr(config, "lighteval") and 
config.lighteval is not None: print(" └── evals/") print(" β”œβ”€β”€ launch-config/") print(" └── logs/") - if config.lighteval.logging.local_output_path== str(Path(config.general.evals_logs_path) / "lighteval-logs"): + if config.lighteval.logging.output_dir == str(Path(config.general.evals_logs_path) / "lighteval-logs"): print(" └── lighteval-logs/") else: @@ -312,27 +361,35 @@ def set_nested_attribute(obj, path, value): print("πŸ’» Running on an interactive node with GPUs.") gpu_config = config.parallelism.dp * config.parallelism.tp * config.parallelism.pp if gpu_count < gpu_config: - raise ValueError(f"Error: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " - f"requires {gpu_config} GPUs, but only {gpu_count} are available.") + raise ValueError( + f"Error: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"requires {gpu_config} GPUs, but only {gpu_count} are available." + ) elif gpu_count == gpu_config: - print(f"πŸš€ Running on {gpu_count} GPUs, which matches your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp})") - total_gpus= gpu_count + print( + f"πŸš€ Running on {gpu_count} GPUs, which matches your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp})" + ) + total_gpus = gpu_count elif gpu_count > gpu_config: - total_gpus= gpu_config - print(f"⚠️ Warning: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " - f"uses {total_gpus} GPUs, but {gpu_count} are available. " - f"You are not fully utilizing all available GPUs on this device.") - - config_path_yaml = str(Path(config.general.config_logs_path) / "launch.yaml") + total_gpus = gpu_config + print( + f"⚠️ Warning: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"uses {total_gpus} GPUs, but {gpu_count} are available. " + f"You are not fully utilizing all available GPUs on this device." + ) + + config_path_yaml = str(Path(config.general.config_logs_path) / "launch_config.yaml") os.makedirs(config.general.config_logs_path, exist_ok=True) config.save_as_yaml(config_path_yaml) trainer_python_file = "run_train.py" - cmd = f"{trainer_python_file} --config-file {args.config_path}" + cmd = f"{trainer_python_file} --config-file {config_path_yaml}" launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {total_gpus} {cmd}" print(f"πŸš€ Launching interactive job with command: {launch_cmd}") - + subprocess.run(launch_cmd, shell=True, check=True) else: - print("❌ Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs.") \ No newline at end of file + print( + "❌ Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs." 
+ ) diff --git a/pyproject.toml b/pyproject.toml index 802a30ab..dbde7f0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ s3 = [ ] lighteval = [ - "lighteval[nanotron]@git+https://github.com/huggingface/lighteval.git@nanotron-compatible", + "lighteval[nanotron]@git+https://github.com/huggingface/lighteval.git", ] [build-system] requires = [ diff --git a/slurm/run_eval.slurm.jinja b/slurm/run_eval.slurm.jinja index 8cd3ee5a..6444858a 100644 --- a/slurm/run_eval.slurm.jinja +++ b/slurm/run_eval.slurm.jinja @@ -70,17 +70,26 @@ echo go $COUNT_NODE echo $HOSTNAMES -torch_dist_args="--nproc_per_node 8 \ +CMD="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path {{ model_checkpoint_path }}/config.yaml \ + --lighteval-config-path {{ lighteval_config_path }} \ + " + +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ --nnodes $COUNT_NODE \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ --max_restarts 0 \ --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path {{ model_checkpoint_path }}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ " -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." +sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" diff --git a/slurm/run_eval_s3.slurm.jinja b/slurm/run_eval_s3.slurm.jinja index ee467274..04441638 100644 --- a/slurm/run_eval_s3.slurm.jinja +++ b/slurm/run_eval_s3.slurm.jinja @@ -75,17 +75,26 @@ echo $HOSTNAMES mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER -torch_dist_args="--nproc_per_node 8 \ +CMD="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \ + --lighteval-config-path {{ lighteval_config_path }} \ + " + +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ --nnodes $COUNT_NODE \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ --max_restarts 0 \ --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ " -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." 
+sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 488ebf96..27105ee8 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -1,18 +1,16 @@ import datetime import os from dataclasses import dataclass, fields -import pathlib from pathlib import Path -from datasets.download.streaming_download_manager import xPath -from typing import List, Optional, Type, Union, Dict +from typing import List, Optional, Type, Union import dacite import torch import yaml from dacite import from_dict -from datasets.download.streaming_download_manager import xPath from yaml.loader import SafeLoader +from datasets.download.streaming_download_manager import xPath from nanotron.config.lighteval_config import LightEvalConfig from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs @@ -22,11 +20,11 @@ cast_str_to_torch_dtype, serialize, ) -from nanotron.s3_checkpoints import check_path_is_local from nanotron.generation.sampler import SamplerType from nanotron.logging import get_logger from nanotron.parallel.pipeline_parallel.engine import PipelineEngine from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.s3_checkpoints import check_path_is_local logger = get_logger(__name__) @@ -93,12 +91,16 @@ def __post_init__(self): self.text_column_name = "text" if self.hf_dataset_splits is None: self.hf_dataset_splits = "train" + + @dataclass class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" remove_after_upload: bool - upload_s3_path: Optional[str] = None # set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 + upload_s3_path: Optional[ + str + ] = None # set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 s5cmd_numworkers: Optional[int] = None s5cmd_concurrency: Optional[int] = None s5cmd_path: Optional[str] = None @@ -109,6 +111,7 @@ def __post_init__(self): if isinstance(self.s5cmd_path, str): self.s5cmd_path = Path(self.s5cmd_path) + @dataclass class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" @@ -200,17 +203,20 @@ class GeneralArgs: ignore_sanity_checks: Whether to ignore sanity checks """ - project: str + project: Optional[str] = None run: Optional[str] = None - logs_path: Optional[str] = "logs" + logs_path: Optional[str] = None launch_slurm_config: Optional[str] = None eval_slurm_config: Optional[str] = None + eval_slurm_template: Optional[str] = None + lighteval_config_path: Optional[str] = None + is_s3_available: Optional[bool] = None timestamp_with_run: Optional[str] = None launch_script_path: Optional[str] = None slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None evals_logs_path: Optional[str] = None - temp_dir: Optional[str] = "temp_dir" + temp_dir: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None consumed_train_samples: Optional[int] = None @@ -389,15 +395,18 @@ def create_empty(cls): return cls(**{f.name: None for f in cls_fields}) def __post_init__(self): - - if hasattr(self, '_post_init_done'): + + if hasattr(self, "_post_init_done"): return self._post_init_done = True 
self.general.__post_init__() if self.s3_upload is not None: self.s3_upload.__post_init__() + self.general.is_s3_available = True + else: + self.general.is_s3_available = False # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: assert self.tokens.train_steps < 10 @@ -430,18 +439,14 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" - - # if lighteval, we need tokenizer to be defined if self.lighteval is not None: assert self.tokenizer.tokenizer_name_or_path is not None - @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp - def save_as_yaml(self, file_path: str): config_dict = serialize(self) @@ -514,11 +519,10 @@ def get_config_from_file( skip_unused_config_keys: whether to skip unused first-nesting-level keys in the config file (for config with additional sections) skip_null_keys: whether to skip keys with value None at first and second nesting level """ - + with open(config_path) as f: config_dict = yaml.load(f, Loader=SafeLoader) - config = get_config_from_dict( config_dict, config_class=config_class, @@ -532,3 +536,14 @@ def get_config_from_file( ) config.model.model_config = model_config_class(**config.model.model_config) return config + + +def save_as_yaml(config, config_class, file_path: str): + + config_dict = serialize(config) + file_path = str(file_path) + with open(file_path, "w") as f: + yaml.dump(config_dict, f) + + # Sanity test config can be reloaded + _ = get_config_from_file(file_path, config_class=config_class) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index fe11437d..3808d60c 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from pathlib import Path from typing import Dict, Optional, Union from nanotron.config.parallelism_config import ParallelismArgs @@ -32,55 +31,28 @@ def __post_init__(self): @dataclass class LightEvalLoggingArgs: """Arguments related to logging for LightEval""" - local_output_path: Optional[Path] = None - private: Optional[bool] = True - push_results_to_hub: Optional[bool] = None - push_details_to_hub: Optional[bool] = None - push_results_to_tensorboard: Optional[bool] = None - hf_user_or_org: Optional[str] = None - hub_repo_results: Optional[str] = None #path is hf_user_or_org/hub_repo_results - hub_repo_details: Optional[str] = None #path is hf_user_or_org/hub_repo_details - hub_repo_tensorboard: Optional[str] = None - tensorboard_metric_prefix: Optional[str] = None - def __post_init__(self): - if isinstance(self.local_output_path, str): - self.local_output_path = Path(self.local_output_path) - if self.push_results_to_hub is not None and self.hf_user_or_org is None: - raise ValueError("hf_user_or_org must be specified if push_results_to_hub is set") - if self.push_details_to_hub is not None and self.hf_user_or_org is None: - raise ValueError("hf_user_or_org must be specified if push_details_to_hub is set") - if self.hf_user_or_org is not None: - if self.push_results_to_hub is not None and self.hub_repo_results is None: - self.hub_repo_results = "evals-results" - if self.push_details_to_hub is not None and self.hub_repo_details is None: - self.hub_repo_details = "evals-details" + output_dir: Optional[str] = None + save_details: 
bool = True + push_to_hub: bool = False + push_to_tensorboard: bool = False + public_run: bool = False + results_org: str | None = None + tensorboard_metric_prefix: str = "eval" @dataclass class LightEvalTasksArgs: """Arguments related to tasks for LightEval""" - tasks: Optional[str] = None + tasks: str custom_tasks: Optional[str] = None max_samples: Optional[int] = None num_fewshot_seeds: Optional[int] = None - dataset_loading_processes: Optional[int] = 8 + dataset_loading_processes: int = 8 multichoice_continuations_start_space: Optional[bool] = None - no_multichoice_continuations_start_space: Optional[bool] = None - - -@dataclass -class LightEvalWandbLoggerConfig: - """Arguments related to the local Wandb logger""" - - wandb_project: str = "" - wandb_entity: Optional[str] = None - wandb_run_name: Optional[str] = None - - def __post_init__(self): - assert self.wandb_project != "", "Please specify a wandb_project" + pair_wise_tokenization: bool = False @dataclass @@ -91,13 +63,8 @@ class LightEvalConfig: the saved config when running LightEval after training. """ - slurm_template: Optional[str] = None - slurm_script_dir: Optional[str] = None - temp_dir: Optional[str] = "temp_dir" - checkpoints_path: Optional[str] = None - parallelism: Optional[ParallelismArgs] = None - batch_size: Optional[int] = None + logging: LightEvalLoggingArgs + tasks: LightEvalTasksArgs + parallelism: ParallelismArgs + batch_size: int = 0 generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None - tasks: Optional[LightEvalTasksArgs] = None - logging: Optional[LightEvalLoggingArgs] = None - wandb: Optional[LightEvalWandbLoggerConfig] = None diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index 124516cd..87d69585 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -1,10 +1,10 @@ from dataclasses import fields from enum import Enum, auto from pathlib import Path -from datasets.download.streaming_download_manager import xPath import torch +from datasets.download.streaming_download_manager import xPath from nanotron.generation.sampler import SamplerType from nanotron.parallel.pipeline_parallel.engine import ( AllForwardAllBackwardPipelineEngine, diff --git a/src/nanotron/lighteval/evaluation_tasks.py b/src/nanotron/lighteval/evaluation_tasks.py index a78fe486..2dd9820c 100644 --- a/src/nanotron/lighteval/evaluation_tasks.py +++ b/src/nanotron/lighteval/evaluation_tasks.py @@ -8,10 +8,11 @@ from dataclasses import asdict from typing import Dict, List, Tuple -from lighteval.metrics import Metrics +import lighteval.tasks.default_prompts as prompt +from lighteval.metrics.metrics import Metrics +from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] _TASKS: List[LightevalTaskConfig] = [] @@ -19,130 +20,123 @@ trust_remote_code = True ## COMMON_SENSE_REASONING_TASKS ## + + +def commonsense_qa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"], + choices=[f" {c}" for c in line["choices"]["text"]], + gold_index=LETTER_INDICES.index(line["answerKey"].strip()), + instruction="", + ) + + +def siqa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["context"] + " " + line["question"], + choices=[f" {c}" for c in 
[line["answerA"], line["answerB"], line["answerC"]]], + gold_index=int(line["label"]) - 1, + instruction="", + ) + + COMMON_SENSE_REASONING_TASKS = [ LightevalTaskConfig( name="hellaswag", - prompt_function="hellaswag_prompt", + prompt_function=prompt.hellaswag_harness, # Updated prompt function hf_repo="hellaswag", hf_subset="default", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], - trust_dataset=True, + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="winogrande", - prompt_function="winogrande", + prompt_function=prompt.winogrande, # Updated prompt function hf_repo="winogrande", hf_subset="winogrande_xl", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="piqa", - prompt_function="piqa_harness", + prompt_function=prompt.piqa_harness, # Updated prompt function hf_repo="piqa", hf_subset="plain_text", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="siqa", - prompt_function="siqa_prompt", + prompt_function=siqa_prompt, # Updated prompt function hf_repo="lighteval/siqa", hf_subset="default", hf_avail_splits=["train", "validation"], - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="openbookqa", - prompt_function="openbookqa", + prompt_function=prompt.openbookqa, # Updated prompt function hf_repo="openbookqa", hf_subset="main", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="arc:easy", - prompt_function="arc", + prompt_function=prompt.arc, # Updated prompt function hf_repo="ai2_arc", hf_subset="ARC-Easy", evaluation_splits=["test"], generation_size=1, - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="arc:challenge", - prompt_function="arc", + prompt_function=prompt.arc, # Updated prompt function hf_repo="ai2_arc", hf_subset="ARC-Challenge", evaluation_splits=["test"], generation_size=1, - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="commonsense_qa", - prompt_function="commonsense_qa_prompt", + prompt_function=commonsense_qa_prompt, # Updated prompt function hf_repo="commonsense_qa", hf_subset="default", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), ] -def commonsense_qa_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["question"], - choices=[f" {c}" for c in line["choices"]["text"]], - gold_index=LETTER_INDICES.index(line["answerKey"].strip()), - instruction="", - ) - - -def siqa_prompt(line, task_name: str = None): 
- return Doc( - task_name=task_name, - query=line["context"] + " " + line["question"], - choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]], - gold_index=int(line["label"]) - 1, - instruction="", - ) +# 0 short for common sense +COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS] +_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING) +_TASKS += COMMON_SENSE_REASONING_TASKS +## WORLD_KNOWLEDGE_TASKS ## -def hellaswag_prompt(line, task_name: str = None): - def preprocess(text): - """Comes from AiHarness""" - # text = text.strip() - # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. - text = text.replace(" [title]", ". ") - text = re.sub("\\[.*?\\]", "", text) - text = text.replace(" ", " ") - return text - ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} " +def natural_questions_prompt(line, task_name: str = None): return Doc( task_name=task_name, - query=preprocess(line["activity_label"] + ": " + ctx), - choices=[" " + preprocess(ending) for ending in line["endings"]], - gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test - # "metric": "choices_loglikelihood", + query=line["question"] + "?\nAnswer: ", + choices=[line["short_answers"]], + gold_index=0, + instruction="", ) -# 0 short for common sense -COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS] -_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING) -_TASKS += COMMON_SENSE_REASONING_TASKS - -## WORLD_KNOWLEDGE_TASKS ## - WORLD_KNOWLEDGE_TASKS = [ LightevalTaskConfig( name="trivia_qa", - prompt_function="triviaqa", + prompt_function=prompt.triviaqa, hf_repo="trivia_qa", hf_subset="rc.nocontext", metric=[Metrics.quasi_exact_match], @@ -152,7 +146,7 @@ def preprocess(text): ), LightevalTaskConfig( name="natural_questions", - prompt_function="natural_questions_prompt", + prompt_function=natural_questions_prompt, hf_repo="lighteval/natural_questions_clean", hf_subset="default", metric=[Metrics.quasi_exact_match], @@ -163,35 +157,33 @@ def preprocess(text): ] -def natural_questions_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["question"] + "?\nAnswer: ", - choices=[line["short_answers"]], - gold_index=0, - instruction="", - ) - - WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS] # WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS} _TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING) _TASKS += WORLD_KNOWLEDGE_TASKS ## Reading comprehension ## +def boolq_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:", + choices=[" No", " Yes"], # Only gold + gold_index=int(line["label"]), + ) + READING_COMP_TASKS = [ LightevalTaskConfig( name="super_glue:boolq", - prompt_function="boolq_prompt", + prompt_function=boolq_prompt, hf_repo="super_glue", hf_subset="boolq", - metric=["target_perplexity"], + metric=[Metrics.target_perplexity], trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="quac", - prompt_function="quac", + prompt_function=prompt.quac, hf_repo="lighteval/quac_helm", hf_subset="deault", metric=[Metrics.quasi_exact_match], @@ -202,15 +194,6 @@ def natural_questions_prompt(line, task_name: str = None): ] -def boolq_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['passage']}\nQuestion: 
{line['question'].capitalize()}?\nAnswer:", - choices=[" No", " Yes"], # Only gold - gold_index=int(line["label"]), - ) - - READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS] _TASKS_STRINGS.extend(READING_COMP_STRING) _TASKS += READING_COMP_TASKS @@ -223,7 +206,7 @@ class CustomMathEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="math", + prompt_function=prompt.math, hf_repo="lighteval/MATH", hf_subset=None, metric=[Metrics.quasi_exact_match_math], @@ -235,7 +218,7 @@ def __init__( generation_size=40, stop_sequence=None, output_regex=None, - frozen=False, + frozen=False, trust_dataset=trust_remote_code, ): super().__init__( @@ -268,7 +251,7 @@ def __init__( ] GSM8K = LightevalTaskConfig( name="gsm8k", - prompt_function="gsm8k", + prompt_function=prompt.gsm8k, hf_repo="gsm8k", hf_subset="main", hf_avail_splits=["train", "test"], @@ -288,20 +271,55 @@ def __init__( ## MMLU ## +def mmlu_harness(line, task_name: str = None): + topic = line["subject"] + prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" + prompt += line["question"] + "\n" + prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) + prompt += "Answer:" + + gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] + "__few_shots" in line and line["__few_shots"] is True # We are adding few shots + + return Doc( + task_name=task_name, + query=prompt, + choices=[" A", " B", " C", " D"], + target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], + gold_index=gold_ix, + instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", + ) + + +def mmlu_prompt(line, task_name: str = None): + """MMLU prompt without letters""" + topic = line["subject"] + prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " + prompt += line["question"] + "\nAnswer:" + + return Doc( + task_name=task_name, + query=prompt, + choices=[f" {c}" for c in line["choices"]], + gold_index=line["answer"], + instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", + ) + + class CustomMMLUEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="mmlu_prompt", + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset=None, # metric=[Metrics.loglikelihood_acc_single_token], - metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], hf_avail_splits=None, evaluation_splits=["test"], few_shots_split="dev", few_shots_select=None, - suite=None, + suite=["custom"], generation_size=-1, stop_sequence=None, output_regex=None, @@ -390,41 +408,6 @@ def __init__( ] -def mmlu_harness(line, task_name: str = None): - topic = line["subject"] - prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" - prompt += line["question"] + "\n" - prompt += "".join([f"{key}. 
{choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) - prompt += "Answer:" - - gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] - "__few_shots" in line and line["__few_shots"] is True # We are adding few shots - - return Doc( - task_name=task_name, - query=prompt, - choices=[" A", " B", " C", " D"], - target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], - gold_index=gold_ix, - instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", - ) - - -def mmlu_prompt(line, task_name: str = None): - """MMLU prompt without letters""" - topic = line["subject"] - prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " - prompt += line["question"] + "\nAnswer:" - - return Doc( - task_name=task_name, - query=prompt, - choices=[f" {c}" for c in line["choices"]], - gold_index=line["answer"], - instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", - ) - - # MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS} MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS] _TASKS_STRINGS.extend(MMLU_STRING) @@ -433,11 +416,20 @@ def mmlu_prompt(line, task_name: str = None): ## BBH ## +def bbh_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["input"] + "\nAnswer: ", + choices=[line["target"]], + gold_index=0, + ) + + class CustomBBHEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="bbh_prompt", + prompt_function=bbh_prompt, hf_repo="lighteval/big_bench_hard", hf_subset=None, metric=[Metrics.exact_match], @@ -445,7 +437,7 @@ def __init__( evaluation_splits=["train"], few_shots_split="train", few_shots_select=None, - suite=None, + suite=["custom"], generation_size=4, stop_sequence=None, output_regex=None, @@ -510,36 +502,80 @@ def __init__( ] -def bbh_prompt(line, task_name: str = None): +# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS} +BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS] +_TASKS_STRINGS.extend(BBH_STRING) +_TASKS += BBH_TASKS + + +## AGI eval ## + + +def agi_eval_math_prompt(line, task_name: str = None): return Doc( task_name=task_name, - query=line["input"] + "\nAnswer: ", - choices=[line["target"]], + query=line["question"], + choices=[line["answer"]], gold_index=0, + instruction="", ) -# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS} -BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS] -_TASKS_STRINGS.extend(BBH_STRING) -_TASKS += BBH_TASKS +def agi_eval_prompt(line, task_name: str = None): + cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]] + prompt = "The following are multiple choice questions (with answers).\n\n" + prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n" + prompt += "Answer: " + + choices = LETTER_INDICES[: len(line["options"])] + + output = Doc( + query=prompt, + instruction="The following are multiple choice questions (with answers).\n\n", + choices=None, # updated below + gold_index=None, # updated below + ) + + if line["label"]: + output.choices = choices + output.gold_index = LETTER_INDICES.index(line["label"].strip()) + else: + output.choices = [line["answer"]] + output.gold_index = 0 + + return output + + +def agi_eval_prompt_no_letters(line, task_name: str = None): + cleaned_options = [ + " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "") + for o in 
line["options"] + ] + + output = Doc( + query=line["question"], + choices=cleaned_options, + gold_index=LETTER_INDICES.index(line["label"].strip()), + instruction="", + ) + + return output -## AGI eval ## class CustomAGIEvalEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="agi_eval_prompt_no_letters", + prompt_function=agi_eval_prompt_no_letters, hf_repo="lighteval/agi_eval_en", hf_subset=None, # metric=[Metrics.loglikelihood_acc_single_token], - metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], hf_avail_splits=["train", "validation"], evaluation_splits=["train"], few_shots_split="validation", few_shots_select=None, - suite=None, + suite=["custom"], generation_size=-1, stop_sequence=None, output_regex=None, @@ -583,57 +619,6 @@ def __init__( ] -def agi_eval_math_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["question"], - choices=[line["answer"]], - gold_index=0, - instruction="", - ) - - -def agi_eval_prompt(line, task_name: str = None): - cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]] - prompt = "The following are multiple choice questions (with answers).\n\n" - prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n" - prompt += "Answer: " - - choices = LETTER_INDICES[: len(line["options"])] - - output = Doc( - query=prompt, - instruction="The following are multiple choice questions (with answers).\n\n", - choices=None, # updated below - gold_index=None, # updated below - ) - - if line["label"]: - output.choices = choices - output.gold_index = LETTER_INDICES.index(line["label"].strip()) - else: - output.choices = [line["answer"]] - output.gold_index = 0 - - return output - - -def agi_eval_prompt_no_letters(line, task_name: str = None): - cleaned_options = [ - " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "") - for o in line["options"] - ] - - output = Doc( - query=line["question"], - choices=cleaned_options, - gold_index=LETTER_INDICES.index(line["label"].strip()), - instruction="", - ) - - return output - - # AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS} AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS] _TASKS_STRINGS.extend(AGIEVAL_STRING) @@ -661,7 +646,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None): EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) # Convert to dict for lighteval -TASKS_TABLE = [task.as_dict() for task in _TASKS] +TASKS_TABLE = _TASKS # You can have a few pre-organised groups of tasks TASKS_GROUPS = { "all": ",".join(t[1] for t in _TASKS_STRINGS), diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py index b56aafda..3321e7ce 100644 --- a/src/nanotron/lighteval/one_job_runner.py +++ b/src/nanotron/lighteval/one_job_runner.py @@ -1,19 +1,19 @@ """ Mostly complete a SLURM template with a link to a single checkpoint on s3 and launch it """ import datetime +import json import os import re import subprocess -from typing import List, Optional, Tuple, Union -import copy -import json +from typing import List, Optional, Tuple + import jinja2 + from nanotron import logging +from nanotron.config import Config, LightEvalConfig from nanotron.logging import log_rank from nanotron.parallel import ParallelContext -from nanotron.config import Config, 
LightEvalConfig - logger = logging.get_logger(__name__) @@ -35,12 +35,11 @@ def eval_single_checkpoint_no_s3(self, checkpoint_path: str) -> Tuple[str, str]: return None, None slurm_job_id, slurm_log = run_slurm_one_job( - config = self.config, - lighteval_config = self.lighteval_config, - slurm_template=self.lighteval_config.slurm_template, + config=self.config, + lighteval_config=self.lighteval_config, + slurm_template=self.config.general.eval_slurm_template, model_checkpoint_path=checkpoint_path, current_step=self.config.general.step, - s3=False, ) return slurm_job_id, slurm_log @@ -73,12 +72,11 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") slurm_job_id, slurm_log = run_slurm_one_job( - config = self.config, - lighteval_config = self.lighteval_config, - slurm_template=self.lighteval_config.slurm_template, + config=self.config, + lighteval_config=self.lighteval_config, + slurm_template=self.config.general.eval_slurm_template, model_checkpoint_path=checkpoint_path, current_step=self.config.general.step, - s3=True, ) return slurm_job_id, slurm_log @@ -90,7 +88,6 @@ def run_slurm_one_job( model_checkpoint_path: str, slurm_template: str, current_step: int, - s3: bool = True, slurm_name: Optional[str] = "eval", ): """Launch a single job on Slurm with the given mapping @@ -98,11 +95,11 @@ def run_slurm_one_job( slurm_config: Slurm configuration mapping: Mapping to use for the job script (see SLURM_ONE_JOB_MAPPING) """ - + s3 = config.general.is_s3_available eval_launch_script_path = os.path.join(config.general.evals_logs_path, "launch-config", str(current_step)) eval_logs_path = os.path.join(config.general.evals_logs_path, "logs", str(current_step)) - with open(config.general.eval_slurm_config, 'r') as f: + with open(config.general.eval_slurm_config, "r") as f: eval_slurm_config = json.load(f) os.makedirs(eval_launch_script_path, exist_ok=True) @@ -118,28 +115,33 @@ def run_slurm_one_job( # Update the config with additional required parameters # Calculate the number of nodes based on parallelism config and gpus_per_node - total_gpus_needed = lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp - gpus_per_node = eval_slurm_config.get('gpus_per_node') + total_gpus_needed = ( + lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp + ) + gpus_per_node = eval_slurm_config.get("gpus_per_node") nodes = (total_gpus_needed + gpus_per_node - 1) // gpus_per_node # Ceiling division - + if s3: - eval_slurm_config.update({ - 'nodes': nodes, # Assuming we want to run on a single node - 'job_name': f"eval-{current_step}", - 'eval_path': eval_logs_path, - 'local_path': config.lighteval.temp_dir, - 'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, - "model_checkpoint_path": model_checkpoint_path, - }) + eval_slurm_config.update( + { + "nodes": nodes, # Assuming we want to run on a single node + "job_name": f"eval-{current_step}", + "eval_path": eval_logs_path, + "local_path": f"{config.general.temp_dir}/eval_{config.general.timestamp_with_run}/{current_step}", + "model_checkpoint_path": model_checkpoint_path, + "lighteval_config_path": config.general.lighteval_config_path, + } + ) else: - eval_slurm_config.update({ - 'nodes': nodes, # Assuming we want to run on a single node - 'job_name': f"eval-{current_step}", - 'eval_path': eval_logs_path, - 
'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, - "model_checkpoint_path": model_checkpoint_path, - }) - + eval_slurm_config.update( + { + "nodes": nodes, # Assuming we want to run on a single node + "job_name": f"eval-{current_step}", + "eval_path": eval_logs_path, + "model_checkpoint_path": model_checkpoint_path, + "lighteval_config_path": config.general.lighteval_config_path, + } + ) launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render(**eval_slurm_config) @@ -164,20 +166,14 @@ def run_slurm_one_job( # Preserve important environment variables env = { - 'PATH': os.environ['PATH'], - 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), - 'HOME': os.path.expanduser("~"), + "PATH": os.environ["PATH"], + "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), + "HOME": os.path.expanduser("~"), } try: # Use subprocess.run instead of check_output for better error handling - result = subprocess.run( - ["sbatch", launch_script_path], - env=env, - check=True, - capture_output=True, - text=True - ) + result = subprocess.run(["sbatch", launch_script_path], env=env, check=True, capture_output=True, text=True) output = result.stdout job_ids = output.split()[-1] output_log = ( diff --git a/src/nanotron/lighteval/run_evals.py b/src/nanotron/lighteval/run_evals.py index 5ee36f53..1fd4b178 100644 --- a/src/nanotron/lighteval/run_evals.py +++ b/src/nanotron/lighteval/run_evals.py @@ -15,9 +15,10 @@ def get_parser(): help="Path to the Nanotron checkpoint YAML or python config file, potentially on S3", ) parser.add_argument( - "--lighteval-override", + "--lighteval-config-path", type=str, - help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config", + required=True, + help="Path to an optional YAML or python Lighteval config", ) parser.add_argument( "--cache-dir", @@ -25,7 +26,6 @@ def get_parser(): default=None, help="Cache directory", ) - return parser @@ -33,7 +33,7 @@ def get_parser(): parser = get_parser() args, unknowns = parser.parse_known_args() main( - checkpoint_config_path=args.checkpoint_config_path, - lighteval_config_path=args.lighteval_override, + checkpoint_config_path=args.checkpoint_config_path, + lighteval_config_path=args.lighteval_config_path, cache_dir=args.cache_dir, - ) \ No newline at end of file + ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 3a4ab60a..76b8fa4a 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -19,14 +19,12 @@ cast, ) -from nanotron.s3_checkpoints import S3Mover, check_path_is_local import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader from nanotron import distributed as dist from nanotron import logging -from nanotron.lighteval import LightEvalRunner from nanotron.config import ( Config, DatasetStageArgs, @@ -48,6 +46,7 @@ log_throughput, lr_scheduler_builder, ) +from nanotron.lighteval import LightEvalRunner from nanotron.logging import ( LoggerWriter, LogItem, @@ -151,14 +150,12 @@ def __init__( data_parallel_size=self.config.parallelism.dp, expert_parallel_size=self.config.parallelism.expert_parallel_size, ) - + self.pre_init() # Set log levels set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging) - - # Log benchmark info if os.environ.get("NANOTRON_BENCHMARK", "0") == "1": log_throughput(self.config, self.parallel_context) @@ -288,12 +285,13 @@ def post_init(self): self.post_checkpoint_callback = None else: # 
Use the no_s3 version of the evaluation function - # TODO: make it one function + make it automatic to switch to the right jinja template + # TODO: make it one function + make it automatic to switch to the right jinja template self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 else: self.post_checkpoint_callback = None else: self.post_checkpoint_callback = None + def pre_training(self, *args, **kwargs): self._print_training_plan() @@ -306,7 +304,7 @@ def pre_training(self, *args, **kwargs): rank=0, ) - current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") + datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") if dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: wandb.init( project=self.config.general.project, @@ -316,20 +314,19 @@ def pre_training(self, *args, **kwargs): # Define tokens metric as x-axis for all metrics wandb.define_metric("Tokens") wandb.define_metric("*", step_metric="Tokens") - + # Handle resuming from a previous run - initial_step = getattr(self.config.general, 'step', 0) + initial_step = getattr(self.config.general, "step", 0) if initial_step is None: initial_step = 0 - + initial_tokens = initial_step * self.global_batch_size - + # Log initial tokens to set the starting point wandb.log({"Tokens": initial_tokens}) - + print(f"Initial Tokens: {initial_tokens}") - def post_train_step(self): # Update our background upload/removal of checkpoints if self.s3_mover is not None: @@ -338,12 +335,11 @@ def post_train_step(self): def post_training(self): if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) - + def post_training(self): if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) - def _print_training_plan(self): if hasattr(self.config, "data_stages") and self.config.data_stages is not None: stages_info = "".join( @@ -748,17 +744,21 @@ def _init_model_instance(self) -> NanotronModel: def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model - # Load or initialize model weights + # Load or initialize model weights reloaded_from_checkpoint = False if self.init_checkpoint_path is not None: - # Load from a pre existing checkpoint + # Load from a pre existing checkpoint if check_path_is_local(self.init_checkpoint_path): - # Reload from a training checkpoint - log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) + # Reload from a training checkpoint + log_rank( + f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0 + ) self.param_shard_metadata = load_weights( - model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + model=unwrapped_model, + parallel_context=self.parallel_context, + root_folder=self.init_checkpoint_path, ) - reloaded_from_checkpoint=True + reloaded_from_checkpoint = True if not reloaded_from_checkpoint: log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0) @@ -894,18 +894,25 @@ def setup_log_writers( def pre_save_checkpoint(self) -> Path: if wandb is not None and dist.get_rank(self.parallel_context.dp_pg) == 0: - if self.config.general.wandb_id is None: + if self.config.general.wandb_id is None: self.config.general.wandb_id = wandb.run.id self.config.general.wandb_project = wandb.run.project - elif 
self.config.general.wandb_id is not None and self.config.general.wandb_id!= wandb.run.id:
-                log_rank("Update the wandb run due too resume from checkpoint", logger=logger, level=logging.WARNING, rank=0)
+            elif self.config.general.wandb_id is not None and self.config.general.wandb_id != wandb.run.id:
+                log_rank(
+                    "Updating the wandb run due to resuming from a checkpoint", logger=logger, level=logging.WARNING, rank=0
+                )
                 self.config.general.wandb_id = wandb.run.id
                 self.config.general.wandb_project = wandb.run.project
         if self.s3_mover is not None:
             self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg)
             if self.s3_mover.post_upload_callback_outputs is not None:
                 slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs
-                log_rank(f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", logger=logger, level=logging.INFO, rank=0)
+                log_rank(
+                    f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval",
+                    logger=logger,
+                    level=logging.INFO,
+                    rank=0,
+                )
 
     def post_save_checkpoint(self):
         # Upload to S3
@@ -914,14 +921,15 @@ def post_save_checkpoint(self):
 
         elif self.post_checkpoint_callback is not None:
             # If we're not using S3, but we have a post-checkpoint callback for evals
-            checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.config.general.step}"
+            checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}"
             self.post_checkpoint_callback(checkpoint_path)
 
     def save_checkpoint(self) -> Path:
         self.pre_save_checkpoint()
-
         checkpoints_path = self.config.checkpoints.checkpoints_path
-        checkpoint_path = checkpoints_path / f"{self.iteration_step}"
+        print(f"config: {self.config}")
+        print(f"checkpoints_path: {checkpoints_path}")
+        checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}"
         if self.config.checkpoints.checkpoints_path_is_shared_file_system:
             should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0
         else: