diff --git a/README.md b/README.md index bb6a749..c5f651e 100755 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # LLM-baselines -A modular codebase to experiment with transformers, inspired by NanoGPT. +A modular codebase to experiment with transformers, inspired by nanoGPT. ## Quickstart @@ -36,44 +36,104 @@ parser.add_argument('--batch_size', default=32, type=int) parser.add_argument('--acc_steps', default=4, type=int) parser.add_argument('--seed', default=0, type=int) # random seed for the parameters parser.add_argument('--data_seed', default=1337, type=int) # random seed defining the data ordering +parser.add_argument('--eval_interval', default=200, type=int) +parser.add_argument('--full_eval_at', nargs="+", type=int) +parser.add_argument('--eval_batches', default=32, type=int) parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations -parser.add_argument('--lr', default=1e-3, type=float) -parser.add_argument('--warmup_percent', default=0.05, type=float) # the total number of warmup steps is iterations * warmup_percent +parser.add_argument('--warmup_steps', default=300, type=int) +parser.add_argument('--lr', default=1e-3, type=float) +parser.add_argument('--wsd_final_lr_scale', default=0.0, type=float) # wsd scheduler +parser.add_argument('--wsd_fract_decay', default=0.1, type=float) # wsd scheduler +parser.add_argument('--decay_type', default='linear', choices=['linear', 'cosine', 'exp', 'miror_cosine', 'square', 'sqrt']) +parser.add_argument('--dd_second_decay_type', default='linear', choices=['linear', 'cosine', 'exp', 'miror_cosine', 'square', 'sqrt']) +parser.add_argument('--dd_first_lr_factor', default=1e-2, type=float) parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter -parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd']) +parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none', 'dd']) +parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta', 'mars', 'adafactor', 'lamb', 'normalized-sgd']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved -parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT +parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in nanoGPT +parser.add_argument('--momentum', default=0.9, type=float) +parser.add_argument('--shampoo_beta', default=-1.0, type=float) +parser.add_argument('--precondition_frequency', default=10, type=int) #for SOAP and Sophia +parser.add_argument('--max_precond_dim', default=10000, type=int) +parser.add_argument('--merge_dims', default=False, type=bool) # merge dimensions till the product of the dimensions is less than or equal to max_precond_dim +parser.add_argument('--precondition_1d', default=False, type=bool) +parser.add_argument('--normalize_grads', default=False, type=bool) +parser.add_argument('--soap_data_format', default='channels_first', type=str) +parser.add_argument('--correct_bias', default=True, type=bool) +parser.add_argument('--nesterov', default=False, type=bool) # whether to use Nesterov-style momentum +parser.add_argument('--muon_ns_steps', default=5, type=int) # the number of steps to use in the newton schulz, if it is iterative +parser.add_argument('--muon_lr_factor', default=0.02, type=float) # a factor by which to reduce the lr for muon +parser.add_argmunet('--adema_beta3', default=0.9, type=float) # beta3 in AdEMAMix +parser.add_argument('--adema_alpha', default=2.0, type=float) # alpha in AdEMAMix +parser.add_argument('--adema_beta3_warmup', default=None, type=int) # AdEMAMix hyperparameter +parser.add_argument('--adema_alpha_warmup', default=None, type=int) # AdEMAMix hyperparameter +parser.add_argument('--schedulefree_r', defalut=0.0, type=float) # schedulefree hyperparameter +parser.add_argument('--weight_lr_power', default=2.0, type=float) # schedulefree hyperparameter +parser.add_argument('--model_sharding', default=None, type=bool) # Adam-mini +parser.add_argument('--adam_mini_verbose', default=False, type=bool) # print all the logs if true +parser.add_argument('--log_interval', default=50, type=int) +parser.add_argument('--dampening', default=0.0, type=float) +parser.add_argument('--prodigy_beta3', default=None, type=float) # coefficients for computing the Prodidy stepsize using running averages +parser.add_argument('--prodigy_decouple', default=True, type=bool) # Use AdamW style decoupled weight decay +parser.add_argument('--prodigy_use_bias_correction', default=False, type=bool) +parser.add_argument('--prodigy_safeguard_warmup', default=False, type=bool) # Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. +parser.add_argument('--prodigy_fsdp_in_use', default=False, type=bool) +parser.add_argument('--sophia_rho', default=0.04, type=float) +parser.add_argument('--clipping_type', default='no', choices=['no', 'local', 'elementwise']) # for methods with clipping +parser.add_argument('--clipping_eta', default=1.0, type=float) +parser.add_argument('--mars_type', default='mars-adamw', choices=['mars-adamw', 'mars-lion', 'mars-shampoo'],) +parser.add_argument('--mars_vr_gamma', default=0.025, type=float) +parser.add_argument('--mars_is_approx', default=True, type=float) +parser.add_argument('--mars_lr', default=3e-3, type=float) +parser.add_argument('--mars_beta1', default=0.95, type=float) +parser.add_argument('--mars_beta2', default=0.99, type=float) +parser.add_argument('--adafactor_decay_rate', default=-0.8, type=float) +parser.add_argument('--lamb_use_bias_correction', default=False, type=bool) # Dataset params -parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) +parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', 'arxiv2000', 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) +parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) parser.add_argument('--vocab_size', default=50304, type=int) parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this # Model params -parser.add_argument('--model', default='base', choices=['base', 'llama2']) -parser.add_argument('--use_pretrained', default="none", type=str) # 'none', 'gpt-2' or a path to the pretraind model +parser.add_argument('--model', default='base', choices=['base', 'llama', 'test']) +parser.add_argument('--parallel_block', action='store_true') +parser.add_argument('--use_pretrained', default='none', type=str) # 'none', 'gpt2' or a path to the pretraind model +parser.add_argument('--from_dense', action='store_true') +parser.add_argument('--init_std', default=0.02, type=float) parser.add_argument('--dropout', default=0.0, type=float) # keep to 0 unless in low data regime (e.g. wikitext) parser.add_argument('--n_head', default=12, type=int) parser.add_argument('--n_layer', default=12, type=int) # depth in (att + ff) blocks parser.add_argument('--n_embd', default=768, type=int) # hidden size ... parser.add_argument('--sequence_length', default=512, type=int) -parser.add_argument('--dtype', default=torch.bfloat16, type=torch.dtype) +parser.add_argument('--dtype', default='bfloat16', type=str, choices=['float32', 'float16', 'bfloat16'],) parser.add_argument('--bias', default=False, type=bool) parser.add_argument('--compile', action='store_true') # if true then model is compiled parser.add_argument('--rmsnorm_eps', default=1e-5, type=float) # used by the llama model parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model make SwiGLU hidden layer size multiple of large power of 2 +parser.add_argument('--n_kv_head', default=None, type=int) # for Adam-mini +# Checkpointing +parser.add_argument('--results_base_folder', default='./exps', type=str) +parser.add_argument('--permanent_ckpt_interval', default=0, type=int) +parser.add_argument('--latest_ckpt_interval', default=0, type=int) +parser.add_argument('--resume_from', default=None, type=str) +parser.add_argument('--resume_from_swa', default=None, type=str) +parser.add_argument('--auto_resume', default=True) # logging params (WandB) parser.add_argument('--wandb', action='store_true') # whether to use wandb or not -parser.add_argument('--wandb_project', default="my-project", type=str) -parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name +parser.add_argument('--wandb_project', default='my-project', type=str) +parser.add_argument('--wandb_entity', default=None, type=none_or_str) # for the team projects +parser.add_argument('--wandb_run_prefix', default='none', type=str) # is added before the autogenerated experiment name parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences +parser.add_argument('--log_dynamics', action='store_true') # Distributed args parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends()) # distributed backend type (e.g. nccl) -parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) ``` ## Using WandB @@ -111,12 +171,15 @@ src/ optim/ utils.py # contains eval and get_batch functions base.py # training function for the base and llama models + ... distributed/ # code to enable simple distributed training ``` Given the above structure, to add your own model, you can just fork the `./src/models/base.py` file, do your modifications, then if necessary fork the `./src/optim/base.py` in case you need some custom training loop or evaluation. You also need to fork the `./src/config/base.py` file to add your own parameters, which imply adding your new config to the mapping `CONFIG_FORMAT_TO_MODULE_MAP` in `./src/config/__init__.py`. To add a new dataset, create a new file in the `data` folder, check `wikitext.py` for the expected format. +**Note:** we use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests. Before committing your code, simply run ```black . && isort .``` and you will be fine. + ## Multi-GPU training Given a multi-GPU machine with e.g. 4 GPUs, one can distribute the training using data-parallelism: diff --git a/requirements.txt b/requirements.txt index 0153e35..c7ee426 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ tiktoken --find-links https://download.pytorch.org/whl/torch_stable.html -torch==2.0.0+cu118 -torchaudio==2.0.0+cu118 -torchvision==0.15.0+cu118 -tqdm==4.65.0 +torch +torchaudio +torchvision +tqdm transformers wandb datasets -zstandard \ No newline at end of file +zstandard +numpy==1.22.4 \ No newline at end of file diff --git a/src/config/base.py b/src/config/base.py index e48c277..4c80165 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -1,95 +1,235 @@ -import torch - import distributed + def none_or_str(value): - if value == 'None': + if value == "None": return None return value + def parse_args(base_parser, args, namespace): parser = base_parser + # General training params - parser.add_argument('--batch_size', default=32, type=int) - parser.add_argument('--acc_steps', default=4, type=int) - parser.add_argument('--seed', default=0, type=int) - parser.add_argument('--data_seed', default=1337, type=int) - parser.add_argument('--device', default='cuda:0', type=str) - parser.add_argument('--iterations', default=25000, type=int) - parser.add_argument('--lr', default=1e-3, type=float) - parser.add_argument('--warmup_percent', default=0.05, type=float) - parser.add_argument('--weight_decay', default=0.1, type=float) - parser.add_argument('--beta1', default=0.9, type=float) - parser.add_argument('--beta2', default=0.95, type=float) - parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) - parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd']) - parser.add_argument('--eval_freq', default=200, type=int) # in iterations - parser.add_argument('--results_base_folder', default="./exps", type=str) - parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT + parser.add_argument("--run_prefix", default=None, type=str) + parser.add_argument("--experiment_name", default=None, type=str) + parser.add_argument("--seed", default=0, type=int) + parser.add_argument("--data_seed", default=1337, type=int) + parser.add_argument("--eval_interval", default=200, type=int) + parser.add_argument("--full_eval_at", nargs="+", type=int) + parser.add_argument("--eval_batches", default=32, type=int) + parser.add_argument("--device", default="cuda:0", type=str) + parser.add_argument( + "--distributed_backend", + default=None, + type=str, + required=False, + choices=distributed.registered_backends(), + ) + parser.add_argument("--log_interval", default=50, type=int) + + # Checkpointing + parser.add_argument("--results_base_folder", default="./exps", type=str) + parser.add_argument("--permanent_ckpt_interval", default=0, type=int) + parser.add_argument("--latest_ckpt_interval", default=0, type=int) + parser.add_argument("--resume_from", default=None, type=str) + parser.add_argument("--resume_from_swa", default=None, type=str) + + parser.add_argument("--auto_resume", default=True) + + # logging params (WandB) + parser.add_argument("--wandb", action="store_true") # whether to use wandb or not + parser.add_argument("--wandb_project", default="my-project", type=str) + parser.add_argument( + "--wandb_run_prefix", default="none", type=str + ) # is added before the autogenerated experiment name + parser.add_argument( + "--eval_seq_prefix", default="none", type=str + ) # prefix used to generate sequences + parser.add_argument("--log_dynamics", action="store_true") + # parser.add_argument( + # "--dynamics_logger_cfg", default="./src/logger/rotational_logger.yaml", type=str + # ) + parser.add_argument("--wandb_entity", default=None, type=none_or_str) + + # Schedule + parser.add_argument( + "--scheduler", + default="cos", + choices=["linear", "cos", "wsd", "none", "cos_inf", "cos_wsd", "dd"], + ) + parser.add_argument("--cos_inf_steps", default=0, type=int) + # parser.add_argument("--cos-final-lr", default=1e-6, type=float) + parser.add_argument("--iterations", default=15000, type=int) + parser.add_argument("--warmup_steps", default=3000, type=int) + parser.add_argument("--lr", default=1e-3, type=float) + # wsd + parser.add_argument("--wsd_final_lr_scale", default=0.0, type=float) + parser.add_argument("--wsd_fract_decay", default=0.1, type=float) + # parser.add_argument("--wsd-exponential-decay", action="store_true") + parser.add_argument( + "--decay_type", + default="linear", + choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], + ) + parser.add_argument( + "--dd_second_decay_type", + default="linear", + choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], + ) + parser.add_argument("--dd_first_lr_factor", default=1e-2, type=float) + + # Optimization + parser.add_argument( + "--opt", + default="adamw", + choices=[ + "adamw", + "sgd", + "muon", + "soap", + "ademamix", + "ademamix2", + "lion", + "sf-adamw", + "sf-sgd", + "adam-mini", + "signsgd", + "signum", + "sgdf", + "prodigy", + "sophiag", + "shampoo", + "adopt", + "clip-adagrad", + "clip-adagrad-delay-eta", + "clip-adam", + "clip-adam-delay-eta", + "mars", + "adafactor", + "lamb", + "normalized-sgd", + ], + ) + parser.add_argument("--batch_size", default=50, type=int) + parser.add_argument("--acc_steps", default=1, type=int) + parser.add_argument("--weight_decay", default=1e-1, type=float) + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.95, type=float) + parser.add_argument( + "--grad_clip", default=1.0, type=float + ) # default value is 1.0 in NanoGPT + parser.add_argument("--momentum", default=0.9, type=float) + parser.add_argument("--shampoo_beta", default=-1.0, type=float) + parser.add_argument("--precondition_frequency", default=10, type=int) + parser.add_argument("--max_precond_dim", default=10000, type=int) + parser.add_argument("--merge_dims", default=False, type=bool) + parser.add_argument("--precondition_1d", default=False, type=bool) + parser.add_argument("--normalize_grads", default=False, type=bool) + parser.add_argument("--soap_data_format", default="channels_first", type=str) + parser.add_argument("--correct_bias", default=True, type=bool) + parser.add_argument("--nesterov", default=False, type=bool) + parser.add_argument("--muon_ns_steps", default=5, type=int) + parser.add_argument("--muon_lr_factor", default=1.0, type=float) + parser.add_argument("--adema_beta3", default=0.9, type=float) + parser.add_argument("--adema_alpha", default=2.0, type=float) + parser.add_argument("--adema_beta3_warmup", default=None, type=int) + parser.add_argument("--adema_alpha_warmup", default=None, type=int) + parser.add_argument("--schedulefree_r", default=0.0, type=float) + parser.add_argument("--weight_lr_power", default=2.0, type=float) + parser.add_argument("--model_sharding", default=None, type=bool) + parser.add_argument("--adam_mini_verbose", default=False, type=bool) + parser.add_argument("--dampening", default=0.0, type=float) + parser.add_argument("--prodigy_beta3", default=None, type=float) + parser.add_argument("--prodigy_decouple", default=True, type=bool) + parser.add_argument("--prodigy_use_bias_correction", default=False, type=bool) + parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) + parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) + parser.add_argument("--sophia_rho", default=0.04, type=float) + parser.add_argument("--sophia_bs", default=480, type=int) + parser.add_argument( + "--clipping_type", default="no", choices=["no", "local", "elementwise"] + ) + parser.add_argument("--clip_eta", default=1.0, type=float) + parser.add_argument( + "--mars_type", + default="mars-adamw", + choices=["mars-adamw", "mars-lion", "mars-shampoo"], + ) + parser.add_argument("--mars_vr_gamma", default=0.025, type=float) + parser.add_argument("--mars_is_approx", default=True, type=float) + parser.add_argument("--mars_lr", default=3e-3, type=float) + parser.add_argument("--mars_beta1", default=0.95, type=float) + parser.add_argument("--mars_beta2", default=0.99, type=float) + parser.add_argument("--adafactor_decay_rate", default=-0.8, type=float) + parser.add_argument("--lamb_use_bias_correction", default=False, type=bool) + # Dataset params - parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) - parser.add_argument('--vocab_size', default=50304, type=int) - parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, mostly useless except for openwebtext2 + parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") + parser.add_argument( + "--dataset", + default="slimpajama", + choices=[ + "wikitext", + "shakespeare-char", + "arxiv", + "arxiv2000", + "arxiv+wiki", + "openwebtext2", + "redpajama", + "slimpajama", + "slimpajama_chunk1", + "redpajamav2", + "fineweb", + "finewebedu", + ], + ) + parser.add_argument( + "--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"] + ) + parser.add_argument("--vocab_size", default=50304, type=int) + parser.add_argument( + "--data_in_ram", action="store_true" + ) # force the data to RAM, mostly useless except for openwebtext2 + # Model params - parser.add_argument('--model', default='base', choices=['base', 'llama2']) - parser.add_argument('--use_pretrained', default="auto", type=none_or_str) # 'none', 'gpt-2' or a path to the pretraind model - parser.add_argument('--dropout', default=0.0, type=float) - parser.add_argument('--n_head', default=12, type=int) - parser.add_argument('--n_layer', default=12, type=int) # depths in att + ff blocks - parser.add_argument('--n_embd', default=768, type=int) # embedding size / hidden size ... - parser.add_argument('--sequence_length', default=512, type=int) - parser.add_argument('--dtype', default=torch.bfloat16, type=torch.dtype) - parser.add_argument('--bias', default=False, type=bool) - parser.add_argument('--compile', action='store_true') # if true then model is compiled - parser.add_argument("--rmsnorm_eps", default=1e-5, type=float) + parser.add_argument( + "--model", + default="llama", + choices=[ + "base", + "llama", + "test", + ], + ) + parser.add_argument("--parallel_block", action="store_true") + parser.add_argument( + "--use_pretrained", default="none", type=str + ) # 'none', 'gpt-2' or a path to the pretraind model + parser.add_argument("--from_dense", action="store_true") + parser.add_argument("--init_std", default=0.02, type=float) + parser.add_argument("--dropout", default=0.0, type=float) + parser.add_argument("--n_head", default=12, type=int) + parser.add_argument("--n_layer", default=24, type=int) # depths in att + ff blocks + parser.add_argument("--sequence_length", default=512, type=int) + parser.add_argument( + "--n_embd", default=768, type=int # embedding size / hidden size ... + ) parser.add_argument( "--multiple_of", # make SwiGLU hidden layer size multiple of large power of 2 default=256, type=int, ) - parser.add_argument('--run_prefix', default=None, type=str, required=False) # is added before the autogenerated experiment name - parser.add_argument('--exp_name', default=None, type=str, required=False) - # logging params (WandB) - parser.add_argument('--wandb', action='store_true') # whether to use wandb or not - parser.add_argument('--wandb_project', default="my-project", type=str) - parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name - parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences - # Distributed args - parser.add_argument('--distributed_backend', default=None, type=str, required=False, - choices=distributed.registered_backends()) # distributed backend type - parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) - - args = parser.parse_args(args, namespace) - - if args.exp_name is None: - special_name_handle_fields = {"model", "lr", "batch_size", - "acc_steps", "seed", "exp_name", - "wandb", "wandb_project", "eval_seq_prefix", - "run_prefix", "distributed_backend", "config_format", - "sequence_length"} - overriden_values = [] - for key in vars(args): - if key in special_name_handle_fields: - continue - if getattr(args, key) != parser.get_default(key): - overriden_values.append((key, getattr(args, key))) - chunk_len = 10 - overriden_values_str_parts = [] - for chunk_id in range(0, len(overriden_values), chunk_len): - overriden_values_str = "_".join(["{}={}".format(key, value) for key, value in overriden_values[chunk_id:chunk_id+chunk_len]]) - overriden_values_str_parts.append(overriden_values_str) - overriden_values_str = "/".join(overriden_values_str_parts) - exp_name = "" - if args.run_prefix is not None: - exp_name += f"{args.run_prefix}_" - exp_name += f"{args.model}_lr{args.lr}_bs{args.batch_size}x{args.acc_steps}_seqlen{args.sequence_length}/{overriden_values_str}_seed={args.seed}" - args.exp_name = exp_name - - if args.dtype == "torch.bfloat16": - args.dtype = torch.bfloat16 - elif args.dtype == "torch.float16": - args.dtype = torch.float16 - elif args.dtype == "torch.float32": - args.dtype = torch.float32 + parser.add_argument("--n_kv_head", default=None, type=int) # for Adam-mini + parser.add_argument("--rmsnorm_eps", default=1e-5, type=float) + parser.add_argument( + "--dtype", + default="bfloat16", + type=str, + choices=["float32", "float16", "bfloat16"], + ) + parser.add_argument("--bias", default=False, type=bool) + parser.add_argument("--compile", action="store_true") + parser.add_argument("--mlp_dim_exp_factor", default=1.0, type=float) - return args + return parser.parse_args(args, namespace) diff --git a/src/data/arxiv.py b/src/data/arxiv.py index bd146f1..71a7fe0 100644 --- a/src/data/arxiv.py +++ b/src/data/arxiv.py @@ -1,35 +1,35 @@ +import logging import os import tarfile -import logging -from pathlib import Path -from typing import Optional from multiprocessing import Pool +from pathlib import Path +from subprocess import PIPE, Popen, TimeoutExpired from tempfile import NamedTemporaryFile -from subprocess import Popen, TimeoutExpired, PIPE -from typing import Tuple, List +from typing import List, Optional, Tuple import numpy as np import requests -from tqdm.auto import tqdm import tiktoken +from tqdm.auto import tqdm def convert_to_markdown(args: Tuple[Path, Path]): texfile, mdroot = args - mdfile = mdroot/f"{texfile.name}.md" - with Popen(["pandoc", "--wrap=none", "--from", "latex", texfile, - "--output", mdfile], stderr=PIPE) as proc: + mdfile = mdroot / f"{texfile.name}.md" + with Popen( + ["pandoc", "--wrap=none", "--from", "latex", texfile, "--output", mdfile], + stderr=PIPE, + ) as proc: try: proc.communicate(timeout=1) except TimeoutExpired: proc.kill() - def fetch_arxiv(root: Path, year: int): # download latex url = f"https://www.cs.cornell.edu/projects/kddcup/download/hep-th-{year}.tar.gz" - texroot = root/"tex" + texroot = root / "tex" print("Downloading Arxiv year", year) req = requests.get(url, timeout=60) with NamedTemporaryFile(suffix=".tar.gz") as f: @@ -40,13 +40,16 @@ def fetch_arxiv(root: Path, year: int): tar.extractall(texroot) # convert to markdown - mdroot = root/"md"/str(year) + mdroot = root / "md" / str(year) mdroot.mkdir(parents=True) - files = list((texroot/str(year)).iterdir()) + files = list((texroot / str(year)).iterdir()) with Pool(os.cpu_count()) as p: args = [(texfile, mdroot) for texfile in files] - for _ in tqdm(p.imap_unordered(convert_to_markdown, args), - desc="Converting to markdown", total=len(files)): + for _ in tqdm( + p.imap_unordered(convert_to_markdown, args), + desc="Converting to markdown", + total=len(files), + ): pass @@ -55,7 +58,7 @@ def tokenize_arxiv(root: Path, year: int): tokens = [] tokens_val = [] tokens_test = [] - mds = root/"md"/str(year) + mds = root / "md" / str(year) # tokenize desc = f"Tokenizing {year}" @@ -70,46 +73,44 @@ def tokenize_arxiv(root: Path, year: int): tokens_test += tokenizer.encode(text) # save to dir - tpath = root/str(year) + tpath = root / str(year) tpath.mkdir(parents=True) - for x, name in zip([tokens, tokens_val, tokens_test], - ["train", "val", "test"]): - mem = np.memmap(tpath/f"{name}.npy", dtype=np.uint16, mode="w+", - shape=len(x)) + for x, name in zip([tokens, tokens_val, tokens_test], ["train", "val", "test"]): + mem = np.memmap(tpath / f"{name}.npy", dtype=np.uint16, mode="w+", shape=len(x)) for i, v in enumerate(x): mem[i] = v def load_arxiv(cachedir: Path, years: Optional[List[int]] = None): - all_years = list(range(1992, 2004)) + all_years = list(range(1993, 2004)) # 1992 seems to give some problem if years is None: years = all_years assert set(years) <= set(all_years) - root = cachedir/"arxiv" + root = cachedir / "arxiv" root.mkdir(exist_ok=True, parents=True) # download all years requested that are not present for year in years: - if not (root/"md"/str(year)).exists(): + if not (root / "md" / str(year)).exists(): fetch_arxiv(root, year) # tokenize all years not previously tokenized for year in years: - if not (root/str(year)).exists(): + if not (root / str(year)).exists(): tokenize_arxiv(root, year) # load meta ret = {} for split in ["train", "val"]: - paths = [root/str(year)/f"{split}.npy" for year in years] + paths = [root / str(year) / f"{split}.npy" for year in years] x = [np.memmap(path, dtype=np.uint16, mode="r") for path in paths] ret[split] = np.concatenate(x) return ret -def get_arxiv_2000(): - return load_arxiv(Path(os.path.dirname(__file__))/"datasets", [2000]) +def get_arxiv_2000(datasets_base_dir): + return load_arxiv(Path(datasets_base_dir), [2000]) -def get_arxiv_full(): - return load_arxiv(Path(os.path.dirname(__file__))/"datasets") +def get_arxiv_full(datasets_base_dir): + return load_arxiv(Path(datasets_base_dir)) diff --git a/src/data/fineweb.py b/src/data/fineweb.py new file mode 100644 index 0000000..3588295 --- /dev/null +++ b/src/data/fineweb.py @@ -0,0 +1,76 @@ +import os + +import numpy as np +import tiktoken +from datasets import load_dataset +from tqdm import tqdm + +tknzr = tiktoken.get_encoding("gpt2") + + +def get_fineweb_data(datasets_dir, num_proc=40): + """To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code.""" + FWEB_DATA_PATH = os.path.join(datasets_dir, "fineweb-100BT/") + if not os.path.exists(os.path.join(FWEB_DATA_PATH, "train.bin")): + os.makedirs(FWEB_DATA_PATH, exist_ok=True) + + dataset = load_dataset( + "HuggingFaceFW/fineweb", + name="sample-100BT", + split="train", + streaming=False, + verification_mode="no_checks", + ) + + split_dataset = dataset.train_test_split( + test_size=0.0001, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... + out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(FWEB_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + return { + "train": os.path.join(FWEB_DATA_PATH, "train.bin"), + "val": os.path.join(FWEB_DATA_PATH, "val.bin"), + } + + +if __name__ == "__main__": + get_fineweb_data("./datasets/") diff --git a/src/data/fineweb_edu.py b/src/data/fineweb_edu.py new file mode 100644 index 0000000..a0ce49e --- /dev/null +++ b/src/data/fineweb_edu.py @@ -0,0 +1,76 @@ +import os + +import numpy as np +import tiktoken +from datasets import load_dataset +from tqdm import tqdm + +tknzr = tiktoken.get_encoding("gpt2") + + +def get_fineweb_edu_data(datasets_dir, num_proc=40): + """To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code.""" + FWEB_DATA_PATH = os.path.join(datasets_dir, "fineweb-edu-100BT/") + if not os.path.exists(os.path.join(FWEB_DATA_PATH, "train.bin")): + os.makedirs(FWEB_DATA_PATH, exist_ok=True) + + dataset = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-100BT", + split="train", + streaming=False, + verification_mode="no_checks", + ) + + split_dataset = dataset.train_test_split( + test_size=0.0001, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... + out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(FWEB_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + return { + "train": os.path.join(FWEB_DATA_PATH, "train.bin"), + "val": os.path.join(FWEB_DATA_PATH, "val.bin"), + } + + +if __name__ == "__main__": + get_fineweb_edu_data("./datasets/") diff --git a/src/data/openwebtext2.py b/src/data/openwebtext2.py index eef9d50..01fb581 100644 --- a/src/data/openwebtext2.py +++ b/src/data/openwebtext2.py @@ -1,58 +1,65 @@ import os -from tqdm import tqdm + import numpy as np import tiktoken -from datasets import load_dataset - +from datasets import load_dataset +from tqdm import tqdm -OWT2_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/openwebtext2/") tknzr = tiktoken.get_encoding("gpt2") -def get_openwebtext2_data(num_proc=40): - """ https://openwebtext2.readthedocs.io/en/latest/ - """ - if not os.path.exists(os.path.join(OWT2_DATA_PATH, 'train.bin')): + +def get_openwebtext2_data(datasets_base_dir, num_proc=40): + """https://openwebtext2.readthedocs.io/en/latest/""" + OWT2_DATA_PATH = os.path.join(datasets_base_dir, "openwebtext2/") + if not os.path.exists(os.path.join(OWT2_DATA_PATH, "train.bin")): os.makedirs(OWT2_DATA_PATH, exist_ok=True) dataset = load_dataset("the_pile_openwebtext2") - split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) - split_dataset['val'] = split_dataset.pop('test') - + split_dataset = dataset["train"].train_test_split( + test_size=0.0005, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + def process(example): - ids = tknzr.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens - ids.append(tknzr.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe # note: I think eot should be prepended not appended... hmm. it's called "eot" though... - out = {'ids': ids, 'len': len(ids)} + out = {"ids": ids, "len": len(ids)} return out # tokenize the dataset tokenized = split_dataset.map( process, - remove_columns=['text'], + remove_columns=["text"], desc="tokenizing the splits", num_proc=num_proc, ) # concatenate all the ids in each dataset into one large file we can use for training for split, dset in tokenized.items(): - arr_len = np.sum(dset['len']) - filename = os.path.join(OWT2_DATA_PATH, f'{split}.bin') - dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) - arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) + arr_len = np.sum(dset["len"]) + filename = os.path.join(OWT2_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) total_batches = 1024 idx = 0 - for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): # Batch together samples for faster write - batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') - arr_batch = np.concatenate(batch['ids']) + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) # Write into mmap arr[idx : idx + len(arr_batch)] = arr_batch idx += len(arr_batch) arr.flush() - train_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') - val_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') - - return {'train': train_data, 'val': val_data} - + return { + "train": os.path.join(OWT2_DATA_PATH, "train.bin"), + "val": os.path.join(OWT2_DATA_PATH, "val.bin"), + } diff --git a/src/data/redpajama.py b/src/data/redpajama.py new file mode 100644 index 0000000..0901719 --- /dev/null +++ b/src/data/redpajama.py @@ -0,0 +1,124 @@ +import os + +import numpy as np +import tiktoken +from datasets import load_dataset +from tqdm import tqdm + +tknzr = tiktoken.get_encoding("gpt2") + + +def get_redpajama_data(datasets_dir, num_proc=40): + RPJ_DATA_PATH = os.path.join(datasets_dir, "redpajama1Tsample/") + if not os.path.exists(os.path.join(RPJ_DATA_PATH, "train.bin")): + os.makedirs(RPJ_DATA_PATH, exist_ok=True) + dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample") + + split_dataset = dataset["train"].train_test_split( + test_size=0.0005, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(RPJ_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + train_data = np.memmap( + os.path.join(RPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" + ) + val_data = np.memmap( + os.path.join(RPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" + ) + + return {"train": train_data, "val": val_data} + + +def get_redpajamav2_data(datasets_dir, num_proc=40): + """https://openwebtext2.readthedocs.io/en/latest/""" + RPJ_V2_DATA_PATH = os.path.join(datasets_dir, "redpajamaV2sample/") + if not os.path.exists(os.path.join(RPJ_V2_DATA_PATH, "train.bin")): + os.makedirs(RPJ_V2_DATA_PATH, exist_ok=True) + dataset = load_dataset("togethercomputer/RedPajama-Data-V2", name="sample") + + split_dataset = dataset["train"].train_test_split( + test_size=0.0005, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["raw_content"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... + out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["raw_content"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(RPJ_V2_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + return { + "train": os.path.join(RPJ_V2_DATA_PATH, "train.bin"), + "val": os.path.join(RPJ_V2_DATA_PATH, "val.bin"), + } diff --git a/src/data/shakespeare.py b/src/data/shakespeare.py index ab6e022..cb13f94 100644 --- a/src/data/shakespeare.py +++ b/src/data/shakespeare.py @@ -4,8 +4,9 @@ import numpy as np import requests - -_char_decode = dict(enumerate(sorted(set(ascii_letters + digits + punctuation + " \n")))) +_char_decode = dict( + enumerate(sorted(set(ascii_letters + digits + punctuation + " \n"))) +) _char_encode = {char: i for i, char in _char_decode.items()} @@ -13,10 +14,9 @@ def char_tknzr(txt: str): return [_char_encode[char] for char in txt if char in _char_encode] -DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets", "shakespeare") - -def get_shakespeare_data(): +def get_shakespeare_data(datasets_dir): """Inspired from https://github.com/karpathy/nanoGPT/""" + DATA_PATH = os.path.join(datasets_dir, "shakespeare") raw_path = os.path.join(DATA_PATH, "raw.txt") train_path = os.path.join(DATA_PATH, f"train.npy") test_path = os.path.join(DATA_PATH, f"test.npy") @@ -36,7 +36,7 @@ def get_shakespeare_data(): # load text with open(raw_path, encoding="utf8") as f: text = "".join(f.readlines()) - i = int(0.8*len(text)) + i = int(0.8 * len(text)) # encode text x = np.array(char_tknzr(text[:i]), dtype=np.uint16) x_test = np.array(char_tknzr(text[i:]), dtype=np.uint16) @@ -46,6 +46,7 @@ def get_shakespeare_data(): mem = np.memmap(test_path, dtype=np.uint16, mode="w+", shape=x_test.shape) mem[:] = x_test - # at this point we know that the binfile was properly created so we load it - return {"train": np.memmap(train_path, dtype=np.uint16, mode="r"), - "val": np.memmap(test_path, dtype=np.uint16, mode="r")} + return { + "train": train_path, + "val": test_path, + } diff --git a/src/data/slimpajama.py b/src/data/slimpajama.py index c3960d7..3811930 100644 --- a/src/data/slimpajama.py +++ b/src/data/slimpajama.py @@ -1,18 +1,15 @@ -from tqdm import tqdm +import os + import numpy as np import tiktoken from datasets import load_dataset -import os - - -SPJ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slimpajama6B/") -SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1") - +from tqdm import tqdm tknzr = tiktoken.get_encoding("gpt2") -def get_slimpajama_data(num_proc=40): +def get_slimpajama_data(datasets_dir, num_proc=40): + SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/") if not os.path.exists(os.path.join(SPJ_DATA_PATH, "train.bin")): os.makedirs(SPJ_DATA_PATH, exist_ok=True) dataset = load_dataset("DKYoon/SlimPajama-6B") @@ -60,17 +57,15 @@ def process(example): idx += len(arr_batch) arr.flush() - train_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" - ) - val_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" - ) + return { + "train": os.path.join(SPJ_DATA_PATH, "train.bin"), + "val": os.path.join(SPJ_DATA_PATH, "val.bin"), + } - return {"train": train_data, "val": val_data} - -def get_slimpajama_chunk1(num_proc=40): +def get_slimpajama_chunk1(datasets_dir, num_proc=40): + SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/") + SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1") if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")): os.makedirs(SPJ_DATA_PATH, exist_ok=True) dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1") @@ -118,11 +113,7 @@ def process(example): idx += len(arr_batch) arr.flush() - train_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" - ) - val_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" - ) - - return {"train": train_data, "val": val_data} + return { + "train": os.path.join(SPJ_DATA_PATH, "train.bin"), + "val": os.path.join(SPJ_DATA_PATH, "val.bin"), + } diff --git a/src/data/utils.py b/src/data/utils.py index b5c7bb8..5f3ee1d 100755 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -1,88 +1,184 @@ -import numpy as np +from pathlib import Path from typing import Dict + +import numpy as np import torch +import torch.distributed as dist -from .shakespeare import get_shakespeare_data -from .wikitext import get_wikitext_data from .arxiv import get_arxiv_2000, get_arxiv_full +from .fineweb import get_fineweb_data +from .fineweb_edu import get_fineweb_edu_data from .openwebtext2 import get_openwebtext2_data +from .redpajama import get_redpajama_data, get_redpajamav2_data +from .shakespeare import get_shakespeare_data from .slimpajama import get_slimpajama_data +from .wikitext import get_wikitext_data def get_dataset(args) -> Dict[str, np.ndarray]: - """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is - contained in its own python file. The expected format at the moment is a dictionary of np.memmap - containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """ - if args.dataset == 'wikitext': - return get_wikitext_data() + """Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is + contained in its own python file. The expected format at the moment is a dictionary of np.memmap + containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. + """ + if args.dataset == "wikitext": + return get_wikitext_data(args.datasets_dir) if args.dataset == "shakespeare-char": - return get_shakespeare_data() + return get_shakespeare_data(args.datasets_dir) if args.dataset == "arxiv2000": - return get_arxiv_2000() + return get_arxiv_2000(args.datasets_dir) if args.dataset == "arxiv": - return get_arxiv_full() + return get_arxiv_full(args.datasets_dir) if args.dataset == "arxiv+wiki": - arxiv_data = get_arxiv_full() - wiki_data = get_wikitext_data() - train_data = np.concatenate((arxiv_data['train'], wiki_data['train'])) - val_data = np.concatenate((arxiv_data['val'], wiki_data['val'])) - return {'train': train_data, 'val': val_data} - if args.dataset == 'openwebtext2': - return get_openwebtext2_data() + arxiv_data = get_arxiv_full(args.datasets_dir) + wiki_data = get_wikitext_data(args.datasets_dir) + train_data = np.concatenate((arxiv_data["train"], wiki_data["train"])) + val_data = np.concatenate((arxiv_data["val"], wiki_data["val"])) + return {"train": train_data, "val": val_data} + if args.dataset == "openwebtext2": + return get_openwebtext2_data(args.datasets_dir) + if args.dataset == "redpajama": + return get_redpajama_data(args.datasets_dir) + if args.dataset == "redpajamav2": + return get_redpajamav2_data(args.datasets_dir) if args.dataset == "slimpajama": - return get_slimpajama_data() + return get_slimpajama_data(args.datasets_dir) + if args.dataset == "fineweb": + return get_fineweb_data(args.datasets_dir) + if args.dataset == "finewebedu": + return get_fineweb_edu_data(args.datasets_dir) else: raise NotImplementedError(f"Unknow dataset key '{args.dataset}'") -class Dataset(torch.utils.data.Dataset): - def __init__(self, data, sequence_length): - super().__init__() - self.data = data + +class DataReader: + def __init__( + self, + data_src, + batch_size, + sequence_length, + seed=1337, + with_replacement=False, + auto_shard=True, + keep_in_ram=False, + ): + if isinstance(data_src, (str, Path)): + self.data_path = Path(data_src) + self.keep_in_ram = keep_in_ram + if keep_in_ram: + self.data = np.array( + np.memmap(self.data_path, dtype=np.uint16, mode="r") + ) + else: + self.data = None + elif isinstance(data_src, (np.ndarray, np.memmap)): + self.data_path = None + self.data = data_src + self.keep_in_ram = True + + self.batch_size = batch_size self.sequence_length = sequence_length + self.seed = seed + self.with_replacement = with_replacement + + self.num_tokens = len(self._get_data()) + + if auto_shard and dist.is_initialized(): + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + print( + f"Distributed DataReader Initialized for Worker {self.rank}/{self.world_size}" + ) + else: + self.world_size = 1 + self.rank = 0 + + # Sampling without replacement + self.last_epoch = None + self.order = None + self.epoch_offset = None + self.step = 0 + self.num_batches_of_seqlen = 0 + if not with_replacement: + self._shuffle_epoch(0) def __len__(self): - total_length = len(self.data) - # chunk the data into sequences of length `sequence_length` - # NOTE: we discard the last remainding sequence if it's not of length `sequence_length` - return (total_length - 1) // self.sequence_length + # Length in valid start indices for a sequence + # Extra -1 to have a valid next token for the final token of the last idx + return self.num_tokens - self.sequence_length - 1 - def __getitem__(self, idx): - seq_length = self.sequence_length - idx = idx * seq_length - x = torch.from_numpy((self.data[idx : idx + seq_length]).astype(np.int64)) + def _get_data(self): + if self.data is not None: + return self.data + else: + # Construct the memmap each time to avoid a memory leak per NanoGPT + # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 + return np.memmap(self.data_path, dtype=np.uint16, mode="r") + def __getitem__(self, idx): + # Return the underlying datapoint, no random sampling, no worker sharding + assert 0 <= idx < len(self) + data = self._get_data() + x = torch.from_numpy(data[idx : idx + self.sequence_length].astype(np.int64)) y = torch.from_numpy( - (self.data[idx + 1 : idx + 1 + seq_length]).astype(np.int64) + data[idx + 1 : idx + self.sequence_length + 1].astype(torch.int64) ) return x, y + def set_step(self, step): + self.step = step -def get_dataloader(data, sequence_length, batch_size, seed=0, distributed_backend=None): - """Create a DataLoader for the given data. If distributed_backend is provided and is truly - distributed (world size > 1), the DataLoader will be created with a DistributedSampler that - splits the data across the processes (in conjunction with DDP). - Otherwise, use a RandomSampler with the specified seed. + def sample_batch(self): + data = self._get_data() - Returns both the dataloader and the sampler. - """ - dataset = Dataset(data, sequence_length=sequence_length) - if distributed_backend and distributed_backend.get_world_size() > 1: - sampler = torch.utils.data.DistributedSampler( - dataset, - shuffle=True, - seed=seed, - ) - else: - g = torch.Generator() - g.manual_seed(seed) - sampler = torch.utils.data.RandomSampler( - dataset, replacement=False, generator=g + if self.with_replacement: + idxs = self._sample_with_replacement(self.step) + else: + idxs = self._sample_without_replacement(self.step) + self.step += 1 + + xy = np.stack([data[i : i + self.sequence_length + 1] for i in idxs]).astype( + np.int64 ) + x = torch.from_numpy(xy[:, :-1]).contiguous() + y = torch.from_numpy(xy[:, 1:]).contiguous() + return x, y + + def _sample_with_replacement(self, idx): + # Return an array of token indices of length self.batch_size + # Sampled with replacement, can get repeats at any time + seed = self.seed + idx * self.world_size + self.rank + rng = np.random.default_rng(seed) + return rng.integers(len(self), self.batch_size) + + def _shuffle_epoch(self, epoch): + seed = self.seed + epoch + rng = np.random.default_rng(seed) + # Drop one sequence to allow different offsets per epoch: + self.order = rng.permutation((len(self)) // self.sequence_length - 1) + # Shift all sequences in this epoch by this amount: + self.epoch_offset = rng.integers(self.sequence_length) + self.last_epoch = epoch + self.num_batches_of_seqlen = ( + len(self.order) // self.batch_size + ) # Drops remainder batch + + def _sample_without_replacement(self, step): + # Return an array of token indices of length self.batch_size + # Sampled without replacement, cycle all sequences before potential repeats + # Sequences are randomly offset in every epoch as well + batch_idx = self.world_size * step + self.rank + epoch_length = self.num_batches_of_seqlen + + epoch = batch_idx // epoch_length + if epoch != self.last_epoch: + self._shuffle_epoch(epoch) + epoch_idx = batch_idx % epoch_length + + start = epoch_idx * self.batch_size + end = start + self.batch_size + return self.order[start:end] * self.sequence_length + self.epoch_offset - loader = torch.utils.data.DataLoader( - dataset, - sampler=sampler, - batch_size=batch_size, - num_workers=4, - ) - return loader, sampler + def num_batches(self): + if self.with_replacement: + return self.num_tokens // self.batch_size + return self.num_batches_of_seqlen diff --git a/src/data/wikitext.py b/src/data/wikitext.py index 646f636..0cd5ea5 100755 --- a/src/data/wikitext.py +++ b/src/data/wikitext.py @@ -1,42 +1,49 @@ import os -import zipfile import urllib +import zipfile + import numpy as np import tiktoken -WIKITEXT_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/wikitext/") - - -def get_wikitext_data(): - """ Inspired from https://github.com/tysam-code/hlb-gpt """ +def get_wikitext_data(datasets_base_dir): + """Inspired from https://github.com/tysam-code/hlb-gpt""" + WIKITEXT_DATA_PATH = os.path.join(datasets_base_dir, "wikitext/") if not os.path.exists(WIKITEXT_DATA_PATH): os.makedirs(WIKITEXT_DATA_PATH, exist_ok=True) print("downloading data and tokenizing (1-2 min)") - raw_data_source = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip' - urllib.request.urlretrieve(raw_data_source, os.path.join(WIKITEXT_DATA_PATH,'data.zip')) - - with zipfile.ZipFile(os.path.join(WIKITEXT_DATA_PATH, "data.zip"), 'r') as zip_ref: + raw_data_source = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip" + urllib.request.urlretrieve( + raw_data_source, os.path.join(WIKITEXT_DATA_PATH, "data.zip") + ) + + with zipfile.ZipFile( + os.path.join(WIKITEXT_DATA_PATH, "data.zip"), "r" + ) as zip_ref: zip_ref.extractall(WIKITEXT_DATA_PATH) - with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), 'r') as data_file: + with open( + os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), "r" + ) as data_file: raw_train_data = data_file.read() - with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), 'r') as data_file: + with open( + os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), "r" + ) as data_file: raw_eval_data = data_file.read() tokenizer = tiktoken.get_encoding("gpt2") raw_tokenized_train = tokenizer.encode_ordinary(raw_train_data) raw_tokenized_eval = tokenizer.encode_ordinary(raw_eval_data) - train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) + train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16) - train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'train.bin')) - eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'val.bin')) + train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "train.bin")) + eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "val.bin")) print("completed the tokenization process!") - train_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') - val_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') - - return {'train': train_data, 'val': val_data} + return { + "train": os.path.join(WIKITEXT_DATA_PATH, "train.bin"), + "val": os.path.join(WIKITEXT_DATA_PATH, "val.bin"), + } diff --git a/src/distributed/__init__.py b/src/distributed/__init__.py index 160eebd..74a4ae7 100644 --- a/src/distributed/__init__.py +++ b/src/distributed/__init__.py @@ -1,6 +1,4 @@ - -from . import ddp -from . import single +from . import ddp, single BACKEND_TYPE_TO_MODULE_MAP = { "nccl": ddp.DataParallelDistributedBackend, diff --git a/src/distributed/backend.py b/src/distributed/backend.py index 3df0fda..5faa6c2 100644 --- a/src/distributed/backend.py +++ b/src/distributed/backend.py @@ -1,16 +1,16 @@ - from typing import List class DistributedBackend(object): - def __init__(self, args): pass def transform_model(self, model): raise NotImplementedError - def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps): + def get_context_for_microstep_forward( + self, model, microstep_idx, gradient_accumulation_steps + ): raise NotImplementedError def is_master_process(self) -> bool: @@ -30,6 +30,3 @@ def get_world_size(self): def finalize(self): pass - - def sync(self): - pass diff --git a/src/distributed/ddp.py b/src/distributed/ddp.py index d4470b6..e66b5f6 100644 --- a/src/distributed/ddp.py +++ b/src/distributed/ddp.py @@ -1,47 +1,49 @@ -import os import math +import os from contextlib import contextmanager +from torch.distributed import (destroy_process_group, get_world_size, + init_process_group) from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed import init_process_group, destroy_process_group, get_world_size, barrier from .backend import DistributedBackend class DataParallelDistributedBackend(DistributedBackend): - def __init__(self, args): - self.rank = int(os.environ.get('RANK', -1)) + self.rank = int(os.environ.get("RANK", -1)) assert self.rank != -1, "DDP backend can not be used without rank" assert "cuda" in args.device, "DDP backend can not be used on non-CUDA devices" init_process_group(backend=args.distributed_backend) - self.local_rank = int(os.environ['LOCAL_RANK']) + self.local_rank = int(os.environ["LOCAL_RANK"]) def get_adjusted_args_for_process(self, args): effective_batch_size = args.batch_size * args.acc_steps world_size = self.get_world_size() - if args.acc_steps % world_size != 0: - raise ValueError(f"Number of accumulation steps " - "{args.acc_steps} is not divisible " - "by the world size {world_size}.") if effective_batch_size % world_size != 0: - raise ValueError(f"Effective batch size " - "{effective_batch_size} is not divisible " - "by the world size {world_size}.") + raise ValueError( + f"Effective batch size " + "{effective_batch_size} is not divisible " + "by the world size {world_size}." + ) acc_steps_div = math.gcd(args.acc_steps, world_size) args.acc_steps = args.acc_steps // acc_steps_div args.batch_size = args.batch_size // (world_size // acc_steps_div) - args.device = f'cuda:{self.local_rank}' + args.device = f"cuda:{self.local_rank}" args.seed = args.seed + self.local_rank + args.data_seed = args.data_seed return args def transform_model(self, model): return DDP(model, device_ids=[self.local_rank]) @contextmanager - def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps): + def get_context_for_microstep_forward( + self, model, microstep_idx, gradient_accumulation_steps + ): model.require_backward_grad_sync = ( - microstep_idx == gradient_accumulation_steps - 1) + microstep_idx == gradient_accumulation_steps - 1 + ) yield def is_master_process(self) -> bool: @@ -51,13 +53,10 @@ def get_raw_model(self, model): return model.module def translate_model_parameter_name_for_node(self, parameter_name): - return [f'module.{parameter_name}'] + return [f"module.{parameter_name}"] def get_world_size(self): return get_world_size() def finalize(self): destroy_process_group() - - def sync(self): - barrier() diff --git a/src/distributed/single.py b/src/distributed/single.py index b852988..5f8adb2 100644 --- a/src/distributed/single.py +++ b/src/distributed/single.py @@ -4,6 +4,9 @@ class SinlgeNodeBackend(DistributedBackend): + def __init__(self, args): + super().__init__(args) + self.rank = 0 def transform_model(self, model): return model diff --git a/src/main.py b/src/main.py index 92ed664..b963956 100755 --- a/src/main.py +++ b/src/main.py @@ -1,161 +1,594 @@ +import argparse +import copy +import inspect +import json import os +import random import sys +from pathlib import Path + import numpy as np import torch -import inspect -import json -import copy -import argparse -import random -import wandb import config -from models.utils import get_model -from data.utils import get_dataset -from optim.base import train_base import distributed +import wandb +from data.utils import DataReader, get_dataset +from models.utils import get_model +from optim.adafactor import Adafactor +from optim.adammini import Adam_mini +from optim.ademamix import AdEMAMix +from optim.ademamix2 import AdEMAMix2 +from optim.adopt import ADOPT +from optim.base import train +from optim.clipped import (AdagradClip, AdaGradClipDelayedEta, AdamClip, + AdamClipDelayedEta) +from optim.lamb import Lamb +from optim.lion import Lion +from optim.mars import MARS +from optim.muon import CombinedScheduler, Muon +from optim.normalized import NormalizedSGD +from optim.prodigy import Prodigy +from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, + dd_schedule, wsd_schedule) +from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree +from optim.sgdf import SGDF +from optim.shampoo import DistributedShampoo +from optim.sign import Signum +from optim.soap import SOAP +from optim.sophia import SophiaG def get_args(): parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument('--config_format', default='base', choices=config.registered_formats()) + parser.add_argument( + "--config_format", default="base", choices=config.registered_formats() + ) args, rem_args = parser.parse_known_args() - return config.parse_args_with_format(format=args.config_format, base_parser=parser, args=rem_args, namespace=args) + final_args = config.parse_args_with_format( + format=args.config_format, base_parser=parser, args=rem_args, namespace=args + ) + return final_args, parser -def main(args): - - torch.backends.cuda.matmul.allow_tf32 = True # allows us to make sure we're able to use tensorfloat32 during training - torch.backends.cudnn.allow_tf32 = True +def main(args, parser): distributed_backend = distributed.make_backend_from_args(args) args = distributed_backend.get_adjusted_args_for_process(args) + args.world_size = distributed_backend.get_world_size() - args.device = torch.device(args.device) - device_type = "cuda" if "cuda" in str(args.device) else "cpu" - if device_type == "cuda": - torch.cuda.set_device(args.device) + if args.full_eval_at is None: + args.full_eval_at = [] + # NOTE args.seed is offset per worker in get_adjusted_args_for_process + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True torch.manual_seed(args.seed) random.seed(args.seed) np.random.seed(args.seed) - - print(f"Loading dataset '{args.dataset}'") - - data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized} - if args.data_in_ram: - data = {'train': np.array(data['train']), 'val': np.array(data['val'])} - - print(f"Num training tokens: {len(data['train'])}") - print(f"Num validation tokens: {len(data['val'])}") - - model = get_model(args).to(args.device) # todo: take care of initializing the model if args.use_pretrained != 'none' + if "cuda" in args.device: + torch.cuda.set_device(torch.device(args.device)) + # torch.use_deterministic_algorithms(True) # CUBLAS_WORKSPACE_CONFIG=:4096:8 + + exp_name = get_exp_name(args, parser, distributed_backend) + exp_dir = Path(args.results_base_folder) / exp_name + if distributed_backend.is_master_process() and args.wandb: + wandb.init( + project=args.wandb_project, + name=exp_name, + config=vars(args), + entity=args.wandb_entity, + ) + wandb.define_metric("iter") + wandb.define_metric("train/*", step_metric="iter") + wandb.define_metric("val/*", step_metric="iter") + wandb.define_metric("lr", step_metric="iter") + + print(f"Starting Experiment: {exp_name}") + print(f"Experiment Directory: {exp_dir}") + print(f"Config:\n{vars(args)}\n") + + print(f"Loading dataset: '{args.dataset}'") + datareaders = get_data_readers(args) + + model = get_model(args).to( + args.device + ) # todo: take care of initializing the model if args.use_pretrained != 'none' + print(f"\nModel:\n{model}") model = distributed_backend.transform_model(model) - + group_specs = distributed_backend.get_raw_model(model).get_parameter_group_specs() param_name_mapping = {p_name: p for p_name, p in model.named_parameters()} optimized_params_cnt = 0 for g in group_specs: params = [] for p_name in g["params"]: - translated_p_names = distributed_backend.translate_model_parameter_name_for_node(p_name) + translated_p_names = ( + distributed_backend.translate_model_parameter_name_for_node(p_name) + ) params += [param_name_mapping[p_name] for p_name in translated_p_names] g["params"] = params optimized_params_cnt += sum([p.numel() for p in g["params"]]) - print("number of optimized parameters: %.2fM" % (optimized_params_cnt/1e6,)) - if args.opt == 'adamw': - use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) + params_cnt = distributed_backend.get_raw_model(model).get_num_params() + print("number of parameters: %.2fM" % (params_cnt / 1e6,)) + print("number of optimized parameters: %.2fM" % (optimized_params_cnt / 1e6,)) + if args.wandb and distributed_backend.is_master_process(): + wandb.log( + {"parameters": params_cnt, "optimized_parameters": optimized_params_cnt} + ) + + args.world_size = distributed_backend.get_world_size() + + if args.opt == "adamw": + device_type = "cuda" if "cuda" in args.device else "cpu" + use_fused = (device_type == "cuda") and ( + "fused" in inspect.signature(torch.optim.AdamW).parameters + ) print(f"using fused AdamW: {use_fused}") extra_args = dict(fused=True) if use_fused else dict() - opt = torch.optim.AdamW(group_specs, lr=args.lr, betas=(args.beta1, args.beta2), - weight_decay=args.weight_decay, **extra_args) + opt = torch.optim.AdamW( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + **extra_args, + ) + elif args.opt == "soap": + opt = SOAP( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + shampoo_beta=args.shampoo_beta, + weight_decay=args.weight_decay, + precondition_frequency=args.precondition_frequency, + max_precond_dim=args.max_precond_dim, + merge_dims=args.merge_dims, + precondition_1d=args.precondition_1d, + normalize_grads=args.normalize_grads, + data_format=args.soap_data_format, + correct_bias=args.correct_bias, + ) + elif args.opt == "muon": + param_list = ( + list(model.parameters()) + if args.distributed_backend is None + else list(model.module.parameters()) + ) + assert ( + sum(p.numel() for p in param_list) == params_cnt + ), "number of parameters must be the same" + opt = Muon( + muon_params=param_list, + lr=args.muon_lr_factor, + momentum=args.momentum, + nesterov=args.nesterov, + ns_steps=args.muon_ns_steps, + adamw_params=None, + adamw_lr=args.lr, + adamw_betas=(args.beta1, args.beta2), + adamw_eps=1e-8, + adamw_wd=args.weight_decay, + ) + elif args.opt == "ademamix": + opt = AdEMAMix( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2, args.adema_beta3), + alpha=args.adema_alpha, + beta3_warmup=args.adema_beta3_warmup, + alpha_warmup=args.adema_alpha_warmup, + weight_decay=args.weight_decay, + ) + elif args.opt == "ademamix2": + opt = AdEMAMix2( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2, args.adema_beta3), + alpha=args.adema_alpha, + beta3_warmup=args.adema_beta3_warmup, + alpha_warmup=args.adema_alpha_warmup, + weight_decay=args.weight_decay, + ) + elif args.opt == "lion": + opt = Lion( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + ) + elif args.opt == "sf-adamw": + opt = AdamWScheduleFree( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + warmup_steps=args.warmup_steps, + r=args.schedulefree_r, + weight_lr_power=args.weight_lr_power, + ) # without foreach argument + elif args.opt == "sf-sgd": + opt = SGDScheduleFree( + group_specs, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + warmup_steps=args.warmup_steps, + r=args.schedulefree_r, + weight_lr_power=args.weight_lr_power, + ) # without foreach argument + elif args.opt == "adam-mini": + opt = Adam_mini( + device=args.device, + world_size=args.world_size, + named_parameters=model.named_parameters(), # check + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + model_sharding=args.model_sharding, + dim=args.n_embd, + n_heads=args.n_head, + n_kv_heads=args.n_kv_head, + verbose=args.adam_mini_verbose, + ) + elif args.opt == "signsgd": + opt = Signum( + group_specs, + lr=args.lr, + momentum=0.0, # always use zero momentum because its signSGD + dampening=args.dampening, + weight_decay=args.weight_decay, + nesterov=args.nesterov, + sign_update=True, + ) + elif args.opt == "signum": + opt = Signum( + group_specs, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + dampening=args.dampening, + nesterov=args.nesterov, + sign_update=True, + ) + elif args.opt == "sgdf": + opt = SGDF( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + ) + elif args.opt == "prodigy": + opt = Prodigy( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + beta3=args.prodigy_beta3, + weight_decay=args.weight_decay, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + fsdp_in_use=args.prodigy_fsdp_in_use, + ) + elif args.opt == "sophiag": + opt = SophiaG( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + rho=args.sophia_rho, + ) + elif args.opt == "shampoo": + opt = DistributedShampoo( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + precondition_frequency=args.precondition_frequency, + weight_decay=args.weight_decay, + use_decoupled_weight_decay=True, + # grafting_config=AdamGraftingConfig( + # beta2=args.beta2, # oroginally, the default value is 0.999 + # epsilon=1e-8, + # ), + ) + elif args.opt == "adopt": + opt = ADOPT( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=1e-6, + weight_decay=args.weight_decay, + ) + elif args.opt in [ + "clip-adagrad", + "clip-adagrad-delay-eta", + "clip-adam", + "clip-adam-delay-eta", + ]: + clipped_adagrad_cfg = { + "lr": args.lr, + "eps": 1e-8, + "weight_decay": args.weight_decay, + "clipping": args.clipping_type, + "max_grad_norm": 1.0, + } + if args.opt == "clip-adagrad": + opt = AdagradClip(**clipped_adagrad_cfg) + clipped_adagrad_delay_eta_cfg = { + **clipped_adagrad_cfg, + "exp_avg_sq_value": 0.0001, + "etta": args.clipping_eta, + } + if args.opt == "clip-adagrad-delay-eta": + opt = AdaGradClipDelayedEta(**clipped_adagrad_delay_eta_cfg) + clipped_adam_cfg = { + **clipped_adagrad_cfg, + "betas": (args.beta1, args.beta2), + "correct_bias": args.correct_bias, + } + if args.opt == "clip-adam": + opt = AdamClip(**clipped_adam_cfg) + clipped_adam_delay_eta_cfg = { + **clipped_adam_cfg, + "exp_avg_sq_value": 0.00001, + "etta": args.clipping_eta, + } + if args.opt == "clip-adam-delay-eta": + opt = AdamClipDelayedEta(**clipped_adam_delay_eta_cfg) + elif args.opt == "mars": + opt = MARS( + group_specs, + lr=args.mars_lr, + betas=(args.mars_beta1, args.mars_beta2), + weight_decay=args.weight_decay, + amsgrad=False, + gamma=args.mars_vr_gamma, + is_approx=args.mars_is_approx, + mars_type=args.mars_type, + optimize_1d=False, # we set in order to optimize 1D parameters with AdamW + lr_1d=args.lr, # AdamW's lr when optimize_1d=False + betas_1d=(args.beta1, args.beta2), # AdamW's betas when optimize_1d=False + weight_decay_1d=0.1, # AdamW's weight decay + ) + elif args.opt == "adafactor": + opt = Adafactor( + group_specs, + lr=args.lr, + decay_rate=args.adafactor_decay_rate, + beta1=args.beta1, + clip_threshold=1.0, + weight_decay=args.weight_decay, + ) + elif args.opt == "lamb": + opt = Lamb( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + adam=False, + bias_correction=args.lamb_use_bias_correction, + ) + elif args.opt == "normalized-sgd": + opt = NormalizedSGD( + group_specs, + lr=args.lr, + momentum=args.momentum, + dampening=args.dampening, + weight_decay=args.weight_decay, + nesterov=args.nesterov, + sign_update=False, + ) else: - opt = torch.optim.SGD(group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) - - if args.scheduler != 'none': - if args.scheduler in ['cos', 'linear']: - scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=args.lr, total_steps=args.iterations, - pct_start=args.warmup_percent, anneal_strategy=args.scheduler, - cycle_momentum=False, div_factor=1e2, final_div_factor=.1) + opt = torch.optim.SGD( + group_specs, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + nesterov=args.nesterov, + ) + print(f"\nOptimizer:\n{opt}") + + if args.scheduler != "none": + assert ( + args.warmup_steps < args.iterations + ), "Warmup steps must be < iterations." # from schedules-and-scaling + if args.scheduler in ["cos", "linear"]: + # initial lr is args.lr / div_factor + # final lr is initial_lr/final_div_factor = args.lr / div_factor / final_div_factor + scheduler = ( + torch.optim.lr_scheduler.OneCycleLR( + optimizer=opt, + max_lr=[ + group.get("lr", args.lr) for group in group_specs + ], # it was args.lr + total_steps=args.iterations, + pct_start=args.warmup_steps + / args.iterations, # it was args.warmup_percent + anneal_strategy=args.scheduler, + cycle_momentum=False, + div_factor=1e2, + final_div_factor=1, + ) + if args.opt != "muon" + else CombinedScheduler(opt, args) + ) + elif args.scheduler == "cos_inf": + lambda_schedule = cos_inf_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + n_inf=args.cos_inf_steps, + div_factor=1e2, + final_div_factor=0.1, + ) + scheduler = ( + torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + if args.opt != "muon" + else CombinedScheduler(opt, args) + ) + elif args.scheduler == "wsd": + lambda_schedule = wsd_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + fract_decay=args.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=args.wsd_final_lr_scale, # should be 0 here + decay_type=args.decay_type, + ) + scheduler = ( + torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + if args.opt != "muon" + else CombinedScheduler(opt, args) + ) + elif args.scheduler == "cos_wsd": + lambda_schedule = cosine_wsd_decay_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + anneal_end_factor=0.15, # 0.2 + fract_decay=args.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=0.1, # should be 0 here + decay_type=args.decay_type, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + elif args.scheduler == "dd": + lambda_schedule = dd_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + fract_fisrt_decay=args.wsd_fract_decay, # this will be responsible for the first decay phase + max_lr=args.lr, # [group.get("lr", args.lr) for group in group_specs], + first_final_lr_factor=args.dd_first_lr_factor, + second_final_lr_factor=0.0, # stop with zero lr + div_factor=1e2, + first_decay_type=args.decay_type, + second_decay_type=args.dd_second_decay_type, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) else: raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") else: scheduler = None - args.world_size = distributed_backend.get_world_size() - exp_name = args.exp_name - if distributed_backend.is_master_process() and args.wandb: - params_copy = copy.deepcopy(vars(args)) - del params_copy['device'] - wandb.init(project=args.wandb_project, name=exp_name, config=params_copy) - - ckpt_path = os.path.join(args.results_base_folder, args.dataset, args.model, exp_name) - if not os.path.exists(ckpt_path): - if distributed_backend.is_master_process(): - os.makedirs(ckpt_path) - distributed_backend.sync() - elif os.path.isfile(os.path.join(ckpt_path, "summary.json")): # the experiment was already completed - print(f"Already found experiment '{ckpt_path}'.\nSkipping.") - sys.exit(0) - - itr = 0 - rng_state_dict = None - if args.use_pretrained == "auto": - checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file] - if checkpoints: - args.use_pretrained = sorted(checkpoints)[-1] + if (exp_dir / "ckpts" / "latest" / "main.pt").exists(): + if not args.auto_resume: + raise ValueError( + f"The experiment dir {exp_dir} already exists. " + + "To resume training, set auto_resume=True. " + + "Otherwise, specify a different experiment name. " + ) else: - args.use_pretrained = None - - if args.use_pretrained is not None: - last_ckpt_path = args.use_pretrained - print(f"Resuming from {last_ckpt_path}") - checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path)) - model_state_dict = {distributed_backend.translate_model_parameter_name_for_node(k.replace("_orig_mod.", ""))[0]:v for k,v in checkpoint['model'].items()} - # FIXME checkpoints from compiled model have _orig_mod keyword - - optimizer_state_dict = checkpoint['optimizer'] - rng_state_dict = { - module: checkpoint[module] for module in [ - "cpu_rng_state", - "gpu_rng_state", - "numpy_rng_state", - "py_rng_state", - "train_sampler_state" - ] - } + # Auto resume overwrites resume_from + args.resume_from = str(exp_dir / "ckpts" / "latest") + elif distributed_backend.is_master_process(): + exp_dir.mkdir(parents=True, exist_ok=True) - model.load_state_dict(model_state_dict) - opt.load_state_dict(optimizer_state_dict) - itr = checkpoint['itr'] - if scheduler is not None: - scheduler_state_dict = checkpoint['scheduler'] - scheduler.load_state_dict(scheduler_state_dict) + stats = train( + model=model, + opt=opt, + datareaders=datareaders, + scheduler=scheduler, + exp_dir=exp_dir, + distributed_backend=distributed_backend, + cfg=args, + ) - if args.model in ['base', 'llama2']: # all train functions have the same interface - train = train_base - else: - raise NotImplementedError(f"No training method implemented for model type '{args.model}'.") - - print(f"\nTraining model={args.model} \n{vars(args)}\n") - - stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length, - eval_freq=args.eval_freq, - distributed_backend=distributed_backend, - ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args) - - args.device = None - args.dtype = None - stats['args'] = vars(args) + stats["args"] = vars(args) if distributed_backend.is_master_process(): - with open(f"{ckpt_path}/summary.json", "w") as fs: + with open(exp_dir / "summary.json", "w") as fs: json.dump(stats, fs) distributed_backend.finalize() +def get_data_readers(args, verbose=True): + data_srcs = get_dataset(args) + train_reader = DataReader( + data_src=data_srcs["train"], + batch_size=args.batch_size, + sequence_length=args.sequence_length, + seed=args.data_seed, + with_replacement=False, + auto_shard=True, + keep_in_ram=args.data_in_ram, + ) + val_reader = DataReader( + data_src=data_srcs["val"], + batch_size=args.batch_size, + sequence_length=args.sequence_length, + seed=args.data_seed, + with_replacement=False, + auto_shard=False, # NOTE Identical Per Rank + keep_in_ram=args.data_in_ram, + ) + + if verbose: + print(f"Num training tokens: {train_reader.num_tokens}") + print(f"Num validation tokens: {val_reader.num_tokens}") + + return { + "train": train_reader, + "val": val_reader, + } + + +def get_exp_name( + args, + parser, + distributed_backend, + key_args=["model", "dataset", "opt"], + ignore_args=[ + "eval_interval", + "full_eval_at", + "distributed_backend", + "latest_ckpt_interval", + "wandb", + "wandb_project", + "wandb_entity", + "batch_size", + "acc_steps", + "results_base_folder", + "run_prefix", + "wandb_run_prefix", + ], +): + # Get the default values + defaults = vars(parser.parse_args([])) + + rank = distributed_backend.rank + + # Generate the prefix with key arguments + prefix_parts = [] + for key in key_args: + if hasattr(args, key): + value = getattr(args, key) + prefix_parts.append(f"{key}-{value}") + + prefix = "_".join(prefix_parts) + prefix = f"{args.batch_size}x{args.acc_steps}(rank={rank})_" + prefix + + # Generate the rest of the string with non-default arguments + non_default_parts = [] + for key, value in vars(args).items(): + if key in ignore_args: + continue + if key not in defaults: + print(f"Warning: {key} not in defaults") + continue + if key not in key_args and value != defaults[key]: + non_default_parts.append(f"{key}-{value}") + + non_default_string = "_".join(non_default_parts) + + if args.run_prefix is not None: + prefix = args.run_prefix + "_" + prefix + + # Combine prefix and non-default string + if non_default_string: + return f"{prefix}__{non_default_string}" + else: + return prefix + + if __name__ == "__main__": - args = get_args() - main(args) + args, parser = get_args() + main(args, parser) diff --git a/src/models/base.py b/src/models/base.py index a844592..3f2dba3 100755 --- a/src/models/base.py +++ b/src/models/base.py @@ -8,7 +8,6 @@ """ import math -import inspect import tiktoken import torch @@ -17,7 +16,7 @@ class LayerNorm(nn.Module): - """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" def __init__(self, ndim, bias): super().__init__() @@ -29,7 +28,6 @@ def forward(self, input): class CausalSelfAttention(nn.Module): - def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 @@ -44,34 +42,52 @@ def __init__(self, config): self.n_embd = config.n_embd self.dropout = config.dropout # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 - self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") if not self.flash: - print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + print( + "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0" + ) # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("bias", torch.tril(torch.ones(config.sequence_length, config.sequence_length)) - .view(1, 1, config.sequence_length, config.sequence_length)) + self.register_buffer( + "bias", + torch.tril( + torch.ones(config.sequence_length, config.sequence_length) + ).view(1, 1, config.sequence_length, config.sequence_length), + ) def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + # batch size, sequence length, embedding dimensionality (n_embd) + ( + B, + T, + C, + ) = x.size() # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + # (B, T, nh, hs) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + + # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True + ) else: # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -79,11 +95,16 @@ def forward(self, x): class MLP(nn.Module): - - def __init__(self, config): + def __init__(self, config, exp_factor=1.0): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dim_exp_factor = exp_factor * 4 + + self.c_fc = nn.Linear( + config.n_embd, int(self.dim_exp_factor * config.n_embd), bias=config.bias + ) + self.c_proj = nn.Linear( + int(self.dim_exp_factor * config.n_embd), config.n_embd, bias=config.bias + ) self.dropout = nn.Dropout(config.dropout) self.activation = nn.GELU() @@ -96,22 +117,30 @@ def forward(self, x): class Block(nn.Module): - def __init__(self, config): super().__init__() self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) self.attn = CausalSelfAttention(config) - self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.parallel = config.parallel_block + if not self.parallel: + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) self.mlp = MLP(config) - def forward(self, x): - x = x + self.attn(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) + def forward(self, x, *args, **kwargs): + if self.parallel: + # from GPT-J 6B https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L299 + x_ln = self.ln_1(x, *args, **kwargs) + x_attn = self.attn(x_ln) + x_ffn = self.mlp(x_ln) + x = x + x_attn + x_ffn + else: + x = x + self.attn(self.ln_1(x, *args, **kwargs)) + x_ = self.mlp(self.ln_2(x, *args, **kwargs)) + x = x + x_ return x - -class GPTBase(nn.Module): +class GPTBase(nn.Module): def __init__(self, config): super().__init__() assert config.vocab_size is not None @@ -119,30 +148,35 @@ def __init__(self, config): self.config = config self.tokenizer = tiktoken.get_encoding("gpt2") - self.transformer = nn.ModuleDict(dict( - wte = nn.Embedding(config.vocab_size, config.n_embd), - wpe = nn.Embedding(config.sequence_length, config.n_embd), - drop = nn.Dropout(config.dropout), - h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = LayerNorm(config.n_embd, bias=config.bias), - )) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + wpe=nn.Embedding(config.sequence_length, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), + ) + ) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # with weight tying when using torch.compile() some warnings get generated: # "UserWarning: functional_call was passed multiple values for tied weights. # This behavior is deprecated and will be an error in future versions" # not 100% sure what this is, so far seems to be harmless. TODO investigate - self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying + self.transformer.wte.weight = ( + self.lm_head.weight + ) # https://paperswithcode.com/method/weight-tying # init all weights self.apply(self._init_weights) # apply special scaled init to the residual projections, per GPT-2 paper for pn, p in self.named_parameters(): - if pn.endswith('c_proj.weight'): - torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) - - # report number of parameters - print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, + mean=0.0, + std=self.config.init_std / math.sqrt(2 * config.n_layer), + ) def get_num_params(self, non_embedding=True): """ @@ -158,22 +192,29 @@ def get_num_params(self, non_embedding=True): def _init_weights(self, module): if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) def forward(self, idx, targets=None, get_logits=False): device = idx.device b, t = idx.size() - assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + assert ( + t <= self.config.sequence_length + ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" + # shape (1, t) + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # forward the GPT model itself - tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) + + # forward pass through all the transformer blocks for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) @@ -181,13 +222,21 @@ def forward(self, idx, targets=None, get_logits=False): if targets is not None: # if we are given some desired targets also calculate the loss logits = self.lm_head(x) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 + ) + else: # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + logits = self.lm_head( + x[:, [-1], :] + ) # note: using list [-1] to preserve the time dim loss = None logits = logits if get_logits else None - return {'logits': logits, 'loss': loss} + return { + "logits": logits, + "loss": loss, + } def crop_sequence_length(self, sequence_length): # model surgery to decrease the block size if necessary @@ -195,14 +244,30 @@ def crop_sequence_length(self, sequence_length): # but want to use a smaller block size for some smaller, simpler model assert sequence_length <= self.config.sequence_length self.config.sequence_length = sequence_length - self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:sequence_length]) + self.transformer.wpe.weight = nn.Parameter( + self.transformer.wpe.weight[:sequence_length] + ) for block in self.transformer.h: - block.attn.bias = block.attn.bias[:,:,:sequence_length,:sequence_length] - - @classmethod - def from_pretrained(cls, model_type, override_args=None): - # TODO - pass + block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length] + + def from_pretrained( + self, + model_path, + ): + paths = model_path.split(",") + if len(paths) == 1: + # TODO: with distributed? + loaded_state = torch.load( + str(model_path + "/ckpt.pt"), + map_location=torch.device(self.config.device), + ) + state_to_load = loaded_state["model"] + + # load the sparse model + state_to_load = { + ".".join(k.split(".")[1:]): v # drop _orig_mod from keys + for k, v in state_to_load.items() + } def get_parameter_group_specs(self): """ @@ -262,7 +327,6 @@ def get_parameter_group_specs(self): {"params": sorted(list(no_decay)), "weight_decay": 0.0}, ] - @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ @@ -272,15 +336,19 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ for _ in range(max_new_tokens): # if the sequence context is growing too long we must crop it at sequence_length - idx_cond = idx if idx.size(1) <= self.config.sequence_length else idx[:, -self.config.sequence_length:] + idx_cond = ( + idx + if idx.size(1) <= self.config.sequence_length + else idx[:, -self.config.sequence_length :] + ) # forward the model to get the logits for the index in the sequence - logits = self(idx_cond, get_logits=True)['logits'] + logits = self(idx_cond, get_logits=True)["logits"] # pluck the logits at the final step and scale by desired temperature logits = logits[:, -1, :] / temperature # optionally crop the logits to only the top k options if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('Inf') + logits[logits < v[:, [-1]]] = -float("Inf") # apply softmax to convert logits to (normalized) probabilities probs = F.softmax(logits, dim=-1) # sample from the distribution @@ -289,9 +357,20 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): idx = torch.cat((idx, idx_next), dim=1) return idx - + @torch.no_grad() def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None): - idx = torch.tensor(self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})).view(1,-1).to(self.lm_head.weight.device) - out_idx = self.generate(idx, max_new_tokens, temperature, top_k).view(-1).to('cpu').numpy() + idx = ( + torch.tensor( + self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"}) + ) + .view(1, -1) + .to(self.lm_head.weight.device) + ) + out_idx = ( + self.generate(idx, max_new_tokens, temperature, top_k) + .view(-1) + .to("cpu") + .numpy() + ) return self.tokenizer.decode(out_idx) diff --git a/src/models/llama.py b/src/models/llama.py index 1604ca7..e6aaec6 100644 --- a/src/models/llama.py +++ b/src/models/llama.py @@ -1,17 +1,6 @@ """ -Llama style Language Model. -References: -1) Llama inference code: -https://github.com/facebookresearch/llama/blob/main/llama/model.py -2) Mistral one file ref: -https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py -3) Llama paper: -https://arxiv.org/pdf/2302.13971.pdf - -Main differences from GPT2: -* Uses RMSNorm instead of LayerNorm -* Uses a slightly different MLP (SwiGLU) -* rotary embeddings (RoPE) +Llama style Language Model that is +compilable (avoids torch complex) """ import math @@ -20,6 +9,7 @@ import torch import torch.nn as nn from torch.nn import functional as F + from models.base import CausalSelfAttention, GPTBase @@ -27,7 +17,10 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore - return torch.polar(torch.ones_like(freqs), freqs) # complex64 + cos_freqs = torch.cos(freqs) + sin_freqs = torch.sin(freqs) + # Stack the cos and sin parts in the last dimension to simulate complex numbers + return torch.stack((cos_freqs, sin_freqs), dim=-1) def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: @@ -37,11 +30,11 @@ def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Te """ ndim = x.ndim assert 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( - freqs_cis.shape, - (x.shape[1], x.shape[-1]), - ) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + assert freqs_cis.shape[:-1] == (x.shape[1], x.shape[-2]) + # New shape for broadcasting + shape = [ + 1 if i != 1 and i != ndim - 2 else d for i, d in enumerate(x.shape[:-1]) + ] + [2] return freqs_cis.view(*shape) @@ -49,12 +42,22 @@ def apply_rotary_emb(q, k, freqs_cis): # q, k: (B, T, nh, hs) # freq_cis: (T, hs) # return: (B, T, nh, hs), (B, T, nh, hs) - q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2)) - k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2)) - freqs_cis = _reshape_for_broadcast(freqs_cis, q_) - xq_out = torch.view_as_real(q_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(k_ * freqs_cis).flatten(3) - return xq_out.type_as(q), xk_out.type_as(k) + q = q.float().reshape(*q.shape[:-1], -1, 2) + k = k.float().reshape(*k.shape[:-1], -1, 2) + + freqs_cis = _reshape_for_broadcast(freqs_cis, q) + + # Perform manual "complex" multiplication + q_cos = q[..., 0] * freqs_cis[..., 0] - q[..., 1] * freqs_cis[..., 1] + q_sin = q[..., 0] * freqs_cis[..., 1] + q[..., 1] * freqs_cis[..., 0] + k_cos = k[..., 0] * freqs_cis[..., 0] - k[..., 1] * freqs_cis[..., 1] + k_sin = k[..., 0] * freqs_cis[..., 1] + k[..., 1] * freqs_cis[..., 0] + + # Combine the results back into the interleaved format expected by q and k + q_out = torch.stack((q_cos, q_sin), dim=-1).reshape(q.shape).flatten(3) + k_out = torch.stack((k_cos, k_sin), dim=-1).reshape(k.shape).flatten(3) + + return q_out, k_out class RMSNorm(nn.Module): @@ -142,7 +145,8 @@ def __init__(self, config): def forward(self, x, freqs_cis): x = x + self.attn(self.ln_1(x), freqs_cis) - x = x + self.mlp(self.ln_2(x)) + x_ = self.mlp(self.ln_2(x)) + x = x + x_ return x @@ -185,7 +189,6 @@ def __init__(self, config): p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) ) - def get_num_params(self, non_embedding=True): """ Return the number of parameters in the model. @@ -211,7 +214,7 @@ def forward(self, idx, targets=None, get_logits=False): x = self.transformer.drop(tok_emb) freqs_cis = self.freqs_cis.to(x.device)[pos] - for block_idx, block in enumerate(self.transformer.h): + for block in self.transformer.h: x = block(x, freqs_cis=freqs_cis) x = self.transformer.ln_f(x) diff --git a/src/models/test.py b/src/models/test.py new file mode 100644 index 0000000..1deb696 --- /dev/null +++ b/src/models/test.py @@ -0,0 +1,232 @@ +""" +Llama style Language Model that is +compilable (avoids torch complex) +""" + +import math + +import tiktoken +import torch +import torch.nn as nn +from torch.nn import functional as F + +from models.base import CausalSelfAttention, GPTBase + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + cos_freqs = torch.cos(freqs) + sin_freqs = torch.sin(freqs) + # Stack the cos and sin parts in the last dimension to simulate complex numbers + return torch.stack((cos_freqs, sin_freqs), dim=-1) + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape[:-1] == (x.shape[1], x.shape[-2]) + # New shape for broadcasting + shape = [ + 1 if i != 1 and i != ndim - 2 else d for i, d in enumerate(x.shape[:-1]) + ] + [2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb(q, k, freqs_cis): + # q, k: (B, T, nh, hs) + # freq_cis: (T, hs) + # return: (B, T, nh, hs), (B, T, nh, hs) + q = q.float().reshape(*q.shape[:-1], -1, 2) + k = k.float().reshape(*k.shape[:-1], -1, 2) + + freqs_cis = _reshape_for_broadcast(freqs_cis, q) + + # Perform manual "complex" multiplication + q_cos = q[..., 0] * freqs_cis[..., 0] - q[..., 1] * freqs_cis[..., 1] + q_sin = q[..., 0] * freqs_cis[..., 1] + q[..., 1] * freqs_cis[..., 0] + k_cos = k[..., 0] * freqs_cis[..., 0] - k[..., 1] * freqs_cis[..., 1] + k_sin = k[..., 0] * freqs_cis[..., 1] + k[..., 1] * freqs_cis[..., 0] + + # Combine the results back into the interleaved format expected by q and k + q_out = torch.stack((q_cos, q_sin), dim=-1).reshape(q.shape).flatten(3) + k_out = torch.stack((k_cos, k_sin), dim=-1).reshape(k.shape).flatten(3) + + return q_out, k_out + + +class RMSNorm2(nn.Module): + def __init__(self, ndim, eps=1e-5, bias=False): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + self.eps = eps + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, self.eps) + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + + hidden_dim = config.n_embd * 4 + + self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) + + def forward(self, x): + return self.c_proj(nn.functional.gelu(self.w1(x))) + + +class LlamaAttention(CausalSelfAttention): + def forward(self, x, freqs_cis): + # batch size, sequence length, embedding dimensionality (n_embd) + ( + B, + T, + C, + ) = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + # (B, T, nh, hs) + k = k.view(B, T, self.n_head, C // self.n_head) + q = q.view(B, T, self.n_head, C // self.n_head) + q, k = apply_rotary_emb(q, k, freqs_cis) + # (B, nh, T, hs) + q, k = q.transpose(1, 2), k.transpose(1, 2) + + # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class LlamaBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = RMSNorm2(config.n_embd, eps=config.rmsnorm_eps) + self.attn = LlamaAttention(config) + self.ln_2 = RMSNorm2(config.n_embd, eps=config.rmsnorm_eps) + self.mlp = LlamaMLP(config) + + def forward(self, x, freqs_cis): + x = x + self.attn(self.ln_1(x), freqs_cis) + x_ = self.mlp(self.ln_2(x)) + x = x + x_ + return x + + +class Test(GPTBase): + def __init__(self, config): + super().__init__(config) + assert config.vocab_size is not None + assert config.sequence_length is not None + self.config = config + self.tokenizer = tiktoken.get_encoding("gpt2") + + # create the token and position embeddings + self.head_dim = config.n_embd // config.n_head + self.freqs_cis = precompute_freqs_cis(self.head_dim, config.sequence_length) + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([LlamaBlock(config) for _ in range(config.n_layer)]), + ln_f=RMSNorm2(config.n_embd, eps=config.rmsnorm_eps), + ) + ) + + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # with weight tying when using torch.compile() some warnings get generated: + # "UserWarning: functional_call was passed multiple values for tied weights. + # This behavior is deprecated and will be an error in future versions" + # not 100% sure what this is, so far seems to be harmless. TODO investigate + self.transformer.wte.weight = ( + self.lm_head.weight + ) # https://paperswithcode.com/method/weight-tying + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default) + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + def forward(self, idx, targets=None, get_logits=False): + device = idx.device + b, t = idx.size() + assert ( + t <= self.config.sequence_length + ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" + # shape (1, t) + pos = torch.arange(0, t, dtype=torch.long, device=device) + + # forward the GPT model itself + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + # pos_emb = self.transformer.wpe(pos) + + x = self.transformer.drop(tok_emb) # + pos_emb) + freqs_cis = self.freqs_cis.to(x.device)[pos] + + for block in self.transformer.h: + x = block(x, freqs_cis=freqs_cis) + x = self.transformer.ln_f(x) + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 + ) + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head( + x[:, [-1], :] + ) # note: using list [-1] to preserve the time dim + loss = None + + logits = logits if get_logits else None + + return { + "logits": logits, + "loss": loss, + } diff --git a/src/models/utils.py b/src/models/utils.py index 6d60e10..05cd133 100755 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -1,23 +1,38 @@ import torch -from .llama import Llama, RMSNorm -from .base import GPTBase, LayerNorm +from .base import GPTBase, LayerNorm +from .llama import Llama, RMSNorm +from .test import RMSNorm2, Test BLACKLIST_WEIGHT_MODULES = ( torch.nn.LayerNorm, LayerNorm, RMSNorm, + RMSNorm2, torch.nn.Embedding, ) def get_model(args): - """ Return the right model """ - if args.model == 'base': + """Return the right model""" + if args.model == "base": model = GPTBase(args) + if args.use_pretrained != "none": + model.from_pretrained(args.use_pretrained) return model - elif args.model == 'llama2': + elif args.model == "llama": model = Llama(args) + if args.use_pretrained != "none": + raise NotImplementedError( + f"Loading of pretrained models not yet implemented for model '{args.model}'." + ) + return model + elif args.model == "test": + model = Test(args) + if args.use_pretrained != "none": + raise NotImplementedError( + f"Loading of pretrained models not yet implemented for model '{args.model}'." + ) return model else: raise KeyError(f"Unknown model '{args.model}'.") diff --git a/src/optim/adafactor.py b/src/optim/adafactor.py new file mode 100644 index 0000000..2381f92 --- /dev/null +++ b/src/optim/adafactor.py @@ -0,0 +1,205 @@ +""" +Here is an implementation of Adafactor. +Source: https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py +""" + +import math +from typing import Any, Callable, Dict, Optional, Tuple + +import torch + +LossClosure = Callable[[], float] +OptLossClosure = Optional[LossClosure] +OptFloat = Optional[float] +ParamGroup = Dict[str, Any] +State = Dict[str, Any] + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + + It has been proposed in: `Adafactor: Adaptive Learning Rates with + Sublinear Memory Cost`__. + + Arguments: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: external learning rate (default: None) + eps2: regularization constans for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold: threshold of root mean square of + final gradient update (default: 1.0) + decay_rate: coefficient used to compute running averages of square + gradient (default: -0.8) + beta1: coefficient used for computing running averages of gradient + (default: None) + weight_decay: weight decay (L2 penalty) (default: 0) + scale_parameter: if true, learning rate is scaled by root mean square + of parameter (default: True) + relative_step: if true, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init: time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + + Example: + >>> import torch_optimizer as optim + >>> optimizer = optim.Adafactor(model.parameters()) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ https://arxiv.org/abs/1804.04235 + + Note: + Reference code: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py # noqa + """ + + def __init__( + self, + params, + lr=1e-3, + eps2=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict( + lr=lr, + eps2=eps2, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + ) + super(Adafactor, self).__init__(params, defaults) + + def _get_lr(self, param_group: ParamGroup, param_state: State) -> float: + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = ( + 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + ) + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps2"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + def _get_options( + self, param_group: ParamGroup, param_shape: Tuple[int, ...] + ) -> Tuple[bool, bool]: + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + def _rms(self, tensor: torch.Tensor) -> float: + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad( + self, + exp_avg_sq_row: torch.Tensor, + exp_avg_sq_col: torch.Tensor, + output: torch.Tensor, + ) -> None: + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + torch.mul(r_factor, c_factor, out=output) + + def step(self, closure: OptLossClosure = None) -> OptFloat: + r"""Performs a single optimization step. + + Arguments: + closure: A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + grad, memory_format=torch.preserve_format + ) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).type_as( + grad + ) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).type_as(grad) + else: + state["exp_avg_sq"] = torch.zeros_like( + grad, memory_format=torch.preserve_format + ) + + state["RMS"] = 0 + + state["step"] += 1 + state["RMS"] = self._rms(p.data) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps2"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) + + # Approximation of exponential moving average of square + # of gradient + self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + torch.rsqrt(exp_avg_sq, out=update).mul_(grad) + + update.div_(max(1.0, self._rms(update) / group["clip_threshold"])) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) + update = exp_avg + + if group["weight_decay"] != 0: + p.data.add_(p.data, alpha=-group["weight_decay"] * lr) + + p.data.add_(-update) + + return loss diff --git a/src/optim/adammini.py b/src/optim/adammini.py new file mode 100644 index 0000000..c340b31 --- /dev/null +++ b/src/optim/adammini.py @@ -0,0 +1,452 @@ +""" +Here is an original implementation of Adam-mini. +Source: https://github.com/zyushun/Adam-mini +""" + +import math +from typing import Iterable, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._tensor import Replicate + + +class Adam_mini(torch.optim.Optimizer): + def __init__( + self, + device, + world_size, + named_parameters: Iterable[Tuple[str, nn.Parameter]], + lr: Union[float, torch.Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + model_sharding: bool = None, + dim: int = 2048, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + verbose=True, + ): + """ + This is the official implementation of Adam-mini (version 1.1.0). + + Paper: [Adam-mini: Use Fewer Learning Rates To Gain More](https://arxiv.org/abs/2406.16793). + + Github repo: https://github.com/zyushun/Adam-mini + + Arguments: + named_parameters ('Iterable[Tuple[str, nn.Parameter]]'): Iterable of named parameters to optimize or dictionaries defining parameter groups. Usually set to model.named_parameters() + + lr (`float`, *optional*, defaults to 0.001): The learning rate to use. + + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): Same as Adam's betas parameters (b1, b2). + + eps (`float`, *optional*, defaults to 1e-06): Same as Adam's epsilon for numerical stability. + + weight_decay (`float`, *optional*, defaults to 0.0): Decoupled weight decay to apply. + + model_sharding (`bool`, *optional*, defaults to None): Set to True if you are using model parallelism with more than 1 GPU, including FSDP and zero_1,2,3 in Deepspeed. Set to False if otherwise. Due to the historical reason, this argument is deprecated since version 1.0.2. We will assume that model parallelism is always used. We will remove this argument in the future version. + + dim (`int`, *optional*, defaults to 2048): Dimension for hidden features. Can be left unspecified if training non-transformer models. + + n_heads (`int`, *optional*, defaults to 32): Number of attention heads. Can be left unspecified if training non-transformer models. + + n_kv_heads (`int`, *optional*, defaults to None): Number of heads for Key and Value. Or equivalently, number of query groups in Group Query Attention. Also known as "n_query_groups". If not specified, it will be equal to n_head. Can be left unspecified if training non-transformer models. + + verbose (`bool`, *optional*, defaults to True): Print all the logs if true. + Example: + + ```python + optimizer = Adam_mini( + named_parameters = model.named_parameters(), + lr = lr, + betas = (beta1,beta2), + eps = eps, + weight_decay = weight_decay, + dim = model_config.dim, + n_heads = model_config.n_heads, + n_kv_heads = model_config.n_kv_heads, + ) + ``` + + """ + self.named_parameters = named_parameters + self.dim = dim + self.n_heads = n_heads + if n_kv_heads is not None: + assert n_heads % n_kv_heads == 0, f"{n_heads} {n_kv_heads}" + self.n_kv_heads = n_kv_heads + else: + self.n_kv_heads = n_heads + + self.device = device + self.world_size = world_size # torch.cuda.device_count() + self.verbose = verbose + self.check_block_name = True + self.head_numel = self.dim * self.dim // self.n_heads + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not self.dim == int(self.dim): + raise ValueError("Invalid dim value: {}".format(self.dim)) + if not self.n_heads == int(self.n_heads): + raise ValueError("Invalid n_heads value: {}".format(self.n_heads)) + if not self.n_kv_heads == int(self.n_kv_heads): + raise ValueError("Invalid n_kv_heads value: {}".format(self.n_kv_heads)) + + if model_sharding is not None and verbose: + print( + "Warning by Adam-mini: model_sharding is deprecated since version 1.0.2. This argument is always set True. We will remove this argument in the future version." + ) + + # Embedding layer. Use one lr per token + self.embd_names = {"embed", "embd", "wte"} # move to mlp + # Output layers. Use one lr per token + self.output_names = {"lm_head.weight", "output.weight"} # move output to mlp + # Query and Keys. User one lr per head + self.wqk_names = {"k_proj.weight", "q_proj.weight", "wq.weight", "wk.weight"} + # Values. Use one lr per neuron + # it is okay to set self.wv_names to be empty and use a single lr for the whole v. But this will bring extra all_reduce operations + self.wv_names = {"v_proj.weight", "wv.weight"} + # attn_proj. Use one lr per neuron + self.attn_proj_names = {"o_proj.weight", "wo.weight", "attn.proj.weight"} + # MLPs. Use one lr per neuron + self.mlp_names = {"feed_forward", "linear", "mlp"} + # Blocks that use Adam. For old versions before v.1.1.0, this is for embedding layer and output layer. For the current version, this is empty + self.adam_block_names = {} + + optim_groups = [] + # count_embd = count_output = count_wqk = 0 + for param_name, param in named_parameters: + if not param.requires_grad: + continue + if verbose: + print( + "Adam-mini found the param block with name:", + param_name, + param.size(), + ) + state = {} + state["name"] = param_name + state["params"] = param + if "norm" in param_name or "ln_f" in param_name: + state["weight_decay"] = 0.0 + else: + state["weight_decay"] = weight_decay + + optim_groups.append(state) + + defaults = dict(lr=lr, beta1=betas[0], beta2=betas[1], eps=eps) + super().__init__(optim_groups, defaults) + + def count_block(self): + count_embd = 0 + count_output = 0 + count_wqk = 0 + count_wv = 0 + count_attn_proj = 0 + count_mlp = 0 + for param_name, param in self.named_parameters: + if not param.requires_grad: + continue + if any(embd_name in param_name for embd_name in self.embd_names): + count_embd += 1 + if any(output_name in param_name for output_name in self.output_names): + count_output += 1 + if any(wqk_name in param_name for wqk_name in self.wqk_names): + count_wqk += 1 + assert ( + self.dim * self.dim + ) % self.n_heads == 0, f"{self.dim} {self.n_heads}" + if any(wv_name in param_name for wv_name in self.wv_names): + count_wv += 1 + if any( + attn_proj_name in param_name for attn_proj_name in self.attn_proj_names + ): + count_attn_proj += 1 + if any(mlp_name in param_name for mlp_name in self.mlp_names): + count_mlp += 1 + if self.verbose: + print( + f"Adam-mini found {count_embd} embedding layers, {count_output} output layers; {count_wqk} Querys and Keys; {count_wv} Values; {count_attn_proj} attn_proj; {count_mlp} MLPs;" + ) + + if count_embd == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No embedding layer found. If you are training Transformers, please check the name of your embedding layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the keywords in the name of your embedding layer'). " + ) + if count_output == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No output layer found. If you are training Transformers (without weight-tying), please check the name of your output layer and manually add them to 'self.output_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.output_names.add('the keywords in the name of your output layer'). Please ignore this warning if you are using weight-tying." + ) + if count_wqk == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No Query or Key found. If you are training Transformers, please check the name of your Query and Key in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding two additional lines of code: optimizer.wqk_names.add('the keywords in the name of your Query' ); optimizer.wqk_names.add('the keywords in the name of your Key'). " + ) + + if count_wv == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No Value found. If you are training Transformers, please check the name of your Value in attention blocks and manually add them to 'self.wv_names' of Adam-mini. You can do this by adding an additional lines of code: optimizer.wv_names.add('the keywords in the name of your Value' ). " + ) + + if count_attn_proj == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No attn_proj found. If you are training Transformers, please check the name of your attn_proj in attention blocks and manually add them to 'self.attn_proj_names' of Adam-mini. You can do this by adding an additional lines of code: optimizer.attn_proj_names.add('the keywords in the name of your attn_proj' ). " + ) + + if count_mlp == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No MLP found. If you are training Transformers, please check the name of your MLP in attention blocks and manually add them to 'self.mlp_names' of Adam-mini. You can do this by adding an additional lines of code: optimizer.attn_proj_names.add('the keywords in the name of your MLP' ). " + ) + + if ( + count_output + + count_embd + + count_wqk + + count_wv + + count_attn_proj + + count_mlp + == 0 + ) and self.verbose: + print( + "=====>>> Warning by Adam-mini: you are using default PyTorch partition for Adam-mini. It can cause training instability on large-scale Transformers." + ) + + @torch.no_grad() + def step(self, closure=None): + if self.check_block_name: + self.count_block() + self.check_block_name = False + + loss = None + device = self.device + if closure is not None: + with torch.enable_grad(): + loss = closure() + for group in self.param_groups: + beta1 = group["beta1"] + beta2 = group["beta2"] + lr = group["lr"] + name = group["name"] + eps = group["eps"] + + for p in group["params"]: + state = self.state[p] + if any( + adam_block_name in name for adam_block_name in self.adam_block_names + ): # For v.1.1.0, we will not enter here + if p.grad is None: + continue + if len(state) == 0: + state["m"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["step"] = 0 + state["v"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + grad = p.grad + state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = lr / bias_correction_1 + p.addcdiv_(state["m"], h, value=-stepsize) + elif any( + wqk_name in name for wqk_name in self.wqk_names + ): # this is for query and key + if p.grad is None: + continue + head_numel = self.head_numel # group["head_numel"] + if len(state) == 0: + m = torch.zeros_like(p, memory_format=torch.preserve_format) + state["m"] = m.view(-1, head_numel) + state["head_per_gpu"] = state["m"].size( + 0 + ) # this is head per gpu + state["step"] = 0 + # NOTE: We must use `zeros_like` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. + # the following line is equivalent to: state["vmean"] = torch.zeros(state["head"]) + state["vmean"] = torch.zeros_like( + state["m"][0 : state["head_per_gpu"], 0:1], + memory_format=torch.preserve_format, + ) + + grad = p.grad # .to(torch.float32) + head_per_gpu = state["head_per_gpu"] + grad = grad.view(head_per_gpu, head_numel) + tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True) + + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = ((1 / bias_correction_1) / h).view(head_per_gpu, 1) + update = (state["m"] * stepsize).view(p.size()) + update.mul_(lr) + p.add_(-update) + elif ( + any(embd_name in name for embd_name in self.embd_names) + or any(output_name in name for output_name in self.output_names) + or any(wv_name in name for wv_name in self.wv_names) + or any(mlp_name in name for mlp_name in self.mlp_names) + or any( + attn_proj_name in name + for attn_proj_name in self.attn_proj_names + ) + ): + if p.grad is None: + continue + # neuron_numel = group["neuron_numel"] # assume grad is a matrix by default, so do not need this + if len(state) == 0: + state["m"] = torch.zeros_like( + p.grad, memory_format=torch.preserve_format + ) # assume grad is a matrix by default, no need to view + # state["m"] = torch.zeros_like(p, memory_format=torch.preserve_format).view(-1, neuron_numel) + state["step"] = 0 + state["neuron_per_gpu"] = state["m"].size( + 0 + ) # this is neuron per gpu + # NOTE: We must use `new_zeros` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. + # for standard tensor: state["vmean"] = torch.zeros(1, device=p.device) + # for DTensor: state["vmean"] = p.new_zeros(1) + # the following implementation unifies the above two lines + state["vmean"] = torch.zeros_like( + state["m"][0 : state["neuron_per_gpu"], 0:1], + memory_format=torch.preserve_format, + ) + + grad = p.grad # .to(torch.float32) + neuron_per_gpu = state["neuron_per_gpu"] + # grad = grad.view(neuron_per_gpu, neuron_numel) # assume grad is a matrix by default, so no need to reshape + tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True) + + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = ((1 / bias_correction_1) / h).view(neuron_per_gpu, 1) + update = (state["m"] * stepsize).view(p.size()) + update.mul_(lr) + p.add_(-update) + + else: # other blocks. By default, this is for LayerNorms. Sometimes it is also fine to put Value here + if len(state) == 0: + block_numel = ( + torch.tensor(p.numel()).to(torch.float32).to(device) + ) + reduced = False + if self.world_size > 1: + tensor_list = [ + torch.zeros_like(block_numel) + for _ in range(self.world_size) + ] + + dist.all_gather(tensor_list, block_numel) + s = 0 + block_numel = 0 + for d in tensor_list: + if d > 0: + s = s + 1 + block_numel = block_numel + d + if s >= 2: + reduced = True + + state["m"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["step"] = 0 + state["reduced"] = reduced + # NOTE: We must use `new_zeros` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. + # For standard tensor: state["vmean"] = torch.zeros(1, device=p.device) + # For DTensor: state["vmean"] = p.new_zeros(1) + # the following implementation unifies the above two lines + state["vmean"] = torch.zeros_like( + torch.sum(p * p), memory_format=torch.preserve_format + ) + state["block_numel"] = block_numel.item() + if p.grad is None: + tmp_lr = torch.zeros_like(torch.sum(p * p)) + else: + grad = p.grad # .to(torch.float32) + tmp_lr = torch.sum(grad * grad) + + if state["reduced"]: + # Force communication over GPUs when GPUs are available + if tmp_lr.device.type == "cpu": + # Move the tensor to the current GPU device + tmp_lr_gpu = tmp_lr.to(torch.cuda.current_device()) + + if "device_mesh" in dir(tmp_lr): + # when tmp_lr is a DTensor in TorchTitan + lr_local = tmp_lr.to_local() + dist.all_reduce(lr_local, op=dist.ReduceOp.SUM) + tmp_lr.redistribute(placements=[Replicate()]) + else: + # when tmp_lr is a standard tensor + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) + + # Move the result back to the CPU tensor + tmp_lr.copy_(tmp_lr_gpu.cpu()) + else: + # Tensor is already on GPU, use NCCL backend + if "device_mesh" in dir(tmp_lr): + # when tmp_lr is a DTensor in TorchTitan + lr_local = tmp_lr.to_local() + dist.all_reduce(lr_local, op=dist.ReduceOp.SUM) + tmp_lr.redistribute(placements=[Replicate()]) + else: + # when tmp_lr is a standard tensor + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) + + if p.grad is None: + continue + tmp_lr = tmp_lr / state["block_numel"] + + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["step"] += 1 + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = (1 / bias_correction_1) / h + update = state["m"] * (stepsize.to(state["m"].device)) + update.mul_(lr) + p.add_(-update) + + return loss diff --git a/src/optim/ademamix.py b/src/optim/ademamix.py new file mode 100644 index 0000000..035697a --- /dev/null +++ b/src/optim/ademamix.py @@ -0,0 +1,187 @@ +""" +Here is an original implementation of AdEMAMix. +Source: https://github.com/apple/ml-ademamix +""" + +import math + +import torch + + +def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): + if step < warmup: + a = step / float(warmup) + return (1.0 - a) * alpha_start + a * alpha_end + return alpha_end + + +def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): + def f(beta, eps=1e-8): + return math.log(0.5) / math.log(beta + eps) - 1 + + def f_inv(t): + return math.pow(0.5, 1 / (t + 1)) + + if step < warmup: + a = step / float(warmup) + return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end)) + return beta_end + + +class AdEMAMix(torch.optim.Optimizer): + r"""Implements the AdEMAMix algorithm. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999, 0.9999)) + corresponding to beta_1, beta_2, beta_3 in AdEMAMix + alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs (default: 2) + beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None) + alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay as in AdamW (default: 0) + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999, 0.9999), + alpha=2.0, + beta3_warmup=None, + alpha_warmup=None, + eps=1e-8, + weight_decay=0, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + alpha=alpha, + beta3_warmup=beta3_warmup, + alpha_warmup=alpha_warmup, + weight_decay=weight_decay, + ) + super(AdEMAMix, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdEMAMix, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + lmbda = group["weight_decay"] + eps = group["eps"] + beta1, beta2, beta3_final = group["betas"] + beta3_warmup = group["beta3_warmup"] + alpha_final = group["alpha"] + alpha_warmup = group["alpha_warmup"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdEMAMix does not support sparse gradients.") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + if beta1 != 0.0: # save memory in case beta1 is 0.0 + state["exp_avg_fast"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + else: + state["exp_avg_fast"] = None + state["exp_avg_slow"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg_fast, exp_avg_slow, exp_avg_sq = ( + state["exp_avg_fast"], + state["exp_avg_slow"], + state["exp_avg_sq"], + ) + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Compute the effective alpha and beta3 in case warmup is used + if alpha_warmup is not None: + alpha = linear_warmup_scheduler( + state["step"], + alpha_end=alpha_final, + alpha_start=0, + warmup=alpha_warmup, + ) + else: + alpha = alpha_final + + if beta3_warmup is not None: + beta3 = linear_hl_warmup_scheduler( + state["step"], + beta_end=beta3_final, + beta_start=beta1, + warmup=beta3_warmup, + ) + else: + beta3 = beta3_final + + # Decay the first and second moment running average coefficient + if beta1 != 0.0: + exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1) + else: + exp_avg_fast = grad + exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + update = ( + exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow + ) / denom + + # decay + update.add_(p, alpha=lmbda) + + p.add_(-lr * update) + + return loss diff --git a/src/optim/ademamix2.py b/src/optim/ademamix2.py new file mode 100644 index 0000000..9cc5084 --- /dev/null +++ b/src/optim/ademamix2.py @@ -0,0 +1,188 @@ +""" +Here is an original implementation of AdEMAMix. +Source: https://github.com/apple/ml-ademamix +""" + +import math + +import torch + + +def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): + if step < warmup: + a = step / float(warmup) + return (1.0 - a) * alpha_start + a * alpha_end + return alpha_end + + +def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): + def f(beta, eps=1e-8): + return math.log(0.5) / math.log(beta + eps) - 1 + + def f_inv(t): + return math.pow(0.5, 1 / (t + 1)) + + if step < warmup: + a = step / float(warmup) + return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end)) + return beta_end + + +class AdEMAMix2(torch.optim.Optimizer): + r"""Implements the AdEMAMix algorithm. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999, 0.9999)) + corresponding to beta_1, beta_2, beta_3 in AdEMAMix + alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs (default: 2) + beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None) + alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay as in AdamW (default: 0) + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999, 0.9999), + alpha=2.0, + beta3_warmup=None, + alpha_warmup=None, + eps=1e-8, + weight_decay=0, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + alpha=alpha, + beta3_warmup=beta3_warmup, + alpha_warmup=alpha_warmup, + weight_decay=weight_decay, + ) + super(AdEMAMix2, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdEMAMix2, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + lmbda = group["weight_decay"] + eps = group["eps"] + beta1, beta2, beta3_final = group["betas"] + beta3_warmup = group["beta3_warmup"] + alpha_final = group["alpha"] + alpha_warmup = group["alpha_warmup"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdEMAMix does not support sparse gradients.") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + if beta1 != 0.0: # save memory in case beta1 is 0.0 + state["exp_avg_fast"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + else: + state["exp_avg_fast"] = None + state["exp_avg_slow"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg_fast, exp_avg_slow, exp_avg_sq = ( + state["exp_avg_fast"], + state["exp_avg_slow"], + state["exp_avg_sq"], + ) + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Compute the effective alpha and beta3 in case warmup is used + if alpha_warmup is not None: + alpha = linear_warmup_scheduler( + state["step"], + alpha_end=alpha_final, + alpha_start=0, + warmup=alpha_warmup, + ) + else: + alpha = alpha_final + + if beta3_warmup is not None: + beta3 = linear_hl_warmup_scheduler( + state["step"], + beta_end=beta3_final, + beta_start=beta1, + warmup=beta3_warmup, + ) + else: + beta3 = beta3_final + + # Decay the first and second moment running average coefficient + if beta1 != 0.0: + exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1) + else: + exp_avg_fast = grad + exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + update = ( + exp_avg_fast.div(bias_correction1) / denom + + alpha * exp_avg_slow / exp_avg_slow.norm() + ) + + # decay + update.add_(p, alpha=lmbda) + + p.add_(-lr * update) + + return loss diff --git a/src/optim/adopt.py b/src/optim/adopt.py new file mode 100644 index 0000000..aa6dc16 --- /dev/null +++ b/src/optim/adopt.py @@ -0,0 +1,72 @@ +""" +Here is an original implementation of ADOPT. +Source: https://github.com/iShohei220/adopt +""" + +import torch + + +def exists(val): + return val is not None + + +class ADOPT(torch.optim.Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(ADOPT, self).__init__(params, defaults) + self.eps = eps + + def __setstate__(self, state): + super(ADOPT, self).__setstate__(state) + + @torch.no_grad() + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + grad_scaler=None, + ): + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group["params"]): + if p.grad is None: + continue + grad = p.grad.data + grad.add_(p.data, alpha=group["weight_decay"]) + state = self.state[p] + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p.data) + state["exp_avg_sq"] = grad.mul(grad) + continue + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if "step" in state: + state["step"] += 1 + else: + state["step"] = 1 + beta1, beta2 = group["betas"] + denom = torch.maximum(exp_avg_sq.sqrt(), torch.tensor(self.eps)) + if state["step"] == 1: + exp_avg = grad.div(denom) + else: + exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) + + p.data.add_(state["exp_avg"], alpha=-group["lr"]) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + return loss diff --git a/src/optim/base.py b/src/optim/base.py index 241f508..19875d0 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -1,184 +1,294 @@ +import copy +import time from contextlib import nullcontext -from data.utils import get_dataloader +from pathlib import Path import torch -import torch.nn.functional as F +import yaml + import wandb -import time -import itertools -import copy -import random -import os -import numpy as np -from .utils import eval, get_batch, save_checkpoint - - -def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend,extra_args, itr=0,rng_state_dict=None): - device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu' - type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast( - device_type=device_type, dtype=torch.bfloat16) # extra_args.dtype) - best_val_loss, text_table = float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible - substep = itr * acc_steps - data["train"], train_sampler = get_dataloader( - data["train"], - sequence_length=sequence_length, - batch_size=batch_size, - seed=data_seed, - distributed_backend=distributed_backend, - ) - - data["val"], val_sampler = get_dataloader( - data["val"], - sequence_length=sequence_length, - batch_size=batch_size, - seed=data_seed, - ) - num_substeps_per_epoch = len(data["train"]) - train_epochs = substep//num_substeps_per_epoch - - if rng_state_dict is not None and rng_state_dict.get("train_sampler_state", None) is not None: - train_sampler.generator.set_state(rng_state_dict["train_sampler_state"]) - if hasattr(train_sampler, "set_epoch"): - train_sampler.set_epoch(train_epochs) +# from logger.logger import DynamicsLogger +from .utils import (eval, get_batch, load_checkpoint, load_worker_state, + save_checkpoint, save_worker_state) + + +def train( + model, + opt, + datareaders, + scheduler, + exp_dir, + distributed_backend, + cfg, +): + not_compiled_model = model + if cfg.compile: + print(f"Compiling model ...") + model = torch.compile(model) + + if "cuda" in cfg.device: + type_ctx = torch.amp.autocast( + device_type="cuda", + dtype={ + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[cfg.dtype], + ) else: - sampler_state_before_iter = train_sampler.generator.get_state() - data_train_iter = iter(data["train"]) + type_ctx = nullcontext() - - # for val data we don't care about epochs? just cycle through (no need to set_epoch to reshuffle) - data_val_iter = itertools.cycle(data["val"]) + if cfg.resume_from: + # This is a full resume including the model weights, optimizer, state + # dataloader state, random seed, etc. Not indended for fine tuning or + # other scenarios where some of these should change. + print(f"\nResuming Training From {cfg.resume_from}") + ckpt_dir = Path(cfg.resume_from) + curr_iter = load_checkpoint( + model, + opt, + scheduler, + ckpt_dir / "main.pt", + cfg.device, + ) + load_worker_state(ckpt_dir) + else: + curr_iter = 0 - stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} + # if distributed_backend.is_master_process() and cfg.log_dynamics: + # with open(cfg.dynamics_logger_cfg, "r") as f: + # dlcfg = yaml.safe_load(f) - - - if extra_args.compile: - print(f"Compiling model ...") - model = torch.compile(model) # requires pytorch 2.0+ + # # Hooks into optimizer + # dlogger = DynamicsLogger( + # model, opt, dlcfg, cfg.results_base_folder, wandb=cfg.wandb + # ) + # dlogger.iteration = curr_iter + substep = curr_iter * cfg.acc_steps + train_reader, val_reader = datareaders["train"], datareaders["val"] + train_reader.set_step(substep) + stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} + grad_norms = [] model.train() - t0 = time.time() - - if rng_state_dict is not None: - torch.set_rng_state(rng_state_dict["cpu_rng_state"]) - torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"]) - np.random.set_state(rng_state_dict["numpy_rng_state"]) - random.setstate(rng_state_dict["py_rng_state"]) - for _ in range(substep % num_substeps_per_epoch): - get_batch(data_train_iter, device=extra_args.device) - - - while itr < iterations: - - for microstep_idx in range(acc_steps): # gradient accumulation - x, y = get_batch(data_train_iter, device=extra_args.device) - + while curr_iter <= cfg.iterations: + # Save permanent checkpoint + if cfg.permanent_ckpt_interval > 0: + if curr_iter % cfg.permanent_ckpt_interval == 0: + ckpt_dir = exp_dir / "ckpts" / str(curr_iter) + if distributed_backend.is_master_process(): + save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) + save_worker_state(ckpt_dir) + + # Save temporary checkpoint for resuming training + if cfg.latest_ckpt_interval > 0: + if curr_iter % cfg.latest_ckpt_interval == 0 or curr_iter == cfg.iterations: + ckpt_dir = exp_dir / "ckpts" / "latest" + if distributed_backend.is_master_process(): + save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) + save_worker_state(ckpt_dir) + + ws = distributed_backend.get_world_size() + tokens = ws * substep * cfg.sequence_length * cfg.batch_size + epoch = tokens / train_reader.num_tokens + if ( + curr_iter % cfg.eval_interval == 0 + or curr_iter == cfg.iterations + or (curr_iter in cfg.full_eval_at) + ): + eval_and_log( + tokens, + curr_iter, + epoch, + model, + val_reader, + type_ctx, + distributed_backend, + cfg, + opt, + full_eval=(curr_iter in cfg.full_eval_at), + ) + + if curr_iter == cfg.iterations: + # Save checkpoints and evaluate at final iteration, but no need to train further + break + + # Train model + t_start = time.perf_counter_ns() + for microstep_idx in range(cfg.acc_steps): # gradient accumulation + x, y = get_batch(train_reader, device=cfg.device) with type_ctx: - with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps): + with distributed_backend.get_context_for_microstep_forward( + model=model, + microstep_idx=microstep_idx, + gradient_accumulation_steps=cfg.acc_steps, + ): outputs = model(x, targets=y) - loss = outputs['loss'] / acc_steps + loss = outputs["loss"] / cfg.acc_steps loss.backward() substep += 1 - if substep % len(data["train"]) == 0: - train_epochs += 1 - print(f"Train epoch {train_epochs} done (full pass over training data)") - if hasattr(train_sampler, "set_epoch"): - # set epoch for reshuffling between epochs - train_sampler.set_epoch(train_epochs) - sampler_state_before_iter = None - else: - sampler_state_before_iter = train_sampler.generator.get_state() - data_train_iter = iter(data["train"]) - - - if extra_args.grad_clip != 0.0: - torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip) - opt.step() - scheduler.step() - opt.zero_grad(set_to_none=True) - itr += 1 - - if itr % eval_freq == 0 or itr == iterations: # from here it's only evaluation code, all the training is above - if distributed_backend.is_master_process(): - t1 = time.time() - dt = t1 - t0 - epoch = substep//num_substeps_per_epoch - - model.eval() - train_loss = loss.detach().cpu().item() * acc_steps - current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr - - eval_steps = ( - 24 if itr < iterations else len(data["val"]) + + if cfg.grad_clip != 0.0: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + grad_norm = torch.nn.utils.clip_grad_norm_( + model.module.parameters(), cfg.grad_clip + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), cfg.grad_clip ) - val_acc, val_loss, val_perplexity = eval( - model, - data_val_iter, - extra_args.device, - max_num_batches=eval_steps, - ctx=type_ctx, + grad_norms.append(grad_norm) + + if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": + opt.train() + ( + opt.step() + if cfg.opt != "sophiag" + else opt.step(bs=cfg.sophia_bs * cfg.sequence_length) + ) + if cfg.scheduler != "none": + scheduler.step() + if cfg.opt == "sophiag": + opt.zero_grad(set_to_none=True) + if curr_iter % cfg.precondition_frequency == cfg.precondition_frequency - 1: + sample_again = model(x, targets=y, get_logits=True) + samp_dist = torch.distributions.Categorical( + logits=sample_again["logits"] + ) + y_sample = samp_dist.sample() + loss_sampled = torch.nn.functional.cross_entropy( + sample_again["logits"].view(-1, sample_again["logits"].size(-1)), + y_sample.view(-1), + ignore_index=-1, ) + (loss_sampled / cfg.acc_steps).backward() + opt.update_hessian() + opt.zero_grad(set_to_none=True) + model.zero_grad() + elif cfg.opt == "mars": + opt.zero_grad(set_to_none=True) + opt.update_last_grad() + else: + opt.zero_grad(set_to_none=True) + # opt.zero_grad(set_to_none=True) + dt = (time.perf_counter_ns() - t_start) / 1e9 + + curr_iter += 1 + + if ( + cfg.log_interval + and curr_iter % cfg.log_interval == 0 + and distributed_backend.is_master_process() # Only log on master rank + ): + train_loss = loss.detach().cpu().item() * cfg.acc_steps + + current_lrs = [param_group["lr"] for param_group in opt.param_groups] - print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:3f}" - print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms" - if scheduler is not None: - print_string += f" [lr] {current_lr:.5f}" - print(print_string) + print( + f"Train: Iter={curr_iter} ({epoch:0.3f} epochs) " + f"train_loss={train_loss:.3f} iter_dt={dt:.2e}s " + f"lr={current_lrs[0]:.2e}" + ) - if extra_args.wandb: - logs = { - "iter": itr, + if cfg.wandb: + wandb.log( + { + "tokens": tokens, + "iter": curr_iter, "train/loss": train_loss, - "val/loss": val_loss, - "val/perplexity": val_perplexity, - "val/acc": val_acc, - "lr": current_lr, + "train/perplexity": 2.71828**train_loss, + "lr": current_lrs[0], + "iter_dt": dt, + "max_grad_norm": max(grad_norms).item() if grad_norms else 0, + "mean_grad_norm": ( + torch.tensor(grad_norms).mean().item() if grad_norms else 0 + ), } + ) - if itr == iterations: - logs["val/final-ppl"] = val_perplexity - logs["val/final-acc"] = val_acc - logs["val/final-loss"] = val_loss - - wandb.log(logs) - - if extra_args.eval_seq_prefix != 'none' and (itr % (eval_freq * 5) == 0 or itr == iterations): - if text_table is None: - text_table = wandb.Table(columns=["itr", "val-pp", "text"]) - - out_str = distributed_backend.get_raw_model(model).generate_from_string( - extra_args.eval_seq_prefix, max_new_tokens=40, temperature=0.9, top_k=None) - text_table.add_data(itr, val_perplexity, out_str) - # why a copy? see github.com/wandb/wandb/issues/2981 - wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)}) - - model.train() - t0 = time.time() - if distributed_backend.is_master_process(): - if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0: - print(f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt") - save_checkpoint(distributed_backend=distributed_backend, - model=model, - opt=opt, - scheduler=scheduler, - itr=itr, - cpu_rng_state=torch.get_rng_state(), - gpu_rng_state=torch.cuda.get_rng_state(), - numpy_rng_state=np.random.get_state(), - py_rng_state=random.getstate(), - train_sampler_state=sampler_state_before_iter, - ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt")) - - if distributed_backend.is_master_process(): - print(f"saving checkpoint to {ckpt_path}") - save_checkpoint(distributed_backend=distributed_backend, - model=model, - opt=opt, - scheduler=scheduler, - itr=itr, - ckpt_path=ckpt_path) + grad_norms = [] return stats + + +def eval_and_log( + tokens, + curr_iter, + epoch, + model, + val_reader, + type_ctx, + distributed_backend, + cfg, + opt, + full_eval=False, +): + if not distributed_backend.is_master_process(): + # Only evaluate and log on master rank + return + + model.eval() + if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": + opt.eval() + + if curr_iter == cfg.iterations or full_eval: + max_num_batches = val_reader.num_batches() + else: + max_num_batches = cfg.eval_batches + + # to make sure we start from the beginning of the validation set, + # i.e. repeat the same batches + val_reader.set_step(0) + val_acc, val_loss, val_perplexity = eval( + model, + val_reader, + cfg.device, + max_num_batches=max_num_batches, + ctx=type_ctx, + cfg=cfg, + ) + + print( + f">Eval: Iter={curr_iter} ({epoch:0.3f} epochs) " + f"val_loss={val_loss:.3f} " + f"val_pp={val_perplexity:.3f} " + f"val_acc={val_acc:3f}" + ) + + if cfg.wandb: + if curr_iter == cfg.iterations or full_eval: + logs = { + "tokens": tokens, + "iter": curr_iter, + "final-val/loss": val_loss, + "final-val/perplexity": val_perplexity, + "final-val/acc": val_acc, + } + else: + logs = { + "tokens": tokens, + "iter": curr_iter, + "val/loss": val_loss, + "val/perplexity": val_perplexity, + "val/acc": val_acc, + } + + wandb.log(logs) + if cfg.eval_seq_prefix != "none" and ( + curr_iter % (cfg.eval_interval * 5) == 0 or curr_iter == cfg.iterations + ): + text_table = wandb.Table(columns=["itr", "val-pp", "text"]) + + out_str = distributed_backend.get_raw_model(model).generate_from_string( + cfg.eval_seq_prefix, + max_new_tokens=40, + temperature=0.9, + top_k=None, + ) + text_table.add_data(curr_iter, val_perplexity, out_str) + # why a copy? see github.com/wandb/wandb/issues/2981 + wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)}) + model.train() diff --git a/src/optim/clipped.py b/src/optim/clipped.py new file mode 100644 index 0000000..0f0fc6f --- /dev/null +++ b/src/optim/clipped.py @@ -0,0 +1,414 @@ +""" +Here is an original implementation of Clip-Adam and Clip-Adagrad. +Source: https://github.com/yaroslavkliukin/Clipped-AdaGrad-and-Adam +""" + +import math + +import torch + + +class AdamClip(torch.optim.Optimizer): + """ + Parameters: + lr (float): learning rate. Default 1e-3. + betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.0 + correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. + clipping (str): "no", "local", "elementwise". Default: "no" + max_grad_norm (float): value to which we clip the gradient. Default: 1.0 + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.0, + correct_bias=True, + clipping="no", + max_grad_norm=1.0, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]) + ) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + clipping=clipping, + max_grad_norm=max_grad_norm, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + grad = p.grad.data + + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p.data) + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = ( + step_size * math.sqrt(bias_correction2) / bias_correction1 + ) + + p.data.addcdiv_(tensor1=exp_avg, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss + + +class AdamClipDelayedEta(torch.optim.Optimizer): + """ + Parameters: + lr (float): learning rate. Default 1e-3. + betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.0 + correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default False. + clipping (str): "no", "local", "elementwise". Default: "no" + max_grad_norm (float): value to which we clip the gradient. Default: 1.0 + exp_avg_sq_value (float): value used to initialise the second moment in the first gradient update step. Default: 0.00001 + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.0, + correct_bias=False, + clipping="no", + max_grad_norm=1.0, + exp_avg_sq_value=0.00001, + etta=1.0, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]) + ) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + clipping=clipping, + max_grad_norm=max_grad_norm, + exp_avg_sq_value=exp_avg_sq_value, + etta=etta, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + if p.grad.data.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p.data) + state["exp_avg_sq"] = ( + torch.ones_like(p.data) * group["exp_avg_sq_value"] + ) + # Gradient from previous step + state["prev_grad"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq, prev_grad = ( + state["exp_avg"], + state["exp_avg_sq"], + state["prev_grad"], + ) + beta1, beta2 = group["betas"] + + state["step"] += 1 + + exp_avg.mul_(beta1).add_(p.grad.data, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + prev_grad, prev_grad, value=(1.0 - beta2) * group["etta"] + ) + state["prev_grad"] = p.grad.data + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = ( + step_size * math.sqrt(bias_correction2) / bias_correction1 + ) + + p.data.addcdiv_(tensor1=exp_avg, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss + + +class AdagradClip(torch.optim.Optimizer): + """Implements Adagrad algorithm with clipping. + Parameters: + lr (float): learning rate. Default 1e-3. + eps (float): Adagrad epsilon. Default: 1e-10 + weight_decay (float): Weight decay. Default: 0.0 + clipping (str): "no", "global", "local", "elementwise". Default: "local" + max_grad_norm (bool): value to which we clip the gradient. Default: 1.0 + """ + + def __init__( + self, + params, + lr=1e-3, + eps=1e-10, + weight_decay=0.0, + clipping="local", + max_grad_norm=1.0, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + eps=eps, + weight_decay=weight_decay, + clipping=clipping, + max_grad_norm=max_grad_norm, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + grad = p.grad.data + + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + + exp_avg_sq.addcmul_(grad, grad, value=1.0) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + p.data.addcdiv_(tensor1=grad, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss + + +class AdaGradClipDelayedEta(torch.optim.Optimizer): + """Implements Adagrad algorithm with clipping, delay and reweighing. + Parameters: + lr (float): learning rate. Default 1e-3. + eps (float): Adagrad epsilon. Default: 1e-10 + weight_decay (float): Weight decay. Default: 0.0 + clipping (str): "no", "global", "local", "elementwise". Default: "local" + max_grad_norm (bool): value to which we clip the gradient. Default: 1.0 + etta (float): reweighing parameter. Default: 1.0 + exp_avg_sq_value (float): initial value to imitate first gradient + """ + + def __init__( + self, + params, + lr=1e-3, + eps=1e-10, + weight_decay=0.0, + clipping="local", + max_grad_norm=1.0, + etta=1.0, + exp_avg_sq_value=0.0001, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + eps=eps, + weight_decay=weight_decay, + clipping=clipping, + max_grad_norm=max_grad_norm, + etta=etta, + exp_avg_sq_value=exp_avg_sq_value, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + if group["clipping"] == "global": + torch.nn.utils.clip_grad_norm_(group["params"], group["max_grad_norm"]) + + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + if p.grad.data.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg_sq"] = ( + torch.ones_like(p.data) * group["exp_avg_sq_value"] + ) + state["prev_grad"] = torch.zeros_like(p.data) + + exp_avg_sq, prev_grad = state["exp_avg_sq"], state["prev_grad"] + state["step"] += 1 + + exp_avg_sq.addcmul_(prev_grad, prev_grad, value=group["etta"]) + state["prev_grad"] = p.grad.data + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + p.data.addcdiv_(tensor1=prev_grad, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss diff --git a/src/optim/lamb.py b/src/optim/lamb.py new file mode 100644 index 0000000..baa63bf --- /dev/null +++ b/src/optim/lamb.py @@ -0,0 +1,126 @@ +""" +Here is an official implementation of LAMB. +Source: https://github.com/cybertronai/pytorch-lamb +""" + +import math + +import torch + + +class Lamb(torch.optim.Optimizer): + r"""Implements Lamb algorithm. + + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + + .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0, + adam=False, + bias_correction=False, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.adam = adam + self.bias_correction = bias_correction + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients, consider SparseAdam instad." + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Paper v3 does not use debiasing. + if self.bias_correction: + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) + else: + step_size = group["lr"] + + weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) + + adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + adam_step.add_(p.data, alpha=group["weight_decay"]) + + adam_norm = adam_step.pow(2).sum().sqrt() + if weight_norm == 0 or adam_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / adam_norm + state["weight_norm"] = weight_norm + state["adam_norm"] = adam_norm + state["trust_ratio"] = trust_ratio + if self.adam: + trust_ratio = 1 + + p.data.add_(adam_step, alpha=-step_size * trust_ratio) + + return loss diff --git a/src/optim/lion.py b/src/optim/lion.py new file mode 100644 index 0000000..2c0c59a --- /dev/null +++ b/src/optim/lion.py @@ -0,0 +1,75 @@ +""" +Here is an original implementation of Lion. +Source: https://github.com/google/automl/tree/master/lion +""" + +import torch + + +class Lion(torch.optim.Optimizer): + r"""Implements Lion algorithm.""" + + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): + """Initialize the hyperparameters. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-4) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + weight_decay (float, optional): weight decay coefficient (default: 0) + """ + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group["lr"] * group["weight_decay"]) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + + exp_avg = state["exp_avg"] + beta1, beta2 = group["betas"] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + + p.add_(update.sign_(), alpha=-group["lr"]) + + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss diff --git a/src/optim/mars.py b/src/optim/mars.py new file mode 100644 index 0000000..ad34576 --- /dev/null +++ b/src/optim/mars.py @@ -0,0 +1,293 @@ +""" +Here is an original implementation of MARS. +Source: https://github.com/AGI-Arena/MARS +""" + +# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +import math + +import torch + +from .muon import zeropower_via_newtonschulz5 + + +def exists(val): + return val is not None + + +def update_fn( + p, + grad, + exp_avg, + exp_avg_sq, + lr, + wd, + beta1, + beta2, + last_grad, + eps, + amsgrad, + max_exp_avg_sq, + step, + gamma, + mars_type, + is_grad_2d, + optimize_1d, + lr_1d_factor, + betas_1d, + weight_decay_1d, +): + # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para + if optimize_1d or is_grad_2d: + c_t = (grad - last_grad).mul(gamma * (beta1 / (1.0 - beta1))).add(grad) + c_t_norm = torch.norm(c_t) + if c_t_norm > 1.0: + c_t = c_t / c_t_norm + exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1) + if (mars_type == "mars-adamw") or ( + mars_type == "mars-shampoo" and not is_grad_2d + ): + exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2) + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + if amsgrad: + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + denom = ( + max_exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + else: + denom = ( + exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom)) + elif mars_type == "mars-lion": + real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign()) + elif mars_type == "mars-shampoo" and is_grad_2d: + factor = max(1, grad.size(0) / grad.size(1)) ** 0.5 + real_update_tmp = ( + zeropower_via_newtonschulz5(exp_avg.mul(1.0 / (1.0 - beta1)), eps=eps) + .mul(factor) + .add(wd, p.data) + .mul(-lr) + ) + p.data.add_(real_update_tmp) + else: + beta1_1d, beta2_1d = betas_1d + exp_avg.mul_(beta1_1d).add_(grad, alpha=1 - beta1_1d) + exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1 - beta2_1d) + bias_correction1 = 1.0 - beta1_1d**step + bias_correction2 = 1.0 - beta2_1d**step + if amsgrad: + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + denom = ( + max_exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + else: + denom = ( + exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + real_update_tmp = ( + -lr + * lr_1d_factor + * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom)) + ) + p.data.add_(real_update_tmp) + return exp_avg, exp_avg_sq + + +class MARS(torch.optim.Optimizer): + def __init__( + self, + params, + lr=3e-3, + betas=(0.95, 0.99), + eps=1e-8, + weight_decay=0.0, + amsgrad=False, + gamma=0.025, + is_approx=True, + mars_type="mars-adamw", + optimize_1d=False, + lr_1d=3e-3, + betas_1d=(0.9, 0.95), + weight_decay_1d=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + assert mars_type in [ + "mars-adamw", + "mars-lion", + "mars-shampoo", + ], "MARS type not supported" + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + mars_type=mars_type, + gamma=gamma, + optimize_1d=optimize_1d, + weight_decay_1d=weight_decay_1d, + ) + super(MARS, self).__init__(params, defaults) + self.eps = eps + self.update_fn = update_fn + self.lr = lr + self.weight_decay = weight_decay + self.amsgrad = amsgrad + self.step_num = 0 + self.is_approx = is_approx + self.gamma = gamma + self.mars_type = mars_type + self.optimize_1d = optimize_1d + self.lr_1d_factor = lr_1d / lr + self.weight_decay_1d = weight_decay_1d + self.betas_1d = betas_1d + + @torch.no_grad() + def update_last_grad(self): + if not self.is_approx: + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + if "last_grad" not in state: + state["last_grad"] = torch.zeros_like(p) + state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0) + + @torch.no_grad() + def update_previous_grad(self): + if not self.is_approx: + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + print(p, "grad is none") + continue + state = self.state[p] + if "previous_grad" not in state: + state["previous_grad"] = torch.zeros_like(p) + state["previous_grad"].zero_().add_(p.grad, alpha=1.0) + + def __setstate__(self, state): + super(MARS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + @torch.no_grad() + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + grad_scaler=None, + ): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + "FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." + ) + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + gamma = self.gamma + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group["params"]): + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group["amsgrad"] + + state = self.state[p] + # ('----- starting a parameter state', state.keys(), 'Length of state', len(state)) + # State initialization + if len(state) <= 1: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Last Gradient + state["last_grad"] = torch.zeros_like(p) + # state['previous_grad'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p.data) + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + last_grad = state["last_grad"] + lr, wd, beta1, beta2 = ( + group["lr"], + group["weight_decay"], + *group["betas"], + ) + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + else: + max_exp_avg_sq = 0 + + if "step" in state: + state["step"] += 1 + else: + state["step"] = 1 + step = state["step"] + is_grad_2d = len(grad.shape) == 2 + exp_avg, exp_avg_sq = self.update_fn( + p, + grad, + exp_avg, + exp_avg_sq, + lr, + wd, + beta1, + beta2, + last_grad, + self.eps, + amsgrad, + max_exp_avg_sq, + step, + gamma, + mars_type=self.mars_type, + is_grad_2d=is_grad_2d, + optimize_1d=self.optimize_1d, + lr_1d_factor=self.lr_1d_factor, + betas_1d=self.betas_1d, + weight_decay_1d=( + self.weight_decay if self.optimize_1d else self.weight_decay_1d + ), + ) + if self.is_approx: + state["last_grad"] = grad + self.step_num = step + + return loss diff --git a/src/optim/muon.py b/src/optim/muon.py new file mode 100644 index 0000000..f0d6e49 --- /dev/null +++ b/src/optim/muon.py @@ -0,0 +1,345 @@ +""" +Here is an original implementation of Muon. +Source: https://github.com/KellerJordan/modded-nanogpt +""" + +import os + +import torch +import torch.distributed as dist + +from .schedule import cos_inf_schedule, cosine_wsd_decay_schedule, wsd_schedule + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps # ensure top singular value <= 1 + if G.size(0) > G.size(1): + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(0) > G.size(1): + X = X.T + return X + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + muon_params, + lr=0.02, + momentum=0.95, + nesterov=True, + ns_steps=6, + adamw_params=None, + adamw_lr=3e-4, + adamw_betas=(0.95, 0.95), + adamw_eps=1e-8, + adamw_wd=0, + ): + + defaults = dict( + lr=lr, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_lr=adamw_lr, + adamw_lr_ratio=adamw_lr / lr, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + adamw_wd=adamw_wd, + ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) + super().__init__(params, defaults) + + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + if p.ndim >= 2 and p.size(0) < 10000: + self.state[p]["use_muon"] = True + # self.state[p]["use_muon"] = True + else: + self.state[p]["use_muon"] = False + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + if "WORLD_SIZE" in os.environ: + self.world_size = int(os.environ["WORLD_SIZE"]) + self.rank = int(os.environ["RANK"]) + else: + self.world_size = 1 + self.rank = 0 + + def step(self): + + for group in self.param_groups: + + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + momentum = group["momentum"] + + # generate weight updates in distributed fashion + total_params = sum(p.numel() for p in params) + updates_flat = torch.zeros( + total_params, device="cuda", dtype=torch.bfloat16 + ) + curr_idx = 0 + for i, p in enumerate(params): + # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs + if i % self.world_size == self.rank: + g = p.grad + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr_idx : curr_idx + p.numel()] = g.flatten() + curr_idx += p.numel() + + # sync updates across devices. we are not memory-constrained so can do this simple deserialization + if self.world_size > 1: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + # deserialize and apply updates + curr_idx = 0 + for p in params: + g = ( + updates_flat[curr_idx : curr_idx + p.numel()] + .view_as(p.data) + .type_as(p.data) + ) + p.data.add_(g, alpha=-lr) + curr_idx += p.numel() + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = ( + group["adamw_lr_ratio"] * group["lr"] + ) # in order for lr schedule to work + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["adamw_wd"] + + for p in params: + g = p.grad + assert g is not None + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + +def separate_params(param_groups): + param_groups_2d = [] + param_groups_non2d = [] + total_param_2d_count = 0 + total_param_non2d_count = 0 + + # Check if param_groups is a list of dicts or list of params + if ( + isinstance(param_groups, list) and isinstance(param_groups[0], dict) + ) or isinstance(param_groups, dict): + if isinstance(param_groups, dict): + param_groups = [param_groups] + # param_groups is a list of dicts + for group in param_groups: + ( + params_2d, + params_non2d, + param_2d_count, + param_non2d_count, + ) = separate_params(group["params"]) + param_group_2d = {"params": params_2d} + param_group_non2d = {"params": params_non2d} + # Copy the group dict and replace the 'params' key with the separated params + for k in group.keys(): + if k != "params": + param_group_2d[k] = group[k] + param_group_non2d[k] = group[k] + + param_groups_2d.append(param_group_2d) + param_groups_non2d.append(param_group_non2d) + total_param_2d_count += param_2d_count + total_param_non2d_count += param_non2d_count + + return ( + param_groups_2d, + param_groups_non2d, + total_param_2d_count, + total_param_non2d_count, + ) + + elif isinstance(param_groups, list) and isinstance(param_groups[0], torch.Tensor): + params_2d = [] + params_non2d = [] + param_group = param_groups + # param_group is a list of param tensors + for param in param_group: + if param.ndim == 2: + params_2d.append(param) + else: + params_non2d.append(param) + return params_2d, params_non2d, len(params_2d), len(params_non2d) + else: + breakpoint() + + +class CombinedScheduler: + """ + CombinedScheduler implements a scheduler for the Muon optimizer: it leverages both Muon and AdamW learning rates, and applies the same sort of scheduler for both of them. + + Arguments: + optimizer: Muon optimizer. + cfg: arguments used for schedulers. + muon_lr_key: defaults["lr"] is responsible for the Muon learning rate. + adamw_lr_key: defaults["adamw_r"] is responsible for the AdamW learning rate. + """ + + def __init__(self, optimizer, cfg, muon_lr_key="lr", adamw_lr_key="adamw_lr"): + self.schedulers = [] + scheduler_map = { + "cos": torch.optim.lr_scheduler.OneCycleLR, + "linear": torch.optim.lr_scheduler.OneCycleLR, + "cos_inf": lambda opt, lr: torch.optim.lr_scheduler.LambdaLR( + opt, + cos_inf_schedule( + n_iterations=cfg.iterations, + n_warmup=cfg.warmup_steps, + n_inf=cfg.cos_inf_steps, + div_factor=1e2, + final_div_factor=0.1, + ), + ), + "wsd": lambda opt, lr: torch.optim.lr_scheduler.LambdaLR( + opt, + wsd_schedule( + n_iterations=cfg.iterations, + n_warmup=cfg.warmup_steps, + fract_decay=cfg.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=cfg.wsd_final_lr_scale, + decay_type=cfg.decay_type, + ), + ), + "cos_wsd": lambda opt, lr: torch.optim.lr_scheduler.LambdaLR( + opt, + cosine_wsd_decay_schedule( + n_iterations=cfg.iterations, + n_warmup=cfg.warmup_steps, + anneal_end_factor=0.15, + fract_decay=cfg.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=0.1, + decay_type=cfg.decay_type, + ), + ), + } + + for group in optimizer.param_groups: + lr_key = muon_lr_key if muon_lr_key in group else adamw_lr_key + if lr_key in group: + scheduler_cls = scheduler_map.get(cfg.scheduler, None) + if scheduler_cls: + if cfg.scheduler in ["cos", "linear"]: + scheduler = scheduler_cls( + optimizer, + max_lr=[group.get(lr_key, getattr(cfg, lr_key.lower()))], + total_steps=cfg.iterations, + pct_start=cfg.warmup_steps / cfg.iterations, + anneal_strategy=cfg.scheduler, + cycle_momentum=False, + div_factor=1e2, + final_div_factor=1, + ) + else: + scheduler = scheduler_cls( + optimizer, group.get(lr_key, getattr(cfg, lr_key.lower())) + ) + self.schedulers.append(scheduler) + + def step(self): + for scheduler in self.schedulers: + scheduler.step() + + def state_dict(self): + state_dict = {} + for i, scheduler in enumerate(self.schedulers): + state_dict[f"scheduler_{i}"] = scheduler.state_dict() + return state_dict + + def load_state_dict(self, state_dict): + for i, scheduler in enumerate(self.schedulers): + scheduler.load_state_dict(state_dict[f"scheduler_{i}"]) diff --git a/src/optim/normalized.py b/src/optim/normalized.py new file mode 100644 index 0000000..7a93c80 --- /dev/null +++ b/src/optim/normalized.py @@ -0,0 +1,116 @@ +""" +Version of different optimizers with normalized update +""" + +from typing import Dict + +import torch + + +class NormalizedSGD(torch.optim.Optimizer): + def __init__( + self, + params, + lr=1e-3, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + sign_update=False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + sign_update=sign_update, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def _init_state(self, example, state=None): + assert isinstance(example, torch.Tensor) + assert isinstance(state, Dict) or state is None + if state is None: + state = {} + state["step"] = 0 + state["momentum_buffer"] = torch.clone(example).detach() + return state + + @torch.no_grad() + def _compute_update( + self, grad, state, lr, momentum, nesterov, dampening, sign_update, **kwargs + ): + if momentum != 0: + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + if sign_update: + grad = grad.sign() + + return grad / (grad.norm() + 1e-8) * (-lr) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + state = self.state[p] + + if group["weight_decay"] != 0: + p.mul_(1 - group["lr"] * group["weight_decay"]) + + if len(state) == 0: + self._init_state(example=p, state=state) + if not group["momentum"]: + state.pop("momentum_buffer", None) + + state["step"] += 1 + + update = self._compute_update( + grad, + state, + group["lr"], + group["momentum"], + group["nesterov"], + group["dampening"], + group["sign_update"], + ) + + p.add_(update) + + return loss diff --git a/src/optim/prodigy.py b/src/optim/prodigy.py new file mode 100644 index 0000000..aa7f5b9 --- /dev/null +++ b/src/optim/prodigy.py @@ -0,0 +1,274 @@ +""" +Here is an original implementation of Prodigy. +Source: https://github.com/konstmish/prodigy +""" + +import math + +import torch +import torch.distributed as dist + + +class Prodigy(torch.optim.Optimizer): + r""" + Implements Adam with Prodigy step-sizes. + Leave LR set to 1 unless you encounter instability. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate. + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + beta3 (float): + coefficients for computing the Prodidy stepsize using running averages. + If set to None, uses the value of square root of beta2 (default: None). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + decouple (boolean): + Use AdamW style decoupled weight decay + use_bias_correction (boolean): + Turn on Adam's bias correction. Off by default. + safeguard_warmup (boolean): + Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + d_coef (float): + Coefficient in the expression for the estimate of d (default 1.0). + Values such as 0.5 and 2.0 typically work as well. + Changing this parameter is the preferred way to tune the method. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use (bool): + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. + """ + + def __init__( + self, + params, + lr=1.0, + betas=(0.9, 0.999), + beta3=None, + eps=1e-8, + weight_decay=0, + decouple=True, + use_bias_correction=False, + safeguard_warmup=False, + d0=1e-6, + d_coef=1.0, + growth_rate=float("inf"), + fsdp_in_use=False, + ): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + if decouple and weight_decay > 0: + print(f"Using decoupled weight decay") + + defaults = dict( + lr=lr, + betas=betas, + beta3=beta3, + eps=eps, + weight_decay=weight_decay, + d=d0, + d0=d0, + d_max=d0, + d_numerator=0.0, + d_coef=d_coef, + k=0, + growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, + safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use, + ) + self.d0 = d0 + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + d_denom = 0.0 + + group = self.param_groups[0] + use_bias_correction = group["use_bias_correction"] + beta1, beta2 = group["betas"] + beta3 = group["beta3"] + if beta3 is None: + beta3 = math.sqrt(beta2) + k = group["k"] + + d = group["d"] + d_max = group["d_max"] + d_coef = group["d_coef"] + lr = max(group["lr"] for group in self.param_groups) + + if use_bias_correction: + bias_correction = ((1 - beta2 ** (k + 1)) ** 0.5) / (1 - beta1 ** (k + 1)) + else: + bias_correction = 1 + + dlr = d * lr * bias_correction + + growth_rate = group["growth_rate"] + decouple = group["decouple"] + fsdp_in_use = group["fsdp_in_use"] + + d_numerator = group["d_numerator"] + d_numerator *= beta3 + + for group in self.param_groups: + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + group_lr = group["lr"] + d0 = group["d0"] + safeguard_warmup = group["safeguard_warmup"] + + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0" + ) + + for p in group["params"]: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + + grad = p.grad.data + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p.data, alpha=decay) + + state = self.state[p] + + # State initialization + if "step" not in state: + state["step"] = 0 + state["s"] = torch.zeros_like(p.data).detach() + state["p0"] = p.detach().clone() + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data).detach() + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data).detach() + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + s = state["s"] + p0 = state["p0"] + + if group_lr > 0.0: + # we use d / d0 instead of just d to avoid getting values that are too small + d_numerator += ( + (d / d0) + * dlr + * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() + ) + + # Adam EMA updates + exp_avg.mul_(beta1).add_(grad, alpha=d * (1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=d * d * (1 - beta2) + ) + + if safeguard_warmup: + s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) + else: + s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) + d_denom += s.abs().sum().item() + + ###### + + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have d_denom > 0 (unless \|g\|=0) + if d_denom == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(2).cuda() + dist_tensor[0] = d_numerator + dist_tensor[1] = d_denom + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_d_numerator = dist_tensor[0] + global_d_denom = dist_tensor[1] + else: + global_d_numerator = d_numerator + global_d_denom = d_denom + + d_hat = d_coef * global_d_numerator / global_d_denom + if d == group["d0"]: + d = max(d, d_hat) + d_max = max(d_max, d_hat) + d = min(d_max, d * growth_rate) + + for group in self.param_groups: + group["d_numerator"] = global_d_numerator + group["d_denom"] = global_d_denom + group["d"] = d + group["d_max"] = d_max + group["d_hat"] = d_hat + + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + state["step"] += 1 + + denom = exp_avg_sq.sqrt().add_(d * eps) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p.data.add_(p.data, alpha=-decay * dlr) + + ### Take step + p.data.addcdiv_(exp_avg, denom, value=-dlr) + + group["k"] = k + 1 + + return loss diff --git a/src/optim/schedule.py b/src/optim/schedule.py new file mode 100644 index 0000000..a6a6730 --- /dev/null +++ b/src/optim/schedule.py @@ -0,0 +1,342 @@ +import math + +import numpy as np + + +def cos_inf_schedule(n_iterations, n_warmup, div_factor, final_div_factor, n_inf): + """Cosine annealing with warmup and _constant_ final_lr after cycle ended. + Args: + n_iterations: total number of iterations + n_warmup: number of warmup iterations + div_factor: initial division factor for warmup + final_div_factor: final division factor for final lr + n_inf: number of iterations for the final lr (constant lr after cycle ended) + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + max_lr = 1.0 + base_lr = max_lr / div_factor + final_lr = base_lr / final_div_factor + + n_anneal_steps = n_iterations - n_inf + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / div_factor + elif step < n_anneal_steps: + t = (step - n_warmup) / (n_anneal_steps - n_warmup) + lr = final_lr + 0.5 * (max_lr - final_lr) * (1 + np.cos(np.pi * t)) + return lr + else: + return final_lr + + return schedule + + +def wsd_schedule( + n_iterations, + final_lr_factor=0.0, + n_warmup=1000, + init_div_factor=100, + fract_decay=0.1, + decay_type="linear", +): + """Warmup, hold, and decay schedule. + Args: + n_iterations: total number of iterations + final_lr_factor: factor by which to reduce max_lr at the end + warmup_fract: fraction of iterations used for warmup + init_div_factor: initial division factor for warmup + fract_decay: fraction of iterations used for decay + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + n_anneal_steps = int(fract_decay * n_iterations) + n_hold = n_iterations - n_anneal_steps + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor + elif step < n_hold: + return 1.0 + elif step < n_iterations: + if decay_type == "linear": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - (step - n_hold) / n_anneal_steps + ) + elif decay_type == "exp": + return final_lr_factor ** ((step - n_hold) / n_anneal_steps) + elif decay_type == "cosine": + return ( + final_lr_factor + + (1 - final_lr_factor) + * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) + * 0.5 + ) + elif decay_type == "miror_cosine": + cosine_value = ( + final_lr_factor + + (1 - final_lr_factor) + * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) + * 0.5 + ) + linear_value = final_lr_factor + (1 - final_lr_factor) * ( + 1 - (step - n_hold) / n_anneal_steps + ) + return linear_value * 2 - cosine_value + elif decay_type == "square": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - ((step - n_hold) / n_anneal_steps) ** 2 + ) + + elif decay_type == "sqrt": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - math.sqrt((step - n_hold) / n_anneal_steps) + ) + + else: + raise ValueError( + f"decay type {decay_type} is not in ['cosine','miror_cosine','linear','exp']" + ) + + else: + return final_lr_factor + + return schedule + + +def cosine_wsd_decay_schedule( + n_iterations, + n_warmup=1000, + anneal_end_factor=0.15, + final_lr_factor=0.0, + init_div_factor=1e-2, + fract_decay=0.1, + decay_type="linear", +): + """Warmup, cosine, and wsd-like decay schedule. + Args: + n_iterations: total number of iterations + n_warmup: number of warmup iterations + anneal_end_factor: factor at which cosine annealing ends + final_lr_factor: factor by which to reduce max_lr at the end + init_div_factor: initial division factor for warmup + fract_decay: fraction of iterations used for decay + decay_type: type of decay after cosine phase + ['linear', 'exp', 'cosine', 'mirror_cosine', 'square', 'sqrt'] + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + valid_decay_types = ["linear", "exp", "cosine", "mirror_cosine", "square", "sqrt"] + if decay_type not in valid_decay_types: + raise ValueError(f"decay_type {decay_type} is not in {valid_decay_types}") + + max_lr = 1.0 + base_lr = max_lr / init_div_factor + # final_lr = base_lr / final_lr_factor + n_decay_steps = int(fract_decay * n_iterations) + n_hold = n_iterations - n_decay_steps + cosine_start = n_warmup + cosine_end = n_warmup + n_hold + + def schedule(step): + if step < n_warmup: + # Warmup phase + return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor + + elif step < cosine_end: + # Cosine regime + t = (step - cosine_start) / (cosine_end - cosine_start) + return anneal_end_factor + (max_lr - anneal_end_factor) * 0.5 * ( + 1 + math.cos(math.pi * t) + ) + + elif step < n_iterations: + # Decay regime + progress = (step - cosine_end) / n_decay_steps + + if decay_type == "linear": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + 1 - progress + ) + + elif decay_type == "exp": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + final_lr_factor ** (progress) + ) + + elif decay_type == "cosine": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + (1 + math.cos(math.pi * progress)) * 0.5 + ) + + elif decay_type == "mirror_cosine": + cosine_value = final_lr_factor + ( + anneal_end_factor - final_lr_factor + ) * ((1 + math.cos(math.pi * progress)) * 0.5) + linear_value = final_lr_factor + ( + anneal_end_factor - final_lr_factor + ) * (1 - progress) + return linear_value * 2 - cosine_value + + elif decay_type == "square": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + 1 - progress**2 + ) + + elif decay_type == "sqrt": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + 1 - math.sqrt(progress) + ) + + else: + return final_lr_factor + + return schedule + + +def dd_schedule( + n_iterations, + n_warmup, + fract_fisrt_decay, + max_lr, + first_final_lr_factor=1e-2, + second_final_lr_factor=0.0, + div_factor=1e2, + first_decay_type="cosine", + second_decay_type="linear", +): + """Warmup, cosine annealing, and linear decay schedule. + Args: + n_iterations: total number of iterations + n_warmup: number of warmup iterations + fract_fisrt_decay: fraction of iterations for the first decay phase + max_lr: the mamimum value of learning rate during the training + first_final_lr_factor: factor by which to reduce max_lr at the end of the first decay phase + second_final_lr_factor: factor by which to reduce first_final_lr_factor at the end of the second decay phase + div_factor: initial division factor for warmup + first_decay_type: which decay approach to use during the fisrt decay phase + second_decay_type: which decay approach to use during the second decay phase + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + if fract_fisrt_decay > 1.0: + raise ValueError( + "Invalid fract_fisrt_decay value: {}".format(fract_fisrt_decay) + ) + n_fisrt_decay = int(fract_fisrt_decay * n_iterations) + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / div_factor + elif step < n_warmup + n_fisrt_decay: + if first_decay_type == "cosine": + return first_final_lr_factor + 0.5 * ( + max_lr - first_final_lr_factor + ) * (1 + math.cos(math.pi * (step - n_warmup) / n_fisrt_decay)) + elif first_decay_type == "linear": + return first_final_lr_factor + (max_lr - first_final_lr_factor) * ( + 1 - (step - n_warmup) / n_fisrt_decay + ) + elif first_decay_type == "exp": + return first_final_lr_factor ** ((step - n_warmup) / n_fisrt_decay) + elif first_decay_type == "mirror_cosine": + cosine_value = ( + first_final_lr_factor + + (max_lr - first_final_lr_factor) + * (1 + math.cos(math.pi * (step - n_warmup) / n_fisrt_decay)) + * 0.5 + ) + linear_value = first_final_lr_factor + ( + max_lr - first_final_lr_factor + ) * (1 - (step - n_warmup) / n_fisrt_decay) + return linear_value * 2 - cosine_value + elif first_decay_type == "square": + return first_final_lr_factor + (max_lr - first_final_lr_factor) * ( + 1 - ((step - n_warmup) / n_fisrt_decay) ** 2 + ) + elif first_decay_type == "sqrt": + return first_final_lr_factor + (max_lr - first_final_lr_factor) * ( + 1 - math.sqrt((step - n_warmup) / n_fisrt_decay) + ) + else: + raise ValueError( + f"decay type {first_decay_type} is not in ['cosine','miror_cosine','linear','exp']" + ) + elif step < n_iterations: + if second_decay_type == "linear": + return second_final_lr_factor + ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + - (step - n_warmup - n_fisrt_decay) / (n_iterations - n_fisrt_decay) + ) + elif second_decay_type == "cosine": + return second_final_lr_factor + 0.5 * ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + + math.cos( + math.pi + * (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ) + elif second_decay_type == "exp": + return first_final_lr_factor ** ( + (step - n_warmup - n_fisrt_decay) / (n_iterations - n_fisrt_decay) + ) + elif second_decay_type == "mirror_cosine": + cosine_value = ( + second_final_lr_factor + + (first_final_lr_factor - second_final_lr_factor) + * ( + 1 + + math.cos( + math.pi + * (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ) + * 0.5 + ) + linear_value = second_final_lr_factor + ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + - (step - n_warmup - n_fisrt_decay) / (n_iterations - n_fisrt_decay) + ) + return linear_value * 2 - cosine_value + elif second_decay_type == "square": + return second_final_lr_factor + ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + - ( + (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ** 2 + ) + elif second_decay_type == "sqrt": + return second_final_lr_factor + ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + - math.sqrt( + (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ) + else: + raise ValueError( + f"decay type {second_decay_type} is not in ['cosine','miror_cosine','linear','exp']" + ) + else: + return second_final_lr_factor + + return schedule diff --git a/src/optim/schedulefree.py b/src/optim/schedulefree.py new file mode 100644 index 0000000..3c5f4f2 --- /dev/null +++ b/src/optim/schedulefree.py @@ -0,0 +1,410 @@ +""" +Here is an original implementation of Schedule-Free AdamW and Schedule-Free SGD. +Source: https://github.com/facebookresearch/schedule_free +""" + +import math +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union + +import torch +from typing_extensions import TypeAlias + +try: + from torch.optim.optimizer import ParamsT +except ImportError: + ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] + + +class AdamWScheduleFree(torch.optim.Optimizer): + r""" + Schedule-Free AdamW + As the name suggests, no scheduler is needed with this optimizer. + To add warmup, rather than using a learning rate schedule you can just + set the warmup_steps parameter. + + This optimizer requires that .train() and .eval() be called before the + beginning of training and evaluation respectively. The optimizer should + also be placed in eval mode when saving checkpoints. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining + parameter groups. + lr (float): + Learning rate parameter (default 0.0025) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)). + eps (float): + Term added to the denominator outside of the root operation to + improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + warmup_steps (int): Enables a linear learning rate warmup (default 0). + r (float): Use polynomial weighting in the average + with power r (default 0). + weight_lr_power (float): During warmup, the weights in the average will + be equal to lr raised to this power. Set to 0 for no weighting + (default 2.0). + foreach (bool): Use a foreach-backed implementation of the optimizer. + Should be significantly faster, but will have higher peak memory + usage (default True if supported in your PyTorch version). + """ + + def __init__( + self, + params: ParamsT, + lr: Union[float, torch.Tensor] = 0.0025, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + warmup_steps: int = 0, + r: float = 0.0, + weight_lr_power: float = 2.0, + foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), + ): + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + r=r, + k=0, + warmup_steps=warmup_steps, + train_mode=False, + weight_sum=0.0, + lr_max=-1.0, + scheduled_lr=0.0, + weight_lr_power=weight_lr_power, + weight_decay=weight_decay, + foreach=foreach, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def eval(self): + for group in self.param_groups: + train_mode = group["train_mode"] + beta1, _ = group["betas"] + if train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p to x + p.lerp_(end=state["z"].to(p.device), weight=1 - 1 / beta1) + group["train_mode"] = False + + @torch.no_grad() + def train(self): + for group in self.param_groups: + train_mode = group["train_mode"] + beta1, _ = group["betas"] + if not train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p to y + p.lerp_(end=state["z"].to(p.device), weight=1 - beta1) + group["train_mode"] = True + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + if not self.param_groups[0]["train_mode"]: + raise Exception( + "Optimizer was not in train mode when step is called. " + "Please insert .train() and .eval() calls on the " + "optimizer. See documentation for details." + ) + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + eps = group["eps"] + beta1, beta2 = group["betas"] + decay = group["weight_decay"] + k = group["k"] + r = group["r"] + warmup_steps = group["warmup_steps"] + weight_lr_power = group["weight_lr_power"] + + if k < warmup_steps: + sched = (k + 1) / warmup_steps + else: + sched = 1.0 + + bias_correction2 = 1 - beta2 ** (k + 1) + lr = group["lr"] * sched + group["scheduled_lr"] = lr # For logging purposes + + lr_max = group["lr_max"] = max(lr, group["lr_max"]) + + weight = ((k + 1) ** r) * (lr_max**weight_lr_power) + weight_sum = group["weight_sum"] = group["weight_sum"] + weight + + try: + ckp1 = weight / weight_sum + except ZeroDivisionError: + ckp1 = 0 + + active_p = [p for p in group["params"] if p.grad is not None] + + for p in active_p: + if "z" not in self.state[p]: + self.state[p]["z"] = torch.clone( + p, memory_format=torch.preserve_format + ) + self.state[p]["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if group["foreach"] and len(active_p) > 0: + y, grad, exp_avg_sq, z = zip( + *[ + (p, p.grad, self.state[p]["exp_avg_sq"], self.state[p]["z"]) + for p in active_p + ] + ) + + # Decay the first and second moment running average coefficient + torch._foreach_mul_(exp_avg_sq, beta2) + torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2) + denom = torch._foreach_div(exp_avg_sq, bias_correction2) + torch._foreach_sqrt_(denom) + torch._foreach_add_(denom, eps) + + # Normalize grad in-place for memory efficiency + torch._foreach_div_(grad, denom) + + # Weight decay calculated at y + if decay != 0: + torch._foreach_add_(grad, y, alpha=decay) + + # These operations update y in-place, + # without computing x explicitly. + torch._foreach_lerp_(y, z, weight=ckp1) + torch._foreach_add_(y, grad, alpha=lr * (beta1 * (1 - ckp1) - 1)) + + # z step + torch._foreach_sub_(z, grad, alpha=lr) + else: + for p in active_p: + y = p # Notation to match theory + grad = p.grad + + state = self.state[p] + + z = state["z"] + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps) + + # Reuse grad buffer for memory efficiency + grad_normalized = grad.div_(denom) + + # Weight decay calculated at y + if decay != 0: + grad_normalized.add_(y, alpha=decay) + + # These operations update y in-place, + # without computing x explicitly. + y.lerp_(end=z, weight=ckp1) + y.add_(grad_normalized, alpha=lr * (beta1 * (1 - ckp1) - 1)) + + # z step + z.sub_(grad_normalized, alpha=lr) + + group["k"] = k + 1 + return loss + + +class SGDScheduleFree(torch.optim.Optimizer): + r""" + Schedule-Free SGD + As the name suggests, no scheduler is needed with this optimizer. + To add warmup, rather than using a learning rate schedule you can just + set the warmup_steps parameter. + + This optimizer requires that .train() and .eval() be called before the + beginning of training and evaluation respectively. The optimizer should + also be placed in eval mode when saving checkpoints. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining + parameter groups. + lr (float): + Learning rate parameter (default 1.0) + momentum (float): momentum factor, must be between 0 and 1 exclusive + (default: 0.9) + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + warmup_steps (int): Enables a linear learning rate warmup (default 0). + r (float): Use polynomial weighting in the average + with power r (default 0). + weight_lr_power (float): During warmup, the weights in the average will + be equal to lr raised to this power. Set to 0 for no weighting + (default 2.0). + foreach (bool): Use a foreach-backed implementation of the optimizer. + Should be significantly faster, but will have higher peak memory + usage (default True if supported in your PyTorch version). + """ + + def __init__( + self, + params: ParamsT, + lr: Union[float, torch.Tensor] = 1.0, + momentum: float = 0.9, + weight_decay: float = 0, + warmup_steps: int = 0, + r: float = 0.0, + weight_lr_power: float = 2, + foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if momentum <= 0 or momentum >= 1: + raise ValueError( + "Momentum must be between 0 and 1 exclusive: {}".format(momentum) + ) + + defaults = dict( + lr=lr, + momentum=momentum, + r=r, + k=0, + warmup_steps=warmup_steps, + train_mode=False, + weight_sum=0.0, + lr_max=-1.0, + scheduled_lr=0.0, + weight_lr_power=weight_lr_power, + weight_decay=weight_decay, + foreach=foreach, + ) + super().__init__(params, defaults) + + @torch.no_grad() + def eval(self): + for group in self.param_groups: + train_mode = group["train_mode"] + momentum = group["momentum"] + if train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p to x + p.lerp_(end=state["z"].to(p.device), weight=1 - 1 / momentum) + group["train_mode"] = False + + @torch.no_grad() + def train(self): + for group in self.param_groups: + train_mode = group["train_mode"] + momentum = group["momentum"] + if not train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p to y + p.lerp_(end=state["z"].to(p.device), weight=1 - momentum) + group["train_mode"] = True + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + if not self.param_groups[0]["train_mode"]: + raise Exception( + "Optimizer was not in train mode when step is called. " + "Please insert .train() and .eval() calls on the " + "optimizer. See documentation for details." + ) + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + momentum = group["momentum"] + lr = group["lr"] + weight_decay = group["weight_decay"] + k = group["k"] + warmup_steps = group["warmup_steps"] + + if k < warmup_steps: + sched = (k + 1) / warmup_steps + else: + sched = 1.0 + lr = group["lr"] * sched + group["scheduled_lr"] = lr # For logging purposes + + weight_lr_power = group["weight_lr_power"] + + r = group["r"] + lr_max = group["lr_max"] = max(lr, group["lr_max"]) + + weight = ((k + 1) ** r) * (lr_max**weight_lr_power) + weight_sum = group["weight_sum"] = group["weight_sum"] + weight + + try: + ckp1 = weight / weight_sum + except ZeroDivisionError: + ckp1 = 0 + + active_p = [p for p in group["params"] if p.grad is not None] + + for p in active_p: + if "z" not in self.state[p]: + self.state[p]["z"] = torch.clone( + p, memory_format=torch.preserve_format + ) + + if group["foreach"] and len(active_p) > 0: + y, grad, z = zip(*[(p, p.grad, self.state[p]["z"]) for p in active_p]) + + # Apply weight decay + if weight_decay != 0: + torch._foreach_add_(grad, y, alpha=weight_decay) + + # These operations update y in-place, + # without computing x explicitly. + torch._foreach_lerp_(y, z, weight=ckp1) + torch._foreach_add_(y, grad, alpha=lr * (momentum * (1 - ckp1) - 1)) + + # SGD step + torch._foreach_sub_(z, grad, alpha=lr) + else: + for p in active_p: + y = p # Notation to match theory + grad = p.grad + z = self.state[p]["z"] + + # Apply weight decay + if weight_decay != 0: + grad.add_(y, alpha=weight_decay) + + # These operations update y in-place, + # without computing x explicitly. + y.lerp_(end=z, weight=ckp1) + y.add_(grad, alpha=lr * (momentum * (1 - ckp1) - 1)) + + # SGD step + z.sub_(grad, alpha=lr) + + group["k"] = k + 1 + return loss diff --git a/src/optim/sgdf.py b/src/optim/sgdf.py new file mode 100644 index 0000000..d3ee42a --- /dev/null +++ b/src/optim/sgdf.py @@ -0,0 +1,83 @@ +""" +Here is an original implementation of SGDF. +Source: https://arxiv.org/abs/2311.02818 +""" + +import torch + + +class SGDF(torch.optim.Optimizer): + def __init__(self, params, lr=1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0): + if lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if betas[0] < 0.0 or betas[1] < 0.0: + raise ValueError("Invalid beta value: {}".format(betas)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(SGDF, self).__init__(params, defaults) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + eps = group["eps"] + beta1, beta2 = group["betas"] + weight_decay = group["weight_decay"] + + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad.data + if weight_decay != 0.0: + grad.add_(p.data, alpha=weight_decay) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(grad) + state["exp_var"] = torch.zeros_like(grad) + + exp_avg = state["exp_avg"] + exp_var = state["exp_var"] + + # Compute gradient 1st and 2nd + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + grad_residual = grad - exp_avg + + exp_var.mul_(beta2).addcmul_( + grad_residual, grad_residual, value=1 - beta2 + ) + + state["step"] += 1 + + # Bias correction + bias_correction1 = 1 - beta1 ** state["step"] + # bias_correction2 = 1 - beta2 ** state['step'] + bias_correction2 = ( + (1 + beta1) + * (1 - beta2 ** state["step"]) + / ((1 - beta1) * (1 - beta1 ** (2 * state["step"]))) + ) + + exp_avg_corr = exp_avg / bias_correction1 + exp_var_corr = exp_var / bias_correction2 + + # Wiener gain + K = exp_var_corr / ( + exp_var_corr + (grad - exp_avg_corr).pow(2).add_(eps) + ) + + grad_hat_residual = grad - exp_avg_corr + grad_hat = exp_avg_corr + K * grad_hat_residual + + p.data.add_(grad_hat, alpha=-lr) + + return loss diff --git a/src/optim/shampoo.py b/src/optim/shampoo.py new file mode 100644 index 0000000..5a7e82a --- /dev/null +++ b/src/optim/shampoo.py @@ -0,0 +1,655 @@ +import math +from typing import Callable, List, Optional, Tuple + +import torch +import torch.distributed as dist + + +class DistributedShampoo(torch.optim.Optimizer): + """ + Args: + params (iterable): Iterable of parameters to optimize. + lr (float, optional): Learning rate (default: 1e-3). + betas (Tuple[float, float, float], optional): Coefficients used for computing + running averages of gradient, squared gradient, and slow EMA (default: (0.9, 0.999, 0.9999)). + eps (float, optional): Term added to denominator to improve numerical stability (default: 1e-8). + weight_decay (float, optional): Weight decay (L2 penalty) (default: 0). + shampoo_decay (float, optional): Decay rate for Shampoo preconditioners (default: 0.9). + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + shampoo_decay: float = 0.9, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if len(betas) != 2: + raise ValueError(f"Invalid betas length: {len(betas)}, expected 2.") + if not all(0.0 <= beta < 1.0 for beta in betas): + raise ValueError(f"Invalid betas: {betas}. Each beta must be in [0, 1).") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= shampoo_decay < 1.0: + raise ValueError( + f"Invalid shampoo_decay value: {shampoo_decay}. Must be in [0, 1)." + ) + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + shampoo_decay=shampoo_decay, + ) + super(DistributedShampoo, self).__init__(params, defaults) + + def __setstate__(self, state): + super(DistributedShampoo, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """ + Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + Optional[float]: The loss if closure is provided, else None. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Iterate over parameter groups + for group in self.param_groups: + # print(group) + params_with_grad = [] + grads = [] + beta1, beta2 = self.defaults["betas"] + exp_avgs = [] + exp_avg_sqs = [] + preconditioners1 = [] + preconditioners2 = [] + state_steps = [] + + # Collect parameters and their states + for p in group["params"]: + if p.grad is None: + continue + if p.grad.is_sparse: + raise RuntimeError( + "DistributedShampoo does not support sparse gradients" + ) + if not p.requires_grad: + continue + + params_with_grad.append(p) + grad = p.grad + grads.append(grad) + + state = self.state[p] + if not state: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Initialize Shampoo preconditioners as identity matrices or scalars + if p.dim() >= 2: + state["preconditioner1"] = torch.eye( + p.size(0), device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.eye( + p.size(1), device=p.device, dtype=p.dtype + ) + else: + state["preconditioner1"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + preconditioners1.append(state["preconditioner1"]) + preconditioners2.append(state["preconditioner2"]) + state_steps.append(state["step"]) + state["step"] += 1 + + if not params_with_grad: + continue # Skip if no parameters to update in this group + + # betas = group['betas'] + # beta1, beta2 = betas + beta1, beta2 = self.defaults["betas"] + lr = group["lr"] + weight_decay = group["weight_decay"] + # eps = group['eps'] + eps = self.defaults["eps"] + # shampoo_decay = group['shampoo_decay'] + shampoo_decay = self.defaults["shampoo_decay"] + + # Update Shampoo preconditioners in a distributed manner + self._update_preconditioners_distributed( + preconditioners1, preconditioners2, grads, group, shampoo_decay, eps + ) + + # Update parameters using Shampoo preconditioning + self._update_distributed_shampoo( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + preconditioners1, + preconditioners2, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + ) + + return loss + + def _update_preconditioners_distributed( + self, + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + grads: List[torch.Tensor], + group: dict, + shampoo_decay: float, + eps: float, + ): + """ + Updates Shampoo preconditioners and synchronizes them across distributed workers. + + Args: + preconditioners1 (List[torch.Tensor]): List of first preconditioners for each parameter. + preconditioners2 (List[torch.Tensor]): List of second preconditioners for each parameter. + grads (List[torch.Tensor]): List of gradients for each parameter. + group (dict): Parameter group options. + shampoo_decay (float): Decay rate for Shampoo preconditioners. + eps (float): Small epsilon for numerical stability. + """ + for pc1, pc2, grad in zip(preconditioners1, preconditioners2, grads): + if grad.dim() >= 2: + A = grad @ grad.t() # [in_features, in_features] + B = grad.t() @ grad # [out_features, out_features] + else: + A = (grad**2).sum() + B = A.clone() # For 1D gradients, B is same as A + + # Update preconditioners with exponential moving average + pc1.mul_(shampoo_decay).add_(A, alpha=1 - shampoo_decay) + pc2.mul_(shampoo_decay).add_(B, alpha=1 - shampoo_decay) + + # Synchronize preconditioners across workers + if dist.is_initialized(): + dist.all_reduce(pc1, op=dist.ReduceOp.SUM) + dist.all_reduce(pc2, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + pc1.div_(world_size) + pc2.div_(world_size) + + def _update_distributed_shampoo( + self, + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + exp_avg_sqs: List[torch.Tensor], + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + steps: List[int], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + ): + """ + Performs an update with Shampoo preconditioning. + + Args: + params (List[torch.Tensor]): List of parameters to update. + grads (List[torch.Tensor]): List of gradients for each parameter. + exp_avgs (List[torch.Tensor]): List of first moment estimates. + exp_avg_sqs (List[torch.Tensor]): List of second moment estimates. + preconditioners1 (List[torch.Tensor]): List of first preconditioners. + preconditioners2 (List[torch.Tensor]): List of second preconditioners. + steps (List[int]): List of step counts for each parameter. + beta1 (float): Coefficient for first moment. + beta2 (float): Coefficient for second moment. + lr (float): Learning rate. + weight_decay (float): Weight decay coefficient. + eps (float): Small epsilon for numerical stability. + """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + pc1 = preconditioners1[i] + pc2 = preconditioners2[i] + step = steps[i] + + # Bias corrections + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + # Update biased first moment estimate + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # Update biased second raw moment estimate + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Compute bias-corrected second moment + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + # Compute step size NOTE: think about another step_size (mb like in the AdEMAMix or SOAP) + step_size = lr / (bias_correction1 if bias_correction1 > 0 else 0.01) + + # Apply weight decay + if weight_decay != 0: + param.mul_(1 - lr * weight_decay) + + # Compute Shampoo preconditioned gradient + if grad.dim() >= 2: + # Safe inversion with added epsilon to diagonal + if not torch.isfinite(pc1).all() or not torch.isfinite(pc2).all(): + raise RuntimeError("Matrix contains NaN or Inf values.") + # inv_pc1 = torch.inverse( + # pc1 + # + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps + # ).sqrt() + # inv_pc2 = torch.inverse( + # pc2 + # + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps + # ).sqrt() + inv_pc1 = torch.linalg.inv( + pc1 + + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps + ).sqrt() + inv_pc2 = torch.linalg.inv( + pc2 + + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps + ).sqrt() + # inv_pc1 = torch.linalg.pinv(pc1 + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps).sqrt() + # inv_pc2 = torch.linalg.pinv(pc2 + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps).sqrt() + + # Precondition the gradient + preconditioned_grad = inv_pc1 @ grad @ inv_pc2 + else: + # For 1D gradients, use scalar preconditioning + preconditioned_grad = grad / (pc1.sqrt() + eps) + + combined_grad = (exp_avg + preconditioned_grad) / 2 # Weighted average + + # Update parameters + param.addcdiv_(combined_grad, denom, value=-step_size) + + # Optional: Gradient Clipping (Uncomment if needed) + # torch.nn.utils.clip_grad_norm_(param, max_norm=1.0) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(lr={self.defaults['lr']}, " + f"betas={self.defaults['betas']}, eps={self.defaults['eps']}, " + f"weight_decay={self.defaults['weight_decay']}" + ) + + +class AdEMAMixDistributedShampoo(torch.optim.Optimizer): + """ + AdEMAMix optimizer with Distributed Shampoo preconditioning. + + Combines the AdEMAMix optimizer with Shampoo’s second-order preconditioning. + Supports distributed training via torch.distributed. + + Args: + params (iterable): Iterable of parameters to optimize. + lr (float, optional): Learning rate (default: 1e-3). + betas (Tuple[float, float, float], optional): Coefficients used for computing + running averages of gradient, squared gradient, and slow EMA (default: (0.9, 0.999, 0.9999)). + eps (float, optional): Term added to denominator to improve numerical stability (default: 1e-8). + weight_decay (float, optional): Weight decay (L2 penalty) (default: 0). + alpha (float, optional): Alpha parameter for AdEMAMix (default: 5.0). + T_alpha_beta3 (Optional[int], optional): Time constant for alpha and beta3 scheduling (default: None). + shampoo_decay (float, optional): Decay rate for Shampoo preconditioners (default: 0.9). + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999), + eps: float = 1e-8, + weight_decay: float = 0, + alpha: float = 5.0, + T_alpha_beta3: Optional[int] = None, + shampoo_decay: float = 0.9, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if len(betas) != 3: + raise ValueError(f"Invalid betas length: {len(betas)}, expected 3.") + if not all(0.0 <= beta < 1.0 for beta in betas): + raise ValueError(f"Invalid betas: {betas}. Each beta must be in [0, 1).") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= shampoo_decay < 1.0: + raise ValueError( + f"Invalid shampoo_decay value: {shampoo_decay}. Must be in [0, 1)." + ) + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + alpha=alpha, + T_alpha_beta3=T_alpha_beta3, + shampoo_decay=shampoo_decay, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """ + Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + Optional[float]: The loss if closure is provided, else None. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Iterate over parameter groups + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_slows = [] + preconditioners1 = [] + preconditioners2 = [] + state_steps = [] + + # Collect parameters and their states + for p in group["params"]: + if p.grad is None: + continue + if p.grad.is_sparse: + raise RuntimeError( + "AdEMAMixDistributedShampoo does not support sparse gradients" + ) + if not p.requires_grad: + continue + + params_with_grad.append(p) + grad = p.grad + grads.append(grad) + + state = self.state[p] + if not state: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_slow"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Initialize Shampoo preconditioners as identity matrices or scalars + if p.dim() >= 2: + state["preconditioner1"] = torch.eye( + p.size(0), device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.eye( + p.size(1), device=p.device, dtype=p.dtype + ) + else: + state["preconditioner1"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + exp_avg_slows.append(state["exp_avg_slow"]) + preconditioners1.append(state["preconditioner1"]) + preconditioners2.append(state["preconditioner2"]) + state_steps.append(state["step"]) + state["step"] += 1 + + if not params_with_grad: + continue # Skip if no parameters to update in this group + + betas = group["betas"] + beta1, beta2, beta3 = betas + alpha = group["alpha"] + T_alpha_beta3 = group["T_alpha_beta3"] + lr = group["lr"] + weight_decay = group["weight_decay"] + eps = group["eps"] + shampoo_decay = group["shampoo_decay"] + + # Update Shampoo preconditioners in a distributed manner + self._update_preconditioners_distributed( + preconditioners1, preconditioners2, grads, group, shampoo_decay, eps + ) + + # Update parameters using AdEMAMix with Shampoo preconditioning + self._update_adamemix_distributed_shampoo( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + exp_avg_slows, + preconditioners1, + preconditioners2, + state_steps, + beta1=beta1, + beta2=beta2, + beta3=beta3, + alpha=alpha, + T_alpha_beta3=T_alpha_beta3, + lr=lr, + weight_decay=weight_decay, + eps=eps, + ) + + return loss + + def _update_preconditioners_distributed( + self, + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + grads: List[torch.Tensor], + group: dict, + shampoo_decay: float, + eps: float, + ): + """ + Updates Shampoo preconditioners and synchronizes them across distributed workers. + + Args: + preconditioners1 (List[torch.Tensor]): List of first preconditioners for each parameter. + preconditioners2 (List[torch.Tensor]): List of second preconditioners for each parameter. + grads (List[torch.Tensor]): List of gradients for each parameter. + group (dict): Parameter group options. + shampoo_decay (float): Decay rate for Shampoo preconditioners. + eps (float): Small epsilon for numerical stability. + """ + for pc1, pc2, grad in zip(preconditioners1, preconditioners2, grads): + if grad.dim() >= 2: + A = grad @ grad.t() # [in_features, in_features] + B = grad.t() @ grad # [out_features, out_features] + else: + A = (grad**2).sum() + B = A.clone() # For 1D gradients, B is same as A + + # Update preconditioners with exponential moving average + pc1.mul_(shampoo_decay).add_(A, alpha=1 - shampoo_decay) + pc2.mul_(shampoo_decay).add_(B, alpha=1 - shampoo_decay) + + # Synchronize preconditioners across workers + if dist.is_initialized(): + dist.all_reduce(pc1, op=dist.ReduceOp.SUM) + dist.all_reduce(pc2, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + pc1.div_(world_size) + pc2.div_(world_size) + + def _update_adamemix_distributed_shampoo( + self, + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + exp_avg_sqs: List[torch.Tensor], + exp_avg_slows: List[torch.Tensor], + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + steps: List[int], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + T_alpha_beta3: Optional[int], + lr: float, + weight_decay: float, + eps: float, + ): + """ + Performs the AdEMAMix update with Shampoo preconditioning. + + Args: + params (List[torch.Tensor]): List of parameters to update. + grads (List[torch.Tensor]): List of gradients for each parameter. + exp_avgs (List[torch.Tensor]): List of first moment estimates. + exp_avg_sqs (List[torch.Tensor]): List of second moment estimates. + exp_avg_slows (List[torch.Tensor]): List of slow EMA estimates. + preconditioners1 (List[torch.Tensor]): List of first preconditioners. + preconditioners2 (List[torch.Tensor]): List of second preconditioners. + steps (List[int]): List of step counts for each parameter. + beta1 (float): Coefficient for first moment. + beta2 (float): Coefficient for second moment. + beta3 (float): Coefficient for slow EMA. + alpha (float): Alpha parameter for AdEMAMix. + T_alpha_beta3 (Optional[int]): Time constant for scheduling. + lr (float): Learning rate. + weight_decay (float): Weight decay coefficient. + eps (float): Small epsilon for numerical stability. + """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_slow = exp_avg_slows[i] + pc1 = preconditioners1[i] + pc2 = preconditioners2[i] + step = steps[i] + + # Bias corrections + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + # Schedule alpha_t and beta3_t + if T_alpha_beta3 is not None and T_alpha_beta3 > 0: + alpha_t = min(step * alpha / T_alpha_beta3, alpha) + # Avoid division by zero + if T_alpha_beta3 != step: + log_beta1 = math.log(beta1) + log_beta3 = math.log(beta3) + denominator = (1 - step / T_alpha_beta3) * log_beta3 + ( + step / T_alpha_beta3 + ) * log_beta1 + if denominator != 0: + beta3_t = min( + math.exp((log_beta1 * log_beta3) / denominator), beta3 + ) + else: + beta3_t = beta3 + else: + beta3_t = beta3 + else: + alpha_t = alpha + beta3_t = beta3 + + # Update biased first moment estimate + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # Update biased second raw moment estimate + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + # Update slow EMA + exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1 - beta3_t) + + # Compute bias-corrected second moment + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + # Compute step size + step_size = lr / (bias_correction1 if bias_correction1 > 0 else 0.01) + + # Apply weight decay + if weight_decay != 0: + param.mul_(1 - lr * weight_decay) + + # Compute Shampoo preconditioned gradient + if grad.dim() >= 2: + # Safe inversion with added epsilon to diagonal + inv_pc1 = torch.inverse( + pc1 + + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps + ).sqrt() + inv_pc2 = torch.inverse( + pc2 + + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps + ).sqrt() + + # Precondition the gradient + preconditioned_grad = inv_pc1 @ grad @ inv_pc2 + else: + # For 1D gradients, use scalar preconditioning + preconditioned_grad = grad / (pc1.sqrt() + eps) + + # Combine AdEMAMix update with Shampoo preconditioning + combined_grad = ( + exp_avg + alpha_t * exp_avg_slow + preconditioned_grad + ) / 3 # Weighted average + + # Update parameters + param.addcdiv_(combined_grad, denom, value=-step_size) + + # Optional: Gradient Clipping (Uncomment if needed) + # torch.nn.utils.clip_grad_norm_(param, max_norm=1.0) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(lr={self.defaults['lr']}, " + f"betas={self.defaults['betas']}, eps={self.defaults['eps']}, " + f"weight_decay={self.defaults['weight_decay']}, alpha={self.defaults['alpha']}, " + f"T_alpha_beta3={self.defaults['T_alpha_beta3']})" + ) diff --git a/src/optim/sign.py b/src/optim/sign.py new file mode 100644 index 0000000..d8e9e37 --- /dev/null +++ b/src/optim/sign.py @@ -0,0 +1,112 @@ +from typing import Dict + +import torch + + +class Signum(torch.optim.Optimizer): + def __init__( + self, + params, + lr=1e-3, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + sign_update=True, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + sign_update=sign_update, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def _init_state(self, example, state=None): + assert isinstance(example, torch.Tensor) + assert isinstance(state, Dict) or state is None + if state is None: + state = {} + state["step"] = 0 + state["momentum_buffer"] = torch.clone(example).detach() + return state + + @torch.no_grad() + def _compute_update( + self, grad, state, lr, momentum, nesterov, dampening, sign_update, **kwargs + ): + if momentum != 0: # Signum check + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + if sign_update: + grad = grad.sign() + + return grad * (-lr) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + state = self.state[p] + + if group["weight_decay"] != 0: + p.mul_(1 - group["lr"] * group["weight_decay"]) + + if len(state) == 0: + self._init_state(example=p, state=state) + if not group["momentum"]: + state.pop("momentum_buffer", None) + + state["step"] += 1 + + update = self._compute_update( + grad, + state, + group["lr"], + group["momentum"], + group["nesterov"], + group["dampening"], + group["sign_update"], + ) + + p.add_(update) + + return loss diff --git a/src/optim/soap.py b/src/optim/soap.py new file mode 100644 index 0000000..70a60cf --- /dev/null +++ b/src/optim/soap.py @@ -0,0 +1,481 @@ +""" +Here is an original implementation of SOAP. +Source: https://github.com/nikhilvyas/SOAP +""" + +from itertools import chain + +import torch +import torch.nn as nn + +# Parts of the code are modifications of Pytorch's AdamW optimizer +# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py + + +class SOAP(torch.optim.Optimizer): + """ + Implements SOAP algorithm (https://arxiv.org/abs/2409.11321). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.003): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`): + Adam's betas parameters (b1, b2). + shampoo_beta (`float`, *optional*, defaults to -1): + If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1]. + eps (`float`, *optional*, defaults to 1e-08): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient. + precondition_frequency (`int`, *optional*, defaults to 10): + How often to update the preconditioner. + max_precond_dim (`int`, *optional*, defaults to 10000): + Maximum dimension of the preconditioner. + Set to 10000, so that we exclude most common vocab sizes while including layers. + merge_dims (`bool`, *optional*, defaults to `False`): + Whether or not to merge dimensions of the preconditioner. + precondition_1d (`bool`, *optional*, defaults to `False`): + Whether or not to precondition 1D gradients. + normalize_grads (`bool`, *optional*, defaults to `False`): + Whether or not to normalize gradients per layer. + Helps at large precondition_frequency (~100 in our experiments), + but hurts performance at small precondition_frequency (~10 in our experiments). + data_format (`str`, *optional*, defaults to `channels_first`): + Data format of the input for convolutional layers. + Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias correction in Adam. + + Example of usage: + optim = SOAP(lr = 3e-3, betas=(.95, .95), weight_decay=.01, precondition_frequency=10) + """ + + def __init__( + self, + params, + lr: float = 3e-3, + betas=(0.95, 0.95), + shampoo_beta: float = -1, + eps: float = 1e-8, + weight_decay: float = 0.01, + precondition_frequency: int = 10, + max_precond_dim: int = 10000, # + merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim. + precondition_1d: bool = False, + normalize_grads: bool = False, + data_format: str = "channels_first", + correct_bias: bool = True, + ): + defaults = { + "lr": lr, + "betas": betas, + "shampoo_beta": shampoo_beta, + "eps": eps, + "weight_decay": weight_decay, + "precondition_frequency": precondition_frequency, + "max_precond_dim": max_precond_dim, + "merge_dims": merge_dims, + "precondition_1d": precondition_1d, + "normalize_grads": normalize_grads, + "correct_bias": correct_bias, + } + super().__init__(params, defaults) + self._data_format = data_format + + def merge_dims(self, grad, max_precond_dim): + """ + Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim. + """ + assert self._data_format in ["channels_first", "channels_last"] + if self._data_format == "channels_last" and grad.dim() == 4: + grad = grad.permute(0, 3, 1, 2) + shape = grad.shape + new_shape = [] + + curr_shape = 1 + for sh in shape: + temp_shape = curr_shape * sh + if temp_shape > max_precond_dim: + if curr_shape > 1: + new_shape.append(curr_shape) + curr_shape = sh + else: + new_shape.append(sh) + curr_shape = 1 + else: + curr_shape = temp_shape + + if curr_shape > 1 or len(new_shape) == 0: + new_shape.append(curr_shape) + + new_grad = grad.reshape(new_shape) + return new_grad + + @torch.no_grad() + def step(self): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + loss = None + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + if "Q" not in state: + self.init_preconditioner( + grad, + state, + precondition_frequency=group["precondition_frequency"], + precondition_1d=group["precondition_1d"], + shampoo_beta=( + group["shampoo_beta"] + if group["shampoo_beta"] >= 0 + else group["betas"][1] + ), + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + ) + self.update_preconditioner( + grad, + state, + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + precondition_1d=group["precondition_1d"], + ) + continue # first step is skipped so that we never use the current gradients in the projection. + + # Projecting gradients to the eigenbases of Shampoo's preconditioner + # i.e. projecting to the eigenbases of matrices in state['GG'] + grad_projected = self.project( + grad, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).add_( + grad_projected.square(), alpha=(1.0 - beta2) + ) + + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner + # i.e. projecting to the eigenbases of matrices in state['GG'] + exp_avg_projected = self.project( + exp_avg, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + step_size = group["lr"] + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** (state["step"]) + bias_correction2 = 1.0 - beta2 ** (state["step"]) + step_size = step_size * (bias_correction2**0.5) / bias_correction1 + + # Projecting back the preconditioned (by Adam) exponential moving average of gradients + # to the original space + norm_grad = self.project_back( + exp_avg_projected / denom, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + if group["normalize_grads"]: + norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5) + + p.add_(norm_grad, alpha=-step_size) + + # From AdamW code: Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + # Update is done after the gradient step to avoid using current gradients in the projection. + self.update_preconditioner( + grad, + state, + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + precondition_1d=group["precondition_1d"], + ) + + return loss + + def init_preconditioner( + self, + grad, + state, + precondition_frequency=10, + shampoo_beta=0.95, + max_precond_dim=10000, + precondition_1d=False, + merge_dims=False, + ): + """ + Initializes the preconditioner matrices (L and R in the paper). + """ + state["GG"] = ( + [] + ) # Will hold all the preconditioner matrices (L and R in the paper). + if grad.dim() == 1: + if not precondition_1d or grad.shape[0] > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append( + torch.zeros(grad.shape[0], grad.shape[0], device=grad.device) + ) + else: + if merge_dims: + grad = self.merge_dims(grad, max_precond_dim) + + for sh in grad.shape: + if sh > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append(torch.zeros(sh, sh, device=grad.device)) + + state["Q"] = None # Will hold all the eigenbases of the preconditioner. + state["precondition_frequency"] = precondition_frequency + state["shampoo_beta"] = shampoo_beta + + def project(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient to the eigenbases of the preconditioner. + """ + original_shape = grad.shape + if merge_dims: + if grad.dim() == 4 and self._data_format == "channels_last": + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [0]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def update_preconditioner( + self, + grad, + state, + max_precond_dim=10000, + merge_dims=False, + precondition_1d=False, + ): + """ + Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper). + """ + if grad.dim() == 1: + if precondition_1d and grad.shape[0] <= max_precond_dim: + state["GG"][0].lerp_( + grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"] + ) + else: + if merge_dims: + new_grad = self.merge_dims(grad, max_precond_dim) + for idx, sh in enumerate(new_grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + new_grad, + new_grad, + dims=[ + [ + *chain( + range(idx), range(idx + 1, len(new_grad.shape)) + ) + ] + ] + * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + else: + for idx, sh in enumerate(grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + grad, + grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] + * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + + if state["Q"] is None: + state["Q"] = self.get_orthogonal_matrix(state["GG"]) + if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0: + state["Q"] = self.get_orthogonal_matrix_QR( + state, max_precond_dim, merge_dims + ) + + def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient back to the original space. + """ + original_shape = grad.shape + if merge_dims: + if self._data_format == "channels_last" and grad.dim() == 4: + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [1]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def get_orthogonal_matrix(self, mat): + """ + Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition. + """ + matrix = [] + for m in mat: + if len(m) == 0: + matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + else: + float_data = True + matrix.append(m.data) + + final = [] + for m in matrix: + if len(m) == 0: + final.append([]) + continue + try: + _, Q = torch.linalg.eigh( + m + 1e-30 * torch.eye(m.shape[0], device=m.device) + ) + except: + _, Q = torch.linalg.eigh( + m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device) + ) + Q = Q.to(m.dtype) + Q = torch.flip(Q, [1]) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + return final + + def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False): + """ + Computes the eigenbases of the preconditioner using one round of power iteration + followed by torch.linalg.qr decomposition. + """ + precond_list = state["GG"] + orth_list = state["Q"] + + matrix = [] + orth_matrix = [] + for m, o in zip(precond_list, orth_list): + if len(m) == 0: + matrix.append([]) + orth_matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + else: + float_data = True + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + + orig_shape = state["exp_avg_sq"].shape + if self._data_format == "channels_last" and len(orig_shape) == 4: + permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape + if merge_dims: + exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim) + else: + exp_avg_sq = state["exp_avg_sq"] + + final = [] + for ind, (m, o) in enumerate(zip(matrix, orth_matrix)): + if len(m) == 0: + final.append([]) + continue + est_eig = torch.diag(o.T @ m @ o) + sort_idx = torch.argsort(est_eig, descending=True) + exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx) + o = o[:, sort_idx] + power_iter = m @ o + Q, _ = torch.linalg.qr(power_iter) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + + if merge_dims: + if self._data_format == "channels_last" and len(orig_shape) == 4: + exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + exp_avg_sq = exp_avg_sq.reshape(orig_shape) + + state["exp_avg_sq"] = exp_avg_sq + return final diff --git a/src/optim/sophia.py b/src/optim/sophia.py new file mode 100644 index 0000000..e0cdc4d --- /dev/null +++ b/src/optim/sophia.py @@ -0,0 +1,246 @@ +""" +Here is an original implementation of SophiaG. +Source: https://github.com/Liuhong99/Sophia +""" + +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +class SophiaG(Optimizer): + def __init__( + self, + params, + lr=1e-4, + betas=(0.965, 0.99), + rho=0.04, + weight_decay=1e-1, + *, + maximize: bool = False, + capturable: bool = False + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= rho: + raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict( + lr=lr, + betas=betas, + rho=rho, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + ) + super(SophiaG, self).__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("maximize", False) + group.setdefault("capturable", False) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]["step"] + ) + if not step_is_tensor: + for s in state_values: + s["step"] = torch.tensor(float(s["step"])) + + @torch.no_grad() + def update_hessian(self): + for group in self.param_groups: + beta1, beta2 = group["betas"] + for p in group["params"]: + if p.grad is None: + continue + state = self.state[p] + + if len(state) == 0: + state["step"] = ( + torch.zeros((1,), dtype=torch.float, device=p.device) + if self.defaults["capturable"] + else torch.tensor(0.0) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if "hessian" not in state.keys(): + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + state["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) + + @torch.no_grad() + def step(self, closure=None, bs=5120): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + state_steps = [] + hessian = [] + beta1, beta2 = group["betas"] + + for p in group["params"]: + if p.grad is None: + continue + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError("Hero does not support sparse gradients") + grads.append(p.grad) + state = self.state[p] + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((1,), dtype=torch.float, device=p.device) + if self.defaults["capturable"] + else torch.tensor(0.0) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if "hessian" not in state.keys(): + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + state_steps.append(state["step"]) + hessian.append(state["hessian"]) + + if self.defaults["capturable"]: + bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs + + sophiag( + params_with_grad, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=group["rho"], + lr=group["lr"], + weight_decay=group["weight_decay"], + maximize=group["maximize"], + capturable=group["capturable"], + ) + + return loss + + +def sophiag( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + capturable: bool = False, + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool +): + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + func = _single_tensor_sophiag + + func( + params, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=rho, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + ) + + +def _single_tensor_sophiag( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool, + capturable: bool +): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + hess = hessian[i] + step_t = state_steps[i] + + if capturable: + assert param.is_cuda and step_t.is_cuda and bs.is_cuda + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + hess = torch.view_as_real(hess) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + if capturable: + step_size = lr + step_size_neg = step_size.neg() + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) + else: + step_size_neg = -lr + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) diff --git a/src/optim/utils.py b/src/optim/utils.py index 68162f6..6db4c39 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -1,11 +1,15 @@ +import math +import random +from contextlib import nullcontext +from pathlib import Path + import numpy as np import torch -import torch.nn.functional as F -from contextlib import nullcontext, contextmanager, ExitStack +import torch.distributed as dist -def get_batch(dataloader, device="cpu"): - x, y = next(dataloader) +def get_batch(datareader, device="cpu"): + x, y = datareader.sample_batch() if "cuda" in torch.device(device).type: # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) x = x.pin_memory().to(device, non_blocking=True) @@ -17,33 +21,164 @@ def get_batch(dataloader, device="cpu"): @torch.no_grad() -def eval(model, data_val_iter, device='cpu', max_num_batches=24, ctx=nullcontext()): +def eval( + model, + reader, + device="cpu", + max_num_batches=24, + ctx=nullcontext(), + cfg=None, +): assert model.training == False loss_list_val, acc_list = [], [] - for _ in range(max_num_batches): - x, y = get_batch(data_val_iter, device=device) + for idx in range(max_num_batches): + x, y = get_batch(reader, device=device) with ctx: outputs = model(x, targets=y, get_logits=True) - val_loss = outputs['loss'] + val_loss = outputs["loss"] + loss_list_val.append(val_loss) - acc_list.append((outputs['logits'].argmax(-1) == y).float().mean()) + acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) val_acc = torch.stack(acc_list).mean().item() val_loss = torch.stack(loss_list_val).mean().item() - val_perplexity = 2.71828 ** val_loss + val_perplexity = 2.71828**val_loss return val_acc, val_loss, val_perplexity -def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args): +@torch.no_grad() +def eval_sweep_dropk( + model, + data_tensor, + sequence_length, + batch_size, + n_heads, + device="cpu", + max_num_batches=24, + ctx=nullcontext(), +): + assert model.training == False + + x_axis, y_axis_pp, y_axis_acc, y_axis_loss = ( + torch.linspace(0.0, 0.95, 15), + [], + [], + [], + ) + loss_list_val, acc_list = [], [] + + for frac in x_axis: + drop_k = int(sequence_length * frac * n_heads) + for _ in range(max_num_batches): + x, y = get_batch(data_tensor, sequence_length, batch_size, device=device) + with ctx: + outputs = model( + x, targets=y, alpha_th=None, drop_k=drop_k, get_logits=True + ) + loss_list_val.append(outputs["ce_loss"]) + acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) + + y_axis_acc.append(torch.stack(acc_list).mean().item()) + y_axis_loss.append(np.mean(loss_list_val)) + y_axis_pp.append(2.71828 ** y_axis_loss[-1]) + + return x_axis, y_axis_acc, y_axis_pp, y_axis_loss + + +@torch.no_grad() +def eval_sweep_alphath( + model, + data_tensor, + sequence_length, + batch_size, + device="cpu", + max_num_batches=24, + ctx=nullcontext(), +): + assert model.training == False + + alpha_ths, y_axis_pp, y_axis_acc, y_axis_loss = ( + [0, 1e-4, 1e-3, 1e-2, 1e-1, 2e-1, 3e-1, 4e-1, 5e-1], + [], + [], + [], + ) + loss_list_val, acc_list, x_axis = [], [], [] + + for alpha_th in alpha_ths: + frac_heads_pruned_list = [] + for _ in range(max_num_batches): + x, y = get_batch(data_tensor, sequence_length, batch_size, device=device) + with ctx: + outputs = model( + x, targets=y, alpha_th=alpha_th, drop_k=None, get_logits=True + ) + nph, nh = ( + outputs["num_head_pruned_per_layer"], + outputs["num_heads_per_layer"], + ) + frac_heads_pruned = np.sum(nph) / np.sum( + nh + ) # fractions of heads removed given alpha_th + frac_heads_pruned_list.append(frac_heads_pruned) + loss_list_val.append(outputs["ce_loss"]) + acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) + + x_axis.append(np.mean(frac_heads_pruned_list)) + y_axis_acc.append(torch.stack(acc_list).mean().item()) + y_axis_loss.append(np.mean(loss_list_val)) + y_axis_pp.append(2.71828 ** y_axis_loss[-1]) + + return x_axis, y_axis_acc, y_axis_pp, y_axis_loss + + +def save_checkpoint(model, opt, scheduler, itr, ckpt_dir: Path): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer": opt.state_dict(), + "scheduler": scheduler.state_dict() if scheduler is not None else None, + "itr": itr, + } + ckpt_dir.mkdir(exist_ok=True, parents=True) + torch.save(checkpoint, ckpt_dir / "main.pt") + + +def load_checkpoint(model, opt, scheduler, ckpt_path, device): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + + ckpt = torch.load(ckpt_path, map_location=device) + model.load_state_dict(ckpt["model"]) + opt.load_state_dict(ckpt["optimizer"]) + if scheduler is not None: + scheduler.load_state_dict(ckpt["scheduler"]) + itr = ckpt["itr"] + return itr + + +def save_worker_state(ckpt_dir: Path): + # Dataloader, rng states + worker_state = { + "rng_torch_cpu": torch.random.get_rng_state(), + "rng_torch_gpu": torch.cuda.get_rng_state(), + "rng_np": np.random.get_state(), + "rng_python": random.getstate(), + } + rank = 0 if not dist.is_initialized() else dist.get_rank() + ckpt_dir.mkdir(exist_ok=True, parents=True) + torch.save(worker_state, ckpt_dir / f"worker_{rank}.pt") - checkpoint = dict({ - 'model': distributed_backend.get_raw_model(model).state_dict(), - 'optimizer': opt.state_dict(), - 'scheduler': scheduler.state_dict(), - 'itr': itr, - }, **extra_args) - torch.save(checkpoint, ckpt_path) +def load_worker_state(ckpt_dir: Path): + rank = 0 if not dist.is_initialized() else dist.get_rank() + worker_state = torch.load(ckpt_dir / f"worker_{rank}.pt") + torch.random.set_rng_state(worker_state["rng_torch_cpu"]) + torch.cuda.set_rng_state(worker_state["rng_torch_gpu"]) + np.random.set_state(worker_state["rng_np"]) + random.setstate(worker_state["rng_python"])