From ea9f4d06be67dfd61d55c0ac9ef82ce54f5e9e3a Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 15 Oct 2024 13:50:32 +0200 Subject: [PATCH 01/58] soap and muon has been added. also black + isort everything --- src/config/base.py | 143 +++++++---- src/data/arxiv.py | 55 +++-- src/data/openwebtext2.py | 59 +++-- src/data/shakespeare.py | 14 +- src/data/slimpajama.py | 6 +- src/data/utils.py | 25 +- src/data/wikitext.py | 42 ++-- src/distributed/__init__.py | 4 +- src/distributed/backend.py | 5 +- src/distributed/ddp.py | 40 +-- src/main.py | 166 ++++++++----- src/models/base.py | 140 +++++++---- src/models/llama.py | 2 +- src/models/utils.py | 10 +- src/optim/base.py | 166 ++++++++----- src/optim/muon.py | 137 +++++++++++ src/optim/soap.py | 479 ++++++++++++++++++++++++++++++++++++ src/optim/utils.py | 32 ++- 18 files changed, 1197 insertions(+), 328 deletions(-) create mode 100644 src/optim/muon.py create mode 100644 src/optim/soap.py diff --git a/src/config/base.py b/src/config/base.py index e48c277..78309b6 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -1,72 +1,118 @@ +import distributed import torch -import distributed def none_or_str(value): - if value == 'None': + if value == "None": return None return value + def parse_args(base_parser, args, namespace): parser = base_parser # General training params - parser.add_argument('--batch_size', default=32, type=int) - parser.add_argument('--acc_steps', default=4, type=int) - parser.add_argument('--seed', default=0, type=int) - parser.add_argument('--data_seed', default=1337, type=int) - parser.add_argument('--device', default='cuda:0', type=str) - parser.add_argument('--iterations', default=25000, type=int) - parser.add_argument('--lr', default=1e-3, type=float) - parser.add_argument('--warmup_percent', default=0.05, type=float) - parser.add_argument('--weight_decay', default=0.1, type=float) - parser.add_argument('--beta1', default=0.9, type=float) - parser.add_argument('--beta2', default=0.95, type=float) - parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) - parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd']) - parser.add_argument('--eval_freq', default=200, type=int) # in iterations - parser.add_argument('--results_base_folder', default="./exps", type=str) - parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT + parser.add_argument("--batch_size", default=32, type=int) + parser.add_argument("--acc_steps", default=4, type=int) + parser.add_argument("--seed", default=0, type=int) + parser.add_argument("--data_seed", default=1337, type=int) + parser.add_argument("--device", default="cuda:0", type=str) + parser.add_argument("--iterations", default=25000, type=int) + parser.add_argument("--lr", default=1e-3, type=float) + parser.add_argument("--warmup_percent", default=0.05, type=float) + parser.add_argument("--weight_decay", default=0.1, type=float) + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.95, type=float) + parser.add_argument("--scheduler", default="cos", choices=["linear", "cos", "none"]) + parser.add_argument("--opt", default="adamw", choices=["adamw", "sgd"]) + parser.add_argument("--eval_freq", default=200, type=int) # in iterations + parser.add_argument("--results_base_folder", default="./exps", type=str) + parser.add_argument( + "--grad_clip", default=0.0, type=float + ) # default value is 1.0 in NanoGPT # Dataset params - parser.add_argument('--dataset', 
default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) - parser.add_argument('--vocab_size', default=50304, type=int) - parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, mostly useless except for openwebtext2 + parser.add_argument( + "--dataset", + default="slimpajama", + choices=[ + "slimpajama", + "wikitext", + "shakespeare-char", + "arxiv", + "arxiv2000", + "arxiv+wiki", + "openwebtext2", + ], + ) + parser.add_argument("--vocab_size", default=50304, type=int) + parser.add_argument( + "--data_in_ram", action="store_true" + ) # force the data to RAM, mostly useless except for openwebtext2 # Model params - parser.add_argument('--model', default='base', choices=['base', 'llama2']) - parser.add_argument('--use_pretrained', default="auto", type=none_or_str) # 'none', 'gpt-2' or a path to the pretraind model - parser.add_argument('--dropout', default=0.0, type=float) - parser.add_argument('--n_head', default=12, type=int) - parser.add_argument('--n_layer', default=12, type=int) # depths in att + ff blocks - parser.add_argument('--n_embd', default=768, type=int) # embedding size / hidden size ... - parser.add_argument('--sequence_length', default=512, type=int) - parser.add_argument('--dtype', default=torch.bfloat16, type=torch.dtype) - parser.add_argument('--bias', default=False, type=bool) - parser.add_argument('--compile', action='store_true') # if true then model is compiled + parser.add_argument("--model", default="base", choices=["base", "llama2"]) + parser.add_argument( + "--use_pretrained", default="auto", type=none_or_str + ) # 'none', 'gpt-2' or a path to the pretraind model + parser.add_argument("--dropout", default=0.0, type=float) + parser.add_argument("--n_head", default=12, type=int) + parser.add_argument("--n_layer", default=12, type=int) # depths in att + ff blocks + parser.add_argument( + "--n_embd", default=768, type=int + ) # embedding size / hidden size ... 
+ parser.add_argument("--sequence_length", default=512, type=int) + parser.add_argument("--dtype", default=torch.bfloat16, type=torch.dtype) + parser.add_argument("--bias", default=False, type=bool) + parser.add_argument( + "--compile", action="store_true" + ) # if true then model is compiled parser.add_argument("--rmsnorm_eps", default=1e-5, type=float) parser.add_argument( "--multiple_of", # make SwiGLU hidden layer size multiple of large power of 2 default=256, type=int, ) - parser.add_argument('--run_prefix', default=None, type=str, required=False) # is added before the autogenerated experiment name - parser.add_argument('--exp_name', default=None, type=str, required=False) + parser.add_argument( + "--run_prefix", default=None, type=str, required=False + ) # is added before the autogenerated experiment name + parser.add_argument("--exp_name", default=None, type=str, required=False) # logging params (WandB) - parser.add_argument('--wandb', action='store_true') # whether to use wandb or not - parser.add_argument('--wandb_project', default="my-project", type=str) - parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name - parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences + parser.add_argument("--wandb", action="store_true") # whether to use wandb or not + parser.add_argument("--wandb_project", default="my-project", type=str) + parser.add_argument( + "--wandb_run_prefix", default="none", type=str + ) # is added before the autogenerated experiment name + parser.add_argument( + "--eval_seq_prefix", default="Once upon a time", type=str + ) # prefix used to generate sequences # Distributed args - parser.add_argument('--distributed_backend', default=None, type=str, required=False, - choices=distributed.registered_backends()) # distributed backend type - parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) + parser.add_argument( + "--distributed_backend", + default=None, + type=str, + required=False, + choices=distributed.registered_backends(), + ) # distributed backend type + parser.add_argument( + "--save_checkpoint_freq", default=None, type=int, required=False + ) args = parser.parse_args(args, namespace) if args.exp_name is None: - special_name_handle_fields = {"model", "lr", "batch_size", - "acc_steps", "seed", "exp_name", - "wandb", "wandb_project", "eval_seq_prefix", - "run_prefix", "distributed_backend", "config_format", - "sequence_length"} + special_name_handle_fields = { + "model", + "lr", + "batch_size", + "acc_steps", + "seed", + "exp_name", + "wandb", + "wandb_project", + "eval_seq_prefix", + "run_prefix", + "distributed_backend", + "config_format", + "sequence_length", + } overriden_values = [] for key in vars(args): if key in special_name_handle_fields: @@ -76,7 +122,12 @@ def parse_args(base_parser, args, namespace): chunk_len = 10 overriden_values_str_parts = [] for chunk_id in range(0, len(overriden_values), chunk_len): - overriden_values_str = "_".join(["{}={}".format(key, value) for key, value in overriden_values[chunk_id:chunk_id+chunk_len]]) + overriden_values_str = "_".join( + [ + "{}={}".format(key, value) + for key, value in overriden_values[chunk_id : chunk_id + chunk_len] + ] + ) overriden_values_str_parts.append(overriden_values_str) overriden_values_str = "/".join(overriden_values_str_parts) exp_name = "" diff --git a/src/data/arxiv.py b/src/data/arxiv.py index bd146f1..e9de234 100644 --- 
a/src/data/arxiv.py +++ b/src/data/arxiv.py @@ -1,35 +1,35 @@ +import logging import os import tarfile -import logging -from pathlib import Path -from typing import Optional from multiprocessing import Pool +from pathlib import Path +from subprocess import PIPE, Popen, TimeoutExpired from tempfile import NamedTemporaryFile -from subprocess import Popen, TimeoutExpired, PIPE -from typing import Tuple, List +from typing import List, Optional, Tuple import numpy as np import requests -from tqdm.auto import tqdm import tiktoken +from tqdm.auto import tqdm def convert_to_markdown(args: Tuple[Path, Path]): texfile, mdroot = args - mdfile = mdroot/f"{texfile.name}.md" - with Popen(["pandoc", "--wrap=none", "--from", "latex", texfile, - "--output", mdfile], stderr=PIPE) as proc: + mdfile = mdroot / f"{texfile.name}.md" + with Popen( + ["pandoc", "--wrap=none", "--from", "latex", texfile, "--output", mdfile], + stderr=PIPE, + ) as proc: try: proc.communicate(timeout=1) except TimeoutExpired: proc.kill() - def fetch_arxiv(root: Path, year: int): # download latex url = f"https://www.cs.cornell.edu/projects/kddcup/download/hep-th-{year}.tar.gz" - texroot = root/"tex" + texroot = root / "tex" print("Downloading Arxiv year", year) req = requests.get(url, timeout=60) with NamedTemporaryFile(suffix=".tar.gz") as f: @@ -40,13 +40,16 @@ def fetch_arxiv(root: Path, year: int): tar.extractall(texroot) # convert to markdown - mdroot = root/"md"/str(year) + mdroot = root / "md" / str(year) mdroot.mkdir(parents=True) - files = list((texroot/str(year)).iterdir()) + files = list((texroot / str(year)).iterdir()) with Pool(os.cpu_count()) as p: args = [(texfile, mdroot) for texfile in files] - for _ in tqdm(p.imap_unordered(convert_to_markdown, args), - desc="Converting to markdown", total=len(files)): + for _ in tqdm( + p.imap_unordered(convert_to_markdown, args), + desc="Converting to markdown", + total=len(files), + ): pass @@ -55,7 +58,7 @@ def tokenize_arxiv(root: Path, year: int): tokens = [] tokens_val = [] tokens_test = [] - mds = root/"md"/str(year) + mds = root / "md" / str(year) # tokenize desc = f"Tokenizing {year}" @@ -70,12 +73,10 @@ def tokenize_arxiv(root: Path, year: int): tokens_test += tokenizer.encode(text) # save to dir - tpath = root/str(year) + tpath = root / str(year) tpath.mkdir(parents=True) - for x, name in zip([tokens, tokens_val, tokens_test], - ["train", "val", "test"]): - mem = np.memmap(tpath/f"{name}.npy", dtype=np.uint16, mode="w+", - shape=len(x)) + for x, name in zip([tokens, tokens_val, tokens_test], ["train", "val", "test"]): + mem = np.memmap(tpath / f"{name}.npy", dtype=np.uint16, mode="w+", shape=len(x)) for i, v in enumerate(x): mem[i] = v @@ -85,31 +86,31 @@ def load_arxiv(cachedir: Path, years: Optional[List[int]] = None): if years is None: years = all_years assert set(years) <= set(all_years) - root = cachedir/"arxiv" + root = cachedir / "arxiv" root.mkdir(exist_ok=True, parents=True) # download all years requested that are not present for year in years: - if not (root/"md"/str(year)).exists(): + if not (root / "md" / str(year)).exists(): fetch_arxiv(root, year) # tokenize all years not previously tokenized for year in years: - if not (root/str(year)).exists(): + if not (root / str(year)).exists(): tokenize_arxiv(root, year) # load meta ret = {} for split in ["train", "val"]: - paths = [root/str(year)/f"{split}.npy" for year in years] + paths = [root / str(year) / f"{split}.npy" for year in years] x = [np.memmap(path, dtype=np.uint16, mode="r") for path in paths] 
ret[split] = np.concatenate(x) return ret def get_arxiv_2000(): - return load_arxiv(Path(os.path.dirname(__file__))/"datasets", [2000]) + return load_arxiv(Path(os.path.dirname(__file__)) / "datasets", [2000]) def get_arxiv_full(): - return load_arxiv(Path(os.path.dirname(__file__))/"datasets") + return load_arxiv(Path(os.path.dirname(__file__)) / "datasets") diff --git a/src/data/openwebtext2.py b/src/data/openwebtext2.py index eef9d50..65eea73 100644 --- a/src/data/openwebtext2.py +++ b/src/data/openwebtext2.py @@ -1,58 +1,69 @@ import os -from tqdm import tqdm + import numpy as np import tiktoken -from datasets import load_dataset - +from datasets import load_dataset +from tqdm import tqdm OWT2_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/openwebtext2/") tknzr = tiktoken.get_encoding("gpt2") + def get_openwebtext2_data(num_proc=40): - """ https://openwebtext2.readthedocs.io/en/latest/ - """ - if not os.path.exists(os.path.join(OWT2_DATA_PATH, 'train.bin')): + """https://openwebtext2.readthedocs.io/en/latest/""" + if not os.path.exists(os.path.join(OWT2_DATA_PATH, "train.bin")): os.makedirs(OWT2_DATA_PATH, exist_ok=True) dataset = load_dataset("the_pile_openwebtext2") - split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) - split_dataset['val'] = split_dataset.pop('test') - + split_dataset = dataset["train"].train_test_split( + test_size=0.0005, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + def process(example): - ids = tknzr.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens - ids.append(tknzr.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
- out = {'ids': ids, 'len': len(ids)} + out = {"ids": ids, "len": len(ids)} return out # tokenize the dataset tokenized = split_dataset.map( process, - remove_columns=['text'], + remove_columns=["text"], desc="tokenizing the splits", num_proc=num_proc, ) # concatenate all the ids in each dataset into one large file we can use for training for split, dset in tokenized.items(): - arr_len = np.sum(dset['len']) - filename = os.path.join(OWT2_DATA_PATH, f'{split}.bin') - dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) - arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) + arr_len = np.sum(dset["len"]) + filename = os.path.join(OWT2_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) total_batches = 1024 idx = 0 - for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): # Batch together samples for faster write - batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') - arr_batch = np.concatenate(batch['ids']) + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) # Write into mmap arr[idx : idx + len(arr_batch)] = arr_batch idx += len(arr_batch) arr.flush() - train_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') - val_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') - - return {'train': train_data, 'val': val_data} + train_data = np.memmap( + os.path.join(OWT2_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" + ) + val_data = np.memmap( + os.path.join(OWT2_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" + ) + return {"train": train_data, "val": val_data} diff --git a/src/data/shakespeare.py b/src/data/shakespeare.py index ab6e022..21c607e 100644 --- a/src/data/shakespeare.py +++ b/src/data/shakespeare.py @@ -4,8 +4,9 @@ import numpy as np import requests - -_char_decode = dict(enumerate(sorted(set(ascii_letters + digits + punctuation + " \n")))) +_char_decode = dict( + enumerate(sorted(set(ascii_letters + digits + punctuation + " \n"))) +) _char_encode = {char: i for i, char in _char_decode.items()} @@ -15,6 +16,7 @@ def char_tknzr(txt: str): DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets", "shakespeare") + def get_shakespeare_data(): """Inspired from https://github.com/karpathy/nanoGPT/""" raw_path = os.path.join(DATA_PATH, "raw.txt") @@ -36,7 +38,7 @@ def get_shakespeare_data(): # load text with open(raw_path, encoding="utf8") as f: text = "".join(f.readlines()) - i = int(0.8*len(text)) + i = int(0.8 * len(text)) # encode text x = np.array(char_tknzr(text[:i]), dtype=np.uint16) x_test = np.array(char_tknzr(text[i:]), dtype=np.uint16) @@ -47,5 +49,7 @@ def get_shakespeare_data(): mem[:] = x_test # at this point we know that the binfile was properly created so we load it - return {"train": np.memmap(train_path, dtype=np.uint16, mode="r"), - "val": np.memmap(test_path, dtype=np.uint16, mode="r")} + return { + "train": np.memmap(train_path, dtype=np.uint16, mode="r"), + "val": np.memmap(test_path, dtype=np.uint16, mode="r"), + } diff --git a/src/data/slimpajama.py b/src/data/slimpajama.py index c3960d7..5bd0600 100644 --- a/src/data/slimpajama.py +++ b/src/data/slimpajama.py @@ -1,9 +1,9 @@ -from tqdm import tqdm 
+import os + import numpy as np import tiktoken from datasets import load_dataset -import os - +from tqdm import tqdm SPJ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slimpajama6B/") SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1") diff --git a/src/data/utils.py b/src/data/utils.py index b5c7bb8..ec5604b 100755 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -1,19 +1,21 @@ -import numpy as np from typing import Dict + +import numpy as np import torch -from .shakespeare import get_shakespeare_data -from .wikitext import get_wikitext_data from .arxiv import get_arxiv_2000, get_arxiv_full from .openwebtext2 import get_openwebtext2_data +from .shakespeare import get_shakespeare_data from .slimpajama import get_slimpajama_data +from .wikitext import get_wikitext_data def get_dataset(args) -> Dict[str, np.ndarray]: - """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is - contained in its own python file. The expected format at the moment is a dictionary of np.memmap - containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """ - if args.dataset == 'wikitext': + """Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is + contained in its own python file. The expected format at the moment is a dictionary of np.memmap + containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. + """ + if args.dataset == "wikitext": return get_wikitext_data() if args.dataset == "shakespeare-char": return get_shakespeare_data() @@ -24,16 +26,17 @@ def get_dataset(args) -> Dict[str, np.ndarray]: if args.dataset == "arxiv+wiki": arxiv_data = get_arxiv_full() wiki_data = get_wikitext_data() - train_data = np.concatenate((arxiv_data['train'], wiki_data['train'])) - val_data = np.concatenate((arxiv_data['val'], wiki_data['val'])) - return {'train': train_data, 'val': val_data} - if args.dataset == 'openwebtext2': + train_data = np.concatenate((arxiv_data["train"], wiki_data["train"])) + val_data = np.concatenate((arxiv_data["val"], wiki_data["val"])) + return {"train": train_data, "val": val_data} + if args.dataset == "openwebtext2": return get_openwebtext2_data() if args.dataset == "slimpajama": return get_slimpajama_data() else: raise NotImplementedError(f"Unknow dataset key '{args.dataset}'") + class Dataset(torch.utils.data.Dataset): def __init__(self, data, sequence_length): super().__init__() diff --git a/src/data/wikitext.py b/src/data/wikitext.py index 646f636..3bc4b03 100755 --- a/src/data/wikitext.py +++ b/src/data/wikitext.py @@ -1,42 +1,54 @@ import os -import zipfile import urllib +import zipfile + import numpy as np import tiktoken - WIKITEXT_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/wikitext/") def get_wikitext_data(): - """ Inspired from https://github.com/tysam-code/hlb-gpt """ + """Inspired from https://github.com/tysam-code/hlb-gpt""" if not os.path.exists(WIKITEXT_DATA_PATH): os.makedirs(WIKITEXT_DATA_PATH, exist_ok=True) print("downloading data and tokenizing (1-2 min)") - raw_data_source = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip' - urllib.request.urlretrieve(raw_data_source, os.path.join(WIKITEXT_DATA_PATH,'data.zip')) - - with zipfile.ZipFile(os.path.join(WIKITEXT_DATA_PATH, "data.zip"), 'r') as zip_ref: + raw_data_source = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip" + urllib.request.urlretrieve( + 
raw_data_source, os.path.join(WIKITEXT_DATA_PATH, "data.zip") + ) + + with zipfile.ZipFile( + os.path.join(WIKITEXT_DATA_PATH, "data.zip"), "r" + ) as zip_ref: zip_ref.extractall(WIKITEXT_DATA_PATH) - with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), 'r') as data_file: + with open( + os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), "r" + ) as data_file: raw_train_data = data_file.read() - with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), 'r') as data_file: + with open( + os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), "r" + ) as data_file: raw_eval_data = data_file.read() tokenizer = tiktoken.get_encoding("gpt2") raw_tokenized_train = tokenizer.encode_ordinary(raw_train_data) raw_tokenized_eval = tokenizer.encode_ordinary(raw_eval_data) - train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) + train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16) - train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'train.bin')) - eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'val.bin')) + train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "train.bin")) + eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "val.bin")) print("completed the tokenization process!") - train_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') - val_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') + train_data = np.memmap( + os.path.join(WIKITEXT_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" + ) + val_data = np.memmap( + os.path.join(WIKITEXT_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" + ) - return {'train': train_data, 'val': val_data} + return {"train": train_data, "val": val_data} diff --git a/src/distributed/__init__.py b/src/distributed/__init__.py index 160eebd..74a4ae7 100644 --- a/src/distributed/__init__.py +++ b/src/distributed/__init__.py @@ -1,6 +1,4 @@ - -from . import ddp -from . import single +from . 
import ddp, single BACKEND_TYPE_TO_MODULE_MAP = { "nccl": ddp.DataParallelDistributedBackend, diff --git a/src/distributed/backend.py b/src/distributed/backend.py index 3df0fda..bcca248 100644 --- a/src/distributed/backend.py +++ b/src/distributed/backend.py @@ -1,4 +1,3 @@ - from typing import List @@ -10,7 +9,9 @@ def __init__(self, args): def transform_model(self, model): raise NotImplementedError - def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps): + def get_context_for_microstep_forward( + self, model, microstep_idx, gradient_accumulation_steps + ): raise NotImplementedError def is_master_process(self) -> bool: diff --git a/src/distributed/ddp.py b/src/distributed/ddp.py index d4470b6..951dbf3 100644 --- a/src/distributed/ddp.py +++ b/src/distributed/ddp.py @@ -1,9 +1,14 @@ -import os import math +import os from contextlib import contextmanager +from torch.distributed import ( + barrier, + destroy_process_group, + get_world_size, + init_process_group, +) from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed import init_process_group, destroy_process_group, get_world_size, barrier from .backend import DistributedBackend @@ -11,27 +16,31 @@ class DataParallelDistributedBackend(DistributedBackend): def __init__(self, args): - self.rank = int(os.environ.get('RANK', -1)) + self.rank = int(os.environ.get("RANK", -1)) assert self.rank != -1, "DDP backend can not be used without rank" assert "cuda" in args.device, "DDP backend can not be used on non-CUDA devices" init_process_group(backend=args.distributed_backend) - self.local_rank = int(os.environ['LOCAL_RANK']) + self.local_rank = int(os.environ["LOCAL_RANK"]) def get_adjusted_args_for_process(self, args): effective_batch_size = args.batch_size * args.acc_steps world_size = self.get_world_size() if args.acc_steps % world_size != 0: - raise ValueError(f"Number of accumulation steps " - "{args.acc_steps} is not divisible " - "by the world size {world_size}.") + raise ValueError( + f"Number of accumulation steps " + "{args.acc_steps} is not divisible " + "by the world size {world_size}." + ) if effective_batch_size % world_size != 0: - raise ValueError(f"Effective batch size " - "{effective_batch_size} is not divisible " - "by the world size {world_size}.") + raise ValueError( + f"Effective batch size " + "{effective_batch_size} is not divisible " + "by the world size {world_size}." 
+ ) acc_steps_div = math.gcd(args.acc_steps, world_size) args.acc_steps = args.acc_steps // acc_steps_div args.batch_size = args.batch_size // (world_size // acc_steps_div) - args.device = f'cuda:{self.local_rank}' + args.device = f"cuda:{self.local_rank}" args.seed = args.seed + self.local_rank return args @@ -39,9 +48,12 @@ def transform_model(self, model): return DDP(model, device_ids=[self.local_rank]) @contextmanager - def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps): + def get_context_for_microstep_forward( + self, model, microstep_idx, gradient_accumulation_steps + ): model.require_backward_grad_sync = ( - microstep_idx == gradient_accumulation_steps - 1) + microstep_idx == gradient_accumulation_steps - 1 + ) yield def is_master_process(self) -> bool: @@ -51,7 +63,7 @@ def get_raw_model(self, model): return model.module def translate_model_parameter_name_for_node(self, parameter_name): - return [f'module.{parameter_name}'] + return [f"module.{parameter_name}"] def get_world_size(self): return get_world_size() diff --git a/src/main.py b/src/main.py index 92ed664..ab0ac56 100755 --- a/src/main.py +++ b/src/main.py @@ -1,33 +1,40 @@ +import argparse +import copy +import inspect +import json import os +import random import sys + import numpy as np import torch -import inspect -import json -import copy -import argparse -import random import wandb import config -from models.utils import get_model +import distributed from data.utils import get_dataset +from models.utils import get_model from optim.base import train_base -import distributed def get_args(): parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument('--config_format', default='base', choices=config.registered_formats()) + parser.add_argument( + "--config_format", default="base", choices=config.registered_formats() + ) args, rem_args = parser.parse_known_args() - return config.parse_args_with_format(format=args.config_format, base_parser=parser, args=rem_args, namespace=args) + return config.parse_args_with_format( + format=args.config_format, base_parser=parser, args=rem_args, namespace=args + ) -def main(args): +def main(args): - torch.backends.cuda.matmul.allow_tf32 = True # allows us to make sure we're able to use tensorfloat32 during training + torch.backends.cuda.matmul.allow_tf32 = ( + True # allows us to make sure we're able to use tensorfloat32 during training + ) torch.backends.cudnn.allow_tf32 = True distributed_backend = distributed.make_backend_from_args(args) @@ -41,45 +48,67 @@ def main(args): torch.manual_seed(args.seed) random.seed(args.seed) np.random.seed(args.seed) - + print(f"Loading dataset '{args.dataset}'") - - data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized} + + data = get_dataset( + args + ) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized} if args.data_in_ram: - data = {'train': np.array(data['train']), 'val': np.array(data['val'])} - + data = {"train": np.array(data["train"]), "val": np.array(data["val"])} + print(f"Num training tokens: {len(data['train'])}") print(f"Num validation tokens: {len(data['val'])}") - - model = get_model(args).to(args.device) # todo: take care of initializing the model if args.use_pretrained != 'none' + + model = get_model(args).to( + args.device + ) # todo: take care of initializing the model if args.use_pretrained != 'none' model = distributed_backend.transform_model(model) - + group_specs = 
distributed_backend.get_raw_model(model).get_parameter_group_specs() param_name_mapping = {p_name: p for p_name, p in model.named_parameters()} optimized_params_cnt = 0 for g in group_specs: params = [] for p_name in g["params"]: - translated_p_names = distributed_backend.translate_model_parameter_name_for_node(p_name) + translated_p_names = ( + distributed_backend.translate_model_parameter_name_for_node(p_name) + ) params += [param_name_mapping[p_name] for p_name in translated_p_names] g["params"] = params optimized_params_cnt += sum([p.numel() for p in g["params"]]) - print("number of optimized parameters: %.2fM" % (optimized_params_cnt/1e6,)) - if args.opt == 'adamw': - use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) + print("number of optimized parameters: %.2fM" % (optimized_params_cnt / 1e6,)) + if args.opt == "adamw": + use_fused = (device_type == "cuda") and ( + "fused" in inspect.signature(torch.optim.AdamW).parameters + ) print(f"using fused AdamW: {use_fused}") extra_args = dict(fused=True) if use_fused else dict() - opt = torch.optim.AdamW(group_specs, lr=args.lr, betas=(args.beta1, args.beta2), - weight_decay=args.weight_decay, **extra_args) + opt = torch.optim.AdamW( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + **extra_args, + ) else: - opt = torch.optim.SGD(group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) - - if args.scheduler != 'none': - if args.scheduler in ['cos', 'linear']: - scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=args.lr, total_steps=args.iterations, - pct_start=args.warmup_percent, anneal_strategy=args.scheduler, - cycle_momentum=False, div_factor=1e2, final_div_factor=.1) + opt = torch.optim.SGD( + group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay + ) + + if args.scheduler != "none": + if args.scheduler in ["cos", "linear"]: + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer=opt, + max_lr=args.lr, + total_steps=args.iterations, + pct_start=args.warmup_percent, + anneal_strategy=args.scheduler, + cycle_momentum=False, + div_factor=1e2, + final_div_factor=0.1, + ) else: raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") else: @@ -89,67 +118,92 @@ def main(args): exp_name = args.exp_name if distributed_backend.is_master_process() and args.wandb: params_copy = copy.deepcopy(vars(args)) - del params_copy['device'] + del params_copy["device"] wandb.init(project=args.wandb_project, name=exp_name, config=params_copy) - - ckpt_path = os.path.join(args.results_base_folder, args.dataset, args.model, exp_name) + + ckpt_path = os.path.join( + args.results_base_folder, args.dataset, args.model, exp_name + ) if not os.path.exists(ckpt_path): if distributed_backend.is_master_process(): os.makedirs(ckpt_path) distributed_backend.sync() - elif os.path.isfile(os.path.join(ckpt_path, "summary.json")): # the experiment was already completed + elif os.path.isfile( + os.path.join(ckpt_path, "summary.json") + ): # the experiment was already completed print(f"Already found experiment '{ckpt_path}'.\nSkipping.") sys.exit(0) itr = 0 rng_state_dict = None if args.use_pretrained == "auto": - checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file] + checkpoints = [file for file in os.listdir(ckpt_path) if "ckpt_" in file] if checkpoints: args.use_pretrained = sorted(checkpoints)[-1] else: args.use_pretrained = None - + if args.use_pretrained is not None: last_ckpt_path 
= args.use_pretrained print(f"Resuming from {last_ckpt_path}") checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path)) - model_state_dict = {distributed_backend.translate_model_parameter_name_for_node(k.replace("_orig_mod.", ""))[0]:v for k,v in checkpoint['model'].items()} + model_state_dict = { + distributed_backend.translate_model_parameter_name_for_node( + k.replace("_orig_mod.", "") + )[0]: v + for k, v in checkpoint["model"].items() + } # FIXME checkpoints from compiled model have _orig_mod keyword - optimizer_state_dict = checkpoint['optimizer'] + optimizer_state_dict = checkpoint["optimizer"] rng_state_dict = { - module: checkpoint[module] for module in [ - "cpu_rng_state", - "gpu_rng_state", - "numpy_rng_state", + module: checkpoint[module] + for module in [ + "cpu_rng_state", + "gpu_rng_state", + "numpy_rng_state", "py_rng_state", - "train_sampler_state" + "train_sampler_state", ] } - model.load_state_dict(model_state_dict) + model.load_state_dict(model_state_dict) opt.load_state_dict(optimizer_state_dict) - itr = checkpoint['itr'] + itr = checkpoint["itr"] if scheduler is not None: - scheduler_state_dict = checkpoint['scheduler'] + scheduler_state_dict = checkpoint["scheduler"] scheduler.load_state_dict(scheduler_state_dict) - if args.model in ['base', 'llama2']: # all train functions have the same interface + if args.model in ["base", "llama2"]: # all train functions have the same interface train = train_base else: - raise NotImplementedError(f"No training method implemented for model type '{args.model}'.") + raise NotImplementedError( + f"No training method implemented for model type '{args.model}'." + ) print(f"\nTraining model={args.model} \n{vars(args)}\n") - stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length, - eval_freq=args.eval_freq, - distributed_backend=distributed_backend, - ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args) - + stats = train( + model, + opt, + data, + args.data_seed, + scheduler, + args.iterations, + args.acc_steps, + args.batch_size, + args.sequence_length, + eval_freq=args.eval_freq, + distributed_backend=distributed_backend, + ckpt_path=f"{ckpt_path}/ckpt.pt", + itr=itr, + rng_state_dict=rng_state_dict, + extra_args=args, + ) + args.device = None args.dtype = None - stats['args'] = vars(args) + stats["args"] = vars(args) if distributed_backend.is_master_process(): with open(f"{ckpt_path}/summary.json", "w") as fs: json.dump(stats, fs) diff --git a/src/models/base.py b/src/models/base.py index a844592..dc2258a 100755 --- a/src/models/base.py +++ b/src/models/base.py @@ -7,8 +7,8 @@ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py """ -import math import inspect +import math import tiktoken import torch @@ -17,7 +17,7 @@ class LayerNorm(nn.Module): - """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" def __init__(self, ndim, bias): super().__init__() @@ -44,34 +44,52 @@ def __init__(self, config): self.n_embd = config.n_embd self.dropout = config.dropout # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 - self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") if not self.flash: - print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") + print( + "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0" + ) # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("bias", torch.tril(torch.ones(config.sequence_length, config.sequence_length)) - .view(1, 1, config.sequence_length, config.sequence_length)) + self.register_buffer( + "bias", + torch.tril( + torch.ones(config.sequence_length, config.sequence_length) + ).view(1, 1, config.sequence_length, config.sequence_length), + ) def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + B, T, C = ( + x.size() + ) # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True + ) else: # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -82,8 +100,8 @@ class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) self.activation = nn.GELU() @@ -108,7 +126,7 @@ def forward(self, x): x = x + self.attn(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x - + class GPTBase(nn.Module): @@ -119,30 +137,36 @@ def __init__(self, config): self.config = config self.tokenizer = tiktoken.get_encoding("gpt2") - self.transformer = nn.ModuleDict(dict( - wte = nn.Embedding(config.vocab_size, config.n_embd), - wpe = nn.Embedding(config.sequence_length, config.n_embd), - drop = nn.Dropout(config.dropout), - h = nn.ModuleList([Block(config) for _ in 
range(config.n_layer)]), - ln_f = LayerNorm(config.n_embd, bias=config.bias), - )) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + wpe=nn.Embedding(config.sequence_length, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), + ) + ) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # with weight tying when using torch.compile() some warnings get generated: # "UserWarning: functional_call was passed multiple values for tied weights. # This behavior is deprecated and will be an error in future versions" # not 100% sure what this is, so far seems to be harmless. TODO investigate - self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying + self.transformer.wte.weight = ( + self.lm_head.weight + ) # https://paperswithcode.com/method/weight-tying # init all weights self.apply(self._init_weights) # apply special scaled init to the residual projections, per GPT-2 paper for pn, p in self.named_parameters(): - if pn.endswith('c_proj.weight'): - torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) # report number of parameters - print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) + print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) def get_num_params(self, non_embedding=True): """ @@ -167,12 +191,18 @@ def _init_weights(self, module): def forward(self, idx, targets=None, get_logits=False): device = idx.device b, t = idx.size() - assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + assert ( + t <= self.config.sequence_length + ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( + 0 + ) # shape (1, t) # forward the GPT model itself - tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) @@ -181,13 +211,17 @@ def forward(self, idx, targets=None, get_logits=False): if targets is not None: # if we are given some desired targets also calculate the loss logits = self.lm_head(x) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 + ) else: # inference-time mini-optimization: only forward the lm_head on the very last position - logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + logits = self.lm_head( + x[:, [-1], :] + ) # note: using list [-1] to preserve the time dim loss = None logits = logits if get_logits else None - return {'logits': logits, 'loss': loss} + return {"logits": logits, "loss": loss} def crop_sequence_length(self, sequence_length): # model surgery to decrease the block size if 
necessary @@ -195,9 +229,11 @@ def crop_sequence_length(self, sequence_length): # but want to use a smaller block size for some smaller, simpler model assert sequence_length <= self.config.sequence_length self.config.sequence_length = sequence_length - self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:sequence_length]) + self.transformer.wpe.weight = nn.Parameter( + self.transformer.wpe.weight[:sequence_length] + ) for block in self.transformer.h: - block.attn.bias = block.attn.bias[:,:,:sequence_length,:sequence_length] + block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length] @classmethod def from_pretrained(cls, model_type, override_args=None): @@ -262,7 +298,6 @@ def get_parameter_group_specs(self): {"params": sorted(list(no_decay)), "weight_decay": 0.0}, ] - @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ @@ -272,15 +307,19 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ for _ in range(max_new_tokens): # if the sequence context is growing too long we must crop it at sequence_length - idx_cond = idx if idx.size(1) <= self.config.sequence_length else idx[:, -self.config.sequence_length:] + idx_cond = ( + idx + if idx.size(1) <= self.config.sequence_length + else idx[:, -self.config.sequence_length :] + ) # forward the model to get the logits for the index in the sequence - logits = self(idx_cond, get_logits=True)['logits'] + logits = self(idx_cond, get_logits=True)["logits"] # pluck the logits at the final step and scale by desired temperature logits = logits[:, -1, :] / temperature # optionally crop the logits to only the top k options if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('Inf') + logits[logits < v[:, [-1]]] = -float("Inf") # apply softmax to convert logits to (normalized) probabilities probs = F.softmax(logits, dim=-1) # sample from the distribution @@ -289,9 +328,20 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): idx = torch.cat((idx, idx_next), dim=1) return idx - + @torch.no_grad() def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None): - idx = torch.tensor(self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})).view(1,-1).to(self.lm_head.weight.device) - out_idx = self.generate(idx, max_new_tokens, temperature, top_k).view(-1).to('cpu').numpy() + idx = ( + torch.tensor( + self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"}) + ) + .view(1, -1) + .to(self.lm_head.weight.device) + ) + out_idx = ( + self.generate(idx, max_new_tokens, temperature, top_k) + .view(-1) + .to("cpu") + .numpy() + ) return self.tokenizer.decode(out_idx) diff --git a/src/models/llama.py b/src/models/llama.py index 1604ca7..34996ba 100644 --- a/src/models/llama.py +++ b/src/models/llama.py @@ -20,6 +20,7 @@ import torch import torch.nn as nn from torch.nn import functional as F + from models.base import CausalSelfAttention, GPTBase @@ -185,7 +186,6 @@ def __init__(self, config): p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) ) - def get_num_params(self, non_embedding=True): """ Return the number of parameters in the model. 
diff --git a/src/models/utils.py b/src/models/utils.py index 6d60e10..2d68c05 100755 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -1,7 +1,7 @@ import torch -from .llama import Llama, RMSNorm -from .base import GPTBase, LayerNorm +from .base import GPTBase, LayerNorm +from .llama import Llama, RMSNorm BLACKLIST_WEIGHT_MODULES = ( torch.nn.LayerNorm, @@ -12,11 +12,11 @@ def get_model(args): - """ Return the right model """ - if args.model == 'base': + """Return the right model""" + if args.model == "base": model = GPTBase(args) return model - elif args.model == 'llama2': + elif args.model == "llama2": model = Llama(args) return model else: diff --git a/src/optim/base.py b/src/optim/base.py index 241f508..486f94b 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -1,23 +1,46 @@ +import copy +import itertools +import os +import random +import time from contextlib import nullcontext -from data.utils import get_dataloader +import numpy as np import torch import torch.nn.functional as F import wandb -import time -import itertools -import copy -import random -import os -import numpy as np +from data.utils import get_dataloader + from .utils import eval, get_batch, save_checkpoint -def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend,extra_args, itr=0,rng_state_dict=None): - device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu' - type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast( - device_type=device_type, dtype=torch.bfloat16) # extra_args.dtype) - best_val_loss, text_table = float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible +def train_base( + model, + opt, + data, + data_seed, + scheduler, + iterations, + acc_steps, + batch_size, + sequence_length, + eval_freq, + ckpt_path, + distributed_backend, + extra_args, + itr=0, + rng_state_dict=None, +): + device_type = "cuda" if "cuda" in str(extra_args.device) else "cpu" + type_ctx = ( + nullcontext() + if device_type == "cpu" + else torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) + ) # extra_args.dtype) + best_val_loss, text_table = ( + float("inf"), + None, + ) # best_val_loss not used atm, early stopping not recommended but possible substep = itr * acc_steps data["train"], train_sampler = get_dataloader( data["train"], @@ -26,7 +49,7 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba seed=data_seed, distributed_backend=distributed_backend, ) - + data["val"], val_sampler = get_dataloader( data["val"], sequence_length=sequence_length, @@ -35,9 +58,12 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba ) num_substeps_per_epoch = len(data["train"]) - train_epochs = substep//num_substeps_per_epoch - - if rng_state_dict is not None and rng_state_dict.get("train_sampler_state", None) is not None: + train_epochs = substep // num_substeps_per_epoch + + if ( + rng_state_dict is not None + and rng_state_dict.get("train_sampler_state", None) is not None + ): train_sampler.generator.set_state(rng_state_dict["train_sampler_state"]) if hasattr(train_sampler, "set_epoch"): train_sampler.set_epoch(train_epochs) @@ -45,23 +71,20 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba sampler_state_before_iter = train_sampler.generator.get_state() data_train_iter = iter(data["train"]) - # for val data we don't care about epochs? 
just cycle through (no need to set_epoch to reshuffle) data_val_iter = itertools.cycle(data["val"]) stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} - - if extra_args.compile: print(f"Compiling model ...") - model = torch.compile(model) # requires pytorch 2.0+ + model = torch.compile(model) # requires pytorch 2.0+ model.train() t0 = time.time() - - if rng_state_dict is not None: + + if rng_state_dict is not None: torch.set_rng_state(rng_state_dict["cpu_rng_state"]) torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"]) np.random.set_state(rng_state_dict["numpy_rng_state"]) @@ -69,17 +92,20 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba for _ in range(substep % num_substeps_per_epoch): get_batch(data_train_iter, device=extra_args.device) - while itr < iterations: - + for microstep_idx in range(acc_steps): # gradient accumulation x, y = get_batch(data_train_iter, device=extra_args.device) - + with type_ctx: - with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps): + with distributed_backend.get_context_for_microstep_forward( + model=model, + microstep_idx=microstep_idx, + gradient_accumulation_steps=acc_steps, + ): outputs = model(x, targets=y) - loss = outputs['loss'] / acc_steps + loss = outputs["loss"] / acc_steps loss.backward() substep += 1 if substep % len(data["train"]) == 0: @@ -93,7 +119,6 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba sampler_state_before_iter = train_sampler.generator.get_state() data_train_iter = iter(data["train"]) - if extra_args.grad_clip != 0.0: torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip) opt.step() @@ -101,19 +126,23 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba opt.zero_grad(set_to_none=True) itr += 1 - if itr % eval_freq == 0 or itr == iterations: # from here it's only evaluation code, all the training is above + if ( + itr % eval_freq == 0 or itr == iterations + ): # from here it's only evaluation code, all the training is above if distributed_backend.is_master_process(): t1 = time.time() dt = t1 - t0 - epoch = substep//num_substeps_per_epoch + epoch = substep // num_substeps_per_epoch model.eval() train_loss = loss.detach().cpu().item() * acc_steps - current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr - - eval_steps = ( - 24 if itr < iterations else len(data["val"]) + current_lr = ( + scheduler.get_last_lr()[0] + if scheduler is not None + else extra_args.lr ) + + eval_steps = 24 if itr < iterations else len(data["val"]) val_acc, val_loss, val_perplexity = eval( model, data_val_iter, @@ -145,40 +174,61 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba wandb.log(logs) - if extra_args.eval_seq_prefix != 'none' and (itr % (eval_freq * 5) == 0 or itr == iterations): + if extra_args.eval_seq_prefix != "none" and ( + itr % (eval_freq * 5) == 0 or itr == iterations + ): if text_table is None: text_table = wandb.Table(columns=["itr", "val-pp", "text"]) - out_str = distributed_backend.get_raw_model(model).generate_from_string( - extra_args.eval_seq_prefix, max_new_tokens=40, temperature=0.9, top_k=None) + out_str = distributed_backend.get_raw_model( + model + ).generate_from_string( + extra_args.eval_seq_prefix, + max_new_tokens=40, + temperature=0.9, + top_k=None, + ) text_table.add_data(itr, val_perplexity, out_str) # why a copy? 
see github.com/wandb/wandb/issues/2981 - wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)}) + wandb.log( + {f"generated-text-{wandb.run.name}": copy.copy(text_table)} + ) model.train() t0 = time.time() if distributed_backend.is_master_process(): - if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0: - print(f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt") - save_checkpoint(distributed_backend=distributed_backend, - model=model, - opt=opt, - scheduler=scheduler, - itr=itr, - cpu_rng_state=torch.get_rng_state(), - gpu_rng_state=torch.cuda.get_rng_state(), - numpy_rng_state=np.random.get_state(), - py_rng_state=random.getstate(), - train_sampler_state=sampler_state_before_iter, - ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt")) - + if ( + extra_args.save_checkpoint_freq is not None + and itr % extra_args.save_checkpoint_freq == 0 + ): + print( + f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt" + ) + save_checkpoint( + distributed_backend=distributed_backend, + model=model, + opt=opt, + scheduler=scheduler, + itr=itr, + cpu_rng_state=torch.get_rng_state(), + gpu_rng_state=torch.cuda.get_rng_state(), + numpy_rng_state=np.random.get_state(), + py_rng_state=random.getstate(), + train_sampler_state=sampler_state_before_iter, + ckpt_path=os.path.join( + os.path.dirname(ckpt_path), f"ckpt_{itr}.pt" + ), + ) + if distributed_backend.is_master_process(): print(f"saving checkpoint to {ckpt_path}") - save_checkpoint(distributed_backend=distributed_backend, - model=model, - opt=opt, - scheduler=scheduler, - itr=itr, - ckpt_path=ckpt_path) + save_checkpoint( + distributed_backend=distributed_backend, + model=model, + opt=opt, + scheduler=scheduler, + itr=itr, + ckpt_path=ckpt_path, + ) return stats diff --git a/src/optim/muon.py b/src/optim/muon.py new file mode 100644 index 0000000..bc967f9 --- /dev/null +++ b/src/optim/muon.py @@ -0,0 +1,137 @@ +"""Here is an original implementation of SOAP. Source: https://github.com/KellerJordan/modded-nanogpt""" + +import torch +import torch.distributed as dist + + +def zeropower_via_svd(G, steps=None): + U, S, V = G.svd() + return U @ V.T + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
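+    Concretely, each step below computes A = X X^T, B = A X and updates
+    X <- a*X + b*B + c*A*B = a*X + b*(X X^T) X + c*(X X^T)^2 X with
+    (a, b, c) = (3.4445, -4.7750, 2.0315). Odd polynomial maps of this form act only on
+    the singular values of X and leave its singular vectors unchanged, which is why the
+    result approximates the polar factor UV^T up to the S' scaling noted above.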
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps # ensure top singular value <= 1 + if G.size(0) > G.size(1): + X = X.T + for _ in range(steps): + A = X @ X.T + B = A @ X + X = a * X + b * B + c * A @ B + if G.size(0) > G.size(1): + X = X.T + return X + + +zeropower_backends = dict( + svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5 +) + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - This optimizer assumes that all parameters passed in are 2D. + - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D + parameters; those should all be optimized by a standard method (e.g., AdamW). + - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. + - We believe it is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). + + Arguments: + lr: The learning rate used by the internal SGD. + momentum: The momentum used by the internal SGD. + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') + backend_steps: The number of iteration steps to use in the backend, if it is iterative. + """ + + def __init__( + self, + params, + lr=3e-4, + momentum=0.95, + nesterov=True, + backend="newtonschulz5", + backend_steps=5, + rank=0, + world_size=1, + ): + defaults = dict( + lr=lr, + momentum=momentum, + nesterov=nesterov, + backend=backend, + backend_steps=backend_steps, + ) + super().__init__(params, defaults) + self.rank = rank + self.world_size = world_size + + def step(self): + + for group in self.param_groups: + + lr = group["lr"] + momentum = group["momentum"] + zeropower_backend = zeropower_backends[group["backend"]] + + # generate weight updates in distributed fashion + total_params = sum(p.numel() for p in group["params"]) + updates_flat = torch.zeros( + total_params, device="cuda", dtype=torch.bfloat16 + ) + curr_idx = 0 + for i, p in enumerate(group["params"]): + # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs + if i % self.world_size == self.rank: + g = p.grad + if g is None: + continue + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + g = zeropower_backend(g, steps=group["backend_steps"]) + g *= ( + max(g.size(0), g.size(1)) ** 0.5 + ) # scale to have update.square().mean() == 1 + updates_flat[curr_idx : curr_idx + p.numel()] = g.flatten() + curr_idx += p.numel() + + # sync updates across devices. 
we are not memory-constrained so can do this simple deserialization + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + # deserialize and apply updates + curr_idx = 0 + for p in group["params"]: + g = ( + updates_flat[curr_idx : curr_idx + p.numel()] + .view_as(p.data) + .type_as(p.data) + ) + p.data.add_(g, alpha=-lr) + curr_idx += p.numel() diff --git a/src/optim/soap.py b/src/optim/soap.py new file mode 100644 index 0000000..0a75004 --- /dev/null +++ b/src/optim/soap.py @@ -0,0 +1,479 @@ +"""Here is an original implementation of SOAP. Source: https://github.com/nikhilvyas/SOAP""" + +from itertools import chain + +import torch +import torch.nn as nn +import torch.optim as optim + +# Parts of the code are modifications of Pytorch's AdamW optimizer +# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py + + +class SOAP(optim.Optimizer): + """ + Implements SOAP algorithm (https://arxiv.org/abs/2409.11321). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.003): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`): + Adam's betas parameters (b1, b2). + shampoo_beta (`float`, *optional*, defaults to -1): + If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1]. + eps (`float`, *optional*, defaults to 1e-08): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient. + precondition_frequency (`int`, *optional*, defaults to 10): + How often to update the preconditioner. + max_precond_dim (`int`, *optional*, defaults to 10000): + Maximum dimension of the preconditioner. + Set to 10000, so that we exclude most common vocab sizes while including layers. + merge_dims (`bool`, *optional*, defaults to `False`): + Whether or not to merge dimensions of the preconditioner. + precondition_1d (`bool`, *optional*, defaults to `False`): + Whether or not to precondition 1D gradients. + normalize_grads (`bool`, *optional*, defaults to `False`): + Whether or not to normalize gradients per layer. + Helps at large precondition_frequency (~100 in our experiments), + but hurts performance at small precondition_frequency (~10 in our experiments). + data_format (`str`, *optional*, defaults to `channels_first`): + Data format of the input for convolutional layers. + Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias correction in Adam. + + Example of usage: + optim = SOAP(lr = 3e-3, betas=(.95, .95), weight_decay=.01, precondition_frequency=10) + """ + + def __init__( + self, + params, + lr: float = 3e-3, + betas=(0.95, 0.95), + shampoo_beta: float = -1, + eps: float = 1e-8, + weight_decay: float = 0.01, + precondition_frequency: int = 10, + max_precond_dim: int = 10000, # + merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim. 
+ precondition_1d: bool = False, + normalize_grads: bool = False, + data_format: str = "channels_first", + correct_bias: bool = True, + ): + defaults = { + "lr": lr, + "betas": betas, + "shampoo_beta": shampoo_beta, + "eps": eps, + "weight_decay": weight_decay, + "precondition_frequency": precondition_frequency, + "max_precond_dim": max_precond_dim, + "merge_dims": merge_dims, + "precondition_1d": precondition_1d, + "normalize_grads": normalize_grads, + "correct_bias": correct_bias, + } + super().__init__(params, defaults) + self._data_format = data_format + + def merge_dims(self, grad, max_precond_dim): + """ + Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim. + """ + assert self._data_format in ["channels_first", "channels_last"] + if self._data_format == "channels_last" and grad.dim() == 4: + grad = grad.permute(0, 3, 1, 2) + shape = grad.shape + new_shape = [] + + curr_shape = 1 + for sh in shape: + temp_shape = curr_shape * sh + if temp_shape > max_precond_dim: + if curr_shape > 1: + new_shape.append(curr_shape) + curr_shape = sh + else: + new_shape.append(sh) + curr_shape = 1 + else: + curr_shape = temp_shape + + if curr_shape > 1 or len(new_shape) == 0: + new_shape.append(curr_shape) + + new_grad = grad.reshape(new_shape) + return new_grad + + @torch.no_grad() + def step(self): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + loss = None + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + if "Q" not in state: + self.init_preconditioner( + grad, + state, + precondition_frequency=group["precondition_frequency"], + precondition_1d=group["precondition_1d"], + shampoo_beta=( + group["shampoo_beta"] + if group["shampoo_beta"] >= 0 + else group["betas"][1] + ), + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + ) + self.update_preconditioner( + grad, + state, + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + precondition_1d=group["precondition_1d"], + ) + continue # first step is skipped so that we never use the current gradients in the projection. + + # Projecting gradients to the eigenbases of Shampoo's preconditioner + # i.e. projecting to the eigenbases of matrices in state['GG'] + grad_projected = self.project( + grad, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).add_( + grad_projected.square(), alpha=(1.0 - beta2) + ) + + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner + # i.e. 
projecting to the eigenbases of matrices in state['GG'] + exp_avg_projected = self.project( + exp_avg, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + step_size = group["lr"] + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** (state["step"]) + bias_correction2 = 1.0 - beta2 ** (state["step"]) + step_size = step_size * (bias_correction2**0.5) / bias_correction1 + + # Projecting back the preconditioned (by Adam) exponential moving average of gradients + # to the original space + norm_grad = self.project_back( + exp_avg_projected / denom, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + if group["normalize_grads"]: + norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5) + + p.add_(norm_grad, alpha=-step_size) + + # From AdamW code: Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + # Update is done after the gradient step to avoid using current gradients in the projection. + self.update_preconditioner( + grad, + state, + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + precondition_1d=group["precondition_1d"], + ) + + return loss + + def init_preconditioner( + self, + grad, + state, + precondition_frequency=10, + shampoo_beta=0.95, + max_precond_dim=10000, + precondition_1d=False, + merge_dims=False, + ): + """ + Initializes the preconditioner matrices (L and R in the paper). + """ + state["GG"] = ( + [] + ) # Will hold all the preconditioner matrices (L and R in the paper). + if grad.dim() == 1: + if not precondition_1d or grad.shape[0] > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append( + torch.zeros(grad.shape[0], grad.shape[0], device=grad.device) + ) + else: + if merge_dims: + grad = self.merge_dims(grad, max_precond_dim) + + for sh in grad.shape: + if sh > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append(torch.zeros(sh, sh, device=grad.device)) + + state["Q"] = None # Will hold all the eigenbases of the preconditioner. + state["precondition_frequency"] = precondition_frequency + state["shampoo_beta"] = shampoo_beta + + def project(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient to the eigenbases of the preconditioner. 
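# For a 2-D weight, the projection implemented below reduces to a two-sided rotation into the
# eigenbases of the two factors (L and R in the paper); a small sketch with stand-in bases:
import torch

m, n = 8, 4
G = torch.randn(m, n)                            # gradient of a 2-D weight
Q_L = torch.linalg.qr(torch.randn(m, m)).Q       # stand-ins for the eigenbases in state["Q"]
Q_R = torch.linalg.qr(torch.randn(n, n)).Q

projected = Q_L.T @ G @ Q_R                      # what project() computes in the 2-D case
restored = Q_L @ projected @ Q_R.T               # what project_back() computes
print(torch.allclose(restored, G, atol=1e-5))    # True: the change of basis is invertible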
+ """ + original_shape = grad.shape + if merge_dims: + if grad.dim() == 4 and self._data_format == "channels_last": + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [0]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def update_preconditioner( + self, + grad, + state, + max_precond_dim=10000, + merge_dims=False, + precondition_1d=False, + ): + """ + Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper). + """ + if grad.dim() == 1: + if precondition_1d and grad.shape[0] <= max_precond_dim: + state["GG"][0].lerp_( + grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"] + ) + else: + if merge_dims: + new_grad = self.merge_dims(grad, max_precond_dim) + for idx, sh in enumerate(new_grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + new_grad, + new_grad, + dims=[ + [ + *chain( + range(idx), range(idx + 1, len(new_grad.shape)) + ) + ] + ] + * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + else: + for idx, sh in enumerate(grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + grad, + grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] + * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + + if state["Q"] is None: + state["Q"] = self.get_orthogonal_matrix(state["GG"]) + if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0: + state["Q"] = self.get_orthogonal_matrix_QR( + state, max_precond_dim, merge_dims + ) + + def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient back to the original space. + """ + original_shape = grad.shape + if merge_dims: + if self._data_format == "channels_last" and grad.dim() == 4: + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [1]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def get_orthogonal_matrix(self, mat): + """ + Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition. 
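# Minimal sketch of that eigendecomposition step: each accumulated factor in state["GG"] is
# symmetric positive semi-definite, so torch.linalg.eigh applies, and the columns are flipped
# because eigh returns eigenvalues in ascending order while a descending order is wanted.
import torch

g = torch.randn(64, 32)
GG = g @ g.T                                     # one accumulated L-factor, 64 x 64, PSD
eigvals, Q = torch.linalg.eigh(GG)               # ascending eigenvalue order
Q = torch.flip(Q, [1])                           # descending order, as in the code below
print(torch.allclose(Q @ Q.T, torch.eye(64), atol=1e-4))  # the eigenbasis is orthogonal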
+ """ + matrix = [] + for m in mat: + if len(m) == 0: + matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + else: + float_data = True + matrix.append(m.data) + + final = [] + for m in matrix: + if len(m) == 0: + final.append([]) + continue + try: + _, Q = torch.linalg.eigh( + m + 1e-30 * torch.eye(m.shape[0], device=m.device) + ) + except: + _, Q = torch.linalg.eigh( + m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device) + ) + Q = Q.to(m.dtype) + Q = torch.flip(Q, [1]) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + return final + + def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False): + """ + Computes the eigenbases of the preconditioner using one round of power iteration + followed by torch.linalg.qr decomposition. + """ + precond_list = state["GG"] + orth_list = state["Q"] + + matrix = [] + orth_matrix = [] + for m, o in zip(precond_list, orth_list): + if len(m) == 0: + matrix.append([]) + orth_matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + else: + float_data = True + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + + orig_shape = state["exp_avg_sq"].shape + if self._data_format == "channels_last" and len(orig_shape) == 4: + permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape + if merge_dims: + exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim) + else: + exp_avg_sq = state["exp_avg_sq"] + + final = [] + for ind, (m, o) in enumerate(zip(matrix, orth_matrix)): + if len(m) == 0: + final.append([]) + continue + est_eig = torch.diag(o.T @ m @ o) + sort_idx = torch.argsort(est_eig, descending=True) + exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx) + o = o[:, sort_idx] + power_iter = m @ o + Q, _ = torch.linalg.qr(power_iter) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + + if merge_dims: + if self._data_format == "channels_last" and len(orig_shape) == 4: + exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + exp_avg_sq = exp_avg_sq.reshape(orig_shape) + + state["exp_avg_sq"] = exp_avg_sq + return final diff --git a/src/optim/utils.py b/src/optim/utils.py index 68162f6..92f690a 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -1,7 +1,8 @@ +from contextlib import ExitStack, contextmanager, nullcontext + import numpy as np import torch import torch.nn.functional as F -from contextlib import nullcontext, contextmanager, ExitStack def get_batch(dataloader, device="cpu"): @@ -17,33 +18,38 @@ def get_batch(dataloader, device="cpu"): @torch.no_grad() -def eval(model, data_val_iter, device='cpu', max_num_batches=24, ctx=nullcontext()): +def eval(model, data_val_iter, device="cpu", max_num_batches=24, ctx=nullcontext()): assert model.training == False loss_list_val, acc_list = [], [] - for _ in range(max_num_batches): + for _ in range(max_num_batches): x, y = get_batch(data_val_iter, device=device) with ctx: outputs = model(x, targets=y, get_logits=True) - val_loss = outputs['loss'] + val_loss = outputs["loss"] loss_list_val.append(val_loss) - acc_list.append((outputs['logits'].argmax(-1) == y).float().mean()) + acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) val_acc = 
torch.stack(acc_list).mean().item() val_loss = torch.stack(loss_list_val).mean().item() - val_perplexity = 2.71828 ** val_loss + val_perplexity = 2.71828**val_loss return val_acc, val_loss, val_perplexity -def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args): +def save_checkpoint( + distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args +): - checkpoint = dict({ - 'model': distributed_backend.get_raw_model(model).state_dict(), - 'optimizer': opt.state_dict(), - 'scheduler': scheduler.state_dict(), - 'itr': itr, - }, **extra_args) + checkpoint = dict( + { + "model": distributed_backend.get_raw_model(model).state_dict(), + "optimizer": opt.state_dict(), + "scheduler": scheduler.state_dict(), + "itr": itr, + }, + **extra_args + ) torch.save(checkpoint, ckpt_path) From d5153c4254d80c7f01dd1f9f9c1bdf17d2ec730d Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 15 Oct 2024 13:53:14 +0200 Subject: [PATCH 02/58] -fix annotations --- src/optim/muon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optim/muon.py b/src/optim/muon.py index bc967f9..fe425af 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -1,4 +1,4 @@ -"""Here is an original implementation of SOAP. Source: https://github.com/KellerJordan/modded-nanogpt""" +"""Here is an original implementation of MUON. Source: https://github.com/KellerJordan/modded-nanogpt""" import torch import torch.distributed as dist From 5dcc69af006994452721e5f0d7888ddd7444a454 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 15 Oct 2024 22:37:24 +0300 Subject: [PATCH 03/58] soap is ready, muon needs to be done --- README.md | 14 +++- src/config/base.py | 16 +++- src/main.py | 33 +++++++- src/optim/muon.py | 186 +++++++++++++++++++++++++++++---------------- src/optim/soap.py | 3 +- 5 files changed, 181 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index bb6a749..6716b71 100755 --- a/README.md +++ b/README.md @@ -44,10 +44,22 @@ parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT +parser.add_argument("--momentum", default=0.9, type=float) +parser.add_argument('--shampoo_beta', default=-1.0, type=float) +parser.add_argument("--precondition_frequency", default=10, type=int) +parser.add_argument("--max_precond_dim", default=10000, type=int) +parser.add_argument("--merge_dims", default=False, type=bool) # merge dimensions till the product of the dimensions is less than or equal to max_precond_dim +parser.add_argument("--precondition_1d", default=False, type=bool) +parser.add_argument("--normalize_grads", default=False, type=bool) +parser.add_argument("--soap_data_format", default="channels_first", type=str) +parser.add_argument("--correct_bias", default=True, type=bool) +parser.add_argument("--nesterov", default=True, type=bool) # whether to use Nesterov-style 
momentum +parser.add_argument("--muon_backend", default="newtonschulz5", type=str) # the chosen backend for the orthogonalization step +parser.add_argument("--muon_backend_steps", default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) parser.add_argument('--vocab_size', default=50304, type=int) diff --git a/src/config/base.py b/src/config/base.py index 78309b6..f710c12 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -23,12 +23,26 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta2", default=0.95, type=float) parser.add_argument("--scheduler", default="cos", choices=["linear", "cos", "none"]) - parser.add_argument("--opt", default="adamw", choices=["adamw", "sgd"]) + parser.add_argument( + "--opt", default="adamw", choices=["adamw", "sgd", "muon", "soap"] + ) parser.add_argument("--eval_freq", default=200, type=int) # in iterations parser.add_argument("--results_base_folder", default="./exps", type=str) parser.add_argument( "--grad_clip", default=0.0, type=float ) # default value is 1.0 in NanoGPT + parser.add_argument("--momentum", default=0.9, type=float) + parser.add_argument("--shampoo_beta", default=-1.0, type=float) + parser.add_argument("--precondition_frequency", default=10, type=int) + parser.add_argument("--max_precond_dim", default=10000, type=int) + parser.add_argument("--merge_dims", default=False, type=bool) + parser.add_argument("--precondition_1d", default=False, type=bool) + parser.add_argument("--normalize_grads", default=False, type=bool) + parser.add_argument("--soap_data_format", default="channels_first", type=str) + parser.add_argument("--correct_bias", default=True, type=bool) + parser.add_argument("--nesterov", default=True, type=bool) + parser.add_argument("--muon_backend", default="newtonschulz5", type=str) + parser.add_argument("--muon_backend_steps", default=5, type=int) # Dataset params parser.add_argument( "--dataset", diff --git a/src/main.py b/src/main.py index ab0ac56..c7bc89b 100755 --- a/src/main.py +++ b/src/main.py @@ -15,6 +15,8 @@ from data.utils import get_dataset from models.utils import get_model from optim.base import train_base +from optim.muon import Muon, zeropower_backends +from optim.soap import SOAP def get_args(): @@ -92,9 +94,38 @@ def main(args): weight_decay=args.weight_decay, **extra_args, ) + elif args.opt == "soap": + opt = SOAP( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + shampoo_beta=args.shampoo_beta, + weight_decay=args.weight_decay, + precondition_frequency=args.precondition_frequency, + max_precond_dim=args.max_precond_dim, + merge_dims=args.merge_dims, + precondition_1d=args.precondition_1d, + normalize_grads=args.normalize_grads, + data_format=args.soap_data_format, + correct_bias=args.correct_bias, + ) + elif args.opt == "muon": + opt = Muon( + group_specs, + lr=args.lr, + momentum=args.momentum, + nesterov=args.nesterov, + backend=args.muon_backend, + backend_steps=args.muon_backend_steps, + # rank=args.rank, + # world_size=args.world_size, + ) # i have left rank and world_size inside Muon else: opt = torch.optim.SGD( - group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay + group_specs, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, ) if 
args.scheduler != "none": diff --git a/src/optim/muon.py b/src/optim/muon.py index fe425af..ab1f56e 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -9,7 +9,7 @@ def zeropower_via_svd(G, steps=None): return U @ V.T -@torch.compile +# @torch.compile def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a @@ -40,6 +40,102 @@ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): ) +# class Muon(torch.optim.Optimizer): +# """ +# Muon - MomentUm Orthogonalized by Newton-schulz + +# Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- +# processing step, in which each 2D parameter's update is replaced with the nearest orthogonal +# matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has +# the advantage that it can be stably run in bfloat16 on the GPU. + +# Some warnings: +# - This optimizer assumes that all parameters passed in are 2D. +# - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D +# parameters; those should all be optimized by a standard method (e.g., AdamW). +# - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. +# - We believe it is unlikely to work well for training with small batch size. +# - We believe it may not work well for finetuning pretrained models, but we haven't tested this. +# - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). + +# Arguments: +# lr: The learning rate used by the internal SGD. +# momentum: The momentum used by the internal SGD. +# nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) +# backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') +# backend_steps: The number of iteration steps to use in the backend, if it is iterative. +# """ + +# def __init__( +# self, +# params, +# lr=3e-4, +# momentum=0.95, +# nesterov=True, +# backend="newtonschulz5", +# backend_steps=5, +# rank=0, +# world_size=1, +# ): +# defaults = dict( +# lr=lr, +# momentum=momentum, +# nesterov=nesterov, +# backend=backend, +# backend_steps=backend_steps, +# ) +# super().__init__(params, defaults) +# self.rank = rank +# self.world_size = world_size + +# def step(self): + +# for group in self.param_groups: + +# lr = group["lr"] +# momentum = group["momentum"] +# zeropower_backend = zeropower_backends[group["backend"]] + +# # generate weight updates in distributed fashion +# total_params = sum(p.numel() for p in group["params"]) +# updates_flat = torch.zeros( +# total_params, device="cuda", dtype=torch.bfloat16 +# ) +# curr_idx = 0 +# for i, p in enumerate(group["params"]): +# # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs +# if i % self.world_size == self.rank: +# g = p.grad +# if g is None: +# continue +# state = self.state[p] +# if "momentum_buffer" not in state: +# state["momentum_buffer"] = torch.zeros_like(g) +# buf = state["momentum_buffer"] +# buf.mul_(momentum).add_(g) +# if group["nesterov"]: +# g = g.add(buf, alpha=momentum) +# g = zeropower_backend(g, steps=group["backend_steps"]) +# g *= ( +# max(g.size(0), g.size(1)) ** 0.5 +# ) # scale to have update.square().mean() == 1 +# updates_flat[curr_idx : curr_idx + p.numel()] = g.flatten() +# curr_idx += p.numel() + +# # sync updates across devices. 
we are not memory-constrained so can do this simple deserialization +# dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + +# # deserialize and apply updates +# curr_idx = 0 +# for p in group["params"]: +# g = ( +# updates_flat[curr_idx : curr_idx + p.numel()] +# .view_as(p.data) +# .type_as(p.data) +# ) +# p.data.add_(g, alpha=-lr) +# curr_idx += p.numel() + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -65,73 +161,31 @@ class Muon(torch.optim.Optimizer): backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') backend_steps: The number of iteration steps to use in the backend, if it is iterative. """ - - def __init__( - self, - params, - lr=3e-4, - momentum=0.95, - nesterov=True, - backend="newtonschulz5", - backend_steps=5, - rank=0, - world_size=1, - ): - defaults = dict( - lr=lr, - momentum=momentum, - nesterov=nesterov, - backend=backend, - backend_steps=backend_steps, - ) + def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, + backend='newtonschulz5', backend_steps=5): + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) super().__init__(params, defaults) - self.rank = rank - self.world_size = world_size def step(self): - + loss = None for group in self.param_groups: - - lr = group["lr"] - momentum = group["momentum"] - zeropower_backend = zeropower_backends[group["backend"]] - - # generate weight updates in distributed fashion - total_params = sum(p.numel() for p in group["params"]) - updates_flat = torch.zeros( - total_params, device="cuda", dtype=torch.bfloat16 - ) - curr_idx = 0 - for i, p in enumerate(group["params"]): - # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs - if i % self.world_size == self.rank: - g = p.grad - if g is None: - continue - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - g = zeropower_backend(g, steps=group["backend_steps"]) - g *= ( - max(g.size(0), g.size(1)) ** 0.5 - ) # scale to have update.square().mean() == 1 - updates_flat[curr_idx : curr_idx + p.numel()] = g.flatten() - curr_idx += p.numel() - - # sync updates across devices. 
we are not memory-constrained so can do this simple deserialization - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - - # deserialize and apply updates - curr_idx = 0 - for p in group["params"]: - g = ( - updates_flat[curr_idx : curr_idx + p.numel()] - .view_as(p.data) - .type_as(p.data) - ) + lr = group['lr'] + momentum = group['momentum'] + zeropower_backend = zeropower_backends[group['backend']] + + for p in group['params']: + g = p.grad + if g is None: + continue + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(momentum).add_(g) + if group['nesterov']: + g = g.add(buf, alpha=momentum) + g = zeropower_backend(g, steps=group['backend_steps']) + g *= max(g.size(0), g.size(1)) ** 0.5 # scale to have update.square().mean() == 1 p.data.add_(g, alpha=-lr) - curr_idx += p.numel() + + return loss diff --git a/src/optim/soap.py b/src/optim/soap.py index 0a75004..3762096 100644 --- a/src/optim/soap.py +++ b/src/optim/soap.py @@ -4,13 +4,12 @@ import torch import torch.nn as nn -import torch.optim as optim # Parts of the code are modifications of Pytorch's AdamW optimizer # Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py -class SOAP(optim.Optimizer): +class SOAP(torch.optim.Optimizer): """ Implements SOAP algorithm (https://arxiv.org/abs/2409.11321). From 49b802dfb44a56080b1366cc1829cd2cfe52f57f Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 16 Oct 2024 19:48:39 +0300 Subject: [PATCH 04/58] AdEMAMix and Lion is here, Muon still TODO --- README.md | 6 +- src/config/base.py | 10 ++- src/main.py | 19 +++++ src/optim/ademamix.py | 186 ++++++++++++++++++++++++++++++++++++++++++ src/optim/lion.py | 72 ++++++++++++++++ src/optim/muon.py | 43 +++++++--- 6 files changed, 321 insertions(+), 15 deletions(-) create mode 100644 src/optim/ademamix.py create mode 100644 src/optim/lion.py diff --git a/README.md b/README.md index 6716b71..2a5b334 100755 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'lion']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -60,6 +60,10 @@ parser.add_argument("--correct_bias", default=True, type=bool) parser.add_argument("--nesterov", default=True, type=bool) # whether to use Nesterov-style momentum parser.add_argument("--muon_backend", default="newtonschulz5", type=str) # the chosen backend for the orthogonalization step parser.add_argument("--muon_backend_steps", default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative +parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta2 in AdEMAMix +parser.add_argument("--adema_alpha", default=2.0, type=float) # alpha in AdEMAMix 
+parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) +parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) parser.add_argument('--vocab_size', default=50304, type=int) diff --git a/src/config/base.py b/src/config/base.py index f710c12..6c11143 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -1,3 +1,5 @@ +from typing import Optional + import distributed import torch @@ -24,7 +26,9 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--beta2", default=0.95, type=float) parser.add_argument("--scheduler", default="cos", choices=["linear", "cos", "none"]) parser.add_argument( - "--opt", default="adamw", choices=["adamw", "sgd", "muon", "soap"] + "--opt", + default="adamw", + choices=["adamw", "sgd", "muon", "soap", "ademamix", "lion"], ) parser.add_argument("--eval_freq", default=200, type=int) # in iterations parser.add_argument("--results_base_folder", default="./exps", type=str) @@ -43,6 +47,10 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--nesterov", default=True, type=bool) parser.add_argument("--muon_backend", default="newtonschulz5", type=str) parser.add_argument("--muon_backend_steps", default=5, type=int) + parser.add_argument("--adema_beta3", default=0.9, type=float) + parser.add_argument("--adema_alpha", default=2.0, type=float) + parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) + parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) # Dataset params parser.add_argument( "--dataset", diff --git a/src/main.py b/src/main.py index c7bc89b..12ca6e6 100755 --- a/src/main.py +++ b/src/main.py @@ -14,7 +14,9 @@ import distributed from data.utils import get_dataset from models.utils import get_model +from optim.ademamix import AdEMAMix from optim.base import train_base +from optim.lion import Lion from optim.muon import Muon, zeropower_backends from optim.soap import SOAP @@ -120,6 +122,23 @@ def main(args): # rank=args.rank, # world_size=args.world_size, ) # i have left rank and world_size inside Muon + elif args.opt == "ademamix": + opt = AdEMAMix( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2, args.adema_beta3), + alpha=args.adema_alpha, + beta3_warmup=args.adema_beta3_warmup, + alpha_warmup=args.adema_alpha_warmup, + weight_decay=args.weight_decay, + ) + elif args.opt == "lion": + opt = Lion( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + ) else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/ademamix.py b/src/optim/ademamix.py new file mode 100644 index 0000000..e339d3a --- /dev/null +++ b/src/optim/ademamix.py @@ -0,0 +1,186 @@ +"""Here is an original implementation of AdEMAMix. 
Source: https://github.com/apple/ml-ademamix""" + +import math + +import torch + + +def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): + if step < warmup: + a = step / float(warmup) + return (1.0 - a) * alpha_start + a * alpha_end + return alpha_end + + +def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): + + def f(beta, eps=1e-8): + return math.log(0.5) / math.log(beta + eps) - 1 + + def f_inv(t): + return math.pow(0.5, 1 / (t + 1)) + + if step < warmup: + a = step / float(warmup) + return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end)) + return beta_end + + +class AdEMAMix(torch.optim.Optimizer): + r"""Implements the AdEMAMix algorithm. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999, 0.9999)) + corresponding to beta_1, beta_2, beta_3 in AdEMAMix + alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs (default: 2) + beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None) + alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay as in AdamW (default: 0) + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999, 0.9999), + alpha=2.0, + beta3_warmup=None, + alpha_warmup=None, + eps=1e-8, + weight_decay=0, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + alpha=alpha, + beta3_warmup=beta3_warmup, + alpha_warmup=alpha_warmup, + weight_decay=weight_decay, + ) + super(AdEMAMix, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdEMAMix, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
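# Per parameter, the update assembled below can be summarised (for beta1 > 0) as
#     p <- p - lr * ( (m_fast / bc1 + alpha * m_slow) / (sqrt(v / bc2) + eps) + weight_decay * p )
# with m_fast the beta1 EMA, m_slow the beta3 EMA (no bias correction), v the beta2 EMA of
# squared gradients, and bc1, bc2 the usual bias corrections. A one-step numerical sketch:
import torch

p, g = torch.ones(3), torch.full((3,), 0.5)
lr, eps, wd, alpha = 1e-3, 1e-8, 0.1, 2.0
beta1, beta2, beta3 = 0.9, 0.999, 0.9999
m_fast = (1 - beta1) * g                 # EMAs start from zero, so this is step 1
m_slow = (1 - beta3) * g
v = (1 - beta2) * g * g
bc1, bc2 = 1 - beta1, 1 - beta2          # bias corrections at step 1
update = (m_fast / bc1 + alpha * m_slow) / ((v / bc2).sqrt() + eps) + wd * p
p = p - lr * update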
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + lr = group["lr"] + lmbda = group["weight_decay"] + eps = group["eps"] + beta1, beta2, beta3_final = group["betas"] + beta3_warmup = group["beta3_warmup"] + alpha_final = group["alpha"] + alpha_warmup = group["alpha_warmup"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdEMAMix does not support sparse gradients.") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + if beta1 != 0.0: # save memory in case beta1 is 0.0 + state["exp_avg_fast"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + else: + state["exp_avg_fast"] = None + state["exp_avg_slow"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg_fast, exp_avg_slow, exp_avg_sq = ( + state["exp_avg_fast"], + state["exp_avg_slow"], + state["exp_avg_sq"], + ) + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Compute the effective alpha and beta3 in case warmup is used + if alpha_warmup is not None: + alpha = linear_warmup_scheduler( + state["step"], + alpha_end=alpha_final, + alpha_start=0, + warmup=alpha_warmup, + ) + else: + alpha = alpha_final + + if beta3_warmup is not None: + beta3 = linear_hl_warmup_scheduler( + state["step"], + beta_end=beta3_final, + beta_start=beta1, + warmup=beta3_warmup, + ) + else: + beta3 = beta3_final + + # Decay the first and second moment running average coefficient + if beta1 != 0.0: + exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1) + else: + exp_avg_fast = grad + exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + update = ( + exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow + ) / denom + + # decay + update.add_(p, alpha=lmbda) + + p.add_(-lr * update) + + return loss diff --git a/src/optim/lion.py b/src/optim/lion.py new file mode 100644 index 0000000..cfcc26d --- /dev/null +++ b/src/optim/lion.py @@ -0,0 +1,72 @@ +"""Here is an original implementation of Lion. Source: https://github.com/google/automl/tree/master/lion""" + +import torch + + +class Lion(torch.optim.Optimizer): + r"""Implements Lion algorithm.""" + + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): + """Initialize the hyperparameters. 
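# In compact form, the update implemented in step() below is
#     p <- p * (1 - lr * weight_decay)                  # decoupled weight decay
#     p <- p - lr * sign(beta1 * m + (1 - beta1) * g)   # sign of the interpolated momentum
#     m <- beta2 * m + (1 - beta2) * g                  # momentum refreshed after the step
# so every coordinate moves by exactly lr, -lr, or 0. A one-step sketch:
import torch

p, g, m = torch.ones(3), torch.tensor([0.5, -2.0, 0.0]), torch.zeros(3)
lr, wd, beta1, beta2 = 1e-4, 0.0, 0.9, 0.99
p = p * (1 - lr * wd)
p = p - lr * torch.sign(beta1 * m + (1 - beta1) * g)
m = beta2 * m + (1 - beta2) * g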
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-4) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + weight_decay (float, optional): weight decay coefficient (default: 0) + """ + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group["lr"] * group["weight_decay"]) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) + + exp_avg = state["exp_avg"] + beta1, beta2 = group["betas"] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + + p.add_(update.sign_(), alpha=-group["lr"]) + + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss diff --git a/src/optim/muon.py b/src/optim/muon.py index ab1f56e..7a3aa05 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -136,6 +136,7 @@ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): # p.data.add_(g, alpha=-lr) # curr_idx += p.numel() + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -161,31 +162,47 @@ class Muon(torch.optim.Optimizer): backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') backend_steps: The number of iteration steps to use in the backend, if it is iterative. 
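# A sketch of the parameter split those warnings call for: 2-D hidden weight matrices go to
# Muon, while embeddings, biases and other non-matrix parameters go to AdamW. The toy model
# is a stand-in; in this repo the parameter groups are assembled in src/main.py.
import torch
from optim.muon import Muon  # module added in this patch

model = torch.nn.Sequential(torch.nn.Embedding(100, 32), torch.nn.Linear(32, 32))
embedding, rest = model[0], model[1:]
matrix_params = [p for p in rest.parameters() if p.ndim == 2]
other_params = [p for p in rest.parameters() if p.ndim != 2] + list(embedding.parameters())

opt_muon = Muon(matrix_params, lr=3e-4, momentum=0.95, nesterov=True,
                backend="newtonschulz5", backend_steps=5)
opt_adamw = torch.optim.AdamW(other_params, lr=1e-3, weight_decay=0.1)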
""" - def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, - backend='newtonschulz5', backend_steps=5): - defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps) + + def __init__( + self, + params, + lr=3e-4, + momentum=0.95, + nesterov=True, + backend="newtonschulz5", + backend_steps=5, + ): + defaults = dict( + lr=lr, + momentum=momentum, + nesterov=nesterov, + backend=backend, + backend_steps=backend_steps, + ) super().__init__(params, defaults) def step(self): loss = None for group in self.param_groups: - lr = group['lr'] - momentum = group['momentum'] - zeropower_backend = zeropower_backends[group['backend']] + lr = group["lr"] + momentum = group["momentum"] + zeropower_backend = zeropower_backends[group["backend"]] - for p in group['params']: + for p in group["params"]: g = p.grad if g is None: continue state = self.state[p] - if 'momentum_buffer' not in state: - state['momentum_buffer'] = torch.zeros_like(g) - buf = state['momentum_buffer'] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] buf.mul_(momentum).add_(g) - if group['nesterov']: + if group["nesterov"]: g = g.add(buf, alpha=momentum) - g = zeropower_backend(g, steps=group['backend_steps']) - g *= max(g.size(0), g.size(1)) ** 0.5 # scale to have update.square().mean() == 1 + g = zeropower_backend(g, steps=group["backend_steps"]) + g *= ( + max(g.size(0), g.size(1)) ** 0.5 + ) # scale to have update.square().mean() == 1 p.data.add_(g, alpha=-lr) return loss From 3d8b24a14703282e0f4ddc8ec1402594f69c9800 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 16 Oct 2024 19:55:13 +0300 Subject: [PATCH 05/58] AdEMAMix and Lion is here, Muon still TODO --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2a5b334..a8f5d04 100755 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ parser.add_argument("--correct_bias", default=True, type=bool) parser.add_argument("--nesterov", default=True, type=bool) # whether to use Nesterov-style momentum parser.add_argument("--muon_backend", default="newtonschulz5", type=str) # the chosen backend for the orthogonalization step parser.add_argument("--muon_backend_steps", default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative -parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta2 in AdEMAMix +parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta3 in AdEMAMix parser.add_argument("--adema_alpha", default=2.0, type=float) # alpha in AdEMAMix parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) From a04778605c684d80ac8596ef22419922a2b27b5b Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 11:52:41 +0300 Subject: [PATCH 06/58] Schedule-Free AdamW is here, warmup_percent -> warmup_steps / iterations --- README.md | 7 +- src/config/base.py | 7 +- src/main.py | 23 +++- src/optim/ademamix.py | 5 +- src/optim/lion.py | 5 +- src/optim/muon.py | 5 +- src/optim/schedulefree.py | 219 ++++++++++++++++++++++++++++++++++++++ src/optim/soap.py | 5 +- 8 files changed, 266 insertions(+), 10 deletions(-) create mode 100644 src/optim/schedulefree.py diff --git a/README.md b/README.md index a8f5d04..8e64474 100755 --- a/README.md +++ b/README.md @@ -38,13 +38,14 @@ parser.add_argument('--seed', default=0, type=int) # random seed for the paramet 
parser.add_argument('--data_seed', default=1337, type=int) # random seed defining the data ordering parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations +parser.add_argument("--warmup_steps", default=300, type=int) # it was only warmup_percent before parser.add_argument('--lr', default=1e-3, type=float) -parser.add_argument('--warmup_percent', default=0.05, type=float) # the total number of warmup steps is iterations * warmup_percent +parser.add_argument('--warmup_percent', default=0.05, type=float) # the total number of warmup_steps is iterations * warmup_percent parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'lion']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'lion', 'sf-adamw']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -64,6 +65,8 @@ parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta3 in AdEMAMi parser.add_argument("--adema_alpha", default=2.0, type=float) # alpha in AdEMAMix parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) +parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulfree hyperparameter +parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulfree hyperparameter # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) parser.add_argument('--vocab_size', default=50304, type=int) diff --git a/src/config/base.py b/src/config/base.py index 6c11143..0cdc013 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -19,8 +19,9 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--data_seed", default=1337, type=int) parser.add_argument("--device", default="cuda:0", type=str) parser.add_argument("--iterations", default=25000, type=int) + parser.add_argument("--warmup_steps", default=300, type=int) # it was only warmup_percent before parser.add_argument("--lr", default=1e-3, type=float) - parser.add_argument("--warmup_percent", default=0.05, type=float) + parser.add_argument("--warmup_percent", default=0.05, type=float) # leave it anyway, warmup_steps / iterations parser.add_argument("--weight_decay", default=0.1, type=float) parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta2", default=0.95, type=float) @@ -28,7 +29,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument( "--opt", default="adamw", - choices=["adamw", "sgd", "muon", "soap", "ademamix", "lion"], + choices=["adamw", "sgd", "muon", "soap", "ademamix", "lion", "sf-adamw"], ) parser.add_argument("--eval_freq", default=200, type=int) # in 
iterations parser.add_argument("--results_base_folder", default="./exps", type=str) @@ -51,6 +52,8 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--adema_alpha", default=2.0, type=float) parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) + parser.add_argument("--schedulefree_r", default=0.0, type=float) + parser.add_argument("--weight_lr_power", default=2.0, type=float) # Dataset params parser.add_argument( "--dataset", diff --git a/src/main.py b/src/main.py index 12ca6e6..8a658b5 100755 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ from optim.base import train_base from optim.lion import Lion from optim.muon import Muon, zeropower_backends +from optim.schedulefree import AdamWScheduleFree from optim.soap import SOAP @@ -139,6 +140,16 @@ def main(args): betas=(args.beta1, args.beta2), weight_decay=args.weight_decay, ) + elif args.opt == "sf-adamw": + opt = AdamWScheduleFree( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + warmup_steps=args.warmup_steps, + r=args.schedulefree_r, + weight_lr_power=args.weight_lr_power, + ) # without foreach argument else: opt = torch.optim.SGD( group_specs, @@ -148,12 +159,20 @@ def main(args): ) if args.scheduler != "none": + assert ( + args.warmup_steps < args.iterations + ), "Warmup steps must be < iterations." # from schedules-and-scaling if args.scheduler in ["cos", "linear"]: + # initial lr is args.lr / div_factor + # final lr is initial_lr/final_div_factor = args.lr / div_factor / final_div_factor scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=opt, - max_lr=args.lr, + max_lr=[ + group.get("lr", args.lr) for group in group_specs + ], # it was args.lr total_steps=args.iterations, - pct_start=args.warmup_percent, + pct_start=args.warmup_steps + / args.iterations, # it was args.warmup_percent anneal_strategy=args.scheduler, cycle_momentum=False, div_factor=1e2, diff --git a/src/optim/ademamix.py b/src/optim/ademamix.py index e339d3a..9a633e3 100644 --- a/src/optim/ademamix.py +++ b/src/optim/ademamix.py @@ -1,4 +1,7 @@ -"""Here is an original implementation of AdEMAMix. Source: https://github.com/apple/ml-ademamix""" +""" +Here is an original implementation of AdEMAMix. +Source: https://github.com/apple/ml-ademamix +""" import math diff --git a/src/optim/lion.py b/src/optim/lion.py index cfcc26d..2c0c59a 100644 --- a/src/optim/lion.py +++ b/src/optim/lion.py @@ -1,4 +1,7 @@ -"""Here is an original implementation of Lion. Source: https://github.com/google/automl/tree/master/lion""" +""" +Here is an original implementation of Lion. +Source: https://github.com/google/automl/tree/master/lion +""" import torch diff --git a/src/optim/muon.py b/src/optim/muon.py index 7a3aa05..ce31081 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -1,4 +1,7 @@ -"""Here is an original implementation of MUON. Source: https://github.com/KellerJordan/modded-nanogpt""" +""" +Here is an original implementation of Muon. +Source: https://github.com/KellerJordan/modded-nanogpt +""" import torch import torch.distributed as dist diff --git a/src/optim/schedulefree.py b/src/optim/schedulefree.py new file mode 100644 index 0000000..5cfa5b9 --- /dev/null +++ b/src/optim/schedulefree.py @@ -0,0 +1,219 @@ +""" +Here is an original implementation of Schedule-Free AdamW and SGD. 
+Source: https://github.com/facebookresearch/schedule_free +""" + +import math +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +import torch +from typing_extensions import TypeAlias + +try: + from torch.optim.optimizer import ParamsT +except ImportError: + ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] + + +class AdamWScheduleFree(torch.optim.Optimizer): + r""" + Schedule-Free AdamW + As the name suggests, no scheduler is needed with this optimizer. + To add warmup, rather than using a learning rate schedule you can just + set the warmup_steps parameter. + + This optimizer requires that .train() and .eval() be called before the + beginning of training and evaluation respectively. The optimizer should + also be placed in eval mode when saving checkpoints. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining + parameter groups. + lr (float): + Learning rate parameter (default 0.0025) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)). + eps (float): + Term added to the denominator outside of the root operation to + improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + warmup_steps (int): Enables a linear learning rate warmup (default 0). + r (float): Use polynomial weighting in the average + with power r (default 0). + weight_lr_power (float): During warmup, the weights in the average will + be equal to lr raised to this power. Set to 0 for no weighting + (default 2.0). + foreach (bool): Use a foreach-backed implementation of the optimizer. + Should be significantly faster, but will have higher peak memory + usage (default True if supported in your PyTorch version). + """ + + def __init__( + self, + params: ParamsT, + lr: Union[float, torch.Tensor] = 0.0025, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + warmup_steps: int = 0, + r: float = 0.0, + weight_lr_power: float = 2.0, + foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), + ): + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + r=r, + k=0, + warmup_steps=warmup_steps, + train_mode=True, + weight_sum=0.0, + lr_max=-1.0, + weight_lr_power=weight_lr_power, + weight_decay=weight_decay, + foreach=foreach, + ) + super().__init__(params, defaults) + + def eval(self): + for group in self.param_groups: + train_mode = group["train_mode"] + beta1, _ = group["betas"] + if train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p.data to x + p.data.lerp_( + end=state["z"].to(p.data.device), weight=1 - 1 / beta1 + ) + group["train_mode"] = False + + def train(self): + for group in self.param_groups: + train_mode = group["train_mode"] + beta1, _ = group["betas"] + if not train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p.data to y + p.data.lerp_(end=state["z"].to(p.data.device), weight=1 - beta1) + group["train_mode"] = True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + eps = group["eps"] + beta1, beta2 = group["betas"] + decay = group["weight_decay"] + k = group["k"] + r = group["r"] + warmup_steps = group["warmup_steps"] + weight_lr_power = group["weight_lr_power"] + + if k < warmup_steps: + sched = (k + 1) / warmup_steps + else: + sched = 1.0 + + bias_correction2 = 1 - beta2 ** (k + 1) + lr = group["lr"] * sched * math.sqrt(bias_correction2) + + lr_max = group["lr_max"] = max(lr, group["lr_max"]) + + weight = ((k + 1) ** r) * (lr_max**weight_lr_power) + weight_sum = group["weight_sum"] = group["weight_sum"] + weight + + try: + ckp1 = weight / weight_sum + except ZeroDivisionError: + ckp1 = 0 + + if not group["train_mode"]: + raise Exception("Not in train mode!") + + active_p = [p for p in group["params"] if p.grad is not None] + + for p in active_p: + if "z" not in self.state[p]: + self.state[p]["z"] = torch.clone(p.data) + self.state[p]["exp_avg_sq"] = torch.zeros_like(p.data) + + if group["foreach"] and len(active_p) > 0: + y, grad, exp_avg_sq, z = zip( + *[ + ( + p.data, + p.grad, + self.state[p]["exp_avg_sq"], + self.state[p]["z"], + ) + for p in active_p + ] + ) + + # Decay the first and second moment running average coefficient + torch._foreach_mul_(exp_avg_sq, beta2) + torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2) + denom = torch._foreach_sqrt(exp_avg_sq) + torch._foreach_add_(denom, eps) + + # Normalize grad in-place for memory efficiency + torch._foreach_div_(grad, denom) + + # Weight decay calculated at y + if decay != 0: + torch._foreach_add_(grad, y, alpha=decay) + + # These operations update y in-place, + # without computing x explicitly. + torch._foreach_lerp_(y, z, weight=ckp1) + torch._foreach_add_(y, grad, alpha=lr * (beta1 * (1 - ckp1) - 1)) + + # z step + torch._foreach_sub_(z, grad, alpha=lr) + else: + for p in active_p: + y = p.data # Notation to match theory + grad = p.grad.data + + state = self.state[p] + + z = state["z"] + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = exp_avg_sq.sqrt().add_(eps) + + # Reuse grad buffer for memory efficiency + grad_normalized = grad.div_(denom) + + # Weight decay calculated at y + if decay != 0: + grad_normalized.add_(y, alpha=decay) + + # These operations update y in-place, + # without computing x explicitly. + y.lerp_(end=z, weight=ckp1) + y.add_(grad_normalized, alpha=lr * (beta1 * (1 - ckp1) - 1)) + + # z step + z.sub_(grad_normalized, alpha=lr) + + group["k"] = k + 1 + return loss diff --git a/src/optim/soap.py b/src/optim/soap.py index 3762096..70a60cf 100644 --- a/src/optim/soap.py +++ b/src/optim/soap.py @@ -1,4 +1,7 @@ -"""Here is an original implementation of SOAP. Source: https://github.com/nikhilvyas/SOAP""" +""" +Here is an original implementation of SOAP. 
+Source: https://github.com/nikhilvyas/SOAP +""" from itertools import chain From 66b77f5c5b3dda5515536e304b8d17b2fb1e7bcd Mon Sep 17 00:00:00 2001 From: mpagli Date: Thu, 17 Oct 2024 12:38:42 +0000 Subject: [PATCH 07/58] eval on a fix subset + better lr decay --- src/config/base.py | 4 ++-- src/main.py | 2 +- src/optim/base.py | 1 + src/optim/utils.py | 6 +++++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/config/base.py b/src/config/base.py index 6c11143..e13064f 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -49,8 +49,8 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--muon_backend_steps", default=5, type=int) parser.add_argument("--adema_beta3", default=0.9, type=float) parser.add_argument("--adema_alpha", default=2.0, type=float) - parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) - parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) + parser.add_argument("--adema_beta3_warmup", default=None, type=int) + parser.add_argument("--adema_alpha_warmup", default=None, type=int) # Dataset params parser.add_argument( "--dataset", diff --git a/src/main.py b/src/main.py index 12ca6e6..374b8c0 100755 --- a/src/main.py +++ b/src/main.py @@ -157,7 +157,7 @@ def main(args): anneal_strategy=args.scheduler, cycle_momentum=False, div_factor=1e2, - final_div_factor=0.1, + final_div_factor=1, ) else: raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") diff --git a/src/optim/base.py b/src/optim/base.py index 486f94b..ccd1bd1 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -149,6 +149,7 @@ def train_base( extra_args.device, max_num_batches=eval_steps, ctx=type_ctx, + reset_iterator=True, ) print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:3f}" diff --git a/src/optim/utils.py b/src/optim/utils.py index 92f690a..6a4ee51 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -1,4 +1,5 @@ from contextlib import ExitStack, contextmanager, nullcontext +import itertools import numpy as np import torch @@ -18,9 +19,12 @@ def get_batch(dataloader, device="cpu"): @torch.no_grad() -def eval(model, data_val_iter, device="cpu", max_num_batches=24, ctx=nullcontext()): +def eval(model, data_val_iter, device="cpu", max_num_batches=24, ctx=nullcontext(), reset_iterator=False): assert model.training == False + if reset_iterator: # ensure that we always eval on the same batches + data_val_iter = itertools.cycle(data_val_iter) + loss_list_val, acc_list = [], [] for _ in range(max_num_batches): From 7f158a3633543705b85a48f0a2fe29b0731a954c Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 15:48:49 +0300 Subject: [PATCH 08/58] Schedule-Free SGD + AdamW are here --- README.md | 7 +- src/config/base.py | 19 +++- src/main.py | 12 ++- src/optim/schedulefree.py | 182 +++++++++++++++++++++++++++++++++++++- 4 files changed, 212 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 8e64474..d90de9d 100755 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'lion', 'sf-adamw']) 
+parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'lion', 'sf-adamw', 'sf-sgd']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -63,8 +63,8 @@ parser.add_argument("--muon_backend", default="newtonschulz5", type=str) # the c parser.add_argument("--muon_backend_steps", default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta3 in AdEMAMix parser.add_argument("--adema_alpha", default=2.0, type=float) # alpha in AdEMAMix -parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) -parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) +parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) # AdEMAMix hyperparameter +parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) # AdEMAMix hyperparameter parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulfree hyperparameter parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulfree hyperparameter # Dataset params @@ -130,6 +130,7 @@ src/ optim/ utils.py # contains eval and get_batch functions base.py # training function for the base and llama models + ... distributed/ # code to enable simple distributed training ``` diff --git a/src/config/base.py b/src/config/base.py index 0cdc013..443eae4 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -19,9 +19,13 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--data_seed", default=1337, type=int) parser.add_argument("--device", default="cuda:0", type=str) parser.add_argument("--iterations", default=25000, type=int) - parser.add_argument("--warmup_steps", default=300, type=int) # it was only warmup_percent before + parser.add_argument( + "--warmup_steps", default=300, type=int + ) # it was only warmup_percent before parser.add_argument("--lr", default=1e-3, type=float) - parser.add_argument("--warmup_percent", default=0.05, type=float) # leave it anyway, warmup_steps / iterations + parser.add_argument( + "--warmup_percent", default=0.05, type=float + ) # leave it anyway, warmup_steps / iterations parser.add_argument("--weight_decay", default=0.1, type=float) parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta2", default=0.95, type=float) @@ -29,7 +33,16 @@ def parse_args(base_parser, args, namespace): parser.add_argument( "--opt", default="adamw", - choices=["adamw", "sgd", "muon", "soap", "ademamix", "lion", "sf-adamw"], + choices=[ + "adamw", + "sgd", + "muon", + "soap", + "ademamix", + "lion", + "sf-adamw", + "sf-sgd", + ], ) parser.add_argument("--eval_freq", default=200, type=int) # in iterations parser.add_argument("--results_base_folder", default="./exps", type=str) diff --git a/src/main.py b/src/main.py index 8a658b5..a3b7005 100755 --- a/src/main.py +++ b/src/main.py @@ -18,7 +18,7 @@ from optim.base import train_base from optim.lion import Lion from optim.muon import Muon, zeropower_backends -from optim.schedulefree import AdamWScheduleFree +from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree from optim.soap import SOAP @@ -150,6 +150,16 @@ def main(args): r=args.schedulefree_r, 
weight_lr_power=args.weight_lr_power, ) # without foreach argument + elif args.opt == "sf-sgd": + opt = SGDScheduleFree( + group_specs, + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + warmup_steps=args.warmup_steps, + r=args.schedulefree_r, + weight_lr_power=args.weight_lr_power, + ) # without foreach argument else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/schedulefree.py b/src/optim/schedulefree.py index 5cfa5b9..266f56a 100644 --- a/src/optim/schedulefree.py +++ b/src/optim/schedulefree.py @@ -1,5 +1,5 @@ """ -Here is an original implementation of Schedule-Free AdamW and SGD. +Here is an original implementation of Schedule-Free AdamW and Schedule-Free SGD. Source: https://github.com/facebookresearch/schedule_free """ @@ -217,3 +217,183 @@ def step(self, closure=None): group["k"] = k + 1 return loss + + +class SGDScheduleFree(torch.optim.Optimizer): + r""" + Schedule-Free SGD + As the name suggests, no scheduler is needed with this optimizer. + To add warmup, rather than using a learning rate schedule you can just + set the warmup_steps parameter. + + This optimizer requires that .train() and .eval() be called before the + beginning of training and evaluation respectively. The optimizer should + also be placed in eval mode when saving checkpoints. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining + parameter groups. + lr (float): + Learning rate parameter (default 1.0) + momentum (float): momentum factor, must be between 0 and 1 exclusive + (default: 0.9) + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + warmup_steps (int): Enables a linear learning rate warmup (default 0). + r (float): Use polynomial weighting in the average + with power r (default 0). + weight_lr_power (float): During warmup, the weights in the average will + be equal to lr raised to this power. Set to 0 for no weighting + (default 2.0). + foreach (bool): Use a foreach-backed implementation of the optimizer. + Should be significantly faster, but will have higher peak memory + usage (default True if supported in your PyTorch version). 
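+
+    Example (a minimal usage sketch; `model`, `loader` and `loss_fn` are
+    placeholders, only the .train()/.eval() calls are specific to this class):
+
+        optimizer = SGDScheduleFree(model.parameters(), lr=1.0, momentum=0.9,
+                                    warmup_steps=300)
+        optimizer.train()   # parameters are placed at the gradient-evaluation point y
+        for inputs, targets in loader:
+            loss_fn(model(inputs), targets).backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        optimizer.eval()    # parameters are placed at the averaged iterate x
+                            # before evaluation or checkpointing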
+ """ + + def __init__( + self, + params: ParamsT, + lr: Union[float, torch.Tensor] = 1.0, + momentum: float = 0.9, + weight_decay: float = 0, + warmup_steps: int = 0, + r: float = 0.0, + weight_lr_power: float = 2, + foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if momentum <= 0 or momentum >= 1: + raise ValueError( + "Momentum must be between 0 and 1 exclusive: {}".format(momentum) + ) + + defaults = dict( + lr=lr, + momentum=momentum, + r=r, + k=0, + warmup_steps=warmup_steps, + train_mode=True, + weight_sum=0.0, + lr_max=-1.0, + weight_lr_power=weight_lr_power, + weight_decay=weight_decay, + foreach=foreach, + ) + super().__init__(params, defaults) + + def eval(self): + for group in self.param_groups: + train_mode = group["train_mode"] + momentum = group["momentum"] + if train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p.data to x + p.data.lerp_( + end=state["z"].to(p.data.device), weight=1 - 1 / momentum + ) + group["train_mode"] = False + + def train(self): + for group in self.param_groups: + train_mode = group["train_mode"] + momentum = group["momentum"] + if not train_mode: + for p in group["params"]: + state = self.state[p] + if "z" in state: + # Set p.data to y + p.data.lerp_( + end=state["z"].to(p.data.device), weight=1 - momentum + ) + group["train_mode"] = True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + momentum = group["momentum"] + lr = group["lr"] + weight_decay = group["weight_decay"] + k = group["k"] + warmup_steps = group["warmup_steps"] + + if k < warmup_steps: + sched = (k + 1) / warmup_steps + else: + sched = 1.0 + lr = group["lr"] * sched + + weight_lr_power = group["weight_lr_power"] + + r = group["r"] + lr_max = group["lr_max"] = max(lr, group["lr_max"]) + + weight = ((k + 1) ** r) * (lr_max**weight_lr_power) + weight_sum = group["weight_sum"] = group["weight_sum"] + weight + + try: + ckp1 = weight / weight_sum + except ZeroDivisionError: + ckp1 = 0 + + if not group["train_mode"]: + raise Exception("Not in train mode!") + + active_p = [p for p in group["params"] if p.grad is not None] + + for p in active_p: + if "z" not in self.state[p]: + self.state[p]["z"] = torch.clone(p.data) + + if group["foreach"] and len(active_p) > 0: + y, grad, z = zip( + *[(p.data, p.grad, self.state[p]["z"]) for p in active_p] + ) + + # Apply weight decay + if weight_decay != 0: + torch._foreach_add_(grad, y, alpha=weight_decay) + + # These operations update y in-place, + # without computing x explicitly. + torch._foreach_lerp_(y, z, weight=ckp1) + torch._foreach_add_(y, grad, alpha=lr * (momentum * (1 - ckp1) - 1)) + + # SGD step + torch._foreach_sub_(z, grad, alpha=lr) + else: + for p in active_p: + y = p.data # Notation to match theory + grad = p.grad.data + z = self.state[p]["z"] + + # Apply weight decay + if weight_decay != 0: + grad.add_(y, alpha=weight_decay) + + # These operations update y in-place, + # without computing x explicitly. 
+ y.lerp_(end=z, weight=ckp1) + y.add_(grad, alpha=lr * (momentum * (1 - ckp1) - 1)) + + # SGD step + z.sub_(grad, alpha=lr) + + group["k"] = k + 1 + return loss From 1f67f6db9621562a81647836b1f6f4baa08ea566 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 15:50:42 +0300 Subject: [PATCH 09/58] Schedule-Free SGD + AdamW are here --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d90de9d..b8641ab 100755 --- a/README.md +++ b/README.md @@ -65,8 +65,8 @@ parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta3 in AdEMAMi parser.add_argument("--adema_alpha", default=2.0, type=float) # alpha in AdEMAMix parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) # AdEMAMix hyperparameter parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) # AdEMAMix hyperparameter -parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulfree hyperparameter -parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulfree hyperparameter +parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulefree hyperparameter +parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulefree hyperparameter # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) parser.add_argument('--vocab_size', default=50304, type=int) From 7378449742b443be2038f62f5aff58e3ad3de007 Mon Sep 17 00:00:00 2001 From: mpagli Date: Thu, 17 Oct 2024 13:48:56 +0000 Subject: [PATCH 10/58] push to wandb team + display grad norm --- src/config/base.py | 1 + src/main.py | 2 +- src/optim/base.py | 7 ++++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/config/base.py b/src/config/base.py index e13064f..9004ffb 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -99,6 +99,7 @@ def parse_args(base_parser, args, namespace): # logging params (WandB) parser.add_argument("--wandb", action="store_true") # whether to use wandb or not parser.add_argument("--wandb_project", default="my-project", type=str) + parser.add_argument('--wandb_entity', default=None, type=none_or_str) parser.add_argument( "--wandb_run_prefix", default="none", type=str ) # is added before the autogenerated experiment name diff --git a/src/main.py b/src/main.py index 374b8c0..21c6d5e 100755 --- a/src/main.py +++ b/src/main.py @@ -169,7 +169,7 @@ def main(args): if distributed_backend.is_master_process() and args.wandb: params_copy = copy.deepcopy(vars(args)) del params_copy["device"] - wandb.init(project=args.wandb_project, name=exp_name, config=params_copy) + wandb.init(project=args.wandb_project, name=exp_name, config=params_copy, entity=args.wandb_entity) ckpt_path = os.path.join( args.results_base_folder, args.dataset, args.model, exp_name diff --git a/src/optim/base.py b/src/optim/base.py index ccd1bd1..ebbe6ff 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -80,6 +80,8 @@ def train_base( print(f"Compiling model ...") model = torch.compile(model) # requires pytorch 2.0+ + grad_norms = [] + model.train() t0 = time.time() @@ -120,7 +122,8 @@ def train_base( data_train_iter = iter(data["train"]) if extra_args.grad_clip != 0.0: - torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip) + grad_norms.append(grad_norm) opt.step() 
scheduler.step() opt.zero_grad(set_to_none=True) @@ -162,6 +165,8 @@ def train_base( logs = { "iter": itr, "train/loss": train_loss, + "train/max_grad_norm": max(grad_norms).item() if grad_norms else 0, + "train/mean_grad_norm": torch.tensor(grad_norms).mean().item() if grad_norms else 0, "val/loss": val_loss, "val/perplexity": val_perplexity, "val/acc": val_acc, From 2b0e71f7917c883ea8f89348d4a3db7e39f62689 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 17:07:49 +0300 Subject: [PATCH 11/58] a code for schedules is here --- README.md | 8 ++- src/config/base.py | 16 +++++- src/main.py | 23 ++++++++- src/optim/utils.py | 120 +++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 158 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b8641ab..2dc47e2 100755 --- a/README.md +++ b/README.md @@ -39,12 +39,16 @@ parser.add_argument('--data_seed', default=1337, type=int) # random seed definin parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations parser.add_argument("--warmup_steps", default=300, type=int) # it was only warmup_percent before -parser.add_argument('--lr', default=1e-3, type=float) +parser.add_argument('--lr', default=1e-3, type=float) +parser.add_argument("--wsd_final_lr_scale", default=0.0, type=float) # wsd scheduler +parser.add_argument("--wsd_fract_decay", default=0.1, type=float) # wsd scheduler +parser.add_argument("--decay_type", default="linear", choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"]) parser.add_argument('--warmup_percent', default=0.05, type=float) # the total number of warmup_steps is iterations * warmup_percent parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter -parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) +parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) +parser.add_argument("--cos_inf_steps", default=0, type=int) # cos_inf scheduler parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'lion', 'sf-adamw', 'sf-sgd']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved diff --git a/src/config/base.py b/src/config/base.py index 3fe8a94..2ea4c3a 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -23,13 +23,27 @@ def parse_args(base_parser, args, namespace): "--warmup_steps", default=300, type=int ) # it was only warmup_percent before parser.add_argument("--lr", default=1e-3, type=float) + parser.add_argument( + "--wsd_final_lr_scale", default=0.0, type=float + ) # wsd scheduler + parser.add_argument("--wsd_fract_decay", default=0.1, type=float) # wsd scheduler + parser.add_argument( + "--decay_type", + default="linear", + choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], + ) parser.add_argument( "--warmup_percent", default=0.05, type=float ) # leave it anyway, warmup_steps / iterations parser.add_argument("--weight_decay", default=0.1, type=float) parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta2", 
default=0.95, type=float) - parser.add_argument("--scheduler", default="cos", choices=["linear", "cos", "none"]) + parser.add_argument( + "--scheduler", + default="cos", + choices=["linear", "cos", "wsd", "cos_inf", "none"], + ) + parser.add_argument("--cos_inf_steps", default=0, type=int) parser.add_argument( "--opt", default="adamw", diff --git a/src/main.py b/src/main.py index fe49690..e1326b6 100755 --- a/src/main.py +++ b/src/main.py @@ -13,7 +13,7 @@ import config import distributed from data.utils import get_dataset -from models.utils import get_model +from models.utils import cos_inf_schedule, get_model, wsd_schedule from optim.ademamix import AdEMAMix from optim.base import train_base from optim.lion import Lion @@ -186,8 +186,27 @@ def main(args): anneal_strategy=args.scheduler, cycle_momentum=False, div_factor=1e2, - final_div_factor=1, + final_div_factor=0.1, ) + elif args.scheduler == "cos_inf": + lambda_schedule = cos_inf_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + n_inf=args.cos_inf_steps, + div_factor=1e2, + final_div_factor=0.1, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + elif args.scheduler == "wsd": + lambda_schedule = wsd_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + fract_decay=args.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=args.wsd_final_lr_scale, # should be 0 here + decay_type=args.decay_type, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) else: raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") else: diff --git a/src/optim/utils.py b/src/optim/utils.py index 6a4ee51..cef37b4 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -1,5 +1,6 @@ -from contextlib import ExitStack, contextmanager, nullcontext import itertools +import math +from contextlib import ExitStack, contextmanager, nullcontext import numpy as np import torch @@ -18,11 +19,122 @@ def get_batch(dataloader, device="cpu"): return x, y +def cos_inf_schedule(n_iterations, n_warmup, div_factor, final_div_factor, n_inf): + """Cosine annealing with warmup and _constant_ final_lr after cycle ended. + Args: + n_iterations: total number of iterations + n_warmup: number of warmup iterations + div_factor: initial division factor for warmup + final_div_factor: final division factor for final lr + n_inf: number of iterations for the final lr (constant lr after cycle ended) + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + max_lr = 1.0 + base_lr = max_lr / div_factor + final_lr = base_lr / final_div_factor + + n_anneal_steps = n_iterations - n_inf + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / div_factor + elif step < n_anneal_steps: + t = (step - n_warmup) / (n_anneal_steps - n_warmup) + lr = final_lr + 0.5 * (max_lr - final_lr) * (1 + np.cos(np.pi * t)) + return lr + else: + return final_lr + + return schedule + + +def wsd_schedule( + n_iterations, + final_lr_factor=0.0, + n_warmup=1000, + init_div_factor=100, + fract_decay=0.1, + decay_type="linear", +): + """Warmup, hold, and decay schedule. 
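+    Linear warmup from 1/init_div_factor to 1 over n_warmup steps, a constant
+    hold phase, then a decay over the final `fract_decay` fraction of the
+    iterations. The returned function maps a step index to a multiplicative LR
+    factor and is meant to be passed to torch.optim.lr_scheduler.LambdaLR, as
+    done for the `wsd` scheduler in src/main.py.
+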
+ Args: + n_iterations: total number of iterations + final_lr_factor: factor by which to reduce max_lr at the end + warmup_fract: fraction of iterations used for warmup + init_div_factor: initial division factor for warmup + fract_decay: fraction of iterations used for decay + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + n_anneal_steps = int(fract_decay * n_iterations) + n_hold = n_iterations - n_anneal_steps + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor + elif step < n_hold: + return 1.0 + elif step < n_iterations: + if decay_type == "linear": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - (step - n_hold) / n_anneal_steps + ) + elif decay_type == "exp": + return final_lr_factor ** ((step - n_hold) / n_anneal_steps) + elif decay_type == "cosine": + return ( + final_lr_factor + + (1 - final_lr_factor) + * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) + * 0.5 + ) + elif decay_type == "miror_cosine": + cosine_value = ( + final_lr_factor + + (1 - final_lr_factor) + * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) + * 0.5 + ) + linear_value = final_lr_factor + (1 - final_lr_factor) * ( + 1 - (step - n_hold) / n_anneal_steps + ) + return linear_value * 2 - cosine_value + elif decay_type == "square": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - ((step - n_hold) / n_anneal_steps) ** 2 + ) + + elif decay_type == "sqrt": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - math.sqrt((step - n_hold) / n_anneal_steps) + ) + + else: + raise ValueError( + f"decay type {decay_type} is not in ['cosine','miror_cosine','linear','exp']" + ) + + else: + return final_lr_factor + + return schedule + + @torch.no_grad() -def eval(model, data_val_iter, device="cpu", max_num_batches=24, ctx=nullcontext(), reset_iterator=False): +def eval( + model, + data_val_iter, + device="cpu", + max_num_batches=24, + ctx=nullcontext(), + reset_iterator=False, +): assert model.training == False - if reset_iterator: # ensure that we always eval on the same batches + if reset_iterator: # ensure that we always eval on the same batches data_val_iter = itertools.cycle(data_val_iter) loss_list_val, acc_list = [], [] @@ -53,7 +165,7 @@ def save_checkpoint( "scheduler": scheduler.state_dict(), "itr": itr, }, - **extra_args + **extra_args, ) torch.save(checkpoint, ckpt_path) From a908d963f1958b98afa6d1cc172616ed3954824f Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 17:23:52 +0300 Subject: [PATCH 12/58] cos_inf and wsd schedules are here --- README.md | 1 + src/config/base.py | 5 +++-- src/distributed/ddp.py | 8 ++------ src/main.py | 12 +++++++++--- src/optim/base.py | 13 ++++++++++--- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 2dc47e2..d5c3b6a 100755 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ parser.add_argument('--multiple_of', default=256, type=int) # used by the llama # logging params (WandB) parser.add_argument('--wandb', action='store_true') # whether to use wandb or not parser.add_argument('--wandb_project', default="my-project", type=str) +parser.add_argument('--wandb_entity', default=None, type=none_or_str) # for the team projects parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to 
generate sequences # Distributed args diff --git a/src/config/base.py b/src/config/base.py index 3a6291f..2de272c 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -1,8 +1,9 @@ from typing import Optional -import distributed import torch +import distributed + def none_or_str(value): if value == "None": @@ -129,7 +130,7 @@ def parse_args(base_parser, args, namespace): # logging params (WandB) parser.add_argument("--wandb", action="store_true") # whether to use wandb or not parser.add_argument("--wandb_project", default="my-project", type=str) - parser.add_argument('--wandb_entity', default=None, type=none_or_str) + parser.add_argument("--wandb_entity", default=None, type=none_or_str) parser.add_argument( "--wandb_run_prefix", default="none", type=str ) # is added before the autogenerated experiment name diff --git a/src/distributed/ddp.py b/src/distributed/ddp.py index 951dbf3..bf47d25 100644 --- a/src/distributed/ddp.py +++ b/src/distributed/ddp.py @@ -2,12 +2,8 @@ import os from contextlib import contextmanager -from torch.distributed import ( - barrier, - destroy_process_group, - get_world_size, - init_process_group, -) +from torch.distributed import (barrier, destroy_process_group, get_world_size, + init_process_group) from torch.nn.parallel import DistributedDataParallel as DDP from .backend import DistributedBackend diff --git a/src/main.py b/src/main.py index d1fd1b5..f043435 100755 --- a/src/main.py +++ b/src/main.py @@ -8,18 +8,19 @@ import numpy as np import torch -import wandb import config import distributed +import wandb from data.utils import get_dataset -from models.utils import cos_inf_schedule, get_model, wsd_schedule +from models.utils import get_model from optim.ademamix import AdEMAMix from optim.base import train_base from optim.lion import Lion from optim.muon import Muon, zeropower_backends from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree from optim.soap import SOAP +from optim.utils import cos_inf_schedule, wsd_schedule def get_args(): @@ -217,7 +218,12 @@ def main(args): if distributed_backend.is_master_process() and args.wandb: params_copy = copy.deepcopy(vars(args)) del params_copy["device"] - wandb.init(project=args.wandb_project, name=exp_name, config=params_copy, entity=args.wandb_entity) + wandb.init( + project=args.wandb_project, + name=exp_name, + config=params_copy, + entity=args.wandb_entity, + ) ckpt_path = os.path.join( args.results_base_folder, args.dataset, args.model, exp_name diff --git a/src/optim/base.py b/src/optim/base.py index ebbe6ff..7bc3f21 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -8,6 +8,7 @@ import numpy as np import torch import torch.nn.functional as F + import wandb from data.utils import get_dataloader @@ -122,7 +123,9 @@ def train_base( data_train_iter = iter(data["train"]) if extra_args.grad_clip != 0.0: - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), extra_args.grad_clip + ) grad_norms.append(grad_norm) opt.step() scheduler.step() @@ -165,8 +168,12 @@ def train_base( logs = { "iter": itr, "train/loss": train_loss, - "train/max_grad_norm": max(grad_norms).item() if grad_norms else 0, - "train/mean_grad_norm": torch.tensor(grad_norms).mean().item() if grad_norms else 0, + "train/max_grad_norm": ( + max(grad_norms).item() if grad_norms else 0 + ), + "train/mean_grad_norm": ( + torch.tensor(grad_norms).mean().item() if grad_norms else 0 + ), "val/loss": val_loss, "val/perplexity": 
val_perplexity, "val/acc": val_acc, From 2e4bc4ecb0753548b90a837837264fa43f0f9e14 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 17:38:58 +0300 Subject: [PATCH 13/58] codestyle --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d5c3b6a..d844ab5 100755 --- a/README.md +++ b/README.md @@ -142,6 +142,8 @@ src/ Given the above structure, to add your own model, you can just fork the `./src/models/base.py` file, do your modifications, then if necessary fork the `./src/optim/base.py` in case you need some custom training loop or evaluation. You also need to fork the `./src/config/base.py` file to add your own parameters, which imply adding your new config to the mapping `CONFIG_FORMAT_TO_MODULE_MAP` in `./src/config/__init__.py`. To add a new dataset, create a new file in the `data` folder, check `wikitext.py` for the expected format. +**Note:** we use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests. Before committing your code, simply run ```black . && isort .``` and you will be fine. + ## Multi-GPU training Given a multi-GPU machine with e.g. 4 GPUs, one can distribute the training using data-parallelism: From c5345d691b72a0c753d51e772cefafaa3ad79d5d Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 18:02:29 +0300 Subject: [PATCH 14/58] --fix in readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d844ab5..b9dd667 100755 --- a/README.md +++ b/README.md @@ -67,8 +67,8 @@ parser.add_argument("--muon_backend", default="newtonschulz5", type=str) # the c parser.add_argument("--muon_backend_steps", default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta3 in AdEMAMix parser.add_argument("--adema_alpha", default=2.0, type=float) # alpha in AdEMAMix -parser.add_argument("--adema_beta3_warmup", default=None, type=Optional[int]) # AdEMAMix hyperparameter -parser.add_argument("--adema_alpha_warmup", default=None, type=Optional[int]) # AdEMAMix hyperparameter +parser.add_argument("--adema_beta3_warmup", default=None, type=int) # AdEMAMix hyperparameter +parser.add_argument("--adema_alpha_warmup", default=None, type=int) # AdEMAMix hyperparameter parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulefree hyperparameter parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulefree hyperparameter # Dataset params From 9b2c2e6c054d7d2889a1a9cda43e24d53d162b6e Mon Sep 17 00:00:00 2001 From: mpagli Date: Thu, 17 Oct 2024 16:12:00 +0000 Subject: [PATCH 15/58] fix grad norm display --- src/optim/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/optim/base.py b/src/optim/base.py index 7bc3f21..4daf4f8 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -207,6 +207,8 @@ def train_base( {f"generated-text-{wandb.run.name}": copy.copy(text_table)} ) + grad_norms = [] + model.train() t0 = time.time() if distributed_backend.is_master_process(): From e142dfcd457b7da9dacb4f1895dcd9ad23c95a24 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 17 Oct 2024 20:28:49 +0300 Subject: [PATCH 16/58] removed warmup_percent argument --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index b9dd667..96775b5 100755 --- a/README.md +++ b/README.md @@ -38,12 +38,11 @@ 
parser.add_argument('--seed', default=0, type=int) # random seed for the paramet parser.add_argument('--data_seed', default=1337, type=int) # random seed defining the data ordering parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations -parser.add_argument("--warmup_steps", default=300, type=int) # it was only warmup_percent before +parser.add_argument("--warmup_steps", default=300, type=int) parser.add_argument('--lr', default=1e-3, type=float) parser.add_argument("--wsd_final_lr_scale", default=0.0, type=float) # wsd scheduler parser.add_argument("--wsd_fract_decay", default=0.1, type=float) # wsd scheduler parser.add_argument("--decay_type", default="linear", choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"]) -parser.add_argument('--warmup_percent', default=0.05, type=float) # the total number of warmup_steps is iterations * warmup_percent parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter From 513b902264608ce9c1e2ce6934625b0a495c0715 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sat, 19 Oct 2024 01:39:13 +0300 Subject: [PATCH 17/58] schedule-free fix, added scheduer check in optim/base.py --- src/config/base.py | 5 +++++ src/optim/base.py | 7 ++++++- src/optim/muon.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/config/base.py b/src/config/base.py index 2de272c..5b39eef 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -13,6 +13,7 @@ def none_or_str(value): def parse_args(base_parser, args, namespace): parser = base_parser + # General training params parser.add_argument("--batch_size", default=32, type=int) parser.add_argument("--acc_steps", default=4, type=int) @@ -82,6 +83,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--adema_alpha_warmup", default=None, type=int) parser.add_argument("--schedulefree_r", default=0.0, type=float) parser.add_argument("--weight_lr_power", default=2.0, type=float) + # Dataset params parser.add_argument( "--dataset", @@ -100,6 +102,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument( "--data_in_ram", action="store_true" ) # force the data to RAM, mostly useless except for openwebtext2 + # Model params parser.add_argument("--model", default="base", choices=["base", "llama2"]) parser.add_argument( @@ -127,6 +130,7 @@ def parse_args(base_parser, args, namespace): "--run_prefix", default=None, type=str, required=False ) # is added before the autogenerated experiment name parser.add_argument("--exp_name", default=None, type=str, required=False) + # logging params (WandB) parser.add_argument("--wandb", action="store_true") # whether to use wandb or not parser.add_argument("--wandb_project", default="my-project", type=str) @@ -137,6 +141,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument( "--eval_seq_prefix", default="Once upon a time", type=str ) # prefix used to generate sequences + # Distributed args parser.add_argument( "--distributed_backend", diff --git a/src/optim/base.py b/src/optim/base.py index 4daf4f8..b8a5963 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -127,8 +127,11 @@ def train_base( model.parameters(), extra_args.grad_clip ) grad_norms.append(grad_norm) 
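+        # Schedule-Free optimizers take gradients at the interpolated point y but
+        # should be evaluated and checkpointed at the averaged iterate x, so they
+        # expose explicit train()/eval() switches that swap the parameters in place.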
+ if extra_args.opt == "sf-sgd" or extra_args.opt == "sf-adamw": + opt.train() opt.step() - scheduler.step() + if extra_args.scheduler != "none": + scheduler.step() opt.zero_grad(set_to_none=True) itr += 1 @@ -141,6 +144,8 @@ def train_base( epoch = substep // num_substeps_per_epoch model.eval() + if extra_args.opt == "sf-sgd" or extra_args.opt == "sf-adamw": + opt.eval() train_loss = loss.detach().cpu().item() * acc_steps current_lr = ( scheduler.get_last_lr()[0] diff --git a/src/optim/muon.py b/src/optim/muon.py index ce31081..e518e87 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -12,7 +12,7 @@ def zeropower_via_svd(G, steps=None): return U @ V.T -# @torch.compile +@torch.compile def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): """ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a From 8e2aa6df2b51fb7859c71232ff60b6ef7cbe50ac Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sat, 19 Oct 2024 11:43:49 +0300 Subject: [PATCH 18/58] --fix saving of a checkpoint if scheduler==none --- src/optim/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optim/utils.py b/src/optim/utils.py index cef37b4..d90762e 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -162,7 +162,7 @@ def save_checkpoint( { "model": distributed_backend.get_raw_model(model).state_dict(), "optimizer": opt.state_dict(), - "scheduler": scheduler.state_dict(), + "scheduler": scheduler.state_dict() if scheduler is not None else None, "itr": itr, }, **extra_args, From eda40bfc96c4bc91d4b8b16691b9c93eacda7e24 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sun, 20 Oct 2024 19:00:48 +0000 Subject: [PATCH 19/58] fix requirements --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0153e35..ac0ff91 100755 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ tqdm==4.65.0 transformers wandb datasets -zstandard \ No newline at end of file +zstandard +numpy==1.22.4 \ No newline at end of file From 3649936e44c6e21aff4d29e547ec10445dc46c8a Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sun, 20 Oct 2024 23:34:37 +0300 Subject: [PATCH 20/58] Adam-mini is here --- README.md | 3 + src/config/base.py | 4 + src/main.py | 18 +- src/optim/adammini.py | 452 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 476 insertions(+), 1 deletion(-) create mode 100644 src/optim/adammini.py diff --git a/README.md b/README.md index 96775b5..7a292fb 100755 --- a/README.md +++ b/README.md @@ -70,6 +70,8 @@ parser.add_argument("--adema_beta3_warmup", default=None, type=int) # AdEMAMix h parser.add_argument("--adema_alpha_warmup", default=None, type=int) # AdEMAMix hyperparameter parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulefree hyperparameter parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulefree hyperparameter +parser.add_argument("--model_sharding", default=None, type=bool) # Adam-mini +parser.add_argument("--adam_mini_verbose", default=False, type=bool) # print all the logs if true # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) parser.add_argument('--vocab_size', default=50304, type=int) @@ -87,6 +89,7 @@ parser.add_argument('--bias', default=False, type=bool) parser.add_argument('--compile', action='store_true') # if true then model is compiled parser.add_argument('--rmsnorm_eps', 
default=1e-5, type=float) # used by the llama model parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model make SwiGLU hidden layer size multiple of large power of 2 +parser.add_argument('--n_kv_head', default=None, type=Optional[int]) # logging params (WandB) parser.add_argument('--wandb', action='store_true') # whether to use wandb or not parser.add_argument('--wandb_project', default="my-project", type=str) diff --git a/src/config/base.py b/src/config/base.py index 5b39eef..51a6684 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -58,6 +58,7 @@ def parse_args(base_parser, args, namespace): "lion", "sf-adamw", "sf-sgd", + "adam-mini", ], ) parser.add_argument("--eval_freq", default=200, type=int) # in iterations @@ -83,6 +84,8 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--adema_alpha_warmup", default=None, type=int) parser.add_argument("--schedulefree_r", default=0.0, type=float) parser.add_argument("--weight_lr_power", default=2.0, type=float) + parser.add_argument("--model_sharding", default=None, type=bool) + parser.add_argument("--adam_mini_verbose", default=False, type=bool) # Dataset params parser.add_argument( @@ -126,6 +129,7 @@ def parse_args(base_parser, args, namespace): default=256, type=int, ) + parser.add_argument("--n_kv_head", default=None, type=Optional[int]) parser.add_argument( "--run_prefix", default=None, type=str, required=False ) # is added before the autogenerated experiment name diff --git a/src/main.py b/src/main.py index f043435..8b8ff08 100755 --- a/src/main.py +++ b/src/main.py @@ -14,6 +14,7 @@ import wandb from data.utils import get_dataset from models.utils import get_model +from optim.adammini import Adam_mini from optim.ademamix import AdEMAMix from optim.base import train_base from optim.lion import Lion @@ -85,6 +86,7 @@ def main(args): g["params"] = params optimized_params_cnt += sum([p.numel() for p in g["params"]]) print("number of optimized parameters: %.2fM" % (optimized_params_cnt / 1e6,)) + args.world_size = distributed_backend.get_world_size() if args.opt == "adamw": use_fused = (device_type == "cuda") and ( "fused" in inspect.signature(torch.optim.AdamW).parameters @@ -161,6 +163,20 @@ def main(args): r=args.schedulefree_r, weight_lr_power=args.weight_lr_power, ) # without foreach argument + elif args.opt == "adam-mini": + opt = Adam_mini( + device=args.device, + world_size=args.world_size, + named_parameters=model.named_parameters(), # check + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + model_sharding=args.model_sharding, + dim=args.n_embd, + n_heads=args.n_head, + n_kv_heads=args.n_kv_head, + verbose=args.adam_mini_verbose, + ) else: opt = torch.optim.SGD( group_specs, @@ -213,7 +229,7 @@ def main(args): else: scheduler = None - args.world_size = distributed_backend.get_world_size() + # args.world_size = distributed_backend.get_world_size() exp_name = args.exp_name if distributed_backend.is_master_process() and args.wandb: params_copy = copy.deepcopy(vars(args)) diff --git a/src/optim/adammini.py b/src/optim/adammini.py new file mode 100644 index 0000000..c340b31 --- /dev/null +++ b/src/optim/adammini.py @@ -0,0 +1,452 @@ +""" +Here is an original implementation of Adam-mini. 
+Source: https://github.com/zyushun/Adam-mini +""" + +import math +from typing import Iterable, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._tensor import Replicate + + +class Adam_mini(torch.optim.Optimizer): + def __init__( + self, + device, + world_size, + named_parameters: Iterable[Tuple[str, nn.Parameter]], + lr: Union[float, torch.Tensor] = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + model_sharding: bool = None, + dim: int = 2048, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + verbose=True, + ): + """ + This is the official implementation of Adam-mini (version 1.1.0). + + Paper: [Adam-mini: Use Fewer Learning Rates To Gain More](https://arxiv.org/abs/2406.16793). + + Github repo: https://github.com/zyushun/Adam-mini + + Arguments: + named_parameters ('Iterable[Tuple[str, nn.Parameter]]'): Iterable of named parameters to optimize or dictionaries defining parameter groups. Usually set to model.named_parameters() + + lr (`float`, *optional*, defaults to 0.001): The learning rate to use. + + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): Same as Adam's betas parameters (b1, b2). + + eps (`float`, *optional*, defaults to 1e-06): Same as Adam's epsilon for numerical stability. + + weight_decay (`float`, *optional*, defaults to 0.0): Decoupled weight decay to apply. + + model_sharding (`bool`, *optional*, defaults to None): Set to True if you are using model parallelism with more than 1 GPU, including FSDP and zero_1,2,3 in Deepspeed. Set to False if otherwise. Due to the historical reason, this argument is deprecated since version 1.0.2. We will assume that model parallelism is always used. We will remove this argument in the future version. + + dim (`int`, *optional*, defaults to 2048): Dimension for hidden features. Can be left unspecified if training non-transformer models. + + n_heads (`int`, *optional*, defaults to 32): Number of attention heads. Can be left unspecified if training non-transformer models. + + n_kv_heads (`int`, *optional*, defaults to None): Number of heads for Key and Value. Or equivalently, number of query groups in Group Query Attention. Also known as "n_query_groups". If not specified, it will be equal to n_head. Can be left unspecified if training non-transformer models. + + verbose (`bool`, *optional*, defaults to True): Print all the logs if true. 
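+
+        Instead of one second-moment entry per coordinate as in AdamW, Adam-mini keeps a
+        single mean-of-squared-gradients scalar per parameter block (one per attention head
+        for Queries/Keys, one per output neuron for Values, attention projections and MLPs,
+        one per token row for embedding and output layers), so the second-moment state
+        shrinks from one value per weight to one value per block.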
+ Example: + + ```python + optimizer = Adam_mini( + named_parameters = model.named_parameters(), + lr = lr, + betas = (beta1,beta2), + eps = eps, + weight_decay = weight_decay, + dim = model_config.dim, + n_heads = model_config.n_heads, + n_kv_heads = model_config.n_kv_heads, + ) + ``` + + """ + self.named_parameters = named_parameters + self.dim = dim + self.n_heads = n_heads + if n_kv_heads is not None: + assert n_heads % n_kv_heads == 0, f"{n_heads} {n_kv_heads}" + self.n_kv_heads = n_kv_heads + else: + self.n_kv_heads = n_heads + + self.device = device + self.world_size = world_size # torch.cuda.device_count() + self.verbose = verbose + self.check_block_name = True + self.head_numel = self.dim * self.dim // self.n_heads + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not self.dim == int(self.dim): + raise ValueError("Invalid dim value: {}".format(self.dim)) + if not self.n_heads == int(self.n_heads): + raise ValueError("Invalid n_heads value: {}".format(self.n_heads)) + if not self.n_kv_heads == int(self.n_kv_heads): + raise ValueError("Invalid n_kv_heads value: {}".format(self.n_kv_heads)) + + if model_sharding is not None and verbose: + print( + "Warning by Adam-mini: model_sharding is deprecated since version 1.0.2. This argument is always set True. We will remove this argument in the future version." + ) + + # Embedding layer. Use one lr per token + self.embd_names = {"embed", "embd", "wte"} # move to mlp + # Output layers. Use one lr per token + self.output_names = {"lm_head.weight", "output.weight"} # move output to mlp + # Query and Keys. User one lr per head + self.wqk_names = {"k_proj.weight", "q_proj.weight", "wq.weight", "wk.weight"} + # Values. Use one lr per neuron + # it is okay to set self.wv_names to be empty and use a single lr for the whole v. But this will bring extra all_reduce operations + self.wv_names = {"v_proj.weight", "wv.weight"} + # attn_proj. Use one lr per neuron + self.attn_proj_names = {"o_proj.weight", "wo.weight", "attn.proj.weight"} + # MLPs. Use one lr per neuron + self.mlp_names = {"feed_forward", "linear", "mlp"} + # Blocks that use Adam. For old versions before v.1.1.0, this is for embedding layer and output layer. 
For the current version, this is empty + self.adam_block_names = {} + + optim_groups = [] + # count_embd = count_output = count_wqk = 0 + for param_name, param in named_parameters: + if not param.requires_grad: + continue + if verbose: + print( + "Adam-mini found the param block with name:", + param_name, + param.size(), + ) + state = {} + state["name"] = param_name + state["params"] = param + if "norm" in param_name or "ln_f" in param_name: + state["weight_decay"] = 0.0 + else: + state["weight_decay"] = weight_decay + + optim_groups.append(state) + + defaults = dict(lr=lr, beta1=betas[0], beta2=betas[1], eps=eps) + super().__init__(optim_groups, defaults) + + def count_block(self): + count_embd = 0 + count_output = 0 + count_wqk = 0 + count_wv = 0 + count_attn_proj = 0 + count_mlp = 0 + for param_name, param in self.named_parameters: + if not param.requires_grad: + continue + if any(embd_name in param_name for embd_name in self.embd_names): + count_embd += 1 + if any(output_name in param_name for output_name in self.output_names): + count_output += 1 + if any(wqk_name in param_name for wqk_name in self.wqk_names): + count_wqk += 1 + assert ( + self.dim * self.dim + ) % self.n_heads == 0, f"{self.dim} {self.n_heads}" + if any(wv_name in param_name for wv_name in self.wv_names): + count_wv += 1 + if any( + attn_proj_name in param_name for attn_proj_name in self.attn_proj_names + ): + count_attn_proj += 1 + if any(mlp_name in param_name for mlp_name in self.mlp_names): + count_mlp += 1 + if self.verbose: + print( + f"Adam-mini found {count_embd} embedding layers, {count_output} output layers; {count_wqk} Querys and Keys; {count_wv} Values; {count_attn_proj} attn_proj; {count_mlp} MLPs;" + ) + + if count_embd == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No embedding layer found. If you are training Transformers, please check the name of your embedding layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the keywords in the name of your embedding layer'). " + ) + if count_output == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No output layer found. If you are training Transformers (without weight-tying), please check the name of your output layer and manually add them to 'self.output_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.output_names.add('the keywords in the name of your output layer'). Please ignore this warning if you are using weight-tying." + ) + if count_wqk == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No Query or Key found. If you are training Transformers, please check the name of your Query and Key in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding two additional lines of code: optimizer.wqk_names.add('the keywords in the name of your Query' ); optimizer.wqk_names.add('the keywords in the name of your Key'). " + ) + + if count_wv == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No Value found. If you are training Transformers, please check the name of your Value in attention blocks and manually add them to 'self.wv_names' of Adam-mini. You can do this by adding an additional lines of code: optimizer.wv_names.add('the keywords in the name of your Value' ). 
" + ) + + if count_attn_proj == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No attn_proj found. If you are training Transformers, please check the name of your attn_proj in attention blocks and manually add them to 'self.attn_proj_names' of Adam-mini. You can do this by adding an additional lines of code: optimizer.attn_proj_names.add('the keywords in the name of your attn_proj' ). " + ) + + if count_mlp == 0 and self.verbose: + # warning + print( + "=====>>> Warning by Adam-mini: No MLP found. If you are training Transformers, please check the name of your MLP in attention blocks and manually add them to 'self.mlp_names' of Adam-mini. You can do this by adding an additional lines of code: optimizer.attn_proj_names.add('the keywords in the name of your MLP' ). " + ) + + if ( + count_output + + count_embd + + count_wqk + + count_wv + + count_attn_proj + + count_mlp + == 0 + ) and self.verbose: + print( + "=====>>> Warning by Adam-mini: you are using default PyTorch partition for Adam-mini. It can cause training instability on large-scale Transformers." + ) + + @torch.no_grad() + def step(self, closure=None): + if self.check_block_name: + self.count_block() + self.check_block_name = False + + loss = None + device = self.device + if closure is not None: + with torch.enable_grad(): + loss = closure() + for group in self.param_groups: + beta1 = group["beta1"] + beta2 = group["beta2"] + lr = group["lr"] + name = group["name"] + eps = group["eps"] + + for p in group["params"]: + state = self.state[p] + if any( + adam_block_name in name for adam_block_name in self.adam_block_names + ): # For v.1.1.0, we will not enter here + if p.grad is None: + continue + if len(state) == 0: + state["m"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["step"] = 0 + state["v"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + grad = p.grad + state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = lr / bias_correction_1 + p.addcdiv_(state["m"], h, value=-stepsize) + elif any( + wqk_name in name for wqk_name in self.wqk_names + ): # this is for query and key + if p.grad is None: + continue + head_numel = self.head_numel # group["head_numel"] + if len(state) == 0: + m = torch.zeros_like(p, memory_format=torch.preserve_format) + state["m"] = m.view(-1, head_numel) + state["head_per_gpu"] = state["m"].size( + 0 + ) # this is head per gpu + state["step"] = 0 + # NOTE: We must use `zeros_like` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. 
+ # the following line is equivalent to: state["vmean"] = torch.zeros(state["head"]) + state["vmean"] = torch.zeros_like( + state["m"][0 : state["head_per_gpu"], 0:1], + memory_format=torch.preserve_format, + ) + + grad = p.grad # .to(torch.float32) + head_per_gpu = state["head_per_gpu"] + grad = grad.view(head_per_gpu, head_numel) + tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True) + + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = ((1 / bias_correction_1) / h).view(head_per_gpu, 1) + update = (state["m"] * stepsize).view(p.size()) + update.mul_(lr) + p.add_(-update) + elif ( + any(embd_name in name for embd_name in self.embd_names) + or any(output_name in name for output_name in self.output_names) + or any(wv_name in name for wv_name in self.wv_names) + or any(mlp_name in name for mlp_name in self.mlp_names) + or any( + attn_proj_name in name + for attn_proj_name in self.attn_proj_names + ) + ): + if p.grad is None: + continue + # neuron_numel = group["neuron_numel"] # assume grad is a matrix by default, so do not need this + if len(state) == 0: + state["m"] = torch.zeros_like( + p.grad, memory_format=torch.preserve_format + ) # assume grad is a matrix by default, no need to view + # state["m"] = torch.zeros_like(p, memory_format=torch.preserve_format).view(-1, neuron_numel) + state["step"] = 0 + state["neuron_per_gpu"] = state["m"].size( + 0 + ) # this is neuron per gpu + # NOTE: We must use `new_zeros` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. + # for standard tensor: state["vmean"] = torch.zeros(1, device=p.device) + # for DTensor: state["vmean"] = p.new_zeros(1) + # the following implementation unifies the above two lines + state["vmean"] = torch.zeros_like( + state["m"][0 : state["neuron_per_gpu"], 0:1], + memory_format=torch.preserve_format, + ) + + grad = p.grad # .to(torch.float32) + neuron_per_gpu = state["neuron_per_gpu"] + # grad = grad.view(neuron_per_gpu, neuron_numel) # assume grad is a matrix by default, so no need to reshape + tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True) + + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + state["step"] += 1 + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = ((1 / bias_correction_1) / h).view(neuron_per_gpu, 1) + update = (state["m"] * stepsize).view(p.size()) + update.mul_(lr) + p.add_(-update) + + else: # other blocks. By default, this is for LayerNorms. 
Sometimes it is also fine to put Value here + if len(state) == 0: + block_numel = ( + torch.tensor(p.numel()).to(torch.float32).to(device) + ) + reduced = False + if self.world_size > 1: + tensor_list = [ + torch.zeros_like(block_numel) + for _ in range(self.world_size) + ] + + dist.all_gather(tensor_list, block_numel) + s = 0 + block_numel = 0 + for d in tensor_list: + if d > 0: + s = s + 1 + block_numel = block_numel + d + if s >= 2: + reduced = True + + state["m"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["step"] = 0 + state["reduced"] = reduced + # NOTE: We must use `new_zeros` for vmean to be a + # DTensor (not `torch.Tensor`) for DTensor parameters. + # For standard tensor: state["vmean"] = torch.zeros(1, device=p.device) + # For DTensor: state["vmean"] = p.new_zeros(1) + # the following implementation unifies the above two lines + state["vmean"] = torch.zeros_like( + torch.sum(p * p), memory_format=torch.preserve_format + ) + state["block_numel"] = block_numel.item() + if p.grad is None: + tmp_lr = torch.zeros_like(torch.sum(p * p)) + else: + grad = p.grad # .to(torch.float32) + tmp_lr = torch.sum(grad * grad) + + if state["reduced"]: + # Force communication over GPUs when GPUs are available + if tmp_lr.device.type == "cpu": + # Move the tensor to the current GPU device + tmp_lr_gpu = tmp_lr.to(torch.cuda.current_device()) + + if "device_mesh" in dir(tmp_lr): + # when tmp_lr is a DTensor in TorchTitan + lr_local = tmp_lr.to_local() + dist.all_reduce(lr_local, op=dist.ReduceOp.SUM) + tmp_lr.redistribute(placements=[Replicate()]) + else: + # when tmp_lr is a standard tensor + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) + + # Move the result back to the CPU tensor + tmp_lr.copy_(tmp_lr_gpu.cpu()) + else: + # Tensor is already on GPU, use NCCL backend + if "device_mesh" in dir(tmp_lr): + # when tmp_lr is a DTensor in TorchTitan + lr_local = tmp_lr.to_local() + dist.all_reduce(lr_local, op=dist.ReduceOp.SUM) + tmp_lr.redistribute(placements=[Replicate()]) + else: + # when tmp_lr is a standard tensor + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) + + if p.grad is None: + continue + tmp_lr = tmp_lr / state["block_numel"] + + if group["weight_decay"] > 0.0: + p.mul_(1 - lr * group["weight_decay"]) + state["step"] += 1 + state["m"].lerp_(grad, 1 - beta1) + bias_correction_1 = 1 - beta1 ** state["step"] + bias_correction_2 = 1 - beta2 ** state["step"] + bias_correction_2_sqrt = math.sqrt(bias_correction_2) + state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2) + h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps) + stepsize = (1 / bias_correction_1) / h + update = state["m"] * (stepsize.to(state["m"].device)) + update.mul_(lr) + p.add_(-update) + + return loss From e75b3be70877f871dd9b6783d3981f9153dcde6e Mon Sep 17 00:00:00 2001 From: mpagli Date: Mon, 21 Oct 2024 11:08:28 +0000 Subject: [PATCH 21/58] refactoring --- requirements.txt | 8 +- src/config/base.py | 201 ++++++++--------- src/data/arxiv.py | 21 +- src/data/openwebtext2.py | 20 +- src/data/redpajama.py | 124 +++++++++++ src/data/shakespeare.py | 14 +- src/data/slimpajama.py | 39 ++-- src/data/utils.py | 188 +++++++++++----- src/data/wikitext.py | 20 +- src/distributed/backend.py | 5 +- src/distributed/ddp.py | 17 +- src/distributed/single.py | 6 +- src/main.py | 316 +++++++++++++++++---------- src/models/base.py | 107 +++++---- src/models/llama.py | 61 +++--- src/models/test.py | 232 ++++++++++++++++++++ src/models/utils.py | 24 +- src/optim/ademamix2.py | 187 ++++++++++++++++ 
src/optim/base.py | 433 +++++++++++++++++++------------------ src/optim/utils.py | 165 ++++++++++++-- 20 files changed, 1525 insertions(+), 663 deletions(-) create mode 100644 src/data/redpajama.py create mode 100644 src/models/test.py create mode 100644 src/optim/ademamix2.py diff --git a/requirements.txt b/requirements.txt index 0153e35..b18b519 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ tiktoken --find-links https://download.pytorch.org/whl/torch_stable.html -torch==2.0.0+cu118 -torchaudio==2.0.0+cu118 -torchvision==0.15.0+cu118 -tqdm==4.65.0 +torch +torchaudio +torchvision +tqdm transformers wandb datasets diff --git a/src/config/base.py b/src/config/base.py index 2de272c..3075874 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -1,7 +1,3 @@ -from typing import Optional - -import torch - import distributed @@ -14,37 +10,69 @@ def none_or_str(value): def parse_args(base_parser, args, namespace): parser = base_parser # General training params - parser.add_argument("--batch_size", default=32, type=int) - parser.add_argument("--acc_steps", default=4, type=int) + parser.add_argument("--run_prefix", default=None, type=str) + parser.add_argument("--experiment_name", default=None, type=str) parser.add_argument("--seed", default=0, type=int) parser.add_argument("--data_seed", default=1337, type=int) + parser.add_argument("--eval_interval", default=200, type=int) + parser.add_argument("--full_eval_at", nargs="+", type=int) + parser.add_argument("--eval_batches", default=32, type=int) parser.add_argument("--device", default="cuda:0", type=str) - parser.add_argument("--iterations", default=25000, type=int) parser.add_argument( - "--warmup_steps", default=300, type=int - ) # it was only warmup_percent before - parser.add_argument("--lr", default=1e-3, type=float) + "--distributed_backend", + default=None, + type=str, + required=False, + choices=distributed.registered_backends(), + ) + parser.add_argument("--log_interval", default=50, type=int) + + # Checkpointing + parser.add_argument("--results_base_folder", default="./exps", type=str) + parser.add_argument("--permanent_ckpt_interval", default=0, type=int) + parser.add_argument("--latest_ckpt_interval", default=0, type=int) + parser.add_argument("--resume_from", default=None, type=str) + parser.add_argument("--resume_from_swa", default=None, type=str) + + parser.add_argument("--auto_resume", default=True) + + # logging params (WandB) + parser.add_argument("--wandb", action="store_true") # whether to use wandb or not + parser.add_argument("--wandb_project", default="my-project", type=str) parser.add_argument( - "--wsd_final_lr_scale", default=0.0, type=float - ) # wsd scheduler - parser.add_argument("--wsd_fract_decay", default=0.1, type=float) # wsd scheduler + "--wandb_run_prefix", default="none", type=str + ) # is added before the autogenerated experiment name parser.add_argument( - "--decay_type", - default="linear", - choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], - ) + "--eval_seq_prefix", default="none", type=str + ) # prefix used to generate sequences + parser.add_argument("--log_dynamics", action="store_true") parser.add_argument( - "--warmup_percent", default=0.05, type=float - ) # leave it anyway, warmup_steps / iterations - parser.add_argument("--weight_decay", default=0.1, type=float) - parser.add_argument("--beta1", default=0.9, type=float) - parser.add_argument("--beta2", default=0.95, type=float) + "--dynamics_logger_cfg", default="./src/logger/rotational_logger.yaml", 
type=str + ) + parser.add_argument("--wandb_entity", default=None, type=none_or_str) + + # Schedule parser.add_argument( "--scheduler", default="cos", - choices=["linear", "cos", "wsd", "cos_inf", "none"], + choices=["linear", "cos", "wsd", "none", "cos_inf"], ) parser.add_argument("--cos_inf_steps", default=0, type=int) + # parser.add_argument("--cos-final-lr", default=1e-6, type=float) + parser.add_argument("--iterations", default=15000, type=int) + parser.add_argument("--warmup_steps", default=3000, type=int) + parser.add_argument("--lr", default=1e-3, type=float) + # wsd + parser.add_argument("--wsd_final_lr_scale", default=0.0, type=float) + parser.add_argument("--wsd_fract_decay", default=0.1, type=float) + # parser.add_argument("--wsd-exponential-decay", action="store_true") + parser.add_argument( + "--decay_type", + default="linear", + choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], + ) + + # Optimization parser.add_argument( "--opt", default="adamw", @@ -54,15 +82,19 @@ def parse_args(base_parser, args, namespace): "muon", "soap", "ademamix", + "ademamix2", "lion", "sf-adamw", "sf-sgd", ], ) - parser.add_argument("--eval_freq", default=200, type=int) # in iterations - parser.add_argument("--results_base_folder", default="./exps", type=str) + parser.add_argument("--batch_size", default=50, type=int) + parser.add_argument("--acc_steps", default=1, type=int) + parser.add_argument("--weight_decay", default=1e-1, type=float) + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.95, type=float) parser.add_argument( - "--grad_clip", default=0.0, type=float + "--grad_clip", default=1.0, type=float ) # default value is 1.0 in NanoGPT parser.add_argument("--momentum", default=0.9, type=float) parser.add_argument("--shampoo_beta", default=-1.0, type=float) @@ -82,119 +114,70 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--adema_alpha_warmup", default=None, type=int) parser.add_argument("--schedulefree_r", default=0.0, type=float) parser.add_argument("--weight_lr_power", default=2.0, type=float) + # Dataset params + parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") parser.add_argument( "--dataset", default="slimpajama", choices=[ - "slimpajama", "wikitext", "shakespeare-char", "arxiv", "arxiv2000", "arxiv+wiki", "openwebtext2", + "redpajama", + "slimpajama", + "slimpajama_chunk1", + "redpajamav2", ], ) + parser.add_argument( + "--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"] + ) parser.add_argument("--vocab_size", default=50304, type=int) parser.add_argument( "--data_in_ram", action="store_true" ) # force the data to RAM, mostly useless except for openwebtext2 + # Model params - parser.add_argument("--model", default="base", choices=["base", "llama2"]) parser.add_argument( - "--use_pretrained", default="auto", type=none_or_str + "--model", + default="llama", + choices=[ + "base", + "llama", + "test", + ], + ) + parser.add_argument("--parallel_block", action="store_true") + parser.add_argument( + "--use_pretrained", default="none", type=str ) # 'none', 'gpt-2' or a path to the pretraind model + parser.add_argument("--from_dense", action="store_true") + parser.add_argument("--init_std", default=0.02, type=float) parser.add_argument("--dropout", default=0.0, type=float) parser.add_argument("--n_head", default=12, type=int) - parser.add_argument("--n_layer", default=12, type=int) # depths in att + ff blocks - parser.add_argument( - "--n_embd", 
default=768, type=int - ) # embedding size / hidden size ... + parser.add_argument("--n_layer", default=24, type=int) # depths in att + ff blocks parser.add_argument("--sequence_length", default=512, type=int) - parser.add_argument("--dtype", default=torch.bfloat16, type=torch.dtype) - parser.add_argument("--bias", default=False, type=bool) parser.add_argument( - "--compile", action="store_true" - ) # if true then model is compiled - parser.add_argument("--rmsnorm_eps", default=1e-5, type=float) + "--n_embd", default=768, type=int # embedding size / hidden size ... + ) parser.add_argument( "--multiple_of", # make SwiGLU hidden layer size multiple of large power of 2 default=256, type=int, ) + parser.add_argument("--rmsnorm_eps", default=1e-5, type=float) parser.add_argument( - "--run_prefix", default=None, type=str, required=False - ) # is added before the autogenerated experiment name - parser.add_argument("--exp_name", default=None, type=str, required=False) - # logging params (WandB) - parser.add_argument("--wandb", action="store_true") # whether to use wandb or not - parser.add_argument("--wandb_project", default="my-project", type=str) - parser.add_argument("--wandb_entity", default=None, type=none_or_str) - parser.add_argument( - "--wandb_run_prefix", default="none", type=str - ) # is added before the autogenerated experiment name - parser.add_argument( - "--eval_seq_prefix", default="Once upon a time", type=str - ) # prefix used to generate sequences - # Distributed args - parser.add_argument( - "--distributed_backend", - default=None, + "--dtype", + default="bfloat16", type=str, - required=False, - choices=distributed.registered_backends(), - ) # distributed backend type - parser.add_argument( - "--save_checkpoint_freq", default=None, type=int, required=False + choices=["float32", "float16", "bfloat16"], ) + parser.add_argument("--bias", default=False, type=bool) + parser.add_argument("--compile", action="store_true") + parser.add_argument("--mlp_dim_exp_factor", default=1.0, type=float) - args = parser.parse_args(args, namespace) - - if args.exp_name is None: - special_name_handle_fields = { - "model", - "lr", - "batch_size", - "acc_steps", - "seed", - "exp_name", - "wandb", - "wandb_project", - "eval_seq_prefix", - "run_prefix", - "distributed_backend", - "config_format", - "sequence_length", - } - overriden_values = [] - for key in vars(args): - if key in special_name_handle_fields: - continue - if getattr(args, key) != parser.get_default(key): - overriden_values.append((key, getattr(args, key))) - chunk_len = 10 - overriden_values_str_parts = [] - for chunk_id in range(0, len(overriden_values), chunk_len): - overriden_values_str = "_".join( - [ - "{}={}".format(key, value) - for key, value in overriden_values[chunk_id : chunk_id + chunk_len] - ] - ) - overriden_values_str_parts.append(overriden_values_str) - overriden_values_str = "/".join(overriden_values_str_parts) - exp_name = "" - if args.run_prefix is not None: - exp_name += f"{args.run_prefix}_" - exp_name += f"{args.model}_lr{args.lr}_bs{args.batch_size}x{args.acc_steps}_seqlen{args.sequence_length}/{overriden_values_str}_seed={args.seed}" - args.exp_name = exp_name - - if args.dtype == "torch.bfloat16": - args.dtype = torch.bfloat16 - elif args.dtype == "torch.float16": - args.dtype = torch.float16 - elif args.dtype == "torch.float32": - args.dtype = torch.float32 - - return args + return parser.parse_args(args, namespace) diff --git a/src/data/arxiv.py b/src/data/arxiv.py index e9de234..91f9a71 100644 --- 
a/src/data/arxiv.py +++ b/src/data/arxiv.py @@ -1,16 +1,17 @@ -import logging import os import tarfile -from multiprocessing import Pool +import logging from pathlib import Path -from subprocess import PIPE, Popen, TimeoutExpired +from typing import Optional +from multiprocessing import Pool from tempfile import NamedTemporaryFile -from typing import List, Optional, Tuple +from subprocess import Popen, TimeoutExpired, PIPE +from typing import Tuple, List import numpy as np import requests -import tiktoken from tqdm.auto import tqdm +import tiktoken def convert_to_markdown(args: Tuple[Path, Path]): @@ -82,7 +83,7 @@ def tokenize_arxiv(root: Path, year: int): def load_arxiv(cachedir: Path, years: Optional[List[int]] = None): - all_years = list(range(1992, 2004)) + all_years = list(range(1993, 2004)) # 1992 seems to give some problem if years is None: years = all_years assert set(years) <= set(all_years) @@ -108,9 +109,9 @@ def load_arxiv(cachedir: Path, years: Optional[List[int]] = None): return ret -def get_arxiv_2000(): - return load_arxiv(Path(os.path.dirname(__file__)) / "datasets", [2000]) +def get_arxiv_2000(datasets_base_dir): + return load_arxiv(Path(datasets_base_dir), [2000]) -def get_arxiv_full(): - return load_arxiv(Path(os.path.dirname(__file__)) / "datasets") +def get_arxiv_full(datasets_base_dir): + return load_arxiv(Path(datasets_base_dir)) \ No newline at end of file diff --git a/src/data/openwebtext2.py b/src/data/openwebtext2.py index 65eea73..d30e7cc 100644 --- a/src/data/openwebtext2.py +++ b/src/data/openwebtext2.py @@ -1,16 +1,16 @@ import os - +from tqdm import tqdm import numpy as np import tiktoken from datasets import load_dataset -from tqdm import tqdm -OWT2_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/openwebtext2/") + tknzr = tiktoken.get_encoding("gpt2") -def get_openwebtext2_data(num_proc=40): +def get_openwebtext2_data(datasets_base_dir, num_proc=40): """https://openwebtext2.readthedocs.io/en/latest/""" + OWT2_DATA_PATH = os.path.join(datasets_base_dir, "openwebtext2/") if not os.path.exists(os.path.join(OWT2_DATA_PATH, "train.bin")): os.makedirs(OWT2_DATA_PATH, exist_ok=True) dataset = load_dataset("the_pile_openwebtext2") @@ -59,11 +59,7 @@ def process(example): idx += len(arr_batch) arr.flush() - train_data = np.memmap( - os.path.join(OWT2_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" - ) - val_data = np.memmap( - os.path.join(OWT2_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" - ) - - return {"train": train_data, "val": val_data} + return { + "train": os.path.join(OWT2_DATA_PATH, "train.bin"), + "val": os.path.join(OWT2_DATA_PATH, "val.bin"), + } \ No newline at end of file diff --git a/src/data/redpajama.py b/src/data/redpajama.py new file mode 100644 index 0000000..93a70f3 --- /dev/null +++ b/src/data/redpajama.py @@ -0,0 +1,124 @@ +import os +from tqdm import tqdm +import numpy as np +import tiktoken +from datasets import load_dataset + + +tknzr = tiktoken.get_encoding("gpt2") + + +def get_redpajama_data(datasets_dir, num_proc=40): + RPJ_DATA_PATH = os.path.join(datasets_dir, "redpajama1Tsample/") + if not os.path.exists(os.path.join(RPJ_DATA_PATH, "train.bin")): + os.makedirs(RPJ_DATA_PATH, exist_ok=True) + dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample") + + split_dataset = dataset["train"].train_test_split( + test_size=0.0005, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores 
any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(RPJ_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + train_data = np.memmap( + os.path.join(RPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" + ) + val_data = np.memmap( + os.path.join(RPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" + ) + + return {"train": train_data, "val": val_data} + + +def get_redpajamav2_data(datasets_dir, num_proc=40): + """https://openwebtext2.readthedocs.io/en/latest/""" + RPJ_V2_DATA_PATH = os.path.join(datasets_dir, "redpajamaV2sample/") + if not os.path.exists(os.path.join(RPJ_V2_DATA_PATH, "train.bin")): + os.makedirs(RPJ_V2_DATA_PATH, exist_ok=True) + dataset = load_dataset("togethercomputer/RedPajama-Data-V2", name="sample") + + split_dataset = dataset["train"].train_test_split( + test_size=0.0005, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["raw_content"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
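+        # Added clarification: the loaders in this file concatenate all documents into
+        # one long token stream, so appending EOT after each document is equivalent to
+        # prepending it before the next one (except at the stream boundaries);
+        # appending matches the GPT-2 / nanoGPT convention used by the other loaders here.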
+ out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["raw_content"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(RPJ_V2_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + return { + "train": os.path.join(RPJ_V2_DATA_PATH, "train.bin"), + "val": os.path.join(RPJ_V2_DATA_PATH, "val.bin"), + } \ No newline at end of file diff --git a/src/data/shakespeare.py b/src/data/shakespeare.py index 21c607e..87ce7e6 100644 --- a/src/data/shakespeare.py +++ b/src/data/shakespeare.py @@ -4,6 +4,7 @@ import numpy as np import requests + _char_decode = dict( enumerate(sorted(set(ascii_letters + digits + punctuation + " \n"))) ) @@ -14,11 +15,9 @@ def char_tknzr(txt: str): return [_char_encode[char] for char in txt if char in _char_encode] -DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets", "shakespeare") - - -def get_shakespeare_data(): +def get_shakespeare_data(datasets_dir): """Inspired from https://github.com/karpathy/nanoGPT/""" + DATA_PATH = os.path.join(datasets_dir, "shakespeare") raw_path = os.path.join(DATA_PATH, "raw.txt") train_path = os.path.join(DATA_PATH, f"train.npy") test_path = os.path.join(DATA_PATH, f"test.npy") @@ -48,8 +47,7 @@ def get_shakespeare_data(): mem = np.memmap(test_path, dtype=np.uint16, mode="w+", shape=x_test.shape) mem[:] = x_test - # at this point we know that the binfile was properly created so we load it return { - "train": np.memmap(train_path, dtype=np.uint16, mode="r"), - "val": np.memmap(test_path, dtype=np.uint16, mode="r"), - } + "train": train_path, + "val": test_path, + } \ No newline at end of file diff --git a/src/data/slimpajama.py b/src/data/slimpajama.py index 5bd0600..4126f6c 100644 --- a/src/data/slimpajama.py +++ b/src/data/slimpajama.py @@ -1,18 +1,15 @@ -import os - +from tqdm import tqdm import numpy as np import tiktoken from datasets import load_dataset -from tqdm import tqdm - -SPJ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slimpajama6B/") -SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1") +import os tknzr = tiktoken.get_encoding("gpt2") -def get_slimpajama_data(num_proc=40): +def get_slimpajama_data(datasets_dir, num_proc=40): + SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/") if not os.path.exists(os.path.join(SPJ_DATA_PATH, "train.bin")): os.makedirs(SPJ_DATA_PATH, exist_ok=True) dataset = load_dataset("DKYoon/SlimPajama-6B") @@ -60,17 +57,15 @@ def process(example): idx += len(arr_batch) arr.flush() - train_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" - ) - val_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" - ) + return { + "train": os.path.join(SPJ_DATA_PATH, "train.bin"), + "val": 
os.path.join(SPJ_DATA_PATH, "val.bin"), + } - return {"train": train_data, "val": val_data} - -def get_slimpajama_chunk1(num_proc=40): +def get_slimpajama_chunk1(datasets_dir, num_proc=40): + SPJ_DATA_PATH = os.path.join(datasets_dir, "slimpajama6B/") + SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1") if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")): os.makedirs(SPJ_DATA_PATH, exist_ok=True) dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1") @@ -118,11 +113,7 @@ def process(example): idx += len(arr_batch) arr.flush() - train_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" - ) - val_data = np.memmap( - os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" - ) - - return {"train": train_data, "val": val_data} + return { + "train": os.path.join(SPJ_DATA_PATH, "train.bin"), + "val": os.path.join(SPJ_DATA_PATH, "val.bin"), + } \ No newline at end of file diff --git a/src/data/utils.py b/src/data/utils.py index ec5604b..bbeee55 100755 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -1,13 +1,15 @@ -from typing import Dict - +from pathlib import Path import numpy as np +from typing import Dict import torch +import torch.distributed as dist +from .shakespeare import get_shakespeare_data +from .wikitext import get_wikitext_data from .arxiv import get_arxiv_2000, get_arxiv_full from .openwebtext2 import get_openwebtext2_data -from .shakespeare import get_shakespeare_data +from .redpajama import get_redpajama_data, get_redpajamav2_data from .slimpajama import get_slimpajama_data -from .wikitext import get_wikitext_data def get_dataset(args) -> Dict[str, np.ndarray]: @@ -16,76 +18,160 @@ def get_dataset(args) -> Dict[str, np.ndarray]: containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. 
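     Note (added for clarity): with this refactor most loaders return paths to the
     tokenized train.bin / val.bin files rather than open np.memmap arrays; the
     DataReader below accepts either form and memmaps a path lazily on each access.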
""" if args.dataset == "wikitext": - return get_wikitext_data() + return get_wikitext_data(args.datasets_dir) if args.dataset == "shakespeare-char": - return get_shakespeare_data() + return get_shakespeare_data(args.datasets_dir) if args.dataset == "arxiv2000": - return get_arxiv_2000() + return get_arxiv_2000(args.datasets_dir) if args.dataset == "arxiv": - return get_arxiv_full() + return get_arxiv_full(args.datasets_dir) if args.dataset == "arxiv+wiki": - arxiv_data = get_arxiv_full() - wiki_data = get_wikitext_data() + arxiv_data = get_arxiv_full(args.datasets_dir) + wiki_data = get_wikitext_data(args.datasets_dir) train_data = np.concatenate((arxiv_data["train"], wiki_data["train"])) val_data = np.concatenate((arxiv_data["val"], wiki_data["val"])) return {"train": train_data, "val": val_data} if args.dataset == "openwebtext2": - return get_openwebtext2_data() + return get_openwebtext2_data(args.datasets_dir) + if args.dataset == "redpajama": + return get_redpajama_data(args.datasets_dir) + if args.dataset == "redpajamav2": + return get_redpajamav2_data(args.datasets_dir) if args.dataset == "slimpajama": - return get_slimpajama_data() + return get_slimpajama_data(args.datasets_dir) else: raise NotImplementedError(f"Unknow dataset key '{args.dataset}'") -class Dataset(torch.utils.data.Dataset): - def __init__(self, data, sequence_length): - super().__init__() - self.data = data +class DataReader: + def __init__( + self, + data_src, + batch_size, + sequence_length, + seed=1337, + with_replacement=False, + auto_shard=True, + keep_in_ram=False, + ): + if isinstance(data_src, (str, Path)): + self.data_path = Path(data_src) + self.keep_in_ram = keep_in_ram + if keep_in_ram: + self.data = np.array( + np.memmap(self.data_path, dtype=np.uint16, mode="r") + ) + else: + self.data = None + elif isinstance(data_src, (np.ndarray, np.memmap)): + self.data_path = None + self.data = data_src + self.keep_in_ram = True + + self.batch_size = batch_size self.sequence_length = sequence_length + self.seed = seed + self.with_replacement = with_replacement + + self.num_tokens = len(self._get_data()) + + if auto_shard and dist.is_initialized(): + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + print( + f"Distributed DataReader Initialized for Worker {self.rank}/{self.world_size}" + ) + else: + self.world_size = 1 + self.rank = 0 + + # Sampling without replacement + self.last_epoch = None + self.order = None + self.epoch_offset = None + self.step = 0 + self.num_batches_of_seqlen = 0 + if not with_replacement: + self._shuffle_epoch(0) def __len__(self): - total_length = len(self.data) - # chunk the data into sequences of length `sequence_length` - # NOTE: we discard the last remainding sequence if it's not of length `sequence_length` - return (total_length - 1) // self.sequence_length + # Length in valid start indices for a sequence + # Extra -1 to have a valid next token for the final token of the last idx + return self.num_tokens - self.sequence_length - 1 - def __getitem__(self, idx): - seq_length = self.sequence_length - idx = idx * seq_length - x = torch.from_numpy((self.data[idx : idx + seq_length]).astype(np.int64)) + def _get_data(self): + if self.data is not None: + return self.data + else: + # Construct the memmap each time to avoid a memory leak per NanoGPT + # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 + return np.memmap(self.data_path, dtype=np.uint16, mode="r") + def __getitem__(self, idx): + # Return the 
underlying datapoint, no random sampling, no worker sharding + assert 0 <= idx < len(self) + data = self._get_data() + x = torch.from_numpy(data[idx : idx + self.sequence_length].astype(np.int64)) y = torch.from_numpy( - (self.data[idx + 1 : idx + 1 + seq_length]).astype(np.int64) + data[idx + 1 : idx + self.sequence_length + 1].astype(torch.int64) ) return x, y + def set_step(self, step): + self.step = step -def get_dataloader(data, sequence_length, batch_size, seed=0, distributed_backend=None): - """Create a DataLoader for the given data. If distributed_backend is provided and is truly - distributed (world size > 1), the DataLoader will be created with a DistributedSampler that - splits the data across the processes (in conjunction with DDP). - Otherwise, use a RandomSampler with the specified seed. + def sample_batch(self): + data = self._get_data() - Returns both the dataloader and the sampler. - """ - dataset = Dataset(data, sequence_length=sequence_length) - if distributed_backend and distributed_backend.get_world_size() > 1: - sampler = torch.utils.data.DistributedSampler( - dataset, - shuffle=True, - seed=seed, - ) - else: - g = torch.Generator() - g.manual_seed(seed) - sampler = torch.utils.data.RandomSampler( - dataset, replacement=False, generator=g + if self.with_replacement: + idxs = self._sample_with_replacement(self.step) + else: + idxs = self._sample_without_replacement(self.step) + self.step += 1 + + xy = np.stack([data[i : i + self.sequence_length + 1] for i in idxs]).astype( + np.int64 ) + x = torch.from_numpy(xy[:, :-1]).contiguous() + y = torch.from_numpy(xy[:, 1:]).contiguous() + return x, y + + def _sample_with_replacement(self, idx): + # Return an array of token indices of length self.batch_size + # Sampled with replacement, can get repeats at any time + seed = self.seed + idx * self.world_size + self.rank + rng = np.random.default_rng(seed) + return rng.integers(len(self), self.batch_size) + + def _shuffle_epoch(self, epoch): + seed = self.seed + epoch + rng = np.random.default_rng(seed) + # Drop one sequence to allow different offsets per epoch: + self.order = rng.permutation((len(self)) // self.sequence_length - 1) + # Shift all sequences in this epoch by this amount: + self.epoch_offset = rng.integers(self.sequence_length) + self.last_epoch = epoch + self.num_batches_of_seqlen = ( + len(self.order) // self.batch_size + ) # Drops remainder batch + + def _sample_without_replacement(self, step): + # Return an array of token indices of length self.batch_size + # Sampled without replacement, cycle all sequences before potential repeats + # Sequences are randomly offset in every epoch as well + batch_idx = self.world_size * step + self.rank + epoch_length = self.num_batches_of_seqlen + + epoch = batch_idx // epoch_length + if epoch != self.last_epoch: + self._shuffle_epoch(epoch) + epoch_idx = batch_idx % epoch_length + + start = epoch_idx * self.batch_size + end = start + self.batch_size + return self.order[start:end] * self.sequence_length + self.epoch_offset - loader = torch.utils.data.DataLoader( - dataset, - sampler=sampler, - batch_size=batch_size, - num_workers=4, - ) - return loader, sampler + def num_batches(self): + if self.with_replacement: + return self.num_tokens // self.batch_size + return self.num_batches_of_seqlen \ No newline at end of file diff --git a/src/data/wikitext.py b/src/data/wikitext.py index 3bc4b03..8964fad 100755 --- a/src/data/wikitext.py +++ b/src/data/wikitext.py @@ -1,15 +1,13 @@ import os -import urllib import zipfile - +import 
urllib import numpy as np import tiktoken -WIKITEXT_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/wikitext/") - -def get_wikitext_data(): +def get_wikitext_data(datasets_base_dir): """Inspired from https://github.com/tysam-code/hlb-gpt""" + WIKITEXT_DATA_PATH = os.path.join(datasets_base_dir, "wikitext/") if not os.path.exists(WIKITEXT_DATA_PATH): os.makedirs(WIKITEXT_DATA_PATH, exist_ok=True) print("downloading data and tokenizing (1-2 min)") @@ -44,11 +42,7 @@ def get_wikitext_data(): eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, "val.bin")) print("completed the tokenization process!") - train_data = np.memmap( - os.path.join(WIKITEXT_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" - ) - val_data = np.memmap( - os.path.join(WIKITEXT_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" - ) - - return {"train": train_data, "val": val_data} + return { + "train": os.path.join(WIKITEXT_DATA_PATH, "train.bin"), + "val": os.path.join(WIKITEXT_DATA_PATH, "val.bin"), + } \ No newline at end of file diff --git a/src/distributed/backend.py b/src/distributed/backend.py index bcca248..06d37e8 100644 --- a/src/distributed/backend.py +++ b/src/distributed/backend.py @@ -30,7 +30,4 @@ def get_world_size(self): raise NotImplementedError def finalize(self): - pass - - def sync(self): - pass + pass \ No newline at end of file diff --git a/src/distributed/ddp.py b/src/distributed/ddp.py index bf47d25..664f069 100644 --- a/src/distributed/ddp.py +++ b/src/distributed/ddp.py @@ -1,10 +1,9 @@ -import math import os +import math from contextlib import contextmanager -from torch.distributed import (barrier, destroy_process_group, get_world_size, - init_process_group) from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed import init_process_group, destroy_process_group, get_world_size from .backend import DistributedBackend @@ -21,12 +20,6 @@ def __init__(self, args): def get_adjusted_args_for_process(self, args): effective_batch_size = args.batch_size * args.acc_steps world_size = self.get_world_size() - if args.acc_steps % world_size != 0: - raise ValueError( - f"Number of accumulation steps " - "{args.acc_steps} is not divisible " - "by the world size {world_size}." 
- ) if effective_batch_size % world_size != 0: raise ValueError( f"Effective batch size " @@ -38,6 +31,7 @@ def get_adjusted_args_for_process(self, args): args.batch_size = args.batch_size // (world_size // acc_steps_div) args.device = f"cuda:{self.local_rank}" args.seed = args.seed + self.local_rank + args.data_seed = args.data_seed return args def transform_model(self, model): @@ -65,7 +59,4 @@ def get_world_size(self): return get_world_size() def finalize(self): - destroy_process_group() - - def sync(self): - barrier() + destroy_process_group() \ No newline at end of file diff --git a/src/distributed/single.py b/src/distributed/single.py index b852988..7e1f8d5 100644 --- a/src/distributed/single.py +++ b/src/distributed/single.py @@ -5,6 +5,10 @@ class SinlgeNodeBackend(DistributedBackend): + def __init__(self, args): + super().__init__(args) + self.rank = 0 + def transform_model(self, model): return model @@ -24,4 +28,4 @@ def get_world_size(self): return 1 def translate_model_parameter_name_for_node(self, parameter_name): - return [parameter_name] + return [parameter_name] \ No newline at end of file diff --git a/src/main.py b/src/main.py index f043435..88bd4e3 100755 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ import os import random import sys +from pathlib import Path import numpy as np import torch @@ -12,10 +13,11 @@ import config import distributed import wandb -from data.utils import get_dataset +from data.utils import DataReader, get_dataset from models.utils import get_model from optim.ademamix import AdEMAMix -from optim.base import train_base +from optim.ademamix2 import AdEMAMix2 +from optim.base import train from optim.lion import Lion from optim.muon import Muon, zeropower_backends from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree @@ -31,44 +33,57 @@ def get_args(): args, rem_args = parser.parse_known_args() - return config.parse_args_with_format( + final_args = config.parse_args_with_format( format=args.config_format, base_parser=parser, args=rem_args, namespace=args ) + return final_args, parser -def main(args): - torch.backends.cuda.matmul.allow_tf32 = ( - True # allows us to make sure we're able to use tensorfloat32 during training - ) - torch.backends.cudnn.allow_tf32 = True +def main(args, parser): distributed_backend = distributed.make_backend_from_args(args) args = distributed_backend.get_adjusted_args_for_process(args) + args.world_size = distributed_backend.get_world_size() - args.device = torch.device(args.device) - device_type = "cuda" if "cuda" in str(args.device) else "cpu" - if device_type == "cuda": - torch.cuda.set_device(args.device) + if args.full_eval_at is None: + args.full_eval_at = [] + # NOTE args.seed is offset per worker in get_adjusted_args_for_process + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True torch.manual_seed(args.seed) random.seed(args.seed) np.random.seed(args.seed) + if "cuda" in args.device: + torch.cuda.set_device(torch.device(args.device)) + # torch.use_deterministic_algorithms(True) # CUBLAS_WORKSPACE_CONFIG=:4096:8 - print(f"Loading dataset '{args.dataset}'") + exp_name = get_exp_name(args, parser, distributed_backend) + exp_dir = Path(args.results_base_folder) / exp_name + if distributed_backend.is_master_process() and args.wandb: + wandb.init( + project=args.wandb_project, + name=exp_name, + config=vars(args), + entity=args.wandb_entity, + ) + wandb.define_metric("iter") + wandb.define_metric("train/*", step_metric="iter") + wandb.define_metric("val/*", 
step_metric="iter") + wandb.define_metric("lr", step_metric="iter") - data = get_dataset( - args - ) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized} - if args.data_in_ram: - data = {"train": np.array(data["train"]), "val": np.array(data["val"])} + print(f"Starting Experiment: {exp_name}") + print(f"Experiment Directory: {exp_dir}") + print(f"Config:\n{vars(args)}\n") - print(f"Num training tokens: {len(data['train'])}") - print(f"Num validation tokens: {len(data['val'])}") + print(f"Loading dataset: '{args.dataset}'") + datareaders = get_data_readers(args) model = get_model(args).to( args.device ) # todo: take care of initializing the model if args.use_pretrained != 'none' + print(f"\nModel:\n{model}") model = distributed_backend.transform_model(model) @@ -84,8 +99,16 @@ def main(args): params += [param_name_mapping[p_name] for p_name in translated_p_names] g["params"] = params optimized_params_cnt += sum([p.numel() for p in g["params"]]) + params_cnt = distributed_backend.get_raw_model(model).get_num_params() + print("number of parameters: %.2fM" % (params_cnt / 1e6,)) print("number of optimized parameters: %.2fM" % (optimized_params_cnt / 1e6,)) + if args.wandb and distributed_backend.is_master_process(): + wandb.log( + {"parameters": params_cnt, "optimized_parameters": optimized_params_cnt} + ) + if args.opt == "adamw": + device_type = 'cuda' if 'cuda' in args.device else 'cpu' use_fused = (device_type == "cuda") and ( "fused" in inspect.signature(torch.optim.AdamW).parameters ) @@ -134,6 +157,16 @@ def main(args): alpha_warmup=args.adema_alpha_warmup, weight_decay=args.weight_decay, ) + elif args.opt == "ademamix2": + opt = AdEMAMix2( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2, args.adema_beta3), + alpha=args.adema_alpha, + beta3_warmup=args.adema_beta3_warmup, + alpha_warmup=args.adema_alpha_warmup, + weight_decay=args.weight_decay, + ) elif args.opt == "lion": opt = Lion( group_specs, @@ -168,6 +201,7 @@ def main(args): momentum=args.momentum, weight_decay=args.weight_decay, ) + print(f"\nOptimizer:\n{opt}") if args.scheduler != "none": assert ( @@ -187,7 +221,7 @@ def main(args): anneal_strategy=args.scheduler, cycle_momentum=False, div_factor=1e2, - final_div_factor=0.1, + final_div_factor=1, ) elif args.scheduler == "cos_inf": lambda_schedule = cos_inf_schedule( @@ -213,107 +247,165 @@ def main(args): else: scheduler = None - args.world_size = distributed_backend.get_world_size() - exp_name = args.exp_name - if distributed_backend.is_master_process() and args.wandb: - params_copy = copy.deepcopy(vars(args)) - del params_copy["device"] - wandb.init( - project=args.wandb_project, - name=exp_name, - config=params_copy, - entity=args.wandb_entity, - ) - - ckpt_path = os.path.join( - args.results_base_folder, args.dataset, args.model, exp_name - ) - if not os.path.exists(ckpt_path): - if distributed_backend.is_master_process(): - os.makedirs(ckpt_path) - distributed_backend.sync() - elif os.path.isfile( - os.path.join(ckpt_path, "summary.json") - ): # the experiment was already completed - print(f"Already found experiment '{ckpt_path}'.\nSkipping.") - sys.exit(0) - - itr = 0 - rng_state_dict = None - if args.use_pretrained == "auto": - checkpoints = [file for file in os.listdir(ckpt_path) if "ckpt_" in file] - if checkpoints: - args.use_pretrained = sorted(checkpoints)[-1] + if (exp_dir / "ckpts" / "latest" / "main.pt").exists(): + if not args.auto_resume: + raise ValueError( + f"The experiment dir {exp_dir} already exists. 
" + + "To resume training, set auto_resume=True. " + + "Otherwise, specify a different experiment name. " + ) else: - args.use_pretrained = None - - if args.use_pretrained is not None: - last_ckpt_path = args.use_pretrained - print(f"Resuming from {last_ckpt_path}") - checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path)) - model_state_dict = { - distributed_backend.translate_model_parameter_name_for_node( - k.replace("_orig_mod.", "") - )[0]: v - for k, v in checkpoint["model"].items() - } - # FIXME checkpoints from compiled model have _orig_mod keyword - - optimizer_state_dict = checkpoint["optimizer"] - rng_state_dict = { - module: checkpoint[module] - for module in [ - "cpu_rng_state", - "gpu_rng_state", - "numpy_rng_state", - "py_rng_state", - "train_sampler_state", - ] - } - - model.load_state_dict(model_state_dict) - opt.load_state_dict(optimizer_state_dict) - itr = checkpoint["itr"] - if scheduler is not None: - scheduler_state_dict = checkpoint["scheduler"] - scheduler.load_state_dict(scheduler_state_dict) - - if args.model in ["base", "llama2"]: # all train functions have the same interface - train = train_base - else: - raise NotImplementedError( - f"No training method implemented for model type '{args.model}'." - ) - - print(f"\nTraining model={args.model} \n{vars(args)}\n") + # Auto resume overwrites resume_from + args.resume_from = str(exp_dir / "ckpts" / "latest") + elif distributed_backend.is_master_process(): + exp_dir.mkdir(parents=True, exist_ok=True) stats = train( - model, - opt, - data, - args.data_seed, - scheduler, - args.iterations, - args.acc_steps, - args.batch_size, - args.sequence_length, - eval_freq=args.eval_freq, + model=model, + opt=opt, + datareaders=datareaders, + scheduler=scheduler, + exp_dir=exp_dir, distributed_backend=distributed_backend, - ckpt_path=f"{ckpt_path}/ckpt.pt", - itr=itr, - rng_state_dict=rng_state_dict, - extra_args=args, + cfg=args, ) - args.device = None - args.dtype = None stats["args"] = vars(args) if distributed_backend.is_master_process(): - with open(f"{ckpt_path}/summary.json", "w") as fs: + with open(exp_dir / "summary.json", "w") as fs: json.dump(stats, fs) distributed_backend.finalize() +def get_data_readers(args, verbose=True): + data_srcs = get_dataset(args) + train_reader = DataReader( + data_src=data_srcs["train"], + batch_size=args.batch_size, + sequence_length=args.sequence_length, + seed=args.data_seed, + with_replacement=False, + auto_shard=True, + keep_in_ram=args.data_in_ram, + ) + val_reader = DataReader( + data_src=data_srcs["val"], + batch_size=args.batch_size, + sequence_length=args.sequence_length, + seed=args.data_seed, + with_replacement=False, + auto_shard=False, # NOTE Identical Per Rank + keep_in_ram=args.data_in_ram, + ) + + if verbose: + print(f"Num training tokens: {train_reader.num_tokens}") + print(f"Num validation tokens: {val_reader.num_tokens}") + + return { + "train": train_reader, + "val": val_reader, + } + + +def get_exp_name(args, parser, distributed_backend, key_args=['model', 'dataset', 'opt'], ignore_args=['eval_interval', 'full_eval_at', 'distributed_backend', 'latest_ckpt_interval', 'wandb', 'wandb_project', 'wandb_entity', 'batch_size', 'acc_steps', 'results_base_folder', 'run_prefix', 'wandb_run_prefix']): + # Get the default values + defaults = vars(parser.parse_args([])) + + rank = distributed_backend.rank + + # Generate the prefix with key arguments + prefix_parts = [] + for key in key_args: + if hasattr(args, key): + value = getattr(args, key) + 
prefix_parts.append(f"{key}-{value}") + + prefix = "_".join(prefix_parts) + prefix = f"{args.batch_size}x{args.acc_steps}(rank={rank})_" + prefix + + # Generate the rest of the string with non-default arguments + non_default_parts = [] + for key, value in vars(args).items(): + if key in ignore_args: + continue + if key not in defaults: + print(f"Warning: {key} not in defaults") + continue + if key not in key_args and value != defaults[key]: + non_default_parts.append(f"{key}-{value}") + + non_default_string = "_".join(non_default_parts) + + if args.run_prefix is not None: + prefix = args.run_prefix + "_" + prefix + + # Combine prefix and non-default string + if non_default_string: + return f"{prefix}__{non_default_string}" + else: + return prefix + + +# def get_exp_name(args, distributed_backend): +# """Returns the name of the experiment, used for saving models and wandb.""" +# if args.experiment_name is not None: +# return args.experiment_name + +# rank = distributed_backend.rank + +# exp_name = ( +# f"{args.dataset}_{args.model}_nlayers{args.n_layer}" +# f"_nhead{args.n_head}_lr{args.lr}" +# f"_sched_{args.scheduler}_warmup{args.warmup_steps}" +# f"_decay_{args.decay_type}_{args.wsd_fract_decay}" +# f"_iter{args.iterations}" +# f"_bs{args.batch_size}x{args.acc_steps}_ws{args.world_size}" +# ) +# # for mup +# if args.model == "mup_noam": +# exp_name = ( +# f"{args.dataset}_{args.model}" +# f"_opt{args.opt}" +# f"_nlayers{args.n_layer}" +# # f"_nhead{args.n_head}" +# f"_lr{args.lr}" +# f"_sched_{args.scheduler}" +# f"_decay_{args.decay_type}" +# # f"_warmup{args.warmup_steps}" +# f"_iter{args.iterations}" +# f"_init{args.init_std}_sce{args.scale_emb}" +# f"_scd{args.scale_depth}" +# # f"_bs{args.batch_size}x{args.acc_steps}_ws{args.world_size}" +# ) +# if args.run_prefix is not None: +# exp_name = args.run_prefix + "_" + exp_name +# if args.wandb_run_prefix != "none": +# exp_name = args.wandb_run_prefix + "_" + exp_name +# exp_name += f"_seed{args.seed - rank}" +# exp_name += f"_data_seed{args.data_seed}" + +# if args.opt == "SFAdamW": +# exp_name += f"_beta1_{args.beta1}" +# exp_name += f"_beta2_{args.beta2}" + +# if args.opt == "ademamix": +# exp_name += f"_beta3_{args.adema_beta3}" +# exp_name += f"_alpha_{args.adema_alpha}" +# exp_name += f"_beta3_warmup_{args.adema_beta3_warmup}" +# exp_name += f"_alpha_warmup_{args.adema_alpha_warmup}" + +# if args.opt == "lion": +# exp_name += f"_beta1_{args.beta1}" +# exp_name += f"_beta2_{args.beta2}" + +# if args.opt == "adamw": +# exp_name += f"_beta1_{args.beta1}" +# exp_name += f"_beta2_{args.beta2}" + +# return exp_name + + if __name__ == "__main__": - args = get_args() - main(args) + args, parser = get_args() + main(args, parser) diff --git a/src/models/base.py b/src/models/base.py index dc2258a..3fd01f6 100755 --- a/src/models/base.py +++ b/src/models/base.py @@ -7,7 +7,6 @@ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py """ -import inspect import math import tiktoken @@ -29,7 +28,6 @@ def forward(self, input): class CausalSelfAttention(nn.Module): - def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 @@ -58,21 +56,21 @@ def __init__(self, config): ) def forward(self, x): - B, T, C = ( - x.size() - ) # batch size, sequence length, embedding dimensionality (n_embd) + # batch size, sequence length, embedding dimensionality (n_embd) + ( + B, + T, + C, + ) = x.size() # calculate query, key, values for all heads in batch and move head forward to be the 
batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) + # (B, T, nh, hs) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + + # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: @@ -97,11 +95,16 @@ def forward(self, x): class MLP(nn.Module): - - def __init__(self, config): + def __init__(self, config, exp_factor=1.0): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dim_exp_factor = exp_factor * 4 + + self.c_fc = nn.Linear( + config.n_embd, int(self.dim_exp_factor * config.n_embd), bias=config.bias + ) + self.c_proj = nn.Linear( + int(self.dim_exp_factor * config.n_embd), config.n_embd, bias=config.bias + ) self.dropout = nn.Dropout(config.dropout) self.activation = nn.GELU() @@ -114,22 +117,30 @@ def forward(self, x): class Block(nn.Module): - def __init__(self, config): super().__init__() self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) self.attn = CausalSelfAttention(config) - self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.parallel = config.parallel_block + if not self.parallel: + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) self.mlp = MLP(config) - def forward(self, x): - x = x + self.attn(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) + def forward(self, x, *args, **kwargs): + if self.parallel: + # from GPT-J 6B https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L299 + x_ln = self.ln_1(x, *args, **kwargs) + x_attn = self.attn(x_ln) + x_ffn = self.mlp(x_ln) + x = x + x_attn + x_ffn + else: + x = x + self.attn(self.ln_1(x, *args, **kwargs)) + x_ = self.mlp(self.ln_2(x, *args, **kwargs)) + x = x + x_ return x class GPTBase(nn.Module): - def __init__(self, config): super().__init__() assert config.vocab_size is not None @@ -162,12 +173,11 @@ def __init__(self, config): for pn, p in self.named_parameters(): if pn.endswith("c_proj.weight"): torch.nn.init.normal_( - p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + p, + mean=0.0, + std=self.config.init_std / math.sqrt(2 * config.n_layer), ) - # report number of parameters - print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) - def get_num_params(self, non_embedding=True): """ Return the number of parameters in the model. 
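Not part of the patch, just an illustrative aside on the hunk above: it replaces the hard-coded 0.02 with a configurable `init_std`, while residual projections (`c_proj.weight`) keep the extra `1/sqrt(2 * n_layer)` scaling so the residual-stream variance stays roughly constant with depth (the GPT-2 practice; each block adds two residual branches, hence the factor 2). A minimal self-contained sketch of that scheme, with a hypothetical helper name and illustrative defaults:

```python
import math

import torch.nn as nn


def init_weights_sketch(model: nn.Module, init_std: float = 0.02, n_layer: int = 24) -> None:
    """Plain normal init, with residual projections downscaled by 1/sqrt(2 * n_layer)."""
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
    # Re-initialize residual projections with the depth-dependent scaling,
    # mirroring the `pn.endswith("c_proj.weight")` loop in the diff above.
    for name, param in model.named_parameters():
        if name.endswith("c_proj.weight"):
            nn.init.normal_(param, mean=0.0, std=init_std / math.sqrt(2 * n_layer))
```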
@@ -182,11 +192,11 @@ def get_num_params(self, non_embedding=True): def _init_weights(self, module): if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) def forward(self, idx, targets=None, get_logits=False): device = idx.device @@ -194,9 +204,8 @@ def forward(self, idx, targets=None, get_logits=False): assert ( t <= self.config.sequence_length ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( - 0 - ) # shape (1, t) + # shape (1, t) + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # forward the GPT model itself tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) @@ -204,6 +213,8 @@ def forward(self, idx, targets=None, get_logits=False): pos ) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) + + # forward pass through all the transformer blocks for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) @@ -214,6 +225,7 @@ def forward(self, idx, targets=None, get_logits=False): loss = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 ) + else: # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head( @@ -221,7 +233,10 @@ def forward(self, idx, targets=None, get_logits=False): ) # note: using list [-1] to preserve the time dim loss = None logits = logits if get_logits else None - return {"logits": logits, "loss": loss} + return { + "logits": logits, + "loss": loss, + } def crop_sequence_length(self, sequence_length): # model surgery to decrease the block size if necessary @@ -235,10 +250,24 @@ def crop_sequence_length(self, sequence_length): for block in self.transformer.h: block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length] - @classmethod - def from_pretrained(cls, model_type, override_args=None): - # TODO - pass + def from_pretrained( + self, + model_path, + ): + paths = model_path.split(",") + if len(paths) == 1: + # TODO: with distributed? + loaded_state = torch.load( + str(model_path + "/ckpt.pt"), + map_location=torch.device(self.config.device), + ) + state_to_load = loaded_state["model"] + + # load the sparse model + state_to_load = { + ".".join(k.split(".")[1:]): v # drop _orig_mod from keys + for k, v in state_to_load.items() + } def get_parameter_group_specs(self): """ @@ -344,4 +373,4 @@ def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=No .to("cpu") .numpy() ) - return self.tokenizer.decode(out_idx) + return self.tokenizer.decode(out_idx) \ No newline at end of file diff --git a/src/models/llama.py b/src/models/llama.py index 34996ba..9ea78c3 100644 --- a/src/models/llama.py +++ b/src/models/llama.py @@ -1,17 +1,6 @@ """ -Llama style Language Model. 
-References: -1) Llama inference code: -https://github.com/facebookresearch/llama/blob/main/llama/model.py -2) Mistral one file ref: -https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py -3) Llama paper: -https://arxiv.org/pdf/2302.13971.pdf - -Main differences from GPT2: -* Uses RMSNorm instead of LayerNorm -* Uses a slightly different MLP (SwiGLU) -* rotary embeddings (RoPE) +Llama style Language Model that is +compilable (avoids torch complex) """ import math @@ -20,7 +9,6 @@ import torch import torch.nn as nn from torch.nn import functional as F - from models.base import CausalSelfAttention, GPTBase @@ -28,7 +16,10 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Te freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore - return torch.polar(torch.ones_like(freqs), freqs) # complex64 + cos_freqs = torch.cos(freqs) + sin_freqs = torch.sin(freqs) + # Stack the cos and sin parts in the last dimension to simulate complex numbers + return torch.stack((cos_freqs, sin_freqs), dim=-1) def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: @@ -38,11 +29,11 @@ def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Te """ ndim = x.ndim assert 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( - freqs_cis.shape, - (x.shape[1], x.shape[-1]), - ) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + assert freqs_cis.shape[:-1] == (x.shape[1], x.shape[-2]) + # New shape for broadcasting + shape = [ + 1 if i != 1 and i != ndim - 2 else d for i, d in enumerate(x.shape[:-1]) + ] + [2] return freqs_cis.view(*shape) @@ -50,12 +41,22 @@ def apply_rotary_emb(q, k, freqs_cis): # q, k: (B, T, nh, hs) # freq_cis: (T, hs) # return: (B, T, nh, hs), (B, T, nh, hs) - q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2)) - k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2)) - freqs_cis = _reshape_for_broadcast(freqs_cis, q_) - xq_out = torch.view_as_real(q_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(k_ * freqs_cis).flatten(3) - return xq_out.type_as(q), xk_out.type_as(k) + q = q.float().reshape(*q.shape[:-1], -1, 2) + k = k.float().reshape(*k.shape[:-1], -1, 2) + + freqs_cis = _reshape_for_broadcast(freqs_cis, q) + + # Perform manual "complex" multiplication + q_cos = q[..., 0] * freqs_cis[..., 0] - q[..., 1] * freqs_cis[..., 1] + q_sin = q[..., 0] * freqs_cis[..., 1] + q[..., 1] * freqs_cis[..., 0] + k_cos = k[..., 0] * freqs_cis[..., 0] - k[..., 1] * freqs_cis[..., 1] + k_sin = k[..., 0] * freqs_cis[..., 1] + k[..., 1] * freqs_cis[..., 0] + + # Combine the results back into the interleaved format expected by q and k + q_out = torch.stack((q_cos, q_sin), dim=-1).reshape(q.shape).flatten(3) + k_out = torch.stack((k_cos, k_sin), dim=-1).reshape(k.shape).flatten(3) + + return q_out, k_out class RMSNorm(nn.Module): @@ -91,6 +92,7 @@ def forward(self, x): class LlamaAttention(CausalSelfAttention): + def forward(self, x, freqs_cis): # batch size, sequence length, embedding dimensionality (n_embd) ( @@ -143,7 +145,8 @@ def __init__(self, config): def forward(self, x, freqs_cis): x = x + self.attn(self.ln_1(x), freqs_cis) - x = x + self.mlp(self.ln_2(x)) + x_ = self.mlp(self.ln_2(x)) + x = x + x_ return x @@ -211,7 +214,7 @@ def forward(self, idx, targets=None, get_logits=False): x = 
self.transformer.drop(tok_emb) freqs_cis = self.freqs_cis.to(x.device)[pos] - for block_idx, block in enumerate(self.transformer.h): + for block in self.transformer.h: x = block(x, freqs_cis=freqs_cis) x = self.transformer.ln_f(x) @@ -233,4 +236,4 @@ def forward(self, idx, targets=None, get_logits=False): return { "logits": logits, "loss": loss, - } + } \ No newline at end of file diff --git a/src/models/test.py b/src/models/test.py new file mode 100644 index 0000000..bd7b1d6 --- /dev/null +++ b/src/models/test.py @@ -0,0 +1,232 @@ +""" +Llama style Language Model that is +compilable (avoids torch complex) +""" + +import math + +import tiktoken +import torch +import torch.nn as nn +from torch.nn import functional as F +from models.base import CausalSelfAttention, GPTBase + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + cos_freqs = torch.cos(freqs) + sin_freqs = torch.sin(freqs) + # Stack the cos and sin parts in the last dimension to simulate complex numbers + return torch.stack((cos_freqs, sin_freqs), dim=-1) + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape[:-1] == (x.shape[1], x.shape[-2]) + # New shape for broadcasting + shape = [ + 1 if i != 1 and i != ndim - 2 else d for i, d in enumerate(x.shape[:-1]) + ] + [2] + return freqs_cis.view(*shape) + + +def apply_rotary_emb(q, k, freqs_cis): + # q, k: (B, T, nh, hs) + # freq_cis: (T, hs) + # return: (B, T, nh, hs), (B, T, nh, hs) + q = q.float().reshape(*q.shape[:-1], -1, 2) + k = k.float().reshape(*k.shape[:-1], -1, 2) + + freqs_cis = _reshape_for_broadcast(freqs_cis, q) + + # Perform manual "complex" multiplication + q_cos = q[..., 0] * freqs_cis[..., 0] - q[..., 1] * freqs_cis[..., 1] + q_sin = q[..., 0] * freqs_cis[..., 1] + q[..., 1] * freqs_cis[..., 0] + k_cos = k[..., 0] * freqs_cis[..., 0] - k[..., 1] * freqs_cis[..., 1] + k_sin = k[..., 0] * freqs_cis[..., 1] + k[..., 1] * freqs_cis[..., 0] + + # Combine the results back into the interleaved format expected by q and k + q_out = torch.stack((q_cos, q_sin), dim=-1).reshape(q.shape).flatten(3) + k_out = torch.stack((k_cos, k_sin), dim=-1).reshape(k.shape).flatten(3) + + return q_out, k_out + + +class RMSNorm2(nn.Module): + def __init__(self, ndim, eps=1e-5, bias=False): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + self.eps = eps + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, self.eps) + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + + hidden_dim = config.n_embd * 4 + + self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False) + self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) + + def forward(self, x): + return self.c_proj(nn.functional.gelu(self.w1(x))) + + +class LlamaAttention(CausalSelfAttention): + + def forward(self, x, freqs_cis): + # batch size, sequence length, embedding dimensionality (n_embd) + ( + B, + T, + C, + ) = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = 
self.c_attn(x).split(self.n_embd, dim=2) + # (B, T, nh, hs) + k = k.view(B, T, self.n_head, C // self.n_head) + q = q.view(B, T, self.n_head, C // self.n_head) + q, k = apply_rotary_emb(q, k, freqs_cis) + # (B, nh, T, hs) + q, k = q.transpose(1, 2), k.transpose(1, 2) + + # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class LlamaBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = RMSNorm2(config.n_embd, eps=config.rmsnorm_eps) + self.attn = LlamaAttention(config) + self.ln_2 = RMSNorm2(config.n_embd, eps=config.rmsnorm_eps) + self.mlp = LlamaMLP(config) + + def forward(self, x, freqs_cis): + x = x + self.attn(self.ln_1(x), freqs_cis) + x_ = self.mlp(self.ln_2(x)) + x = x + x_ + return x + + +class Test(GPTBase): + def __init__(self, config): + super().__init__(config) + assert config.vocab_size is not None + assert config.sequence_length is not None + self.config = config + self.tokenizer = tiktoken.get_encoding("gpt2") + + # create the token and position embeddings + self.head_dim = config.n_embd // config.n_head + self.freqs_cis = precompute_freqs_cis(self.head_dim, config.sequence_length) + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([LlamaBlock(config) for _ in range(config.n_layer)]), + ln_f=RMSNorm2(config.n_embd, eps=config.rmsnorm_eps), + ) + ) + + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # with weight tying when using torch.compile() some warnings get generated: + # "UserWarning: functional_call was passed multiple values for tied weights. + # This behavior is deprecated and will be an error in future versions" + # not 100% sure what this is, so far seems to be harmless. TODO investigate + self.transformer.wte.weight = ( + self.lm_head.weight + ) # https://paperswithcode.com/method/weight-tying + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default) + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. 
+ """ + n_params = sum(p.numel() for p in self.parameters()) + return n_params + + def forward(self, idx, targets=None, get_logits=False): + device = idx.device + b, t = idx.size() + assert ( + t <= self.config.sequence_length + ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" + # shape (1, t) + pos = torch.arange(0, t, dtype=torch.long, device=device) + + # forward the GPT model itself + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + # pos_emb = self.transformer.wpe(pos) + + x = self.transformer.drop(tok_emb) # + pos_emb) + freqs_cis = self.freqs_cis.to(x.device)[pos] + + for block in self.transformer.h: + x = block(x, freqs_cis=freqs_cis) + x = self.transformer.ln_f(x) + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 + ) + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head( + x[:, [-1], :] + ) # note: using list [-1] to preserve the time dim + loss = None + + logits = logits if get_logits else None + + return { + "logits": logits, + "loss": loss, + } \ No newline at end of file diff --git a/src/models/utils.py b/src/models/utils.py index 2d68c05..35003e8 100755 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -1,12 +1,13 @@ -import torch - -from .base import GPTBase, LayerNorm from .llama import Llama, RMSNorm +from .base import GPTBase, LayerNorm +from .test import Test, RMSNorm2 +import torch BLACKLIST_WEIGHT_MODULES = ( torch.nn.LayerNorm, LayerNorm, RMSNorm, + RMSNorm2, torch.nn.Embedding, ) @@ -15,9 +16,22 @@ def get_model(args): """Return the right model""" if args.model == "base": model = GPTBase(args) + if args.use_pretrained != "none": + model.from_pretrained(args.use_pretrained) return model - elif args.model == "llama2": + elif args.model == "llama": model = Llama(args) + if args.use_pretrained != "none": + raise NotImplementedError( + f"Loading of pretrained models not yet implemented for model '{args.model}'." + ) + return model + elif args.model == "test": + model = Test(args) + if args.use_pretrained != "none": + raise NotImplementedError( + f"Loading of pretrained models not yet implemented for model '{args.model}'." + ) return model else: - raise KeyError(f"Unknown model '{args.model}'.") + raise KeyError(f"Unknown model '{args.model}'.") \ No newline at end of file diff --git a/src/optim/ademamix2.py b/src/optim/ademamix2.py new file mode 100644 index 0000000..df901f5 --- /dev/null +++ b/src/optim/ademamix2.py @@ -0,0 +1,187 @@ +""" +Here is an original implementation of AdEMAMix. +Source: https://github.com/apple/ml-ademamix +""" + +import math + +import torch + + +def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): + if step < warmup: + a = step / float(warmup) + return (1.0 - a) * alpha_start + a * alpha_end + return alpha_end + + +def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): + + def f(beta, eps=1e-8): + return math.log(0.5) / math.log(beta + eps) - 1 + + def f_inv(t): + return math.pow(0.5, 1 / (t + 1)) + + if step < warmup: + a = step / float(warmup) + return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end)) + return beta_end + + +class AdEMAMix2(torch.optim.Optimizer): + r"""Implements the AdEMAMix algorithm. 
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999, 0.9999)) + corresponding to beta_1, beta_2, beta_3 in AdEMAMix + alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs (default: 2) + beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None) + alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay as in AdamW (default: 0) + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999, 0.9999), + alpha=2.0, + beta3_warmup=None, + alpha_warmup=None, + eps=1e-8, + weight_decay=0, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + alpha=alpha, + beta3_warmup=beta3_warmup, + alpha_warmup=alpha_warmup, + weight_decay=weight_decay, + ) + super(AdEMAMix2, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdEMAMix2, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + lr = group["lr"] + lmbda = group["weight_decay"] + eps = group["eps"] + beta1, beta2, beta3_final = group["betas"] + beta3_warmup = group["beta3_warmup"] + alpha_final = group["alpha"] + alpha_warmup = group["alpha_warmup"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdEMAMix does not support sparse gradients.") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + if beta1 != 0.0: # save memory in case beta1 is 0.0 + state["exp_avg_fast"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + else: + state["exp_avg_fast"] = None + state["exp_avg_slow"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg_fast, exp_avg_slow, exp_avg_sq = ( + state["exp_avg_fast"], + state["exp_avg_slow"], + state["exp_avg_sq"], + ) + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Compute the effective alpha and beta3 in case warmup is used + if alpha_warmup is not None: + alpha = linear_warmup_scheduler( + state["step"], + alpha_end=alpha_final, + alpha_start=0, + warmup=alpha_warmup, + ) + else: + alpha = alpha_final + + if beta3_warmup is not None: + beta3 = linear_hl_warmup_scheduler( + state["step"], + beta_end=beta3_final, + beta_start=beta1, + warmup=beta3_warmup, + ) + else: + beta3 = beta3_final + + # Decay the first and second moment running average coefficient + if beta1 != 0.0: + exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1) + else: + exp_avg_fast = grad + exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + update = exp_avg_fast.div(bias_correction1) / denom + alpha * exp_avg_slow / exp_avg_slow.norm() + + # decay + update.add_(p, alpha=lmbda) + + p.add_(-lr * update) + + return loss diff --git a/src/optim/base.py b/src/optim/base.py index 4daf4f8..8b4d570 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -1,249 +1,272 @@ +from contextlib import nullcontext import copy -import itertools -import os -import random +from pathlib import Path import time -from contextlib import nullcontext +import yaml -import numpy as np import torch -import torch.nn.functional as F - import wandb -from data.utils import get_dataloader -from .utils import eval, get_batch, save_checkpoint +# from logger.logger import DynamicsLogger +from .utils import ( + eval, + get_batch, + load_checkpoint, + load_worker_state, + save_checkpoint, + save_worker_state, +) + +def compute_gradient_norm(model): + total_norm = 0.0 + for param in model.parameters(): + if param.grad is not None: + param_norm = param.grad.data.norm(2) + total_norm += param_norm ** 2 + return torch.sqrt(torch.tensor([total_norm])) -def train_base( + +def train( model, opt, - data, - data_seed, + datareaders, scheduler, - iterations, - acc_steps, - batch_size, - sequence_length, - eval_freq, - ckpt_path, + exp_dir, distributed_backend, - extra_args, - itr=0, - rng_state_dict=None, + cfg, ): - device_type = "cuda" if "cuda" in str(extra_args.device) else 
"cpu" - type_ctx = ( - nullcontext() - if device_type == "cpu" - else torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) - ) # extra_args.dtype) - best_val_loss, text_table = ( - float("inf"), - None, - ) # best_val_loss not used atm, early stopping not recommended but possible - substep = itr * acc_steps - data["train"], train_sampler = get_dataloader( - data["train"], - sequence_length=sequence_length, - batch_size=batch_size, - seed=data_seed, - distributed_backend=distributed_backend, - ) - - data["val"], val_sampler = get_dataloader( - data["val"], - sequence_length=sequence_length, - batch_size=batch_size, - seed=data_seed, - ) - - num_substeps_per_epoch = len(data["train"]) - train_epochs = substep // num_substeps_per_epoch - - if ( - rng_state_dict is not None - and rng_state_dict.get("train_sampler_state", None) is not None - ): - train_sampler.generator.set_state(rng_state_dict["train_sampler_state"]) - if hasattr(train_sampler, "set_epoch"): - train_sampler.set_epoch(train_epochs) + not_compiled_model = model + if cfg.compile: + print(f"Compiling model ...") + model = torch.compile(model) + + if "cuda" in cfg.device: + type_ctx = torch.amp.autocast( + device_type="cuda", + dtype={ + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[cfg.dtype], + ) else: - sampler_state_before_iter = train_sampler.generator.get_state() - data_train_iter = iter(data["train"]) - - # for val data we don't care about epochs? just cycle through (no need to set_epoch to reshuffle) - data_val_iter = itertools.cycle(data["val"]) + type_ctx = nullcontext() + + if cfg.resume_from: + # This is a full resume including the model weights, optimizer, state + # dataloader state, random seed, etc. Not indended for fine tuning or + # other scenarios where some of these should change. 
+ print(f"\nResuming Training From {cfg.resume_from}") + ckpt_dir = Path(cfg.resume_from) + curr_iter = load_checkpoint( + model, + opt, + scheduler, + ckpt_dir / "main.pt", + cfg.device, + ) + load_worker_state(ckpt_dir) + else: + curr_iter = 0 - stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} + # if distributed_backend.is_master_process() and cfg.log_dynamics: + # with open(cfg.dynamics_logger_cfg, "r") as f: + # dlcfg = yaml.safe_load(f) - if extra_args.compile: - print(f"Compiling model ...") - model = torch.compile(model) # requires pytorch 2.0+ + # # Hooks into optimizer + # dlogger = DynamicsLogger( + # model, opt, dlcfg, cfg.results_base_folder, wandb=cfg.wandb + # ) + # dlogger.iteration = curr_iter + substep = curr_iter * cfg.acc_steps + train_reader, val_reader = datareaders["train"], datareaders["val"] + train_reader.set_step(substep) + stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} grad_norms = [] - model.train() - t0 = time.time() - - if rng_state_dict is not None: - torch.set_rng_state(rng_state_dict["cpu_rng_state"]) - torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"]) - np.random.set_state(rng_state_dict["numpy_rng_state"]) - random.setstate(rng_state_dict["py_rng_state"]) - for _ in range(substep % num_substeps_per_epoch): - get_batch(data_train_iter, device=extra_args.device) - - while itr < iterations: + while curr_iter <= cfg.iterations: + # Save permanent checkpoint + if cfg.permanent_ckpt_interval > 0: + if curr_iter % cfg.permanent_ckpt_interval == 0: + ckpt_dir = exp_dir / "ckpts" / str(curr_iter) + if distributed_backend.is_master_process(): + save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) + save_worker_state(ckpt_dir) + + # Save temporary checkpoint for resuming training + if cfg.latest_ckpt_interval > 0: + if curr_iter % cfg.latest_ckpt_interval == 0 or curr_iter == cfg.iterations: + ckpt_dir = exp_dir / "ckpts" / "latest" + if distributed_backend.is_master_process(): + save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) + save_worker_state(ckpt_dir) + + ws = distributed_backend.get_world_size() + tokens = ws * substep * cfg.sequence_length * cfg.batch_size + epoch = tokens / train_reader.num_tokens + if ( + curr_iter % cfg.eval_interval == 0 + or curr_iter == cfg.iterations + or (curr_iter in cfg.full_eval_at) + ): + eval_and_log( + curr_iter, + epoch, + model, + val_reader, + type_ctx, + distributed_backend, + cfg, + opt, + full_eval=(curr_iter in cfg.full_eval_at), + ) - for microstep_idx in range(acc_steps): # gradient accumulation - x, y = get_batch(data_train_iter, device=extra_args.device) + if curr_iter == cfg.iterations: + # Save checkpoints and evaluate at final iteration, but no need to train further + break + # Train model + t_start = time.perf_counter_ns() + for microstep_idx in range(cfg.acc_steps): # gradient accumulation + x, y = get_batch(train_reader, device=cfg.device) with type_ctx: with distributed_backend.get_context_for_microstep_forward( model=model, microstep_idx=microstep_idx, - gradient_accumulation_steps=acc_steps, + gradient_accumulation_steps=cfg.acc_steps, ): outputs = model(x, targets=y) - loss = outputs["loss"] / acc_steps + loss = outputs["loss"] / cfg.acc_steps loss.backward() substep += 1 - if substep % len(data["train"]) == 0: - train_epochs += 1 - print(f"Train epoch {train_epochs} done (full pass over training data)") - if hasattr(train_sampler, "set_epoch"): - # set epoch for reshuffling between epochs - train_sampler.set_epoch(train_epochs) - 
sampler_state_before_iter = None - else: - sampler_state_before_iter = train_sampler.generator.get_state() - data_train_iter = iter(data["train"]) - - if extra_args.grad_clip != 0.0: - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), extra_args.grad_clip - ) + + if cfg.grad_clip != 0.0: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + grad_norm = torch.nn.utils.clip_grad_norm_(model.module.parameters(), cfg.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) grad_norms.append(grad_norm) + # grad_norms.append(compute_gradient_norm(model)) + if cfg.opt == "SFAdamW": + opt.train() opt.step() scheduler.step() opt.zero_grad(set_to_none=True) - itr += 1 + dt = (time.perf_counter_ns() - t_start) / 1e9 - if ( - itr % eval_freq == 0 or itr == iterations - ): # from here it's only evaluation code, all the training is above - if distributed_backend.is_master_process(): - t1 = time.time() - dt = t1 - t0 - epoch = substep // num_substeps_per_epoch - - model.eval() - train_loss = loss.detach().cpu().item() * acc_steps - current_lr = ( - scheduler.get_last_lr()[0] - if scheduler is not None - else extra_args.lr - ) - - eval_steps = 24 if itr < iterations else len(data["val"]) - val_acc, val_loss, val_perplexity = eval( - model, - data_val_iter, - extra_args.device, - max_num_batches=eval_steps, - ctx=type_ctx, - reset_iterator=True, - ) + curr_iter += 1 - print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:3f}" - print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms" - if scheduler is not None: - print_string += f" [lr] {current_lr:.5f}" - print(print_string) + if ( + cfg.log_interval + and curr_iter % cfg.log_interval == 0 + and distributed_backend.is_master_process() # Only log on master rank + ): + train_loss = loss.detach().cpu().item() * cfg.acc_steps + + current_lrs = [param_group["lr"] for param_group in opt.param_groups] + + print( + f"Train: Iter={curr_iter} ({epoch:0.3f} epochs) " + f"train_loss={train_loss:.3f} iter_dt={dt:.2e}s " + f"lr={current_lrs[0]:.2e}" + ) - if extra_args.wandb: - logs = { - "iter": itr, + if cfg.wandb: + wandb.log( + { + "iter": curr_iter, "train/loss": train_loss, - "train/max_grad_norm": ( - max(grad_norms).item() if grad_norms else 0 - ), - "train/mean_grad_norm": ( - torch.tensor(grad_norms).mean().item() if grad_norms else 0 - ), - "val/loss": val_loss, - "val/perplexity": val_perplexity, - "val/acc": val_acc, - "lr": current_lr, + "train/perplexity": 2.71828**train_loss, + "lr": current_lrs[0], + "iter_dt": dt, + "max_grad_norm": max(grad_norms).item() if grad_norms else 0, + "mean_grad_norm": torch.tensor(grad_norms).mean().item() if grad_norms else 0, } - - if itr == iterations: - logs["val/final-ppl"] = val_perplexity - logs["val/final-acc"] = val_acc - logs["val/final-loss"] = val_loss - - wandb.log(logs) - - if extra_args.eval_seq_prefix != "none" and ( - itr % (eval_freq * 5) == 0 or itr == iterations - ): - if text_table is None: - text_table = wandb.Table(columns=["itr", "val-pp", "text"]) - - out_str = distributed_backend.get_raw_model( - model - ).generate_from_string( - extra_args.eval_seq_prefix, - max_new_tokens=40, - temperature=0.9, - top_k=None, - ) - text_table.add_data(itr, val_perplexity, out_str) - # why a copy? 
see github.com/wandb/wandb/issues/2981 - wandb.log( - {f"generated-text-{wandb.run.name}": copy.copy(text_table)} - ) - - grad_norms = [] - - model.train() - t0 = time.time() - if distributed_backend.is_master_process(): - if ( - extra_args.save_checkpoint_freq is not None - and itr % extra_args.save_checkpoint_freq == 0 - ): - print( - f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt" - ) - save_checkpoint( - distributed_backend=distributed_backend, - model=model, - opt=opt, - scheduler=scheduler, - itr=itr, - cpu_rng_state=torch.get_rng_state(), - gpu_rng_state=torch.cuda.get_rng_state(), - numpy_rng_state=np.random.get_state(), - py_rng_state=random.getstate(), - train_sampler_state=sampler_state_before_iter, - ckpt_path=os.path.join( - os.path.dirname(ckpt_path), f"ckpt_{itr}.pt" - ), ) + + grad_norms = [] - if distributed_backend.is_master_process(): - print(f"saving checkpoint to {ckpt_path}") - save_checkpoint( - distributed_backend=distributed_backend, - model=model, - opt=opt, - scheduler=scheduler, - itr=itr, - ckpt_path=ckpt_path, - ) return stats + + +def eval_and_log( + curr_iter, + epoch, + model, + val_reader, + type_ctx, + distributed_backend, + cfg, + opt, + full_eval=False, +): + if not distributed_backend.is_master_process(): + # Only evaluate and log on master rank + return + + model.eval() + if cfg.opt == "SFAdamW": + opt.eval() + + if curr_iter == cfg.iterations or full_eval: + max_num_batches = val_reader.num_batches() + else: + max_num_batches = cfg.eval_batches + + # to make sure we start from the beginning of the validation set, + # i.e. repeat the same batches + val_reader.set_step(0) + val_acc, val_loss, val_perplexity = eval( + model, + val_reader, + cfg.device, + max_num_batches=max_num_batches, + ctx=type_ctx, + cfg=cfg, + ) + + print( + f">Eval: Iter={curr_iter} ({epoch:0.3f} epochs) " + f"val_loss={val_loss:.3f} " + f"val_pp={val_perplexity:.3f} " + f"val_acc={val_acc:3f}" + ) + + if cfg.wandb: + if curr_iter == cfg.iterations or full_eval: + logs = { + "iter": curr_iter, + "final-val/loss": val_loss, + "final-val/perplexity": val_perplexity, + "final-val/acc": val_acc, + } + else: + logs = { + "iter": curr_iter, + "val/loss": val_loss, + "val/perplexity": val_perplexity, + "val/acc": val_acc, + } + + wandb.log(logs) + if cfg.eval_seq_prefix != "none" and ( + curr_iter % (cfg.eval_interval * 5) == 0 or curr_iter == cfg.iterations + ): + text_table = wandb.Table(columns=["itr", "val-pp", "text"]) + + out_str = distributed_backend.get_raw_model(model).generate_from_string( + cfg.eval_seq_prefix, + max_new_tokens=40, + temperature=0.9, + top_k=None, + ) + text_table.add_data(curr_iter, val_perplexity, out_str) + # why a copy? 
see github.com/wandb/wandb/issues/2981 + wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)}) + model.train() + diff --git a/src/optim/utils.py b/src/optim/utils.py index cef37b4..87c8c10 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -1,14 +1,16 @@ -import itertools -import math -from contextlib import ExitStack, contextmanager, nullcontext - +from pathlib import Path +import random import numpy as np import torch import torch.nn.functional as F +from contextlib import nullcontext +import torch.distributed as dist +import math +import wandb -def get_batch(dataloader, device="cpu"): - x, y = next(dataloader) +def get_batch(datareader, device="cpu"): + x, y = datareader.sample_batch() if "cuda" in torch.device(device).type: # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) x = x.pin_memory().to(device, non_blocking=True) @@ -126,24 +128,22 @@ def schedule(step): @torch.no_grad() def eval( model, - data_val_iter, + reader, device="cpu", max_num_batches=24, ctx=nullcontext(), - reset_iterator=False, + cfg=None, ): assert model.training == False - if reset_iterator: # ensure that we always eval on the same batches - data_val_iter = itertools.cycle(data_val_iter) - loss_list_val, acc_list = [], [] - for _ in range(max_num_batches): - x, y = get_batch(data_val_iter, device=device) + for idx in range(max_num_batches): + x, y = get_batch(reader, device=device) with ctx: outputs = model(x, targets=y, get_logits=True) val_loss = outputs["loss"] + loss_list_val.append(val_loss) acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) @@ -154,18 +154,135 @@ def eval( return val_acc, val_loss, val_perplexity -def save_checkpoint( - distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args +@torch.no_grad() +def eval_sweep_dropk( + model, + data_tensor, + sequence_length, + batch_size, + n_heads, + device="cpu", + max_num_batches=24, + ctx=nullcontext(), +): + assert model.training == False + + x_axis, y_axis_pp, y_axis_acc, y_axis_loss = ( + torch.linspace(0.0, 0.95, 15), + [], + [], + [], + ) + loss_list_val, acc_list = [], [] + + for frac in x_axis: + drop_k = int(sequence_length * frac * n_heads) + for _ in range(max_num_batches): + x, y = get_batch(data_tensor, sequence_length, batch_size, device=device) + with ctx: + outputs = model( + x, targets=y, alpha_th=None, drop_k=drop_k, get_logits=True + ) + loss_list_val.append(outputs["ce_loss"]) + acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) + + y_axis_acc.append(torch.stack(acc_list).mean().item()) + y_axis_loss.append(np.mean(loss_list_val)) + y_axis_pp.append(2.71828 ** y_axis_loss[-1]) + + return x_axis, y_axis_acc, y_axis_pp, y_axis_loss + + +@torch.no_grad() +def eval_sweep_alphath( + model, + data_tensor, + sequence_length, + batch_size, + device="cpu", + max_num_batches=24, + ctx=nullcontext(), ): + assert model.training == False - checkpoint = dict( - { - "model": distributed_backend.get_raw_model(model).state_dict(), - "optimizer": opt.state_dict(), - "scheduler": scheduler.state_dict(), - "itr": itr, - }, - **extra_args, + alpha_ths, y_axis_pp, y_axis_acc, y_axis_loss = ( + [0, 1e-4, 1e-3, 1e-2, 1e-1, 2e-1, 3e-1, 4e-1, 5e-1], + [], + [], + [], ) + loss_list_val, acc_list, x_axis = [], [], [] + + for alpha_th in alpha_ths: + frac_heads_pruned_list = [] + for _ in range(max_num_batches): + x, y = get_batch(data_tensor, sequence_length, batch_size, device=device) + with ctx: + outputs = model( + x, targets=y, alpha_th=alpha_th, 
drop_k=None, get_logits=True + ) + nph, nh = ( + outputs["num_head_pruned_per_layer"], + outputs["num_heads_per_layer"], + ) + frac_heads_pruned = np.sum(nph) / np.sum( + nh + ) # fractions of heads removed given alpha_th + frac_heads_pruned_list.append(frac_heads_pruned) + loss_list_val.append(outputs["ce_loss"]) + acc_list.append((outputs["logits"].argmax(-1) == y).float().mean()) + + x_axis.append(np.mean(frac_heads_pruned_list)) + y_axis_acc.append(torch.stack(acc_list).mean().item()) + y_axis_loss.append(np.mean(loss_list_val)) + y_axis_pp.append(2.71828 ** y_axis_loss[-1]) + + return x_axis, y_axis_acc, y_axis_pp, y_axis_loss + + +def save_checkpoint(model, opt, scheduler, itr, ckpt_dir: Path): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer": opt.state_dict(), + "scheduler": scheduler.state_dict(), + "itr": itr, + } + ckpt_dir.mkdir(exist_ok=True, parents=True) + torch.save(checkpoint, ckpt_dir / "main.pt") + + +def load_checkpoint(model, opt, scheduler, ckpt_path, device): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + + ckpt = torch.load(ckpt_path, map_location=device) + model.load_state_dict(ckpt["model"]) + opt.load_state_dict(ckpt["optimizer"]) + scheduler.load_state_dict(ckpt["scheduler"]) + itr = ckpt["itr"] + return itr + + +def save_worker_state(ckpt_dir: Path): + # Dataloader, rng states + worker_state = { + "rng_torch_cpu": torch.random.get_rng_state(), + "rng_torch_gpu": torch.cuda.get_rng_state(), + "rng_np": np.random.get_state(), + "rng_python": random.getstate(), + } + rank = 0 if not dist.is_initialized() else dist.get_rank() + ckpt_dir.mkdir(exist_ok=True, parents=True) + torch.save(worker_state, ckpt_dir / f"worker_{rank}.pt") + - torch.save(checkpoint, ckpt_path) +def load_worker_state(ckpt_dir: Path): + rank = 0 if not dist.is_initialized() else dist.get_rank() + worker_state = torch.load(ckpt_dir / f"worker_{rank}.pt") + torch.random.set_rng_state(worker_state["rng_torch_cpu"]) + torch.cuda.set_rng_state(worker_state["rng_torch_gpu"]) + np.random.set_state(worker_state["rng_np"]) + random.setstate(worker_state["rng_python"]) \ No newline at end of file From 5876db9325bad5ee12a3656bd84430b14de7e017 Mon Sep 17 00:00:00 2001 From: Andrei Semenov <67924720+Andron00e@users.noreply.github.com> Date: Mon, 21 Oct 2024 13:32:28 +0000 Subject: [PATCH 22/58] reviewed --- README.md | 28 ++++++++++++++++++++++------ src/config/base.py | 6 +++--- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 7a292fb..5297bad 100755 --- a/README.md +++ b/README.md @@ -36,6 +36,9 @@ parser.add_argument('--batch_size', default=32, type=int) parser.add_argument('--acc_steps', default=4, type=int) parser.add_argument('--seed', default=0, type=int) # random seed for the parameters parser.add_argument('--data_seed', default=1337, type=int) # random seed defining the data ordering +parser.add_argument("--eval_interval", default=200, type=int) +parser.add_argument("--full_eval_at", nargs="+", type=int) +parser.add_argument("--eval_batches", default=32, type=int) parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations parser.add_argument("--warmup_steps", default=300, type=int) @@ -48,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter 
parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument("--cos_inf_steps", default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'lion', 'sf-adamw', 'sf-sgd']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -72,34 +75,47 @@ parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulefree parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulefree hyperparameter parser.add_argument("--model_sharding", default=None, type=bool) # Adam-mini parser.add_argument("--adam_mini_verbose", default=False, type=bool) # print all the logs if true +parser.add_argument("--log_interval", default=50, type=int) # Dataset params -parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) +parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1']) +parser.add_argument("--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"]) parser.add_argument('--vocab_size', default=50304, type=int) parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this # Model params -parser.add_argument('--model', default='base', choices=['base', 'llama2']) -parser.add_argument('--use_pretrained', default="none", type=str) # 'none', 'gpt-2' or a path to the pretraind model +parser.add_argument('--model', default='base', choices=['base', 'llama', 'test']) +parser.add_argument("--parallel_block", action="store_true") +parser.add_argument('--use_pretrained', default="none", type=str) # 'none', 'gpt2' or a path to the pretraind model +parser.add_argument("--from_dense", action="store_true") +parser.add_argument("--init_std", default=0.02, type=float) parser.add_argument('--dropout', default=0.0, type=float) # keep to 0 unless in low data regime (e.g. wikitext) parser.add_argument('--n_head', default=12, type=int) parser.add_argument('--n_layer', default=12, type=int) # depth in (att + ff) blocks parser.add_argument('--n_embd', default=768, type=int) # hidden size ... 
parser.add_argument('--sequence_length', default=512, type=int) -parser.add_argument('--dtype', default=torch.bfloat16, type=torch.dtype) +parser.add_argument("--dtype", default="bfloat16", type=str, choices=["float32", "float16", "bfloat16"],) parser.add_argument('--bias', default=False, type=bool) parser.add_argument('--compile', action='store_true') # if true then model is compiled parser.add_argument('--rmsnorm_eps', default=1e-5, type=float) # used by the llama model parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model make SwiGLU hidden layer size multiple of large power of 2 parser.add_argument('--n_kv_head', default=None, type=Optional[int]) +# Checkpointing +parser.add_argument("--results_base_folder", default="./exps", type=str) +parser.add_argument("--permanent_ckpt_interval", default=0, type=int) +parser.add_argument("--latest_ckpt_interval", default=0, type=int) +parser.add_argument("--resume_from", default=None, type=str) +parser.add_argument("--resume_from_swa", default=None, type=str) +parser.add_argument("--auto_resume", default=True) # logging params (WandB) parser.add_argument('--wandb', action='store_true') # whether to use wandb or not parser.add_argument('--wandb_project', default="my-project", type=str) parser.add_argument('--wandb_entity', default=None, type=none_or_str) # for the team projects parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences +parser.add_argument("--log_dynamics", action="store_true") # Distributed args parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends()) # distributed backend type (e.g. 
nccl) -parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) +# parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) ``` ## Using WandB diff --git a/src/config/base.py b/src/config/base.py index e016be1..90fd4e2 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -47,9 +47,9 @@ def parse_args(base_parser, args, namespace): "--eval_seq_prefix", default="none", type=str ) # prefix used to generate sequences parser.add_argument("--log_dynamics", action="store_true") - parser.add_argument( - "--dynamics_logger_cfg", default="./src/logger/rotational_logger.yaml", type=str - ) + # parser.add_argument( + # "--dynamics_logger_cfg", default="./src/logger/rotational_logger.yaml", type=str + # ) parser.add_argument("--wandb_entity", default=None, type=none_or_str) # Schedule From 94ca273f24aac5bf9625d5fee03e8dc623d28d09 Mon Sep 17 00:00:00 2001 From: Andrei Semenov <67924720+Andron00e@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:49:38 +0200 Subject: [PATCH 23/58] extra_args removed --- src/optim/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optim/base.py b/src/optim/base.py index e691069..84de093 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -142,10 +142,10 @@ def train( grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) grad_norms.append(grad_norm) - if extra_args.opt == "sf-sgd" or extra_args.opt == "sf-adamw": + if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": opt.train() opt.step() - if extra_args.scheduler != "none": + if cfg.scheduler != "none": scheduler.step() opt.zero_grad(set_to_none=True) dt = (time.perf_counter_ns() - t_start) / 1e9 From a1ce2f436c3fc385259e86a3a5ffbb3c0d9195e4 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 24 Oct 2024 15:32:53 +0300 Subject: [PATCH 24/58] --fixed scheduler ckpt again, signSGD and Signum are here, codestyle --- README.md | 90 ++++++------- src/config/base.py | 3 + src/data/arxiv.py | 13 +- src/data/openwebtext2.py | 6 +- src/data/redpajama.py | 6 +- src/data/shakespeare.py | 3 +- src/data/slimpajama.py | 8 +- src/data/utils.py | 9 +- src/data/wikitext.py | 5 +- src/distributed/backend.py | 2 +- src/distributed/ddp.py | 7 +- src/distributed/single.py | 2 +- src/main.py | 50 +++++-- src/models/base.py | 2 +- src/models/llama.py | 3 +- src/models/test.py | 5 +- src/models/utils.py | 9 +- src/optim/ademamix2.py | 5 +- src/optim/base.py | 264 ------------------------------------- src/optim/sign.py | 71 ++++++++++ src/optim/utils.py | 17 ++- 21 files changed, 217 insertions(+), 363 deletions(-) create mode 100644 src/optim/sign.py diff --git a/README.md b/README.md index 5297bad..136c53e 100755 --- a/README.md +++ b/README.md @@ -36,82 +36,82 @@ parser.add_argument('--batch_size', default=32, type=int) parser.add_argument('--acc_steps', default=4, type=int) parser.add_argument('--seed', default=0, type=int) # random seed for the parameters parser.add_argument('--data_seed', default=1337, type=int) # random seed defining the data ordering -parser.add_argument("--eval_interval", default=200, type=int) -parser.add_argument("--full_eval_at", nargs="+", type=int) -parser.add_argument("--eval_batches", default=32, type=int) +parser.add_argument('--eval_interval', default=200, type=int) +parser.add_argument('--full_eval_at', nargs="+", type=int) +parser.add_argument('--eval_batches', default=32, type=int) parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs 
parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations -parser.add_argument("--warmup_steps", default=300, type=int) +parser.add_argument('--warmup_steps', default=300, type=int) parser.add_argument('--lr', default=1e-3, type=float) -parser.add_argument("--wsd_final_lr_scale", default=0.0, type=float) # wsd scheduler -parser.add_argument("--wsd_fract_decay", default=0.1, type=float) # wsd scheduler -parser.add_argument("--decay_type", default="linear", choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"]) +parser.add_argument('--wsd_final_lr_scale', default=0.0, type=float) # wsd scheduler +parser.add_argument('--wsd_fract_decay', default=0.1, type=float) # wsd scheduler +parser.add_argument('--decay_type', default='linear', choices=['linear', 'cosine', 'exp', 'miror_cosine', 'square', 'sqrt']) parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) -parser.add_argument("--cos_inf_steps", default=0, type=int) # cos_inf scheduler +parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT -parser.add_argument("--momentum", default=0.9, type=float) +parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--shampoo_beta', default=-1.0, type=float) -parser.add_argument("--precondition_frequency", default=10, type=int) -parser.add_argument("--max_precond_dim", default=10000, type=int) -parser.add_argument("--merge_dims", default=False, type=bool) # merge dimensions till the product of the dimensions is less than or equal to max_precond_dim -parser.add_argument("--precondition_1d", default=False, type=bool) -parser.add_argument("--normalize_grads", default=False, type=bool) -parser.add_argument("--soap_data_format", default="channels_first", type=str) -parser.add_argument("--correct_bias", default=True, type=bool) -parser.add_argument("--nesterov", default=True, type=bool) # whether to use Nesterov-style momentum -parser.add_argument("--muon_backend", default="newtonschulz5", type=str) # the chosen backend for the orthogonalization step -parser.add_argument("--muon_backend_steps", default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative -parser.add_argmunet("--adema_beta3", default=0.9, type=float) # beta3 in AdEMAMix -parser.add_argument("--adema_alpha", default=2.0, type=float) # alpha in AdEMAMix -parser.add_argument("--adema_beta3_warmup", default=None, type=int) # AdEMAMix hyperparameter -parser.add_argument("--adema_alpha_warmup", default=None, type=int) # AdEMAMix hyperparameter -parser.add_argument("--schedulefree_r", defalut=0.0, type=float) # schedulefree hyperparameter -parser.add_argument("--weight_lr_power", default=2.0, type=float) # schedulefree hyperparameter -parser.add_argument("--model_sharding", 
default=None, type=bool) # Adam-mini -parser.add_argument("--adam_mini_verbose", default=False, type=bool) # print all the logs if true -parser.add_argument("--log_interval", default=50, type=int) +parser.add_argument('--precondition_frequency', default=10, type=int) +parser.add_argument('--max_precond_dim', default=10000, type=int) +parser.add_argument('--merge_dims', default=False, type=bool) # merge dimensions till the product of the dimensions is less than or equal to max_precond_dim +parser.add_argument('--precondition_1d', default=False, type=bool) +parser.add_argument('--normalize_grads', default=False, type=bool) +parser.add_argument('--soap_data_format', default='channels_first', type=str) +parser.add_argument('--correct_bias', default=True, type=bool) +parser.add_argument('--nesterov', default=True, type=bool) # whether to use Nesterov-style momentum +parser.add_argument('--muon_backend', default='newtonschulz5', type=str) # the chosen backend for the orthogonalization step +parser.add_argument('--muon_backend_steps', default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative +parser.add_argmunet('--adema_beta3', default=0.9, type=float) # beta3 in AdEMAMix +parser.add_argument('--adema_alpha', default=2.0, type=float) # alpha in AdEMAMix +parser.add_argument('--adema_beta3_warmup', default=None, type=int) # AdEMAMix hyperparameter +parser.add_argument('--adema_alpha_warmup', default=None, type=int) # AdEMAMix hyperparameter +parser.add_argument('--schedulefree_r', defalut=0.0, type=float) # schedulefree hyperparameter +parser.add_argument('--weight_lr_power', default=2.0, type=float) # schedulefree hyperparameter +parser.add_argument('--model_sharding', default=None, type=bool) # Adam-mini +parser.add_argument('--adam_mini_verbose', default=False, type=bool) # print all the logs if true +parser.add_argument('--log_interval', default=50, type=int) # Dataset params -parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1']) -parser.add_argument("--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"]) +parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1']) +parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) parser.add_argument('--vocab_size', default=50304, type=int) parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this # Model params parser.add_argument('--model', default='base', choices=['base', 'llama', 'test']) -parser.add_argument("--parallel_block", action="store_true") -parser.add_argument('--use_pretrained', default="none", type=str) # 'none', 'gpt2' or a path to the pretraind model -parser.add_argument("--from_dense", action="store_true") -parser.add_argument("--init_std", default=0.02, type=float) +parser.add_argument('--parallel_block', action='store_true') +parser.add_argument('--use_pretrained', default='none', type=str) # 'none', 'gpt2' or a path to the pretraind model +parser.add_argument('--from_dense', action='store_true') +parser.add_argument('--init_std', default=0.02, type=float) parser.add_argument('--dropout', default=0.0, type=float) # keep to 0 unless in low data regime (e.g. 
wikitext) parser.add_argument('--n_head', default=12, type=int) parser.add_argument('--n_layer', default=12, type=int) # depth in (att + ff) blocks parser.add_argument('--n_embd', default=768, type=int) # hidden size ... parser.add_argument('--sequence_length', default=512, type=int) -parser.add_argument("--dtype", default="bfloat16", type=str, choices=["float32", "float16", "bfloat16"],) +parser.add_argument('--dtype', default='bfloat16', type=str, choices=['float32', 'float16', 'bfloat16'],) parser.add_argument('--bias', default=False, type=bool) parser.add_argument('--compile', action='store_true') # if true then model is compiled parser.add_argument('--rmsnorm_eps', default=1e-5, type=float) # used by the llama model parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model make SwiGLU hidden layer size multiple of large power of 2 -parser.add_argument('--n_kv_head', default=None, type=Optional[int]) +parser.add_argument('--n_kv_head', default=None, type=int) # for Adam-mini # Checkpointing -parser.add_argument("--results_base_folder", default="./exps", type=str) -parser.add_argument("--permanent_ckpt_interval", default=0, type=int) -parser.add_argument("--latest_ckpt_interval", default=0, type=int) -parser.add_argument("--resume_from", default=None, type=str) -parser.add_argument("--resume_from_swa", default=None, type=str) -parser.add_argument("--auto_resume", default=True) +parser.add_argument('--results_base_folder', default='./exps', type=str) +parser.add_argument('--permanent_ckpt_interval', default=0, type=int) +parser.add_argument('--latest_ckpt_interval', default=0, type=int) +parser.add_argument('--resume_from', default=None, type=str) +parser.add_argument('--resume_from_swa', default=None, type=str) +parser.add_argument('--auto_resume', default=True) # logging params (WandB) parser.add_argument('--wandb', action='store_true') # whether to use wandb or not -parser.add_argument('--wandb_project', default="my-project", type=str) +parser.add_argument('--wandb_project', default='my-project', type=str) parser.add_argument('--wandb_entity', default=None, type=none_or_str) # for the team projects -parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name +parser.add_argument('--wandb_run_prefix', default='none', type=str) # is added before the autogenerated experiment name parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences -parser.add_argument("--log_dynamics", action="store_true") +parser.add_argument('--log_dynamics', action='store_true') # Distributed args parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends()) # distributed backend type (e.g. 
nccl) diff --git a/src/config/base.py b/src/config/base.py index 90fd4e2..3d27e70 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -88,6 +88,8 @@ def parse_args(base_parser, args, namespace): "sf-adamw", "sf-sgd", "adam-mini", + "signsgd", + "signum", ], ) parser.add_argument("--batch_size", default=50, type=int) @@ -173,6 +175,7 @@ def parse_args(base_parser, args, namespace): default=256, type=int, ) + parser.add_argument("--n_kv_head", default=None, type=int) # for Adam-mini parser.add_argument("--rmsnorm_eps", default=1e-5, type=float) parser.add_argument( "--dtype", diff --git a/src/data/arxiv.py b/src/data/arxiv.py index 91f9a71..71a7fe0 100644 --- a/src/data/arxiv.py +++ b/src/data/arxiv.py @@ -1,17 +1,16 @@ +import logging import os import tarfile -import logging -from pathlib import Path -from typing import Optional from multiprocessing import Pool +from pathlib import Path +from subprocess import PIPE, Popen, TimeoutExpired from tempfile import NamedTemporaryFile -from subprocess import Popen, TimeoutExpired, PIPE -from typing import Tuple, List +from typing import List, Optional, Tuple import numpy as np import requests -from tqdm.auto import tqdm import tiktoken +from tqdm.auto import tqdm def convert_to_markdown(args: Tuple[Path, Path]): @@ -114,4 +113,4 @@ def get_arxiv_2000(datasets_base_dir): def get_arxiv_full(datasets_base_dir): - return load_arxiv(Path(datasets_base_dir)) \ No newline at end of file + return load_arxiv(Path(datasets_base_dir)) diff --git a/src/data/openwebtext2.py b/src/data/openwebtext2.py index d30e7cc..01fb581 100644 --- a/src/data/openwebtext2.py +++ b/src/data/openwebtext2.py @@ -1,9 +1,9 @@ import os -from tqdm import tqdm + import numpy as np import tiktoken from datasets import load_dataset - +from tqdm import tqdm tknzr = tiktoken.get_encoding("gpt2") @@ -62,4 +62,4 @@ def process(example): return { "train": os.path.join(OWT2_DATA_PATH, "train.bin"), "val": os.path.join(OWT2_DATA_PATH, "val.bin"), - } \ No newline at end of file + } diff --git a/src/data/redpajama.py b/src/data/redpajama.py index 93a70f3..0901719 100644 --- a/src/data/redpajama.py +++ b/src/data/redpajama.py @@ -1,9 +1,9 @@ import os -from tqdm import tqdm + import numpy as np import tiktoken from datasets import load_dataset - +from tqdm import tqdm tknzr = tiktoken.get_encoding("gpt2") @@ -121,4 +121,4 @@ def process(example): return { "train": os.path.join(RPJ_V2_DATA_PATH, "train.bin"), "val": os.path.join(RPJ_V2_DATA_PATH, "val.bin"), - } \ No newline at end of file + } diff --git a/src/data/shakespeare.py b/src/data/shakespeare.py index 87ce7e6..cb13f94 100644 --- a/src/data/shakespeare.py +++ b/src/data/shakespeare.py @@ -4,7 +4,6 @@ import numpy as np import requests - _char_decode = dict( enumerate(sorted(set(ascii_letters + digits + punctuation + " \n"))) ) @@ -50,4 +49,4 @@ def get_shakespeare_data(datasets_dir): return { "train": train_path, "val": test_path, - } \ No newline at end of file + } diff --git a/src/data/slimpajama.py b/src/data/slimpajama.py index 4126f6c..3811930 100644 --- a/src/data/slimpajama.py +++ b/src/data/slimpajama.py @@ -1,9 +1,9 @@ -from tqdm import tqdm +import os + import numpy as np import tiktoken from datasets import load_dataset -import os - +from tqdm import tqdm tknzr = tiktoken.get_encoding("gpt2") @@ -116,4 +116,4 @@ def process(example): return { "train": os.path.join(SPJ_DATA_PATH, "train.bin"), "val": os.path.join(SPJ_DATA_PATH, "val.bin"), - } \ No newline at end of file + } diff --git a/src/data/utils.py 
b/src/data/utils.py index bbeee55..ad53f5f 100755 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -1,15 +1,16 @@ from pathlib import Path -import numpy as np from typing import Dict + +import numpy as np import torch import torch.distributed as dist -from .shakespeare import get_shakespeare_data -from .wikitext import get_wikitext_data from .arxiv import get_arxiv_2000, get_arxiv_full from .openwebtext2 import get_openwebtext2_data from .redpajama import get_redpajama_data, get_redpajamav2_data +from .shakespeare import get_shakespeare_data from .slimpajama import get_slimpajama_data +from .wikitext import get_wikitext_data def get_dataset(args) -> Dict[str, np.ndarray]: @@ -174,4 +175,4 @@ def _sample_without_replacement(self, step): def num_batches(self): if self.with_replacement: return self.num_tokens // self.batch_size - return self.num_batches_of_seqlen \ No newline at end of file + return self.num_batches_of_seqlen diff --git a/src/data/wikitext.py b/src/data/wikitext.py index 8964fad..0cd5ea5 100755 --- a/src/data/wikitext.py +++ b/src/data/wikitext.py @@ -1,6 +1,7 @@ import os -import zipfile import urllib +import zipfile + import numpy as np import tiktoken @@ -45,4 +46,4 @@ def get_wikitext_data(datasets_base_dir): return { "train": os.path.join(WIKITEXT_DATA_PATH, "train.bin"), "val": os.path.join(WIKITEXT_DATA_PATH, "val.bin"), - } \ No newline at end of file + } diff --git a/src/distributed/backend.py b/src/distributed/backend.py index 06d37e8..9fc0539 100644 --- a/src/distributed/backend.py +++ b/src/distributed/backend.py @@ -30,4 +30,4 @@ def get_world_size(self): raise NotImplementedError def finalize(self): - pass \ No newline at end of file + pass diff --git a/src/distributed/ddp.py b/src/distributed/ddp.py index 664f069..9226ff1 100644 --- a/src/distributed/ddp.py +++ b/src/distributed/ddp.py @@ -1,9 +1,10 @@ -import os import math +import os from contextlib import contextmanager +from torch.distributed import (destroy_process_group, get_world_size, + init_process_group) from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed import init_process_group, destroy_process_group, get_world_size from .backend import DistributedBackend @@ -59,4 +60,4 @@ def get_world_size(self): return get_world_size() def finalize(self): - destroy_process_group() \ No newline at end of file + destroy_process_group() diff --git a/src/distributed/single.py b/src/distributed/single.py index 7e1f8d5..8ece239 100644 --- a/src/distributed/single.py +++ b/src/distributed/single.py @@ -28,4 +28,4 @@ def get_world_size(self): return 1 def translate_model_parameter_name_for_node(self, parameter_name): - return [parameter_name] \ No newline at end of file + return [parameter_name] diff --git a/src/main.py b/src/main.py index 1577d67..727cf93 100755 --- a/src/main.py +++ b/src/main.py @@ -22,6 +22,7 @@ from optim.lion import Lion from optim.muon import Muon, zeropower_backends from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree +from optim.sign import Signum from optim.soap import SOAP from optim.utils import cos_inf_schedule, wsd_schedule @@ -111,7 +112,7 @@ def main(args, parser): args.world_size = distributed_backend.get_world_size() if args.opt == "adamw": - device_type = 'cuda' if 'cuda' in args.device else 'cpu' + device_type = "cuda" if "cuda" in args.device else "cpu" use_fused = (device_type == "cuda") and ( "fused" in inspect.signature(torch.optim.AdamW).parameters ) @@ -211,6 +212,20 @@ def main(args, parser): n_kv_heads=args.n_kv_head, 
verbose=args.adam_mini_verbose, ) + elif args.opt == "signsgd": + opt = Signum( + group_specs, + lr=args.lr, + momentum=0.0, + weight_decay=args.weight_decay, + ) + elif args.opt == "signum": + opt = Signum( + group_specs, + lr=args.lr, + momuntum=args.momentum, + weight_decay=args.weight_decay, + ) else: opt = torch.optim.SGD( group_specs, @@ -325,22 +340,41 @@ def get_data_readers(args, verbose=True): } -def get_exp_name(args, parser, distributed_backend, key_args=['model', 'dataset', 'opt'], ignore_args=['eval_interval', 'full_eval_at', 'distributed_backend', 'latest_ckpt_interval', 'wandb', 'wandb_project', 'wandb_entity', 'batch_size', 'acc_steps', 'results_base_folder', 'run_prefix', 'wandb_run_prefix']): +def get_exp_name( + args, + parser, + distributed_backend, + key_args=["model", "dataset", "opt"], + ignore_args=[ + "eval_interval", + "full_eval_at", + "distributed_backend", + "latest_ckpt_interval", + "wandb", + "wandb_project", + "wandb_entity", + "batch_size", + "acc_steps", + "results_base_folder", + "run_prefix", + "wandb_run_prefix", + ], +): # Get the default values defaults = vars(parser.parse_args([])) rank = distributed_backend.rank - + # Generate the prefix with key arguments prefix_parts = [] for key in key_args: if hasattr(args, key): value = getattr(args, key) prefix_parts.append(f"{key}-{value}") - + prefix = "_".join(prefix_parts) prefix = f"{args.batch_size}x{args.acc_steps}(rank={rank})_" + prefix - + # Generate the rest of the string with non-default arguments non_default_parts = [] for key, value in vars(args).items(): @@ -351,18 +385,18 @@ def get_exp_name(args, parser, distributed_backend, key_args=['model', 'dataset' continue if key not in key_args and value != defaults[key]: non_default_parts.append(f"{key}-{value}") - + non_default_string = "_".join(non_default_parts) if args.run_prefix is not None: prefix = args.run_prefix + "_" + prefix - + # Combine prefix and non-default string if non_default_string: return f"{prefix}__{non_default_string}" else: return prefix - + if __name__ == "__main__": args, parser = get_args() diff --git a/src/models/base.py b/src/models/base.py index 3fd01f6..3f2dba3 100755 --- a/src/models/base.py +++ b/src/models/base.py @@ -373,4 +373,4 @@ def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=No .to("cpu") .numpy() ) - return self.tokenizer.decode(out_idx) \ No newline at end of file + return self.tokenizer.decode(out_idx) diff --git a/src/models/llama.py b/src/models/llama.py index 9ea78c3..ebb7430 100644 --- a/src/models/llama.py +++ b/src/models/llama.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from torch.nn import functional as F + from models.base import CausalSelfAttention, GPTBase @@ -236,4 +237,4 @@ def forward(self, idx, targets=None, get_logits=False): return { "logits": logits, "loss": loss, - } \ No newline at end of file + } diff --git a/src/models/test.py b/src/models/test.py index bd7b1d6..d146880 100644 --- a/src/models/test.py +++ b/src/models/test.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from torch.nn import functional as F + from models.base import CausalSelfAttention, GPTBase @@ -204,7 +205,7 @@ def forward(self, idx, targets=None, get_logits=False): tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) # pos_emb = self.transformer.wpe(pos) - x = self.transformer.drop(tok_emb) # + pos_emb) + x = self.transformer.drop(tok_emb) # + pos_emb) freqs_cis = self.freqs_cis.to(x.device)[pos] for block in self.transformer.h: @@ -229,4 +230,4 
@@ def forward(self, idx, targets=None, get_logits=False): return { "logits": logits, "loss": loss, - } \ No newline at end of file + } diff --git a/src/models/utils.py b/src/models/utils.py index 35003e8..05cd133 100755 --- a/src/models/utils.py +++ b/src/models/utils.py @@ -1,8 +1,9 @@ -from .llama import Llama, RMSNorm -from .base import GPTBase, LayerNorm -from .test import Test, RMSNorm2 import torch +from .base import GPTBase, LayerNorm +from .llama import Llama, RMSNorm +from .test import RMSNorm2, Test + BLACKLIST_WEIGHT_MODULES = ( torch.nn.LayerNorm, LayerNorm, @@ -34,4 +35,4 @@ def get_model(args): ) return model else: - raise KeyError(f"Unknown model '{args.model}'.") \ No newline at end of file + raise KeyError(f"Unknown model '{args.model}'.") diff --git a/src/optim/ademamix2.py b/src/optim/ademamix2.py index df901f5..1db0aad 100644 --- a/src/optim/ademamix2.py +++ b/src/optim/ademamix2.py @@ -177,7 +177,10 @@ def step(self, closure=None): denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - update = exp_avg_fast.div(bias_correction1) / denom + alpha * exp_avg_slow / exp_avg_slow.norm() + update = ( + exp_avg_fast.div(bias_correction1) / denom + + alpha * exp_avg_slow / exp_avg_slow.norm() + ) # decay update.add_(p, alpha=lmbda) diff --git a/src/optim/base.py b/src/optim/base.py index 84de093..e69de29 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -1,264 +0,0 @@ -from contextlib import nullcontext -import copy -from pathlib import Path -import time -import yaml - -import torch -import wandb - -# from logger.logger import DynamicsLogger -from .utils import ( - eval, - get_batch, - load_checkpoint, - load_worker_state, - save_checkpoint, - save_worker_state, -) - - -def train( - model, - opt, - datareaders, - scheduler, - exp_dir, - distributed_backend, - cfg, -): - not_compiled_model = model - if cfg.compile: - print(f"Compiling model ...") - model = torch.compile(model) - - if "cuda" in cfg.device: - type_ctx = torch.amp.autocast( - device_type="cuda", - dtype={ - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - }[cfg.dtype], - ) - else: - type_ctx = nullcontext() - - if cfg.resume_from: - # This is a full resume including the model weights, optimizer, state - # dataloader state, random seed, etc. Not indended for fine tuning or - # other scenarios where some of these should change. 
- print(f"\nResuming Training From {cfg.resume_from}") - ckpt_dir = Path(cfg.resume_from) - curr_iter = load_checkpoint( - model, - opt, - scheduler, - ckpt_dir / "main.pt", - cfg.device, - ) - load_worker_state(ckpt_dir) - else: - curr_iter = 0 - - # if distributed_backend.is_master_process() and cfg.log_dynamics: - # with open(cfg.dynamics_logger_cfg, "r") as f: - # dlcfg = yaml.safe_load(f) - - # # Hooks into optimizer - # dlogger = DynamicsLogger( - # model, opt, dlcfg, cfg.results_base_folder, wandb=cfg.wandb - # ) - # dlogger.iteration = curr_iter - - substep = curr_iter * cfg.acc_steps - train_reader, val_reader = datareaders["train"], datareaders["val"] - train_reader.set_step(substep) - stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} - grad_norms = [] - model.train() - - while curr_iter <= cfg.iterations: - # Save permanent checkpoint - if cfg.permanent_ckpt_interval > 0: - if curr_iter % cfg.permanent_ckpt_interval == 0: - ckpt_dir = exp_dir / "ckpts" / str(curr_iter) - if distributed_backend.is_master_process(): - save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) - save_worker_state(ckpt_dir) - - # Save temporary checkpoint for resuming training - if cfg.latest_ckpt_interval > 0: - if curr_iter % cfg.latest_ckpt_interval == 0 or curr_iter == cfg.iterations: - ckpt_dir = exp_dir / "ckpts" / "latest" - if distributed_backend.is_master_process(): - save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) - save_worker_state(ckpt_dir) - - ws = distributed_backend.get_world_size() - tokens = ws * substep * cfg.sequence_length * cfg.batch_size - epoch = tokens / train_reader.num_tokens - if ( - curr_iter % cfg.eval_interval == 0 - or curr_iter == cfg.iterations - or (curr_iter in cfg.full_eval_at) - ): - eval_and_log( - curr_iter, - epoch, - model, - val_reader, - type_ctx, - distributed_backend, - cfg, - opt, - full_eval=(curr_iter in cfg.full_eval_at), - ) - - if curr_iter == cfg.iterations: - # Save checkpoints and evaluate at final iteration, but no need to train further - break - - # Train model - t_start = time.perf_counter_ns() - for microstep_idx in range(cfg.acc_steps): # gradient accumulation - x, y = get_batch(train_reader, device=cfg.device) - with type_ctx: - with distributed_backend.get_context_for_microstep_forward( - model=model, - microstep_idx=microstep_idx, - gradient_accumulation_steps=cfg.acc_steps, - ): - outputs = model(x, targets=y) - - loss = outputs["loss"] / cfg.acc_steps - loss.backward() - substep += 1 - - if cfg.grad_clip != 0.0: - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - grad_norm = torch.nn.utils.clip_grad_norm_(model.module.parameters(), cfg.grad_clip) - else: - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) - grad_norms.append(grad_norm) - - if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": - opt.train() - opt.step() - if cfg.scheduler != "none": - scheduler.step() - opt.zero_grad(set_to_none=True) - dt = (time.perf_counter_ns() - t_start) / 1e9 - - curr_iter += 1 - - if ( - cfg.log_interval - and curr_iter % cfg.log_interval == 0 - and distributed_backend.is_master_process() # Only log on master rank - ): - train_loss = loss.detach().cpu().item() * cfg.acc_steps - - current_lrs = [param_group["lr"] for param_group in opt.param_groups] - - print( - f"Train: Iter={curr_iter} ({epoch:0.3f} epochs) " - f"train_loss={train_loss:.3f} iter_dt={dt:.2e}s " - f"lr={current_lrs[0]:.2e}" - ) - - if cfg.wandb: - wandb.log( - { - "iter": curr_iter, - "train/loss": 
train_loss, - "train/perplexity": 2.71828**train_loss, - "lr": current_lrs[0], - "iter_dt": dt, - "max_grad_norm": max(grad_norms).item() if grad_norms else 0, - "mean_grad_norm": torch.tensor(grad_norms).mean().item() if grad_norms else 0, - } - ) - - grad_norms = [] - - - return stats - - -def eval_and_log( - curr_iter, - epoch, - model, - val_reader, - type_ctx, - distributed_backend, - cfg, - opt, - full_eval=False, -): - if not distributed_backend.is_master_process(): - # Only evaluate and log on master rank - return - - model.eval() - if cfg.opt == "SFAdamW": - opt.eval() - - if curr_iter == cfg.iterations or full_eval: - max_num_batches = val_reader.num_batches() - else: - max_num_batches = cfg.eval_batches - - # to make sure we start from the beginning of the validation set, - # i.e. repeat the same batches - val_reader.set_step(0) - val_acc, val_loss, val_perplexity = eval( - model, - val_reader, - cfg.device, - max_num_batches=max_num_batches, - ctx=type_ctx, - cfg=cfg, - ) - - print( - f">Eval: Iter={curr_iter} ({epoch:0.3f} epochs) " - f"val_loss={val_loss:.3f} " - f"val_pp={val_perplexity:.3f} " - f"val_acc={val_acc:3f}" - ) - - if cfg.wandb: - if curr_iter == cfg.iterations or full_eval: - logs = { - "iter": curr_iter, - "final-val/loss": val_loss, - "final-val/perplexity": val_perplexity, - "final-val/acc": val_acc, - } - else: - logs = { - "iter": curr_iter, - "val/loss": val_loss, - "val/perplexity": val_perplexity, - "val/acc": val_acc, - } - - wandb.log(logs) - if cfg.eval_seq_prefix != "none" and ( - curr_iter % (cfg.eval_interval * 5) == 0 or curr_iter == cfg.iterations - ): - text_table = wandb.Table(columns=["itr", "val-pp", "text"]) - - out_str = distributed_backend.get_raw_model(model).generate_from_string( - cfg.eval_seq_prefix, - max_new_tokens=40, - temperature=0.9, - top_k=None, - ) - text_table.add_data(curr_iter, val_perplexity, out_str) - # why a copy? see github.com/wandb/wandb/issues/2981 - wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)}) - model.train() - diff --git a/src/optim/sign.py b/src/optim/sign.py new file mode 100644 index 0000000..d5bc45e --- /dev/null +++ b/src/optim/sign.py @@ -0,0 +1,71 @@ +import torch + + +class Signum(torch.optim.Optimizer): + r"""Implements Signum optimizer that takes the sign of gradient or momentum. 
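For orientation, the core update implemented by Signum.step in this new file reduces to a few tensor operations; the following is a minimal standalone sketch (function and variable names are illustrative, not part of the patch), assuming the caller creates and maintains the momentum buffer:

import torch

def signum_update(param, grad, buf, lr=1e-3, momentum=0.9, weight_decay=0.0):
    # Coupled weight decay, matching the class here: grad <- grad + wd * param
    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)
    if momentum != 0:
        # Signum: exponential moving average of gradients, then take its sign
        buf.mul_(momentum).add_(grad, alpha=1 - momentum)
        step_dir = torch.sign(buf)
    else:
        # signSGD: sign of the raw (weight-decayed) gradient
        step_dir = torch.sign(grad)
    param.add_(step_dir, alpha=-lr)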
+ + See details in the original paper at: https://arxiv.org/abs/1711.05101 + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0.9) + weight_decay (float, optional): weight decay (default: 0) + + Example: + >>> optimizer = signum.Signum(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + """ + + def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=0, **kwargs): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) + + super(Signum, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Signum, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group["weight_decay"] + momentum = group["momentum"] + + for p in group["params"]: + if p.grad is None: + continue + d_p = p.grad.data + if weight_decay != 0: + d_p.add_(p.data, alpha=weight_decay) + if momentum != 0: + # Signum + param_state = self.state[p] + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = torch.zeros_like(p.data) + + else: + buf = param_state["momentum_buffer"] + + buf.mul_(momentum).add_(d_p, alpha=(1 - momentum)) + d_p = torch.sign(buf) + else: + # signSGD + d_p = torch.sign(d_p) + + p.data.add_(d_p, alpha=-group["lr"]) + + return loss diff --git a/src/optim/utils.py b/src/optim/utils.py index 87c8c10..893ea26 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -1,11 +1,13 @@ -from pathlib import Path +import math import random +from contextlib import nullcontext +from pathlib import Path + import numpy as np import torch -import torch.nn.functional as F -from contextlib import nullcontext import torch.distributed as dist -import math +import torch.nn.functional as F + import wandb @@ -247,7 +249,7 @@ def save_checkpoint(model, opt, scheduler, itr, ckpt_dir: Path): checkpoint = { "model": model.state_dict(), "optimizer": opt.state_dict(), - "scheduler": scheduler.state_dict(), + "scheduler": scheduler.state_dict() if scheduler is not None else None, "itr": itr, } ckpt_dir.mkdir(exist_ok=True, parents=True) @@ -261,7 +263,8 @@ def load_checkpoint(model, opt, scheduler, ckpt_path, device): ckpt = torch.load(ckpt_path, map_location=device) model.load_state_dict(ckpt["model"]) opt.load_state_dict(ckpt["optimizer"]) - scheduler.load_state_dict(ckpt["scheduler"]) + if scheduler is not None: + scheduler.load_state_dict(ckpt["scheduler"]) itr = ckpt["itr"] return itr @@ -285,4 +288,4 @@ def load_worker_state(ckpt_dir: Path): torch.random.set_rng_state(worker_state["rng_torch_cpu"]) torch.cuda.set_rng_state(worker_state["rng_torch_gpu"]) np.random.set_state(worker_state["rng_np"]) - random.setstate(worker_state["rng_python"]) \ No newline at end of file + random.setstate(worker_state["rng_python"]) From 5ddffa52742106ed69c34ad6b4d240012e0cd2b0 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 24 Oct 2024 17:38:04 +0300 Subject: [PATCH 25/58] updated signum --- README.md | 3 +- src/config/base.py | 3 +- src/main.py | 13 ++- 
src/optim/base.py | 263 +++++++++++++++++++++++++++++++++++++++++++++ src/optim/sign.py | 150 ++++++++++++++++---------- 5 files changed, 373 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index 136c53e..a78a996 100755 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ parser.add_argument('--precondition_1d', default=False, type=bool) parser.add_argument('--normalize_grads', default=False, type=bool) parser.add_argument('--soap_data_format', default='channels_first', type=str) parser.add_argument('--correct_bias', default=True, type=bool) -parser.add_argument('--nesterov', default=True, type=bool) # whether to use Nesterov-style momentum +parser.add_argument('--nesterov', default=False, type=bool) # whether to use Nesterov-style momentum parser.add_argument('--muon_backend', default='newtonschulz5', type=str) # the chosen backend for the orthogonalization step parser.add_argument('--muon_backend_steps', default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative parser.add_argmunet('--adema_beta3', default=0.9, type=float) # beta3 in AdEMAMix @@ -76,6 +76,7 @@ parser.add_argument('--weight_lr_power', default=2.0, type=float) # schedulefree parser.add_argument('--model_sharding', default=None, type=bool) # Adam-mini parser.add_argument('--adam_mini_verbose', default=False, type=bool) # print all the logs if true parser.add_argument('--log_interval', default=50, type=int) +parser.add_argument('--dampening', default=0.0, type=float) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) diff --git a/src/config/base.py b/src/config/base.py index 3d27e70..efb8be7 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -109,7 +109,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--normalize_grads", default=False, type=bool) parser.add_argument("--soap_data_format", default="channels_first", type=str) parser.add_argument("--correct_bias", default=True, type=bool) - parser.add_argument("--nesterov", default=True, type=bool) + parser.add_argument("--nesterov", default=False, type=bool) parser.add_argument("--muon_backend", default="newtonschulz5", type=str) parser.add_argument("--muon_backend_steps", default=5, type=int) parser.add_argument("--adema_beta3", default=0.9, type=float) @@ -120,6 +120,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--weight_lr_power", default=2.0, type=float) parser.add_argument("--model_sharding", default=None, type=bool) parser.add_argument("--adam_mini_verbose", default=False, type=bool) + parser.add_argument("--dampening", default=0.0, type=float) # Dataset params parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") diff --git a/src/main.py b/src/main.py index 727cf93..4693a5f 100755 --- a/src/main.py +++ b/src/main.py @@ -145,7 +145,7 @@ def main(args, parser): group_specs, lr=args.lr, momentum=args.momentum, - nesterov=args.nesterov, + nesterov=args.nesterov, # use True for Muon as a default backend=args.muon_backend, backend_steps=args.muon_backend_steps, # rank=args.rank, @@ -216,15 +216,21 @@ def main(args, parser): opt = Signum( group_specs, lr=args.lr, - momentum=0.0, + momentum=0.0, # always use zero momentum because its signSGD + dampening=args.dampening, 
weight_decay=args.weight_decay, + nesterov=args.nesterov, + sign_update=True, ) elif args.opt == "signum": opt = Signum( group_specs, lr=args.lr, - momuntum=args.momentum, + momentum=args.momentum, weight_decay=args.weight_decay, + dampening=args.dampening, + nesterov=args.nesterov, + sign_update=True, ) else: opt = torch.optim.SGD( @@ -232,6 +238,7 @@ def main(args, parser): lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, + nesterov=args.nesterov, ) print(f"\nOptimizer:\n{opt}") diff --git a/src/optim/base.py b/src/optim/base.py index e69de29..4359d9d 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -0,0 +1,263 @@ +import copy +import time +from contextlib import nullcontext +from pathlib import Path + +import torch +import yaml + +import wandb + +# from logger.logger import DynamicsLogger +from .utils import (eval, get_batch, load_checkpoint, load_worker_state, + save_checkpoint, save_worker_state) + + +def train( + model, + opt, + datareaders, + scheduler, + exp_dir, + distributed_backend, + cfg, +): + not_compiled_model = model + if cfg.compile: + print(f"Compiling model ...") + model = torch.compile(model) + + if "cuda" in cfg.device: + type_ctx = torch.amp.autocast( + device_type="cuda", + dtype={ + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[cfg.dtype], + ) + else: + type_ctx = nullcontext() + + if cfg.resume_from: + # This is a full resume including the model weights, optimizer, state + # dataloader state, random seed, etc. Not indended for fine tuning or + # other scenarios where some of these should change. + print(f"\nResuming Training From {cfg.resume_from}") + ckpt_dir = Path(cfg.resume_from) + curr_iter = load_checkpoint( + model, + opt, + scheduler, + ckpt_dir / "main.pt", + cfg.device, + ) + load_worker_state(ckpt_dir) + else: + curr_iter = 0 + + # if distributed_backend.is_master_process() and cfg.log_dynamics: + # with open(cfg.dynamics_logger_cfg, "r") as f: + # dlcfg = yaml.safe_load(f) + + # # Hooks into optimizer + # dlogger = DynamicsLogger( + # model, opt, dlcfg, cfg.results_base_folder, wandb=cfg.wandb + # ) + # dlogger.iteration = curr_iter + + substep = curr_iter * cfg.acc_steps + train_reader, val_reader = datareaders["train"], datareaders["val"] + train_reader.set_step(substep) + stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []} + grad_norms = [] + model.train() + + while curr_iter <= cfg.iterations: + # Save permanent checkpoint + if cfg.permanent_ckpt_interval > 0: + if curr_iter % cfg.permanent_ckpt_interval == 0: + ckpt_dir = exp_dir / "ckpts" / str(curr_iter) + if distributed_backend.is_master_process(): + save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) + save_worker_state(ckpt_dir) + + # Save temporary checkpoint for resuming training + if cfg.latest_ckpt_interval > 0: + if curr_iter % cfg.latest_ckpt_interval == 0 or curr_iter == cfg.iterations: + ckpt_dir = exp_dir / "ckpts" / "latest" + if distributed_backend.is_master_process(): + save_checkpoint(model, opt, scheduler, curr_iter, ckpt_dir) + save_worker_state(ckpt_dir) + + ws = distributed_backend.get_world_size() + tokens = ws * substep * cfg.sequence_length * cfg.batch_size + epoch = tokens / train_reader.num_tokens + if ( + curr_iter % cfg.eval_interval == 0 + or curr_iter == cfg.iterations + or (curr_iter in cfg.full_eval_at) + ): + eval_and_log( + curr_iter, + epoch, + model, + val_reader, + type_ctx, + distributed_backend, + cfg, + opt, + full_eval=(curr_iter in cfg.full_eval_at), + ) + + 
if curr_iter == cfg.iterations: + # Save checkpoints and evaluate at final iteration, but no need to train further + break + + # Train model + t_start = time.perf_counter_ns() + for microstep_idx in range(cfg.acc_steps): # gradient accumulation + x, y = get_batch(train_reader, device=cfg.device) + with type_ctx: + with distributed_backend.get_context_for_microstep_forward( + model=model, + microstep_idx=microstep_idx, + gradient_accumulation_steps=cfg.acc_steps, + ): + outputs = model(x, targets=y) + + loss = outputs["loss"] / cfg.acc_steps + loss.backward() + substep += 1 + + if cfg.grad_clip != 0.0: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + grad_norm = torch.nn.utils.clip_grad_norm_( + model.module.parameters(), cfg.grad_clip + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), cfg.grad_clip + ) + grad_norms.append(grad_norm) + + if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": + opt.train() + opt.step() + if cfg.scheduler != "none": + scheduler.step() + opt.zero_grad(set_to_none=True) + dt = (time.perf_counter_ns() - t_start) / 1e9 + + curr_iter += 1 + + if ( + cfg.log_interval + and curr_iter % cfg.log_interval == 0 + and distributed_backend.is_master_process() # Only log on master rank + ): + train_loss = loss.detach().cpu().item() * cfg.acc_steps + + current_lrs = [param_group["lr"] for param_group in opt.param_groups] + + print( + f"Train: Iter={curr_iter} ({epoch:0.3f} epochs) " + f"train_loss={train_loss:.3f} iter_dt={dt:.2e}s " + f"lr={current_lrs[0]:.2e}" + ) + + if cfg.wandb: + wandb.log( + { + "iter": curr_iter, + "train/loss": train_loss, + "train/perplexity": 2.71828**train_loss, + "lr": current_lrs[0], + "iter_dt": dt, + "max_grad_norm": max(grad_norms).item() if grad_norms else 0, + "mean_grad_norm": ( + torch.tensor(grad_norms).mean().item() if grad_norms else 0 + ), + } + ) + + grad_norms = [] + + return stats + + +def eval_and_log( + curr_iter, + epoch, + model, + val_reader, + type_ctx, + distributed_backend, + cfg, + opt, + full_eval=False, +): + if not distributed_backend.is_master_process(): + # Only evaluate and log on master rank + return + + model.eval() + if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": + opt.eval() + + if curr_iter == cfg.iterations or full_eval: + max_num_batches = val_reader.num_batches() + else: + max_num_batches = cfg.eval_batches + + # to make sure we start from the beginning of the validation set, + # i.e. 
repeat the same batches + val_reader.set_step(0) + val_acc, val_loss, val_perplexity = eval( + model, + val_reader, + cfg.device, + max_num_batches=max_num_batches, + ctx=type_ctx, + cfg=cfg, + ) + + print( + f">Eval: Iter={curr_iter} ({epoch:0.3f} epochs) " + f"val_loss={val_loss:.3f} " + f"val_pp={val_perplexity:.3f} " + f"val_acc={val_acc:3f}" + ) + + if cfg.wandb: + if curr_iter == cfg.iterations or full_eval: + logs = { + "iter": curr_iter, + "final-val/loss": val_loss, + "final-val/perplexity": val_perplexity, + "final-val/acc": val_acc, + } + else: + logs = { + "iter": curr_iter, + "val/loss": val_loss, + "val/perplexity": val_perplexity, + "val/acc": val_acc, + } + + wandb.log(logs) + if cfg.eval_seq_prefix != "none" and ( + curr_iter % (cfg.eval_interval * 5) == 0 or curr_iter == cfg.iterations + ): + text_table = wandb.Table(columns=["itr", "val-pp", "text"]) + + out_str = distributed_backend.get_raw_model(model).generate_from_string( + cfg.eval_seq_prefix, + max_new_tokens=40, + temperature=0.9, + top_k=None, + ) + text_table.add_data(curr_iter, val_perplexity, out_str) + # why a copy? see github.com/wandb/wandb/issues/2981 + wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)}) + model.train() diff --git a/src/optim/sign.py b/src/optim/sign.py index d5bc45e..c1ac3b8 100644 --- a/src/optim/sign.py +++ b/src/optim/sign.py @@ -1,71 +1,113 @@ +from typing import Dict + import torch class Signum(torch.optim.Optimizer): - r"""Implements Signum optimizer that takes the sign of gradient or momentum. - - See details in the original paper at: https://arxiv.org/abs/1711.05101 - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float): learning rate - momentum (float, optional): momentum factor (default: 0.9) - weight_decay (float, optional): weight decay (default: 0) - - Example: - >>> optimizer = signum.Signum(model.parameters(), lr=0.1, momentum=0.9) - >>> optimizer.zero_grad() - >>> loss_fn(model(input), target).backward() - >>> optimizer.step() - """ - - def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=0, **kwargs): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= momentum: - raise ValueError("Invalid momentum value: {}".format(momentum)) - if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - - defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) - - super(Signum, self).__init__(params, defaults) + def __init__( + self, + params, + lr=1e-3, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + sign_update=True, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + sign_update=sign_update, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) def __setstate__(self, state): - super(Signum, self).__setstate__(state) - + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def _init_state(self, example, state=None): + assert isinstance(example, torch.Tensor) + assert isinstance(state, 
Dict) or state is None + if state is None: + state = {} + state["step"] = 0 + state["momentum_buffer"] = torch.clone(example).detach() + return state + + @torch.no_grad() + def _compute_update( + self, grad, state, lr, momentum, nesterov, dampening, sign_update, **kwargs + ): + + if momentum != 0: # Signum check + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + if sign_update: + grad = grad.sign() + + return grad * (-lr) + + @torch.no_grad() def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: - weight_decay = group["weight_decay"] - momentum = group["momentum"] - for p in group["params"]: if p.grad is None: continue - d_p = p.grad.data - if weight_decay != 0: - d_p.add_(p.data, alpha=weight_decay) - if momentum != 0: - # Signum - param_state = self.state[p] - if "momentum_buffer" not in param_state: - buf = param_state["momentum_buffer"] = torch.zeros_like(p.data) - - else: - buf = param_state["momentum_buffer"] - - buf.mul_(momentum).add_(d_p, alpha=(1 - momentum)) - d_p = torch.sign(buf) - else: - # signSGD - d_p = torch.sign(d_p) - - p.data.add_(d_p, alpha=-group["lr"]) + + grad = p.grad + state = self.state[p] + + if group["weight_decay"] != 0: + grad = grad.add(p, alpha=group["weight_decay"]) + + if len(state) == 0: + self._init_state(example=p, state=state) + if not group["momentum"]: + state.pop("momentum_buffer", None) + + state["step"] += 1 + + update = self._compute_update( + grad, + state, + group["lr"], + group["momentum"], + group["nesterov"], + group["dampening"], + group["sign_update"], + ) + + p.add_(update) return loss From 3311ca829a3504b60a1177d8b194ced92801dab1 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 31 Oct 2024 15:45:19 +0300 Subject: [PATCH 26/58] changed logic a bit, schedules moved from utils to schedules.py; cos_wsd is testing --- README.md | 2 +- src/config/base.py | 2 +- src/main.py | 14 ++- src/optim/schedule.py | 197 ++++++++++++++++++++++++++++++++++++++++++ src/optim/utils.py | 104 ---------------------- 5 files changed, 212 insertions(+), 107 deletions(-) create mode 100644 src/optim/schedule.py diff --git a/README.md b/README.md index a78a996..2345593 100755 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT diff --git a/src/config/base.py b/src/config/base.py index 
efb8be7..a902761 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -56,7 +56,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument( "--scheduler", default="cos", - choices=["linear", "cos", "wsd", "none", "cos_inf"], + choices=["linear", "cos", "wsd", "none", "cos_inf", "cos_wsd"], ) parser.add_argument("--cos_inf_steps", default=0, type=int) # parser.add_argument("--cos-final-lr", default=1e-6, type=float) diff --git a/src/main.py b/src/main.py index 4693a5f..0125297 100755 --- a/src/main.py +++ b/src/main.py @@ -21,10 +21,11 @@ from optim.base import train from optim.lion import Lion from optim.muon import Muon, zeropower_backends +from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, + wsd_schedule) from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree from optim.sign import Signum from optim.soap import SOAP -from optim.utils import cos_inf_schedule, wsd_schedule def get_args(): @@ -281,6 +282,17 @@ def main(args, parser): decay_type=args.decay_type, ) scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + elif args.scheduler == "cos_wsd": + lambda_schedule = cosine_wsd_decay_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + anneal_end_factor=0.15, # 0.2 + fract_decay=args.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=0.1, # should be 0 here + decay_type=args.decay_type, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) else: raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") else: diff --git a/src/optim/schedule.py b/src/optim/schedule.py new file mode 100644 index 0000000..89538d4 --- /dev/null +++ b/src/optim/schedule.py @@ -0,0 +1,197 @@ +import math + +import numpy as np + + +def cos_inf_schedule(n_iterations, n_warmup, div_factor, final_div_factor, n_inf): + """Cosine annealing with warmup and _constant_ final_lr after cycle ended. + Args: + n_iterations: total number of iterations + n_warmup: number of warmup iterations + div_factor: initial division factor for warmup + final_div_factor: final division factor for final lr + n_inf: number of iterations for the final lr (constant lr after cycle ended) + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + max_lr = 1.0 + base_lr = max_lr / div_factor + final_lr = base_lr / final_div_factor + + n_anneal_steps = n_iterations - n_inf + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / div_factor + elif step < n_anneal_steps: + t = (step - n_warmup) / (n_anneal_steps - n_warmup) + lr = final_lr + 0.5 * (max_lr - final_lr) * (1 + np.cos(np.pi * t)) + return lr + else: + return final_lr + + return schedule + + +def wsd_schedule( + n_iterations, + final_lr_factor=0.0, + n_warmup=1000, + init_div_factor=100, + fract_decay=0.1, + decay_type="linear", +): + """Warmup, hold, and decay schedule. 
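As a quick illustration of the warmup-hold-decay shape produced here, the returned multiplier can be probed directly (a hypothetical toy configuration, not taken from the patch):

# assuming the wsd_schedule defined in this file
sched = wsd_schedule(
    n_iterations=1000,
    n_warmup=100,
    init_div_factor=100,
    fract_decay=0.1,
    final_lr_factor=0.0,
    decay_type="linear",
)
print(sched(0))     # 0.01 -> warmup starts at 1 / init_div_factor
print(sched(100))   # 1.0  -> full learning rate once warmup ends
print(sched(500))   # 1.0  -> held constant until the decay phase
print(sched(950))   # 0.5  -> halfway through the linear decay (last 10% of steps)
print(sched(1000))  # 0.0  -> final_lr_factor at the end

Composed with torch.optim.lr_scheduler.LambdaLR, as done in src/main.py, these factors multiply the optimizer's base learning rate.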
+ Args: + n_iterations: total number of iterations + final_lr_factor: factor by which to reduce max_lr at the end + warmup_fract: fraction of iterations used for warmup + init_div_factor: initial division factor for warmup + fract_decay: fraction of iterations used for decay + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + n_anneal_steps = int(fract_decay * n_iterations) + n_hold = n_iterations - n_anneal_steps + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor + elif step < n_hold: + return 1.0 + elif step < n_iterations: + if decay_type == "linear": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - (step - n_hold) / n_anneal_steps + ) + elif decay_type == "exp": + return final_lr_factor ** ((step - n_hold) / n_anneal_steps) + elif decay_type == "cosine": + return ( + final_lr_factor + + (1 - final_lr_factor) + * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) + * 0.5 + ) + elif decay_type == "miror_cosine": + cosine_value = ( + final_lr_factor + + (1 - final_lr_factor) + * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) + * 0.5 + ) + linear_value = final_lr_factor + (1 - final_lr_factor) * ( + 1 - (step - n_hold) / n_anneal_steps + ) + return linear_value * 2 - cosine_value + elif decay_type == "square": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - ((step - n_hold) / n_anneal_steps) ** 2 + ) + + elif decay_type == "sqrt": + return final_lr_factor + (1 - final_lr_factor) * ( + 1 - math.sqrt((step - n_hold) / n_anneal_steps) + ) + + else: + raise ValueError( + f"decay type {decay_type} is not in ['cosine','miror_cosine','linear','exp']" + ) + + else: + return final_lr_factor + + return schedule + + +def cosine_wsd_decay_schedule( + n_iterations, + n_warmup=1000, + anneal_end_factor=0.15, + final_lr_factor=0.0, + init_div_factor=1e-2, + fract_decay=0.1, + decay_type="linear", +): + """Warmup, cosine, and wsd-like decay schedule. 
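The same kind of probe sketches the three phases of this cosine-then-decay variant (again a hypothetical configuration, with init_div_factor=100 passed explicitly, as src/main.py does):

# assuming the cosine_wsd_decay_schedule defined in this file
sched = cosine_wsd_decay_schedule(
    n_iterations=1000,
    n_warmup=100,
    anneal_end_factor=0.15,
    final_lr_factor=0.0,
    init_div_factor=100,
    fract_decay=0.2,
    decay_type="linear",
)
print(sched(0))    # 0.01    -> warmup starts at 1 / init_div_factor
print(sched(100))  # 1.0     -> peak at the end of warmup (cosine phase begins)
print(sched(500))  # ~0.575  -> cosine anneal toward anneal_end_factor
print(sched(900))  # 0.15    -> decay phase starts from anneal_end_factor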
+ Args: + n_iterations: total number of iterations + n_warmup: number of warmup iterations + anneal_end_factor: factor at which cosine annealing ends + final_lr_factor: factor by which to reduce max_lr at the end + init_div_factor: initial division factor for warmup + fract_decay: fraction of iterations used for decay + decay_type: type of decay after cosine phase + ['linear', 'exp', 'cosine', 'mirror_cosine', 'square', 'sqrt'] + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + valid_decay_types = ["linear", "exp", "cosine", "mirror_cosine", "square", "sqrt"] + if decay_type not in valid_decay_types: + raise ValueError(f"decay_type {decay_type} is not in {valid_decay_types}") + + max_lr = 1.0 + base_lr = max_lr / init_div_factor + # final_lr = base_lr / final_lr_factor + n_decay_steps = int(fract_decay * n_iterations) + n_hold = n_iterations - n_decay_steps + cosine_start = n_warmup + cosine_end = n_warmup + n_hold + + def schedule(step): + if step < n_warmup: + # Warmup phase + return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor + + elif step < cosine_end: + # Cosine regime + t = (step - cosine_start) / (cosine_end - cosine_start) + return anneal_end_factor + (max_lr - anneal_end_factor) * 0.5 * ( + 1 + math.cos(math.pi * t) + ) + + elif step < n_iterations: + # Decay regime + progress = (step - cosine_end) / n_decay_steps + + if decay_type == "linear": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + 1 - progress + ) + + elif decay_type == "exp": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + final_lr_factor ** (progress) + ) + + elif decay_type == "cosine": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + (1 + math.cos(math.pi * progress)) * 0.5 + ) + + elif decay_type == "mirror_cosine": + cosine_value = final_lr_factor + ( + anneal_end_factor - final_lr_factor + ) * ((1 + math.cos(math.pi * progress)) * 0.5) + linear_value = final_lr_factor + ( + anneal_end_factor - final_lr_factor + ) * (1 - progress) + return linear_value * 2 - cosine_value + + elif decay_type == "square": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + 1 - progress**2 + ) + + elif decay_type == "sqrt": + return final_lr_factor + (anneal_end_factor - final_lr_factor) * ( + 1 - math.sqrt(progress) + ) + + return final_lr_factor + + return schedule diff --git a/src/optim/utils.py b/src/optim/utils.py index 893ea26..1bdd01e 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -23,110 +23,6 @@ def get_batch(datareader, device="cpu"): return x, y -def cos_inf_schedule(n_iterations, n_warmup, div_factor, final_div_factor, n_inf): - """Cosine annealing with warmup and _constant_ final_lr after cycle ended. 
- Args: - n_iterations: total number of iterations - n_warmup: number of warmup iterations - div_factor: initial division factor for warmup - final_div_factor: final division factor for final lr - n_inf: number of iterations for the final lr (constant lr after cycle ended) - Returns: - schedule: a function that takes the current iteration and - returns the multiplicative factor for the learning rate - """ - max_lr = 1.0 - base_lr = max_lr / div_factor - final_lr = base_lr / final_div_factor - - n_anneal_steps = n_iterations - n_inf - - def schedule(step): - if step < n_warmup: - return (step / n_warmup) + (1 - step / n_warmup) / div_factor - elif step < n_anneal_steps: - t = (step - n_warmup) / (n_anneal_steps - n_warmup) - lr = final_lr + 0.5 * (max_lr - final_lr) * (1 + np.cos(np.pi * t)) - return lr - else: - return final_lr - - return schedule - - -def wsd_schedule( - n_iterations, - final_lr_factor=0.0, - n_warmup=1000, - init_div_factor=100, - fract_decay=0.1, - decay_type="linear", -): - """Warmup, hold, and decay schedule. - Args: - n_iterations: total number of iterations - final_lr_factor: factor by which to reduce max_lr at the end - warmup_fract: fraction of iterations used for warmup - init_div_factor: initial division factor for warmup - fract_decay: fraction of iterations used for decay - Returns: - schedule: a function that takes the current iteration and - returns the multiplicative factor for the learning rate - """ - n_anneal_steps = int(fract_decay * n_iterations) - n_hold = n_iterations - n_anneal_steps - - def schedule(step): - if step < n_warmup: - return (step / n_warmup) + (1 - step / n_warmup) / init_div_factor - elif step < n_hold: - return 1.0 - elif step < n_iterations: - if decay_type == "linear": - return final_lr_factor + (1 - final_lr_factor) * ( - 1 - (step - n_hold) / n_anneal_steps - ) - elif decay_type == "exp": - return final_lr_factor ** ((step - n_hold) / n_anneal_steps) - elif decay_type == "cosine": - return ( - final_lr_factor - + (1 - final_lr_factor) - * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) - * 0.5 - ) - elif decay_type == "miror_cosine": - cosine_value = ( - final_lr_factor - + (1 - final_lr_factor) - * (1 + math.cos(math.pi * (step - n_hold) / n_anneal_steps)) - * 0.5 - ) - linear_value = final_lr_factor + (1 - final_lr_factor) * ( - 1 - (step - n_hold) / n_anneal_steps - ) - return linear_value * 2 - cosine_value - elif decay_type == "square": - return final_lr_factor + (1 - final_lr_factor) * ( - 1 - ((step - n_hold) / n_anneal_steps) ** 2 - ) - - elif decay_type == "sqrt": - return final_lr_factor + (1 - final_lr_factor) * ( - 1 - math.sqrt((step - n_hold) / n_anneal_steps) - ) - - else: - raise ValueError( - f"decay type {decay_type} is not in ['cosine','miror_cosine','linear','exp']" - ) - - else: - return final_lr_factor - - return schedule - - @torch.no_grad() def eval( model, From 1ca7a763da3ee1fd60100a7c23c02d102f95b8ec Mon Sep 17 00:00:00 2001 From: Andron00e Date: Mon, 4 Nov 2024 23:40:40 +0300 Subject: [PATCH 27/58] sgdf --- README.md | 2 +- src/config/base.py | 1 + src/main.py | 8 +++++ src/optim/sgdf.py | 78 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 src/optim/sgdf.py diff --git a/README.md b/README.md index 2345593..82b9e44 100755 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter 
parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT diff --git a/src/config/base.py b/src/config/base.py index a902761..d6507fd 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -90,6 +90,7 @@ def parse_args(base_parser, args, namespace): "adam-mini", "signsgd", "signum", + "sgdf", ], ) parser.add_argument("--batch_size", default=50, type=int) diff --git a/src/main.py b/src/main.py index 0125297..a3236e8 100755 --- a/src/main.py +++ b/src/main.py @@ -24,6 +24,7 @@ from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, wsd_schedule) from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree +from optim.sgdf import SGDF from optim.sign import Signum from optim.soap import SOAP @@ -233,6 +234,13 @@ def main(args, parser): nesterov=args.nesterov, sign_update=True, ) + elif args.opt == "sgdf": + opt = SGDF( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + ) else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/sgdf.py b/src/optim/sgdf.py new file mode 100644 index 0000000..ab873af --- /dev/null +++ b/src/optim/sgdf.py @@ -0,0 +1,78 @@ +import torch + + +class SGDF(torch.optim.Optimizer): + def __init__(self, params, lr=1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0): + if lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if betas[0] < 0.0 or betas[1] < 0.0: + raise ValueError("Invalid beta value: {}".format(betas)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(SGDF, self).__init__(params, defaults) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + eps = group["eps"] + beta1, beta2 = group["betas"] + weight_decay = group["weight_decay"] + + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad.data + if weight_decay != 0.0: + grad.add_(p.data, alpha=weight_decay) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(grad) + state["exp_var"] = torch.zeros_like(grad) + + exp_avg = state["exp_avg"] + exp_var = state["exp_var"] + + # Compute gradient 1st and 2nd + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + grad_residual = grad - exp_avg + + exp_var.mul_(beta2).addcmul_( + grad_residual, grad_residual, value=1 - beta2 + ) + + state["step"] += 1 + + # Bias correction + bias_correction1 = 1 - beta1 ** state["step"] + # bias_correction2 = 1 - beta2 ** state['step'] + bias_correction2 = ( + (1 + beta1) + * (1 - beta2 ** state["step"]) + / ((1 - beta1) * (1 - beta1 ** (2 * state["step"]))) + ) + + exp_avg_corr = exp_avg / 
bias_correction1 + exp_var_corr = exp_var / bias_correction2 + + # Wiener gain + K = exp_var_corr / ( + exp_var_corr + (grad - exp_avg_corr).pow(2).add_(eps) + ) + + grad_hat_residual = grad - exp_avg_corr + grad_hat = exp_avg_corr + K * grad_hat_residual + + p.data.add_(grad_hat, alpha=-lr) + + return loss From d485b7492158c96784c80b095b79878d9280fde4 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 5 Nov 2024 00:56:47 +0300 Subject: [PATCH 28/58] prodigy --- README.md | 7 +- src/config/base.py | 6 + src/main.py | 17 ++- src/optim/muon.py | 221 +++++++++++++++++------------------ src/optim/prodigy.py | 269 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 403 insertions(+), 117 deletions(-) create mode 100644 src/optim/prodigy.py diff --git a/README.md b/README.md index 82b9e44..ee0c233 100755 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -77,6 +77,11 @@ parser.add_argument('--model_sharding', default=None, type=bool) # Adam-mini parser.add_argument('--adam_mini_verbose', default=False, type=bool) # print all the logs if true parser.add_argument('--log_interval', default=50, type=int) parser.add_argument('--dampening', default=0.0, type=float) +parser.add_argument("--prodigy_beta3", default=None, type=float) # coefficients for computing the Prodidy stepsize using running averages +parser.add_argument("--prodigy_decouple", default=True, type=bool) # Use AdamW style decoupled weight decay +parser.add_argument("--prodigy_use_bias_correction", default=False, type=bool) +parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) # Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. 
+parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) diff --git a/src/config/base.py b/src/config/base.py index d6507fd..9b775e0 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -91,6 +91,7 @@ def parse_args(base_parser, args, namespace): "signsgd", "signum", "sgdf", + "prodigy", ], ) parser.add_argument("--batch_size", default=50, type=int) @@ -122,6 +123,11 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--model_sharding", default=None, type=bool) parser.add_argument("--adam_mini_verbose", default=False, type=bool) parser.add_argument("--dampening", default=0.0, type=float) + parser.add_argument("--prodigy_beta3", default=None, type=float) + parser.add_argument("--prodigy_decouple", default=True, type=bool) + parser.add_argument("--prodigy_use_bias_correction", default=False, type=bool) + parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) + parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) # Dataset params parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") diff --git a/src/main.py b/src/main.py index a3236e8..ffe1d37 100755 --- a/src/main.py +++ b/src/main.py @@ -21,6 +21,7 @@ from optim.base import train from optim.lion import Lion from optim.muon import Muon, zeropower_backends +from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, wsd_schedule) from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree @@ -150,9 +151,7 @@ def main(args, parser): nesterov=args.nesterov, # use True for Muon as a default backend=args.muon_backend, backend_steps=args.muon_backend_steps, - # rank=args.rank, - # world_size=args.world_size, - ) # i have left rank and world_size inside Muon + ) elif args.opt == "ademamix": opt = AdEMAMix( group_specs, @@ -241,6 +240,18 @@ def main(args, parser): betas=(args.beta1, args.beta2), weight_decay=args.weight_decay, ) + elif args.opt == "prodigy": + opt = Prodigy( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + beta3=args.prodigy_beta3, + weight_decay=args.weight_decay, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + fsdp_in_use=args.prodigy_fsdp_in_use, + ) else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/muon.py b/src/optim/muon.py index e518e87..cd4ef6b 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -3,6 +3,8 @@ Source: https://github.com/KellerJordan/modded-nanogpt """ +import os + import torch import torch.distributed as dist @@ -43,6 +45,96 @@ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): ) +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - This optimizer assumes that all parameters passed in are 2D. 
+ - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D + parameters; those should all be optimized by a standard method (e.g., AdamW). + - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. + - We believe it is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). + + Arguments: + lr: The learning rate used by the internal SGD. + momentum: The momentum used by the internal SGD. + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') + backend_steps: The number of iteration steps to use in the backend, if it is iterative. + """ + + def __init__( + self, + params, + lr=0.02, + momentum=0.95, + nesterov=True, + backend="newtonschulz5", + backend_steps=5, + ): + defaults = dict( + lr=lr, + momentum=momentum, + nesterov=nesterov, + backend=backend, + backend_steps=backend_steps, + ) + super().__init__(params, defaults) + + def step(self): + + for group in self.param_groups: + + lr = group["lr"] + momentum = group["momentum"] + zeropower_backend = zeropower_backends[group["backend"]] + + # generate weight updates in distributed fashion + total_params = sum(p.numel() for p in group["params"]) + updates_flat = torch.zeros( + total_params, device="cuda", dtype=torch.bfloat16 + ) + curr_idx = 0 + for i, p in enumerate(group["params"]): + # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs + if i % int(os.environ["WORLD_SIZE"]) == int(os.environ["RANK"]): + g = p.grad + assert g is not None + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + g = zeropower_backend(g, steps=group["backend_steps"]) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr_idx : curr_idx + p.numel()] = g.flatten() + curr_idx += p.numel() + + # sync updates across devices. 
we are not memory-constrained so can do this simple deserialization + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + # deserialize and apply updates + curr_idx = 0 + for p in group["params"]: + g = ( + updates_flat[curr_idx : curr_idx + p.numel()] + .view_as(p.data) + .type_as(p.data) + ) + p.data.add_(g, alpha=-lr) + curr_idx += p.numel() + + # class Muon(torch.optim.Optimizer): # """ # Muon - MomentUm Orthogonalized by Newton-schulz @@ -77,8 +169,6 @@ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): # nesterov=True, # backend="newtonschulz5", # backend_steps=5, -# rank=0, -# world_size=1, # ): # defaults = dict( # lr=lr, @@ -88,124 +178,29 @@ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): # backend_steps=backend_steps, # ) # super().__init__(params, defaults) -# self.rank = rank -# self.world_size = world_size # def step(self): - +# loss = None # for group in self.param_groups: - # lr = group["lr"] # momentum = group["momentum"] # zeropower_backend = zeropower_backends[group["backend"]] -# # generate weight updates in distributed fashion -# total_params = sum(p.numel() for p in group["params"]) -# updates_flat = torch.zeros( -# total_params, device="cuda", dtype=torch.bfloat16 -# ) -# curr_idx = 0 -# for i, p in enumerate(group["params"]): -# # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs -# if i % self.world_size == self.rank: -# g = p.grad -# if g is None: -# continue -# state = self.state[p] -# if "momentum_buffer" not in state: -# state["momentum_buffer"] = torch.zeros_like(g) -# buf = state["momentum_buffer"] -# buf.mul_(momentum).add_(g) -# if group["nesterov"]: -# g = g.add(buf, alpha=momentum) -# g = zeropower_backend(g, steps=group["backend_steps"]) -# g *= ( -# max(g.size(0), g.size(1)) ** 0.5 -# ) # scale to have update.square().mean() == 1 -# updates_flat[curr_idx : curr_idx + p.numel()] = g.flatten() -# curr_idx += p.numel() - -# # sync updates across devices. we are not memory-constrained so can do this simple deserialization -# dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) - -# # deserialize and apply updates -# curr_idx = 0 # for p in group["params"]: -# g = ( -# updates_flat[curr_idx : curr_idx + p.numel()] -# .view_as(p.data) -# .type_as(p.data) -# ) +# g = p.grad +# if g is None: +# continue +# state = self.state[p] +# if "momentum_buffer" not in state: +# state["momentum_buffer"] = torch.zeros_like(g) +# buf = state["momentum_buffer"] +# buf.mul_(momentum).add_(g) +# if group["nesterov"]: +# g = g.add(buf, alpha=momentum) +# g = zeropower_backend(g, steps=group["backend_steps"]) +# g *= ( +# max(g.size(0), g.size(1)) ** 0.5 +# ) # scale to have update.square().mean() == 1 # p.data.add_(g, alpha=-lr) -# curr_idx += p.numel() - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - This optimizer assumes that all parameters passed in are 2D. - - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D - parameters; those should all be optimized by a standard method (e.g., AdamW). 
- - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. - - We believe it is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). - - Arguments: - lr: The learning rate used by the internal SGD. - momentum: The momentum used by the internal SGD. - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') - backend_steps: The number of iteration steps to use in the backend, if it is iterative. - """ - - def __init__( - self, - params, - lr=3e-4, - momentum=0.95, - nesterov=True, - backend="newtonschulz5", - backend_steps=5, - ): - defaults = dict( - lr=lr, - momentum=momentum, - nesterov=nesterov, - backend=backend, - backend_steps=backend_steps, - ) - super().__init__(params, defaults) - - def step(self): - loss = None - for group in self.param_groups: - lr = group["lr"] - momentum = group["momentum"] - zeropower_backend = zeropower_backends[group["backend"]] - - for p in group["params"]: - g = p.grad - if g is None: - continue - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - g = zeropower_backend(g, steps=group["backend_steps"]) - g *= ( - max(g.size(0), g.size(1)) ** 0.5 - ) # scale to have update.square().mean() == 1 - p.data.add_(g, alpha=-lr) - return loss +# return loss diff --git a/src/optim/prodigy.py b/src/optim/prodigy.py new file mode 100644 index 0000000..2f6c049 --- /dev/null +++ b/src/optim/prodigy.py @@ -0,0 +1,269 @@ +import math + +import torch +import torch.distributed as dist + + +class Prodigy(torch.optim.Optimizer): + r""" + Implements Adam with Prodigy step-sizes. + Leave LR set to 1 unless you encounter instability. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate. + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + beta3 (float): + coefficients for computing the Prodidy stepsize using running averages. + If set to None, uses the value of square root of beta2 (default: None). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + decouple (boolean): + Use AdamW style decoupled weight decay + use_bias_correction (boolean): + Turn on Adam's bias correction. Off by default. + safeguard_warmup (boolean): + Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + d_coef (float): + Coefficient in the expression for the estimate of d (default 1.0). + Values such as 0.5 and 2.0 typically work as well. + Changing this parameter is the preferred way to tune the method. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. 
Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use (bool): + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. + """ + + def __init__( + self, + params, + lr=1.0, + betas=(0.9, 0.999), + beta3=None, + eps=1e-8, + weight_decay=0, + decouple=True, + use_bias_correction=False, + safeguard_warmup=False, + d0=1e-6, + d_coef=1.0, + growth_rate=float("inf"), + fsdp_in_use=False, + ): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + if decouple and weight_decay > 0: + print(f"Using decoupled weight decay") + + defaults = dict( + lr=lr, + betas=betas, + beta3=beta3, + eps=eps, + weight_decay=weight_decay, + d=d0, + d0=d0, + d_max=d0, + d_numerator=0.0, + d_coef=d_coef, + k=0, + growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, + safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use, + ) + self.d0 = d0 + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
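+
+        Example (a minimal sketch; assumes a standard backward/step loop with
+        ``model``, ``loss_fn``, ``x`` and ``y`` as placeholders, and lr left at
+        1.0 as recommended above):
+
+            opt = Prodigy(model.parameters(), lr=1.0)
+            loss = loss_fn(model(x), y)
+            loss.backward()
+            opt.step()
+            opt.zero_grad()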
+ """ + loss = None + if closure is not None: + loss = closure() + + d_denom = 0.0 + + group = self.param_groups[0] + use_bias_correction = group["use_bias_correction"] + beta1, beta2 = group["betas"] + beta3 = group["beta3"] + if beta3 is None: + beta3 = math.sqrt(beta2) + k = group["k"] + + d = group["d"] + d_max = group["d_max"] + d_coef = group["d_coef"] + lr = max(group["lr"] for group in self.param_groups) + + if use_bias_correction: + bias_correction = ((1 - beta2 ** (k + 1)) ** 0.5) / (1 - beta1 ** (k + 1)) + else: + bias_correction = 1 + + dlr = d * lr * bias_correction + + growth_rate = group["growth_rate"] + decouple = group["decouple"] + fsdp_in_use = group["fsdp_in_use"] + + d_numerator = group["d_numerator"] + d_numerator *= beta3 + + for group in self.param_groups: + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + group_lr = group["lr"] + d0 = group["d0"] + safeguard_warmup = group["safeguard_warmup"] + + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0" + ) + + for p in group["params"]: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + + grad = p.grad.data + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p.data, alpha=decay) + + state = self.state[p] + + # State initialization + if "step" not in state: + state["step"] = 0 + state["s"] = torch.zeros_like(p.data).detach() + state["p0"] = p.detach().clone() + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data).detach() + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data).detach() + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + s = state["s"] + p0 = state["p0"] + + if group_lr > 0.0: + # we use d / d0 instead of just d to avoid getting values that are too small + d_numerator += ( + (d / d0) + * dlr + * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() + ) + + # Adam EMA updates + exp_avg.mul_(beta1).add_(grad, alpha=d * (1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=d * d * (1 - beta2) + ) + + if safeguard_warmup: + s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) + else: + s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) + d_denom += s.abs().sum().item() + + ###### + + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have d_denom > 0 (unless \|g\|=0) + if d_denom == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(2).cuda() + dist_tensor[0] = d_numerator + dist_tensor[1] = d_denom + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_d_numerator = dist_tensor[0] + global_d_denom = dist_tensor[1] + else: + global_d_numerator = d_numerator + global_d_denom = d_denom + + d_hat = d_coef * global_d_numerator / global_d_denom + if d == group["d0"]: + d = max(d, d_hat) + d_max = max(d_max, d_hat) + d = min(d_max, d * growth_rate) + + for group in self.param_groups: + group["d_numerator"] = global_d_numerator + group["d_denom"] = global_d_denom + group["d"] = d + group["d_max"] = d_max + group["d_hat"] = d_hat + + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + state["step"] += 1 + + denom = 
exp_avg_sq.sqrt().add_(d * eps) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p.data.add_(p.data, alpha=-decay * dlr) + + ### Take step + p.data.addcdiv_(exp_avg, denom, value=-dlr) + + group["k"] = k + 1 + + return loss From 6007a953db4fcb4eb1fa3ac8bacdb2d424d72777 Mon Sep 17 00:00:00 2001 From: mpagli Date: Mon, 4 Nov 2024 23:02:51 +0000 Subject: [PATCH 29/58] add fineweb --- src/config/base.py | 1 + src/data/fineweb.py | 71 +++++++++++++++++++++++++++++++++++++++++++++ src/data/utils.py | 3 ++ 3 files changed, 75 insertions(+) create mode 100644 src/data/fineweb.py diff --git a/src/config/base.py b/src/config/base.py index a902761..fbfe66e 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -138,6 +138,7 @@ def parse_args(base_parser, args, namespace): "slimpajama", "slimpajama_chunk1", "redpajamav2", + "fineweb", ], ) parser.add_argument( diff --git a/src/data/fineweb.py b/src/data/fineweb.py new file mode 100644 index 0000000..a2d9b96 --- /dev/null +++ b/src/data/fineweb.py @@ -0,0 +1,71 @@ +import os + +import numpy as np +import tiktoken +from datasets import load_dataset +from tqdm import tqdm + +tknzr = tiktoken.get_encoding("gpt2") + + +def get_fineweb_data(datasets_dir, num_proc=40): + """ To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code. """ + FWEB_DATA_PATH = os.path.join(datasets_dir, "fineweb-100BT/") + if not os.path.exists(os.path.join(FWEB_DATA_PATH, "train.bin")): + os.makedirs(FWEB_DATA_PATH, exist_ok=True) + + dataset = load_dataset("HuggingFaceFW/fineweb", name="sample-100BT", split="train", + streaming=False, verification_mode="no_checks") + + split_dataset = dataset.train_test_split( + test_size=0.0001, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
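+        # The "len" field is summed below (np.sum(dset["len"])) to pre-size the
+        # train.bin / val.bin memmaps, so it must count every id written for this
+        # example, including the appended eot token.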
+ out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(FWEB_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + return { + "train": os.path.join(FWEB_DATA_PATH, "train.bin"), + "val": os.path.join(FWEB_DATA_PATH, "val.bin"), + } + + +if __name__ == "__main__": + get_fineweb_data("./datasets/") \ No newline at end of file diff --git a/src/data/utils.py b/src/data/utils.py index ad53f5f..bf30a1f 100755 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -11,6 +11,7 @@ from .shakespeare import get_shakespeare_data from .slimpajama import get_slimpajama_data from .wikitext import get_wikitext_data +from .fineweb import get_fineweb_data def get_dataset(args) -> Dict[str, np.ndarray]: @@ -40,6 +41,8 @@ def get_dataset(args) -> Dict[str, np.ndarray]: return get_redpajamav2_data(args.datasets_dir) if args.dataset == "slimpajama": return get_slimpajama_data(args.datasets_dir) + if args.dataset == "fineweb": + return get_fineweb_data(args.datasets_dir) else: raise NotImplementedError(f"Unknow dataset key '{args.dataset}'") From dc5169ccda038f15a2f80e58577a23d26f8b8392 Mon Sep 17 00:00:00 2001 From: Andrei Semenov <67924720+Andron00e@users.noreply.github.com> Date: Tue, 5 Nov 2024 00:10:39 +0100 Subject: [PATCH 30/58] -- description --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ee0c233..c2ce666 100755 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ parser.add_argument("--prodigy_use_bias_correction", default=False, type=bool) parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) # Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. 
parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) # Dataset params -parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1']) +parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) parser.add_argument('--vocab_size', default=50304, type=int) parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this From 7dd576b2c4e7ecbfe8ddcda56345f92d5d6744ba Mon Sep 17 00:00:00 2001 From: mpagli Date: Tue, 5 Nov 2024 09:19:28 +0000 Subject: [PATCH 31/58] log tokens processed --- src/optim/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/optim/base.py b/src/optim/base.py index 4359d9d..1be6921 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -99,6 +99,7 @@ def train( or (curr_iter in cfg.full_eval_at) ): eval_and_log( + tokens, curr_iter, epoch, model, @@ -169,6 +170,7 @@ def train( if cfg.wandb: wandb.log( { + "tokens": tokens, "iter": curr_iter, "train/loss": train_loss, "train/perplexity": 2.71828**train_loss, @@ -187,6 +189,7 @@ def train( def eval_and_log( + tokens, curr_iter, epoch, model, @@ -232,6 +235,7 @@ def eval_and_log( if cfg.wandb: if curr_iter == cfg.iterations or full_eval: logs = { + "tokens": tokens, "iter": curr_iter, "final-val/loss": val_loss, "final-val/perplexity": val_perplexity, @@ -239,6 +243,7 @@ def eval_and_log( } else: logs = { + "tokens": tokens, "iter": curr_iter, "val/loss": val_loss, "val/perplexity": val_perplexity, From 94a4e57108524b7dbc9936e58c2ab5425e3cfdfd Mon Sep 17 00:00:00 2001 From: mpagli Date: Tue, 5 Nov 2024 09:21:18 +0000 Subject: [PATCH 32/58] add fineweb-edu --- src/config/base.py | 3 +- src/data/fineweb_edu.py | 71 +++++++++++++++++++++++++++++++++++++++++ src/data/utils.py | 3 ++ 3 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 src/data/fineweb_edu.py diff --git a/src/config/base.py b/src/config/base.py index f39ddc9..8fb73e5 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -145,7 +145,8 @@ def parse_args(base_parser, args, namespace): "slimpajama", "slimpajama_chunk1", "redpajamav2", - "fineweb", + "fineweb", + "finewebedu", ], ) parser.add_argument( diff --git a/src/data/fineweb_edu.py b/src/data/fineweb_edu.py new file mode 100644 index 0000000..c7ed7d9 --- /dev/null +++ b/src/data/fineweb_edu.py @@ -0,0 +1,71 @@ +import os + +import numpy as np +import tiktoken +from datasets import load_dataset +from tqdm import tqdm + +tknzr = tiktoken.get_encoding("gpt2") + + +def get_fineweb_edu_data(datasets_dir, num_proc=40): + """ To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code. 
""" + FWEB_DATA_PATH = os.path.join(datasets_dir, "fineweb-edu-100BT/") + if not os.path.exists(os.path.join(FWEB_DATA_PATH, "train.bin")): + os.makedirs(FWEB_DATA_PATH, exist_ok=True) + + dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", split="train", + streaming=False, verification_mode="no_checks") + + split_dataset = dataset.train_test_split( + test_size=0.0001, seed=2357, shuffle=True + ) + split_dataset["val"] = split_dataset.pop("test") + + def process(example): + ids = tknzr.encode_ordinary( + example["text"] + ) # encode_ordinary ignores any special tokens + ids.append( + tknzr.eot_token + ) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... + out = {"ids": ids, "len": len(ids)} + return out + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"]) + filename = os.path.join(FWEB_DATA_PATH, f"{split}.bin") + dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) + total_batches = min(1024, len(dset)) + + idx = 0 + for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + return { + "train": os.path.join(FWEB_DATA_PATH, "train.bin"), + "val": os.path.join(FWEB_DATA_PATH, "val.bin"), + } + + +if __name__ == "__main__": + get_fineweb_edu_data("./datasets/") \ No newline at end of file diff --git a/src/data/utils.py b/src/data/utils.py index bf30a1f..a592785 100755 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -12,6 +12,7 @@ from .slimpajama import get_slimpajama_data from .wikitext import get_wikitext_data from .fineweb import get_fineweb_data +from .fineweb_edu import get_fineweb_edu_data def get_dataset(args) -> Dict[str, np.ndarray]: @@ -43,6 +44,8 @@ def get_dataset(args) -> Dict[str, np.ndarray]: return get_slimpajama_data(args.datasets_dir) if args.dataset == "fineweb": return get_fineweb_data(args.datasets_dir) + if args.dataset == "finewebedu": + return get_fineweb_edu_data(args.datasets_dir) else: raise NotImplementedError(f"Unknow dataset key '{args.dataset}'") From c8e6fd87bbba486fb9b19a3b5337bd667e07d550 Mon Sep 17 00:00:00 2001 From: Andrei Semenov <67924720+Andron00e@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:10:11 +0100 Subject: [PATCH 33/58] finewebedu --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c2ce666..173fe21 100755 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ parser.add_argument("--prodigy_use_bias_correction", default=False, type=bool) parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) # Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. 
parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) # Dataset params -parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb']) +parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) parser.add_argument('--vocab_size', default=50304, type=int) parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this From 99d7af619d8f5b43da5c91c29edf41eab8404248 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 5 Nov 2024 19:05:52 +0300 Subject: [PATCH 34/58] sophia is in the process --- README.md | 13 +-- src/config/base.py | 2 + src/main.py | 10 +- src/optim/base.py | 15 ++- src/optim/sophia.py | 243 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 275 insertions(+), 8 deletions(-) create mode 100644 src/optim/sophia.py diff --git a/README.md b/README.md index ee0c233..bfd6515 100755 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -77,11 +77,12 @@ parser.add_argument('--model_sharding', default=None, type=bool) # Adam-mini parser.add_argument('--adam_mini_verbose', default=False, type=bool) # print all the logs if true parser.add_argument('--log_interval', default=50, type=int) parser.add_argument('--dampening', default=0.0, type=float) -parser.add_argument("--prodigy_beta3", default=None, type=float) # coefficients for computing the Prodidy stepsize using running averages -parser.add_argument("--prodigy_decouple", default=True, type=bool) # Use AdamW style decoupled weight decay -parser.add_argument("--prodigy_use_bias_correction", default=False, type=bool) -parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) # Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. 
-parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) +parser.add_argument('--prodigy_beta3', default=None, type=float) # coefficients for computing the Prodidy stepsize using running averages +parser.add_argument('--prodigy_decouple', default=True, type=bool) # Use AdamW style decoupled weight decay +parser.add_argument('--prodigy_use_bias_correction', default=False, type=bool) +parser.add_argument('--prodigy_safeguard_warmup', default=False, type=bool) # Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. +parser.add_argument('--prodigy_fsdp_in_use', default=False, type=bool) +parser.add_argument('--sophia_rho', default=0.04, type=float) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) diff --git a/src/config/base.py b/src/config/base.py index 9b775e0..e5d6738 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -92,6 +92,7 @@ def parse_args(base_parser, args, namespace): "signum", "sgdf", "prodigy", + "sophiag", ], ) parser.add_argument("--batch_size", default=50, type=int) @@ -128,6 +129,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--prodigy_use_bias_correction", default=False, type=bool) parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) + parser.add_argument("--sophia_rho", default=0.04, type=float) # Dataset params parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") diff --git a/src/main.py b/src/main.py index ffe1d37..189f0cb 100755 --- a/src/main.py +++ b/src/main.py @@ -28,7 +28,7 @@ from optim.sgdf import SGDF from optim.sign import Signum from optim.soap import SOAP - +from optim.sophia import SophiaG def get_args(): parser = argparse.ArgumentParser(allow_abbrev=False) @@ -252,6 +252,14 @@ def main(args, parser): safeguard_warmup=args.prodigy_safeguard_warmup, fsdp_in_use=args.prodigy_fsdp_in_use, ) + elif args.opt == "sophiag": + opt = SophiaG( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + rho = args.sophia_rho, + ) else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/base.py b/src/optim/base.py index 4359d9d..bd08392 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -146,7 +146,20 @@ def train( opt.step() if cfg.scheduler != "none": scheduler.step() - opt.zero_grad(set_to_none=True) + if cfg.opt == "sophia": + opt.zero_grad(set_to_none=True) + if curr_iter % 10 != 10 - 1: + continue + else: + samp_dist = torch.distributions.Categorical(logits=outputs["logits"]) + y_sample = samp_dist.sample() + loss_sampled = torch.nn.functional.cross_entropy(outputs["logits"].view(-1, outputs["logits"].size(-1)), y_sample.view(-1), ignore_index=-1) + loss_sampled.backward() + opt.update_hessian() + opt.zero_grad(set_to_none=True) + model.zero_grad() + else: + opt.zero_grad(set_to_none=True) dt = (time.perf_counter_ns() - t_start) / 1e9 curr_iter += 1 diff --git a/src/optim/sophia.py b/src/optim/sophia.py new file mode 100644 index 0000000..69ceec3 --- /dev/null +++ b/src/optim/sophia.py @@ -0,0 +1,243 @@ +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +class 
SophiaG(Optimizer): + def __init__( + self, + params, + lr=1e-4, + betas=(0.965, 0.99), + rho=0.04, + weight_decay=1e-1, + *, + maximize: bool = False, + capturable: bool = False + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= rho: + raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict( + lr=lr, + betas=betas, + rho=rho, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + ) + super(SophiaG, self).__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("maximize", False) + group.setdefault("capturable", False) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]["step"] + ) + if not step_is_tensor: + for s in state_values: + s["step"] = torch.tensor(float(s["step"])) + + @torch.no_grad() + def update_hessian(self): + for group in self.param_groups: + beta1, beta2 = group["betas"] + for p in group["params"]: + if p.grad is None: + continue + state = self.state[p] + + if len(state) == 0: + state["step"] = ( + torch.zeros((1,), dtype=torch.float, device=p.device) + if self.defaults["capturable"] + else torch.tensor(0.0) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if "hessian" not in state.keys(): + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + state["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) + + @torch.no_grad() + def step(self, closure=None, bs=5120): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + state_steps = [] + hessian = [] + beta1, beta2 = group["betas"] + + for p in group["params"]: + if p.grad is None: + continue + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError("Hero does not support sparse gradients") + grads.append(p.grad) + state = self.state[p] + # State initialization + if len(state) == 0: + state["step"] = ( + torch.zeros((1,), dtype=torch.float, device=p.device) + if self.defaults["capturable"] + else torch.tensor(0.0) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if "hessian" not in state.keys(): + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + state_steps.append(state["step"]) + hessian.append(state["hessian"]) + + if self.defaults["capturable"]: + bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs + + sophiag( + params_with_grad, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=group["rho"], + lr=group["lr"], + weight_decay=group["weight_decay"], + maximize=group["maximize"], + capturable=group["capturable"], + ) + + return loss + + +def sophiag( + params: List[Tensor], + grads: List[Tensor], 
+ exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + capturable: bool = False, + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool +): + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) + + func = _single_tensor_sophiag + + func( + params, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=rho, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + ) + + +def _single_tensor_sophiag( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool, + capturable: bool +): + + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + hess = hessian[i] + step_t = state_steps[i] + + if capturable: + assert param.is_cuda and step_t.is_cuda and bs.is_cuda + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + hess = torch.view_as_real(hess) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + if capturable: + step_size = lr + step_size_neg = step_size.neg() + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) + else: + step_size_neg = -lr + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) From fd9aa7e6034d3d131bdbfd393fe5121d152a4cb6 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 6 Nov 2024 05:21:14 +0300 Subject: [PATCH 35/58] shampoo is here, needs major improvements (very memory- and time- consuming due to the torch.inverse) --- README.md | 2 +- src/config/base.py | 1 + src/main.py | 14 +- src/optim/base.py | 7 +- src/optim/sgdf.py | 5 + src/optim/shampoo.py | 647 +++++++++++++++++++++++++++++++++++++++++++ src/optim/sophia.py | 5 + 7 files changed, 677 insertions(+), 4 deletions(-) create mode 100644 src/optim/shampoo.py diff --git a/README.md b/README.md index bfd6515..67fe2d8 100755 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', 
default=0.0, type=float) # default value is 1.0 in NanoGPT diff --git a/src/config/base.py b/src/config/base.py index e5d6738..6e74bfc 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -93,6 +93,7 @@ def parse_args(base_parser, args, namespace): "sgdf", "prodigy", "sophiag", + "shampoo", ], ) parser.add_argument("--batch_size", default=50, type=int) diff --git a/src/main.py b/src/main.py index 189f0cb..870abfe 100755 --- a/src/main.py +++ b/src/main.py @@ -26,9 +26,11 @@ wsd_schedule) from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree from optim.sgdf import SGDF +from optim.shampoo import DistributedShampoo from optim.sign import Signum from optim.soap import SOAP -from optim.sophia import SophiaG +from optim.sophia import SophiaG + def get_args(): parser = argparse.ArgumentParser(allow_abbrev=False) @@ -258,7 +260,15 @@ def main(args, parser): lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay, - rho = args.sophia_rho, + rho=args.sophia_rho, + ) + elif args.opt == "shampoo": + opt = DistributedShampoo( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + shampoo_decay=args.momentum, # decay rate for Shampoo preconditioners with the momentum constant + weight_decay=args.weight_decay, ) else: opt = torch.optim.SGD( diff --git a/src/optim/base.py b/src/optim/base.py index bd08392..2a41fb3 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -153,13 +153,18 @@ def train( else: samp_dist = torch.distributions.Categorical(logits=outputs["logits"]) y_sample = samp_dist.sample() - loss_sampled = torch.nn.functional.cross_entropy(outputs["logits"].view(-1, outputs["logits"].size(-1)), y_sample.view(-1), ignore_index=-1) + loss_sampled = torch.nn.functional.cross_entropy( + outputs["logits"].view(-1, outputs["logits"].size(-1)), + y_sample.view(-1), + ignore_index=-1, + ) loss_sampled.backward() opt.update_hessian() opt.zero_grad(set_to_none=True) model.zero_grad() else: opt.zero_grad(set_to_none=True) + # opt.zero_grad(set_to_none=True) dt = (time.perf_counter_ns() - t_start) / 1e9 curr_iter += 1 diff --git a/src/optim/sgdf.py b/src/optim/sgdf.py index ab873af..d3ee42a 100644 --- a/src/optim/sgdf.py +++ b/src/optim/sgdf.py @@ -1,3 +1,8 @@ +""" +Here is an original implementation of SGDF. +Source: https://arxiv.org/abs/2311.02818 +""" + import torch diff --git a/src/optim/shampoo.py b/src/optim/shampoo.py new file mode 100644 index 0000000..5119f47 --- /dev/null +++ b/src/optim/shampoo.py @@ -0,0 +1,647 @@ +import math +from typing import Callable, List, Optional, Tuple + +import torch +import torch.distributed as dist + + +class DistributedShampoo(torch.optim.Optimizer): + """ + Args: + params (iterable): Iterable of parameters to optimize. + lr (float, optional): Learning rate (default: 1e-3). + betas (Tuple[float, float, float], optional): Coefficients used for computing + running averages of gradient, squared gradient, and slow EMA (default: (0.9, 0.999, 0.9999)). + eps (float, optional): Term added to denominator to improve numerical stability (default: 1e-8). + weight_decay (float, optional): Weight decay (L2 penalty) (default: 0). + shampoo_decay (float, optional): Decay rate for Shampoo preconditioners (default: 0.9). 
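+
+    Example (a minimal sketch; ``model`` is a placeholder and the arguments are
+    those of ``__init__`` below, which takes a two-element ``betas``):
+        opt = DistributedShampoo(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
+        loss.backward()
+        opt.step()  # preconditioners are all-reduced if torch.distributed is initialized
+        opt.zero_grad()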
+ """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + shampoo_decay: float = 0.9, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if len(betas) != 2: + raise ValueError(f"Invalid betas length: {len(betas)}, expected 2.") + if not all(0.0 <= beta < 1.0 for beta in betas): + raise ValueError(f"Invalid betas: {betas}. Each beta must be in [0, 1).") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= shampoo_decay < 1.0: + raise ValueError( + f"Invalid shampoo_decay value: {shampoo_decay}. Must be in [0, 1)." + ) + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + shampoo_decay=shampoo_decay, + ) + super(DistributedShampoo, self).__init__(params, defaults) + + def __setstate__(self, state): + super(DistributedShampoo, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """ + Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + Optional[float]: The loss if closure is provided, else None. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Iterate over parameter groups + for group in self.param_groups: + # print(group) + params_with_grad = [] + grads = [] + beta1, beta2 = self.defaults["betas"] + exp_avgs = [] + exp_avg_sqs = [] + preconditioners1 = [] + preconditioners2 = [] + state_steps = [] + + # Collect parameters and their states + for p in group["params"]: + if p.grad is None: + continue + if p.grad.is_sparse: + raise RuntimeError( + "DistributedShampoo does not support sparse gradients" + ) + if not p.requires_grad: + continue + + params_with_grad.append(p) + grad = p.grad + grads.append(grad) + + state = self.state[p] + if not state: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Initialize Shampoo preconditioners as identity matrices or scalars + if p.dim() >= 2: + state["preconditioner1"] = torch.eye( + p.size(0), device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.eye( + p.size(1), device=p.device, dtype=p.dtype + ) + else: + state["preconditioner1"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + preconditioners1.append(state["preconditioner1"]) + preconditioners2.append(state["preconditioner2"]) + state_steps.append(state["step"]) + state["step"] += 1 + + if not params_with_grad: + continue # Skip if no parameters to update in this group + + # betas = group['betas'] + # beta1, beta2 = betas + beta1, beta2 = self.defaults["betas"] + lr = group["lr"] + weight_decay = group["weight_decay"] + # eps = group['eps'] + eps = self.defaults["eps"] + # shampoo_decay = group['shampoo_decay'] + shampoo_decay = self.defaults["shampoo_decay"] + + # Update Shampoo preconditioners in a distributed manner + self._update_preconditioners_distributed( + preconditioners1, preconditioners2, grads, group, shampoo_decay, eps + ) + + # Update 
parameters using Shampoo preconditioning + self._update_distributed_shampoo( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + preconditioners1, + preconditioners2, + state_steps, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + ) + + return loss + + def _update_preconditioners_distributed( + self, + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + grads: List[torch.Tensor], + group: dict, + shampoo_decay: float, + eps: float, + ): + """ + Updates Shampoo preconditioners and synchronizes them across distributed workers. + + Args: + preconditioners1 (List[torch.Tensor]): List of first preconditioners for each parameter. + preconditioners2 (List[torch.Tensor]): List of second preconditioners for each parameter. + grads (List[torch.Tensor]): List of gradients for each parameter. + group (dict): Parameter group options. + shampoo_decay (float): Decay rate for Shampoo preconditioners. + eps (float): Small epsilon for numerical stability. + """ + for pc1, pc2, grad in zip(preconditioners1, preconditioners2, grads): + if grad.dim() >= 2: + A = grad @ grad.t() # [in_features, in_features] + B = grad.t() @ grad # [out_features, out_features] + else: + A = (grad**2).sum() + B = A.clone() # For 1D gradients, B is same as A + + # Update preconditioners with exponential moving average + pc1.mul_(shampoo_decay).add_(A, alpha=1 - shampoo_decay) + pc2.mul_(shampoo_decay).add_(B, alpha=1 - shampoo_decay) + + # Synchronize preconditioners across workers + if dist.is_initialized(): + dist.all_reduce(pc1, op=dist.ReduceOp.SUM) + dist.all_reduce(pc2, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + pc1.div_(world_size) + pc2.div_(world_size) + + def _update_distributed_shampoo( + self, + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + exp_avg_sqs: List[torch.Tensor], + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + steps: List[int], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + ): + """ + Performs an update with Shampoo preconditioning. + + Args: + params (List[torch.Tensor]): List of parameters to update. + grads (List[torch.Tensor]): List of gradients for each parameter. + exp_avgs (List[torch.Tensor]): List of first moment estimates. + exp_avg_sqs (List[torch.Tensor]): List of second moment estimates. + preconditioners1 (List[torch.Tensor]): List of first preconditioners. + preconditioners2 (List[torch.Tensor]): List of second preconditioners. + steps (List[int]): List of step counts for each parameter. + beta1 (float): Coefficient for first moment. + beta2 (float): Coefficient for second moment. + lr (float): Learning rate. + weight_decay (float): Weight decay coefficient. + eps (float): Small epsilon for numerical stability. 
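+
+        Note: parameters are updated in place. Gradients with dim >= 2 are
+        preconditioned on both sides using the accumulated ``grad @ grad.T`` and
+        ``grad.T @ grad`` statistics; 1-D gradients use the scalar preconditioner.
+        The preconditioned gradient is then averaged with the Adam first moment
+        and applied via ``addcdiv_`` with the bias-corrected second-moment
+        denominator.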
+ """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + pc1 = preconditioners1[i] + pc2 = preconditioners2[i] + step = steps[i] + + # Bias corrections + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + # Update biased first moment estimate + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # Update biased second raw moment estimate + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Compute bias-corrected second moment + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + # Compute step size NOTE: think about another step_size (mb like in the AdEMAMix or SOAP) + step_size = lr / (bias_correction1 if bias_correction1 > 0 else 0.01) + + # Apply weight decay + if weight_decay != 0: + param.mul_(1 - lr * weight_decay) + + # Compute Shampoo preconditioned gradient + if grad.dim() >= 2: + # Safe inversion with added epsilon to diagonal + if not torch.isfinite(pc1).all() or not torch.isfinite(pc2).all(): + raise RuntimeError("Matrix contains NaN or Inf values.") + inv_pc1 = torch.inverse( + pc1 + + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps + ).sqrt() + inv_pc2 = torch.inverse( + pc2 + + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps + ).sqrt() + # inv_pc1 = torch.linalg.pinv(pc1 + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps).sqrt() + # inv_pc2 = torch.linalg.pinv(pc2 + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps).sqrt() + + # Precondition the gradient + preconditioned_grad = inv_pc1 @ grad @ inv_pc2 + else: + # For 1D gradients, use scalar preconditioning + preconditioned_grad = grad / (pc1.sqrt() + eps) + + combined_grad = (exp_avg + preconditioned_grad) / 2 # Weighted average + + # Update parameters + param.addcdiv_(combined_grad, denom, value=-step_size) + + # Optional: Gradient Clipping (Uncomment if needed) + # torch.nn.utils.clip_grad_norm_(param, max_norm=1.0) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(lr={self.defaults['lr']}, " + f"betas={self.defaults['betas']}, eps={self.defaults['eps']}, " + f"weight_decay={self.defaults['weight_decay']}" + ) + + +class AdEMAMixDistributedShampoo(torch.optim.Optimizer): + """ + AdEMAMix optimizer with Distributed Shampoo preconditioning. + + Combines the AdEMAMix optimizer with Shampoo’s second-order preconditioning. + Supports distributed training via torch.distributed. + + Args: + params (iterable): Iterable of parameters to optimize. + lr (float, optional): Learning rate (default: 1e-3). + betas (Tuple[float, float, float], optional): Coefficients used for computing + running averages of gradient, squared gradient, and slow EMA (default: (0.9, 0.999, 0.9999)). + eps (float, optional): Term added to denominator to improve numerical stability (default: 1e-8). + weight_decay (float, optional): Weight decay (L2 penalty) (default: 0). + alpha (float, optional): Alpha parameter for AdEMAMix (default: 5.0). + T_alpha_beta3 (Optional[int], optional): Time constant for alpha and beta3 scheduling (default: None). + shampoo_decay (float, optional): Decay rate for Shampoo preconditioners (default: 0.9). 
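+
+    Example (a minimal sketch; ``model`` is a placeholder and the arguments are
+    those of ``__init__`` below; the third beta and ``alpha`` control the slow EMA):
+        opt = AdEMAMixDistributedShampoo(
+            model.parameters(), lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0
+        )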
+ """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999), + eps: float = 1e-8, + weight_decay: float = 0, + alpha: float = 5.0, + T_alpha_beta3: Optional[int] = None, + shampoo_decay: float = 0.9, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if len(betas) != 3: + raise ValueError(f"Invalid betas length: {len(betas)}, expected 3.") + if not all(0.0 <= beta < 1.0 for beta in betas): + raise ValueError(f"Invalid betas: {betas}. Each beta must be in [0, 1).") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 <= shampoo_decay < 1.0: + raise ValueError( + f"Invalid shampoo_decay value: {shampoo_decay}. Must be in [0, 1)." + ) + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + alpha=alpha, + T_alpha_beta3=T_alpha_beta3, + shampoo_decay=shampoo_decay, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """ + Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + Optional[float]: The loss if closure is provided, else None. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Iterate over parameter groups + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + exp_avg_slows = [] + preconditioners1 = [] + preconditioners2 = [] + state_steps = [] + + # Collect parameters and their states + for p in group["params"]: + if p.grad is None: + continue + if p.grad.is_sparse: + raise RuntimeError( + "AdEMAMixDistributedShampoo does not support sparse gradients" + ) + if not p.requires_grad: + continue + + params_with_grad.append(p) + grad = p.grad + grads.append(grad) + + state = self.state[p] + if not state: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_slow"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Initialize Shampoo preconditioners as identity matrices or scalars + if p.dim() >= 2: + state["preconditioner1"] = torch.eye( + p.size(0), device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.eye( + p.size(1), device=p.device, dtype=p.dtype + ) + else: + state["preconditioner1"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + state["preconditioner2"] = torch.tensor( + 1.0, device=p.device, dtype=p.dtype + ) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + exp_avg_slows.append(state["exp_avg_slow"]) + preconditioners1.append(state["preconditioner1"]) + preconditioners2.append(state["preconditioner2"]) + state_steps.append(state["step"]) + state["step"] += 1 + + if not params_with_grad: + continue # Skip if no parameters to update in this group + + betas = group["betas"] + beta1, beta2, beta3 = betas + alpha = group["alpha"] + T_alpha_beta3 = group["T_alpha_beta3"] + lr = group["lr"] + weight_decay = group["weight_decay"] + eps = group["eps"] + shampoo_decay = group["shampoo_decay"] + + # Update Shampoo preconditioners in a distributed manner + 
self._update_preconditioners_distributed( + preconditioners1, preconditioners2, grads, group, shampoo_decay, eps + ) + + # Update parameters using AdEMAMix with Shampoo preconditioning + self._update_adamemix_distributed_shampoo( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + exp_avg_slows, + preconditioners1, + preconditioners2, + state_steps, + beta1=beta1, + beta2=beta2, + beta3=beta3, + alpha=alpha, + T_alpha_beta3=T_alpha_beta3, + lr=lr, + weight_decay=weight_decay, + eps=eps, + ) + + return loss + + def _update_preconditioners_distributed( + self, + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + grads: List[torch.Tensor], + group: dict, + shampoo_decay: float, + eps: float, + ): + """ + Updates Shampoo preconditioners and synchronizes them across distributed workers. + + Args: + preconditioners1 (List[torch.Tensor]): List of first preconditioners for each parameter. + preconditioners2 (List[torch.Tensor]): List of second preconditioners for each parameter. + grads (List[torch.Tensor]): List of gradients for each parameter. + group (dict): Parameter group options. + shampoo_decay (float): Decay rate for Shampoo preconditioners. + eps (float): Small epsilon for numerical stability. + """ + for pc1, pc2, grad in zip(preconditioners1, preconditioners2, grads): + if grad.dim() >= 2: + A = grad @ grad.t() # [in_features, in_features] + B = grad.t() @ grad # [out_features, out_features] + else: + A = (grad**2).sum() + B = A.clone() # For 1D gradients, B is same as A + + # Update preconditioners with exponential moving average + pc1.mul_(shampoo_decay).add_(A, alpha=1 - shampoo_decay) + pc2.mul_(shampoo_decay).add_(B, alpha=1 - shampoo_decay) + + # Synchronize preconditioners across workers + if dist.is_initialized(): + dist.all_reduce(pc1, op=dist.ReduceOp.SUM) + dist.all_reduce(pc2, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + pc1.div_(world_size) + pc2.div_(world_size) + + def _update_adamemix_distributed_shampoo( + self, + params: List[torch.Tensor], + grads: List[torch.Tensor], + exp_avgs: List[torch.Tensor], + exp_avg_sqs: List[torch.Tensor], + exp_avg_slows: List[torch.Tensor], + preconditioners1: List[torch.Tensor], + preconditioners2: List[torch.Tensor], + steps: List[int], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + T_alpha_beta3: Optional[int], + lr: float, + weight_decay: float, + eps: float, + ): + """ + Performs the AdEMAMix update with Shampoo preconditioning. + + Args: + params (List[torch.Tensor]): List of parameters to update. + grads (List[torch.Tensor]): List of gradients for each parameter. + exp_avgs (List[torch.Tensor]): List of first moment estimates. + exp_avg_sqs (List[torch.Tensor]): List of second moment estimates. + exp_avg_slows (List[torch.Tensor]): List of slow EMA estimates. + preconditioners1 (List[torch.Tensor]): List of first preconditioners. + preconditioners2 (List[torch.Tensor]): List of second preconditioners. + steps (List[int]): List of step counts for each parameter. + beta1 (float): Coefficient for first moment. + beta2 (float): Coefficient for second moment. + beta3 (float): Coefficient for slow EMA. + alpha (float): Alpha parameter for AdEMAMix. + T_alpha_beta3 (Optional[int]): Time constant for scheduling. + lr (float): Learning rate. + weight_decay (float): Weight decay coefficient. + eps (float): Small epsilon for numerical stability. 
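+
+        Note: when ``T_alpha_beta3`` is set, ``alpha_t`` ramps linearly from 0 to
+        ``alpha`` over the first ``T_alpha_beta3`` steps and ``beta3_t`` is
+        interpolated from beta1 to beta3 in log space; otherwise the constant
+        values are used. The final in-place update applies
+        ``(exp_avg + alpha_t * exp_avg_slow + preconditioned_grad) / 3`` divided
+        by the bias-corrected second-moment denominator.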
+ """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + exp_avg_slow = exp_avg_slows[i] + pc1 = preconditioners1[i] + pc2 = preconditioners2[i] + step = steps[i] + + # Bias corrections + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + # Schedule alpha_t and beta3_t + if T_alpha_beta3 is not None and T_alpha_beta3 > 0: + alpha_t = min(step * alpha / T_alpha_beta3, alpha) + # Avoid division by zero + if T_alpha_beta3 != step: + log_beta1 = math.log(beta1) + log_beta3 = math.log(beta3) + denominator = (1 - step / T_alpha_beta3) * log_beta3 + ( + step / T_alpha_beta3 + ) * log_beta1 + if denominator != 0: + beta3_t = min( + math.exp((log_beta1 * log_beta3) / denominator), beta3 + ) + else: + beta3_t = beta3 + else: + beta3_t = beta3 + else: + alpha_t = alpha + beta3_t = beta3 + + # Update biased first moment estimate + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # Update biased second raw moment estimate + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + # Update slow EMA + exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1 - beta3_t) + + # Compute bias-corrected second moment + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + # Compute step size + step_size = lr / (bias_correction1 if bias_correction1 > 0 else 0.01) + + # Apply weight decay + if weight_decay != 0: + param.mul_(1 - lr * weight_decay) + + # Compute Shampoo preconditioned gradient + if grad.dim() >= 2: + # Safe inversion with added epsilon to diagonal + inv_pc1 = torch.inverse( + pc1 + + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps + ).sqrt() + inv_pc2 = torch.inverse( + pc2 + + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps + ).sqrt() + + # Precondition the gradient + preconditioned_grad = inv_pc1 @ grad @ inv_pc2 + else: + # For 1D gradients, use scalar preconditioning + preconditioned_grad = grad / (pc1.sqrt() + eps) + + # Combine AdEMAMix update with Shampoo preconditioning + combined_grad = ( + exp_avg + alpha_t * exp_avg_slow + preconditioned_grad + ) / 3 # Weighted average + + # Update parameters + param.addcdiv_(combined_grad, denom, value=-step_size) + + # Optional: Gradient Clipping (Uncomment if needed) + # torch.nn.utils.clip_grad_norm_(param, max_norm=1.0) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(lr={self.defaults['lr']}, " + f"betas={self.defaults['betas']}, eps={self.defaults['eps']}, " + f"weight_decay={self.defaults['weight_decay']}, alpha={self.defaults['alpha']}, " + f"T_alpha_beta3={self.defaults['T_alpha_beta3']})" + ) diff --git a/src/optim/sophia.py b/src/optim/sophia.py index 69ceec3..142ed14 100644 --- a/src/optim/sophia.py +++ b/src/optim/sophia.py @@ -1,3 +1,8 @@ +""" +Here is an original implementation of SophiaG. 
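For intuition, here is a minimal single-tensor sketch of the Shampoo preconditioning maintained by the optimizer above (illustrative shapes, no distributed all-reduce). The helper matrix_inv_sqrt is not part of the code above: it computes a proper matrix inverse square root via an eigendecomposition, whereas the step above applies torch.inverse(...) followed by an elementwise .sqrt().

import torch

def matrix_inv_sqrt(mat, eps=1e-8):
    # Symmetric matrix inverse square root via eigendecomposition (illustrative helper).
    eigvals, eigvecs = torch.linalg.eigh(mat)
    return eigvecs @ torch.diag(eigvals.clamp_min(eps).rsqrt()) @ eigvecs.T

shampoo_decay = 0.9
grad = torch.randn(16, 8)                 # stands in for p.grad of a 2-D parameter
pc1 = torch.eye(16)                       # left preconditioner,  shape [rows, rows]
pc2 = torch.eye(8)                        # right preconditioner, shape [cols, cols]

# Exponential moving average of G @ G^T and G^T @ G, as in _update_preconditioners_distributed.
pc1.mul_(shampoo_decay).add_(grad @ grad.T, alpha=1 - shampoo_decay)
pc2.mul_(shampoo_decay).add_(grad.T @ grad, alpha=1 - shampoo_decay)

# Precondition the gradient from both sides before it enters the AdEMAMix update.
preconditioned_grad = matrix_inv_sqrt(pc1) @ grad @ matrix_inv_sqrt(pc2)
print(preconditioned_grad.shape)          # torch.Size([16, 8])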
+Source: https://github.com/Liuhong99/Sophia +""" + from typing import List import torch From 479b56be0d20b3d73d17c355fd80efb3fba6c79f Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 6 Nov 2024 05:33:25 +0300 Subject: [PATCH 36/58] shampoo is here, problems with memory due to torch.inverse, needs improvements --- src/optim/shampoo.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/optim/shampoo.py b/src/optim/shampoo.py index 5119f47..5a7e82a 100644 --- a/src/optim/shampoo.py +++ b/src/optim/shampoo.py @@ -268,11 +268,19 @@ def _update_distributed_shampoo( # Safe inversion with added epsilon to diagonal if not torch.isfinite(pc1).all() or not torch.isfinite(pc2).all(): raise RuntimeError("Matrix contains NaN or Inf values.") - inv_pc1 = torch.inverse( + # inv_pc1 = torch.inverse( + # pc1 + # + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps + # ).sqrt() + # inv_pc2 = torch.inverse( + # pc2 + # + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps + # ).sqrt() + inv_pc1 = torch.linalg.inv( pc1 + torch.eye(pc1.size(0), device=pc1.device, dtype=pc1.dtype) * eps ).sqrt() - inv_pc2 = torch.inverse( + inv_pc2 = torch.linalg.inv( pc2 + torch.eye(pc2.size(1), device=pc2.device, dtype=pc2.dtype) * eps ).sqrt() From c81173982e6d4d9fdc826338fbe749204d730034 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 7 Nov 2024 01:15:49 +0300 Subject: [PATCH 37/58] muon is here, still fix shampoo --- README.md | 1 + src/config/base.py | 1 + src/main.py | 34 +++++-- src/optim/muon.py | 187 ++++++++++++++++++++++++-------------- src/optim/schedulefree.py | 104 +++++++++++---------- 5 files changed, 204 insertions(+), 123 deletions(-) diff --git a/README.md b/README.md index 67fe2d8..cb6c020 100755 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ parser.add_argument('--correct_bias', default=True, type=bool) parser.add_argument('--nesterov', default=False, type=bool) # whether to use Nesterov-style momentum parser.add_argument('--muon_backend', default='newtonschulz5', type=str) # the chosen backend for the orthogonalization step parser.add_argument('--muon_backend_steps', default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative +parser.add_argument('--muon_lr_factor', default=0.1, type=float) # a factor by which to reduce the lr for muon parser.add_argmunet('--adema_beta3', default=0.9, type=float) # beta3 in AdEMAMix parser.add_argument('--adema_alpha', default=2.0, type=float) # alpha in AdEMAMix parser.add_argument('--adema_beta3_warmup', default=None, type=int) # AdEMAMix hyperparameter diff --git a/src/config/base.py b/src/config/base.py index 6e74bfc..0d4324f 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -116,6 +116,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--nesterov", default=False, type=bool) parser.add_argument("--muon_backend", default="newtonschulz5", type=str) parser.add_argument("--muon_backend_steps", default=5, type=int) + parser.add_argument("--muon_lr_factor", default=0.1, type=float) parser.add_argument("--adema_beta3", default=0.9, type=float) parser.add_argument("--adema_alpha", default=2.0, type=float) parser.add_argument("--adema_beta3_warmup", default=None, type=int) diff --git a/src/main.py b/src/main.py index 870abfe..88ede61 100755 --- a/src/main.py +++ b/src/main.py @@ -19,8 +19,10 @@ from optim.ademamix import AdEMAMix from optim.ademamix2 import AdEMAMix2 from optim.base import train +# from 
optim.distributed_shampoo.distributed_shampoo import DistributedShampoo +# from optim.distributed_shampoo.shampoo_types import AdamGraftingConfig from optim.lion import Lion -from optim.muon import Muon, zeropower_backends +from optim.muon import CombinedOptimizer, Muon from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, wsd_schedule) @@ -146,13 +148,24 @@ def main(args, parser): correct_bias=args.correct_bias, ) elif args.opt == "muon": - opt = Muon( + opt = CombinedOptimizer( group_specs, - lr=args.lr, - momentum=args.momentum, - nesterov=args.nesterov, # use True for Muon as a default - backend=args.muon_backend, - backend_steps=args.muon_backend_steps, + [Muon, torch.optim.AdamW], + [ + { + "lr": args.muon_lr_factor * args.lr, + "momentum": args.momentum, + "nesterov": args.nesterov, + "backend": args.muon_backend, + "backend_steps": args.muon_backend_steps, + }, + { + "lr": args.lr, + "betas": (args.beta1, args.beta2), + "weight_decay": args.weight_decay, + "fused": True, + }, + ], ) elif args.opt == "ademamix": opt = AdEMAMix( @@ -267,8 +280,13 @@ def main(args, parser): group_specs, lr=args.lr, betas=(args.beta1, args.beta2), - shampoo_decay=args.momentum, # decay rate for Shampoo preconditioners with the momentum constant + precondition_frequency=args.precondition_frequency, weight_decay=args.weight_decay, + use_decoupled_weight_decay=True, + # grafting_config=AdamGraftingConfig( + # beta2=args.beta2, # oroginally, the default value is 0.999 + # epsilon=1e-8, + # ), ) else: opt = torch.optim.SGD( diff --git a/src/optim/muon.py b/src/optim/muon.py index cd4ef6b..e8729ab 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -135,72 +135,121 @@ def step(self): curr_idx += p.numel() -# class Muon(torch.optim.Optimizer): -# """ -# Muon - MomentUm Orthogonalized by Newton-schulz - -# Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- -# processing step, in which each 2D parameter's update is replaced with the nearest orthogonal -# matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has -# the advantage that it can be stably run in bfloat16 on the GPU. - -# Some warnings: -# - This optimizer assumes that all parameters passed in are 2D. -# - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D -# parameters; those should all be optimized by a standard method (e.g., AdamW). -# - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. -# - We believe it is unlikely to work well for training with small batch size. -# - We believe it may not work well for finetuning pretrained models, but we haven't tested this. -# - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). - -# Arguments: -# lr: The learning rate used by the internal SGD. -# momentum: The momentum used by the internal SGD. -# nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) -# backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') -# backend_steps: The number of iteration steps to use in the backend, if it is iterative. 
-# """ - -# def __init__( -# self, -# params, -# lr=3e-4, -# momentum=0.95, -# nesterov=True, -# backend="newtonschulz5", -# backend_steps=5, -# ): -# defaults = dict( -# lr=lr, -# momentum=momentum, -# nesterov=nesterov, -# backend=backend, -# backend_steps=backend_steps, -# ) -# super().__init__(params, defaults) - -# def step(self): -# loss = None -# for group in self.param_groups: -# lr = group["lr"] -# momentum = group["momentum"] -# zeropower_backend = zeropower_backends[group["backend"]] - -# for p in group["params"]: -# g = p.grad -# if g is None: -# continue -# state = self.state[p] -# if "momentum_buffer" not in state: -# state["momentum_buffer"] = torch.zeros_like(g) -# buf = state["momentum_buffer"] -# buf.mul_(momentum).add_(g) -# if group["nesterov"]: -# g = g.add(buf, alpha=momentum) -# g = zeropower_backend(g, steps=group["backend_steps"]) -# g *= ( -# max(g.size(0), g.size(1)) ** 0.5 -# ) # scale to have update.square().mean() == 1 -# p.data.add_(g, alpha=-lr) - -# return loss +def separate_params(param_groups): + param_groups_2d = [] + param_groups_non2d = [] + total_param_2d_count = 0 + total_param_non2d_count = 0 + + # Check if param_groups is a list of dicts or list of params + if ( + isinstance(param_groups, list) and isinstance(param_groups[0], dict) + ) or isinstance(param_groups, dict): + if isinstance(param_groups, dict): + param_groups = [param_groups] + # param_groups is a list of dicts + for group in param_groups: + params_2d, params_non2d, param_2d_count, param_non2d_count = ( + separate_params(group["params"]) + ) + param_group_2d = {"params": params_2d} + param_group_non2d = {"params": params_non2d} + # Copy the group dict and replace the 'params' key with the separated params + for k in group.keys(): + if k != "params": + param_group_2d[k] = group[k] + param_group_non2d[k] = group[k] + + param_groups_2d.append(param_group_2d) + param_groups_non2d.append(param_group_non2d) + total_param_2d_count += param_2d_count + total_param_non2d_count += param_non2d_count + + return ( + param_groups_2d, + param_groups_non2d, + total_param_2d_count, + total_param_non2d_count, + ) + + elif isinstance(param_groups, list) and isinstance(param_groups[0], torch.Tensor): + params_2d = [] + params_non2d = [] + param_group = param_groups + # param_group is a list of param tensors + for param in param_group: + if param.ndim == 2: + params_2d.append(param) + else: + params_non2d.append(param) + return params_2d, params_non2d, len(params_2d), len(params_non2d) + else: + breakpoint() + + +class CombinedOptimizer(torch.optim.Optimizer): + """ +# Note that CombinedOptimizer is not a torch.optim.Optimizer, but a wrapper around multiple optimizers. +# Original Example: + optimizer = CombinedOptimizer([ + torch.optim.AdamW(self.lm_head.parameters(), lr=learning_rate, betas=betas, weight_decay=0, fused=True), + OrthogonalNesterov(self.transformer.h.parameters(), lr=0.1*learning_rate, momentum=0.95) + ]) +# Refactored Example: + optimizer = CombinedOptimizer(\ + self.parameters(), + [OrthogonalNesterov, torch.optim.AdamW], + [{'lr': 0.1*learning_rate, 'momentum': 0.95}, + {'lr': learning_rate, 'betas': betas, 'weight_decay': 0, 'fused': True} + ]) +""" + + def __init__(self, params, optimizer_types, configs): + # Separate 2D and non-2D parameters. + # param_groups_2d_non2d: (param_groups_2d, param_groups_non2d). + # If params is a list of tensors, then each of param_groups_2d and param_groups_non2d + # will be a list of tensors. 
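As a quick illustration of the comment above, separate_params on a plain list of tensors splits it by dimensionality; a small sketch (assuming src/ is on the path, as in main.py):

import torch
from optim.muon import separate_params  # as imported in src/main.py

weight = torch.nn.Parameter(torch.randn(8, 4))   # 2-D -> handled by Muon
bias = torch.nn.Parameter(torch.randn(4))        # 1-D -> handled by AdamW
params_2d, params_non2d, n_2d, n_non2d = separate_params([weight, bias])
assert params_2d[0] is weight and params_non2d[0] is bias
assert (n_2d, n_non2d) == (1, 1)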
+ # If params is a list of dicts, then each of param_groups_2d and param_groups_non2d + # will be a list of dicts. + # If params is a dict, then each of param_groups_2d and param_groups_non2d will + # be a list of dicts containing only one dict. + ( + param_groups_2d, + param_groups_non2d, + total_param_2d_count, + total_param_non2d_count, + ) = separate_params(params) + param_groups_2d_non2d = (param_groups_2d, param_groups_non2d) + print( + f"Total 2D params: {total_param_2d_count}, Total non-2D params: {total_param_non2d_count}" + ) + + assert ( + len(optimizer_types) == len(configs) == 2 + ), "You must use only two optimizers" + assert optimizer_types[0] == Muon, "The first optimizer must be Muon" + self.optimizers = [ + optimizer_types[i](param_groups_2d_non2d[i], **configs[i]) for i in range(2) + ] + self.param_groups = [pg for opt in self.optimizers for pg in opt.param_groups] + self.base_lrs = [opt.param_groups[0]["lr"] for opt in self.optimizers] + # Combine the state dicts of all opt in self.optimizers into a single dict + self.state = {k: v for opt in self.optimizers for k, v in opt.state.items()} + # Initially all states are empty. So no point to print their counts. + # Only use the defaults of the OrthogonalNesterov optimizer + self.defaults = self.optimizers[0].defaults + + def step(self, *args, **kwargs): + for opt in self.optimizers: + opt.step(*args, **kwargs) + + def zero_grad(self, **kwargs): + for opt in self.optimizers: + opt.zero_grad(**kwargs) + + def scale_lrs(self, lr_scale): + for base_lr, opt in zip(self.base_lrs, self.optimizers): + opt.param_groups[0]["lr"] = base_lr * lr_scale + + def state_dict(self): + return [opt.state_dict() for opt in self.optimizers] diff --git a/src/optim/schedulefree.py b/src/optim/schedulefree.py index 266f56a..bbf18fa 100644 --- a/src/optim/schedulefree.py +++ b/src/optim/schedulefree.py @@ -4,7 +4,7 @@ """ import math -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch from typing_extensions import TypeAlias @@ -70,15 +70,17 @@ def __init__( r=r, k=0, warmup_steps=warmup_steps, - train_mode=True, + train_mode=False, weight_sum=0.0, lr_max=-1.0, + scheduled_lr=0.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay, foreach=foreach, ) super().__init__(params, defaults) + @torch.no_grad() def eval(self): for group in self.param_groups: train_mode = group["train_mode"] @@ -87,12 +89,11 @@ def eval(self): for p in group["params"]: state = self.state[p] if "z" in state: - # Set p.data to x - p.data.lerp_( - end=state["z"].to(p.data.device), weight=1 - 1 / beta1 - ) + # Set p to x + p.lerp_(end=state["z"].to(p.device), weight=1 - 1 / beta1) group["train_mode"] = False + @torch.no_grad() def train(self): for group in self.param_groups: train_mode = group["train_mode"] @@ -101,21 +102,29 @@ def train(self): for p in group["params"]: state = self.state[p] if "z" in state: - # Set p.data to y - p.data.lerp_(end=state["z"].to(p.data.device), weight=1 - beta1) + # Set p to y + p.lerp_(end=state["z"].to(p.device), weight=1 - beta1) group["train_mode"] = True - def step(self, closure=None): + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
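A minimal usage sketch for the schedule-free optimizers in src/optim/schedulefree.py (the "sf-adamw" / "sf-sgd" options): since train_mode now defaults to False and step() raises when the optimizer is not in train mode, .train() and .eval() calls are required around training and evaluation. The class name AdamWScheduleFree is assumed here for illustration; substitute whatever name the module actually exports.

import torch
# Class name assumed for illustration; use the class src/optim/schedulefree.py exports for "sf-adamw".
from optim.schedulefree import AdamWScheduleFree

model = torch.nn.Linear(16, 4)
opt = AdamWScheduleFree(model.parameters(), lr=1e-3)

opt.train()                              # required: step() raises if the optimizer is not in train mode
for _ in range(10):
    loss = model(torch.randn(8, 16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

opt.eval()                               # switch weights to the averaged iterate (x) before evaluation
with torch.no_grad():
    val_loss = model(torch.randn(8, 16)).square().mean()
opt.train()                              # switch back before resuming training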
""" + if not self.param_groups[0]["train_mode"]: + raise Exception( + "Optimizer was not in train mode when step is called. " + "Please insert .train() and .eval() calls on the " + "optimizer. See documentation for details." + ) loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: eps = group["eps"] @@ -132,7 +141,8 @@ def step(self, closure=None): sched = 1.0 bias_correction2 = 1 - beta2 ** (k + 1) - lr = group["lr"] * sched * math.sqrt(bias_correction2) + lr = group["lr"] * sched + group["scheduled_lr"] = lr # For logging purposes lr_max = group["lr_max"] = max(lr, group["lr_max"]) @@ -144,25 +154,21 @@ def step(self, closure=None): except ZeroDivisionError: ckp1 = 0 - if not group["train_mode"]: - raise Exception("Not in train mode!") - active_p = [p for p in group["params"] if p.grad is not None] for p in active_p: if "z" not in self.state[p]: - self.state[p]["z"] = torch.clone(p.data) - self.state[p]["exp_avg_sq"] = torch.zeros_like(p.data) + self.state[p]["z"] = torch.clone( + p, memory_format=torch.preserve_format + ) + self.state[p]["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) if group["foreach"] and len(active_p) > 0: y, grad, exp_avg_sq, z = zip( *[ - ( - p.data, - p.grad, - self.state[p]["exp_avg_sq"], - self.state[p]["z"], - ) + (p, p.grad, self.state[p]["exp_avg_sq"], self.state[p]["z"]) for p in active_p ] ) @@ -170,7 +176,8 @@ def step(self, closure=None): # Decay the first and second moment running average coefficient torch._foreach_mul_(exp_avg_sq, beta2) torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2) - denom = torch._foreach_sqrt(exp_avg_sq) + denom = torch._foreach_div(exp_avg_sq, bias_correction2) + torch._foreach_sqrt_(denom) torch._foreach_add_(denom, eps) # Normalize grad in-place for memory efficiency @@ -189,8 +196,8 @@ def step(self, closure=None): torch._foreach_sub_(z, grad, alpha=lr) else: for p in active_p: - y = p.data # Notation to match theory - grad = p.grad.data + y = p # Notation to match theory + grad = p.grad state = self.state[p] @@ -198,7 +205,7 @@ def step(self, closure=None): exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = exp_avg_sq.sqrt().add_(eps) + denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps) # Reuse grad buffer for memory efficiency grad_normalized = grad.div_(denom) @@ -277,15 +284,17 @@ def __init__( r=r, k=0, warmup_steps=warmup_steps, - train_mode=True, + train_mode=False, weight_sum=0.0, lr_max=-1.0, + scheduled_lr=0.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay, foreach=foreach, ) super().__init__(params, defaults) + @torch.no_grad() def eval(self): for group in self.param_groups: train_mode = group["train_mode"] @@ -294,12 +303,11 @@ def eval(self): for p in group["params"]: state = self.state[p] if "z" in state: - # Set p.data to x - p.data.lerp_( - end=state["z"].to(p.data.device), weight=1 - 1 / momentum - ) + # Set p to x + p.lerp_(end=state["z"].to(p.device), weight=1 - 1 / momentum) group["train_mode"] = False + @torch.no_grad() def train(self): for group in self.param_groups: train_mode = group["train_mode"] @@ -308,23 +316,29 @@ def train(self): for p in group["params"]: state = self.state[p] if "z" in state: - # Set p.data to y - p.data.lerp_( - end=state["z"].to(p.data.device), weight=1 - momentum - ) + # Set p to y + p.lerp_(end=state["z"].to(p.device), weight=1 - momentum) group["train_mode"] = True - def 
step(self, closure=None): + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ + if not self.param_groups[0]["train_mode"]: + raise Exception( + "Optimizer was not in train mode when step is called. " + "Please insert .train() and .eval() calls on the " + "optimizer. See documentation for details." + ) loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: momentum = group["momentum"] @@ -338,6 +352,7 @@ def step(self, closure=None): else: sched = 1.0 lr = group["lr"] * sched + group["scheduled_lr"] = lr # For logging purposes weight_lr_power = group["weight_lr_power"] @@ -352,19 +367,16 @@ def step(self, closure=None): except ZeroDivisionError: ckp1 = 0 - if not group["train_mode"]: - raise Exception("Not in train mode!") - active_p = [p for p in group["params"] if p.grad is not None] for p in active_p: if "z" not in self.state[p]: - self.state[p]["z"] = torch.clone(p.data) + self.state[p]["z"] = torch.clone( + p, memory_format=torch.preserve_format + ) if group["foreach"] and len(active_p) > 0: - y, grad, z = zip( - *[(p.data, p.grad, self.state[p]["z"]) for p in active_p] - ) + y, grad, z = zip(*[(p, p.grad, self.state[p]["z"]) for p in active_p]) # Apply weight decay if weight_decay != 0: @@ -379,8 +391,8 @@ def step(self, closure=None): torch._foreach_sub_(z, grad, alpha=lr) else: for p in active_p: - y = p.data # Notation to match theory - grad = p.grad.data + y = p # Notation to match theory + grad = p.grad z = self.state[p]["z"] # Apply weight decay From 361213ddc67938d16fff13fbd65e34f145560485 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 7 Nov 2024 01:38:35 +0300 Subject: [PATCH 38/58] --fix sophiag --- src/optim/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/optim/base.py b/src/optim/base.py index 2a41fb3..dbac184 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -146,11 +146,9 @@ def train( opt.step() if cfg.scheduler != "none": scheduler.step() - if cfg.opt == "sophia": + if cfg.opt == "sophiag": opt.zero_grad(set_to_none=True) - if curr_iter % 10 != 10 - 1: - continue - else: + if curr_iter % 10 == 10 - 1: samp_dist = torch.distributions.Categorical(logits=outputs["logits"]) y_sample = samp_dist.sample() loss_sampled = torch.nn.functional.cross_entropy( From 2b9d0f3c002b49f698d1474f22b6e10a86d8a3d4 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 7 Nov 2024 02:19:19 +0300 Subject: [PATCH 39/58] sophiag fixed, test two adamw runs using muon and soap branches, if they are the same, then merge into soap --- src/optim/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/optim/base.py b/src/optim/base.py index dbac184..e3c9325 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -143,20 +143,21 @@ def train( if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": opt.train() - opt.step() + opt.step() if cfg.opt != "sophiag" else opt.step(bs=tokens) if cfg.scheduler != "none": scheduler.step() if cfg.opt == "sophiag": opt.zero_grad(set_to_none=True) if curr_iter % 10 == 10 - 1: - samp_dist = torch.distributions.Categorical(logits=outputs["logits"]) + sample_again = model(x, targets=y, get_logits=True) + samp_dist = torch.distributions.Categorical(logits=sample_again["logits"]) y_sample = samp_dist.sample() loss_sampled = 
torch.nn.functional.cross_entropy( - outputs["logits"].view(-1, outputs["logits"].size(-1)), + sample_again["logits"].view(-1, sample_again["logits"].size(-1)), y_sample.view(-1), ignore_index=-1, ) - loss_sampled.backward() + (loss_sampled / cfg.acc_steps).backward() opt.update_hessian() opt.zero_grad(set_to_none=True) model.zero_grad() From 19ff7b111fef2092f1c5968c238078963a40c14b Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 7 Nov 2024 20:32:39 +0300 Subject: [PATCH 40/58] --adopt todo, --fix sophia --- README.md | 2 +- src/config/base.py | 5 +++-- src/data/fineweb.py | 13 +++++++++---- src/data/fineweb_edu.py | 13 +++++++++---- src/data/utils.py | 4 ++-- src/distributed/ddp.py | 3 +-- src/optim/base.py | 16 ++++++++++++---- 7 files changed, 37 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 6295d0b..4767fbc 100755 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT diff --git a/src/config/base.py b/src/config/base.py index 320d804..e27fc29 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -94,6 +94,7 @@ def parse_args(base_parser, args, namespace): "prodigy", "sophiag", "shampoo", + "adopt", ], ) parser.add_argument("--batch_size", default=50, type=int) @@ -149,8 +150,8 @@ def parse_args(base_parser, args, namespace): "slimpajama", "slimpajama_chunk1", "redpajamav2", - "fineweb", - "finewebedu", + "fineweb", + "finewebedu", ], ) parser.add_argument( diff --git a/src/data/fineweb.py b/src/data/fineweb.py index a2d9b96..3588295 100644 --- a/src/data/fineweb.py +++ b/src/data/fineweb.py @@ -9,13 +9,18 @@ def get_fineweb_data(datasets_dir, num_proc=40): - """ To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code. 
""" + """To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code.""" FWEB_DATA_PATH = os.path.join(datasets_dir, "fineweb-100BT/") if not os.path.exists(os.path.join(FWEB_DATA_PATH, "train.bin")): os.makedirs(FWEB_DATA_PATH, exist_ok=True) - dataset = load_dataset("HuggingFaceFW/fineweb", name="sample-100BT", split="train", - streaming=False, verification_mode="no_checks") + dataset = load_dataset( + "HuggingFaceFW/fineweb", + name="sample-100BT", + split="train", + streaming=False, + verification_mode="no_checks", + ) split_dataset = dataset.train_test_split( test_size=0.0001, seed=2357, shuffle=True @@ -68,4 +73,4 @@ def process(example): if __name__ == "__main__": - get_fineweb_data("./datasets/") \ No newline at end of file + get_fineweb_data("./datasets/") diff --git a/src/data/fineweb_edu.py b/src/data/fineweb_edu.py index c7ed7d9..a0ce49e 100644 --- a/src/data/fineweb_edu.py +++ b/src/data/fineweb_edu.py @@ -9,13 +9,18 @@ def get_fineweb_edu_data(datasets_dir, num_proc=40): - """ To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code. """ + """To change the cache dir, run `export HF_HOME=/path/to/cache/` before running the code.""" FWEB_DATA_PATH = os.path.join(datasets_dir, "fineweb-edu-100BT/") if not os.path.exists(os.path.join(FWEB_DATA_PATH, "train.bin")): os.makedirs(FWEB_DATA_PATH, exist_ok=True) - dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", split="train", - streaming=False, verification_mode="no_checks") + dataset = load_dataset( + "HuggingFaceFW/fineweb-edu", + name="sample-100BT", + split="train", + streaming=False, + verification_mode="no_checks", + ) split_dataset = dataset.train_test_split( test_size=0.0001, seed=2357, shuffle=True @@ -68,4 +73,4 @@ def process(example): if __name__ == "__main__": - get_fineweb_edu_data("./datasets/") \ No newline at end of file + get_fineweb_edu_data("./datasets/") diff --git a/src/data/utils.py b/src/data/utils.py index a592785..5f3ee1d 100755 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -6,13 +6,13 @@ import torch.distributed as dist from .arxiv import get_arxiv_2000, get_arxiv_full +from .fineweb import get_fineweb_data +from .fineweb_edu import get_fineweb_edu_data from .openwebtext2 import get_openwebtext2_data from .redpajama import get_redpajama_data, get_redpajamav2_data from .shakespeare import get_shakespeare_data from .slimpajama import get_slimpajama_data from .wikitext import get_wikitext_data -from .fineweb import get_fineweb_data -from .fineweb_edu import get_fineweb_edu_data def get_dataset(args) -> Dict[str, np.ndarray]: diff --git a/src/distributed/ddp.py b/src/distributed/ddp.py index 9226ff1..f805cc3 100644 --- a/src/distributed/ddp.py +++ b/src/distributed/ddp.py @@ -2,8 +2,7 @@ import os from contextlib import contextmanager -from torch.distributed import (destroy_process_group, get_world_size, - init_process_group) +from torch.distributed import destroy_process_group, get_world_size, init_process_group from torch.nn.parallel import DistributedDataParallel as DDP from .backend import DistributedBackend diff --git a/src/optim/base.py b/src/optim/base.py index 04cf7af..69eb063 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -9,8 +9,14 @@ import wandb # from logger.logger import DynamicsLogger -from .utils import (eval, get_batch, load_checkpoint, load_worker_state, - save_checkpoint, save_worker_state) +from .utils import ( + eval, + get_batch, + load_checkpoint, + load_worker_state, + save_checkpoint, + 
save_worker_state, +) def train( @@ -144,14 +150,16 @@ def train( if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": opt.train() - opt.step() if cfg.opt != "sophiag" else opt.step(bs=tokens) + opt.step() if cfg.opt != "sophiag" else opt.step(bs=480 * cfg.sequence_length) if cfg.scheduler != "none": scheduler.step() if cfg.opt == "sophiag": opt.zero_grad(set_to_none=True) if curr_iter % 10 == 10 - 1: sample_again = model(x, targets=y, get_logits=True) - samp_dist = torch.distributions.Categorical(logits=sample_again["logits"]) + samp_dist = torch.distributions.Categorical( + logits=sample_again["logits"] + ) y_sample = samp_dist.sample() loss_sampled = torch.nn.functional.cross_entropy( sample_again["logits"].view(-1, sample_again["logits"].size(-1)), From 6709eb64d53ee9bb3127255477f666c29f71c134 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Fri, 8 Nov 2024 01:17:05 +0100 Subject: [PATCH 41/58] clipped version are here, muon schedules todo --- README.md | 5 +- src/config/base.py | 8 + src/distributed/backend.py | 1 - src/distributed/ddp.py | 4 +- src/distributed/single.py | 1 - src/main.py | 43 +++- src/models/llama.py | 1 - src/models/test.py | 1 - src/optim/ademamix.py | 2 - src/optim/ademamix2.py | 2 - src/optim/base.py | 13 +- src/optim/clipped.py | 414 +++++++++++++++++++++++++++++++++++++ src/optim/muon.py | 11 +- src/optim/prodigy.py | 5 + src/optim/schedulefree.py | 1 - src/optim/sign.py | 1 - src/optim/sophia.py | 2 - src/optim/utils.py | 1 - 18 files changed, 482 insertions(+), 34 deletions(-) create mode 100644 src/optim/clipped.py diff --git a/README.md b/README.md index 4767fbc..dbf7fc4 100755 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -84,6 +84,8 @@ parser.add_argument('--prodigy_use_bias_correction', default=False, type=bool) parser.add_argument('--prodigy_safeguard_warmup', default=False, type=bool) # Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. 
parser.add_argument('--prodigy_fsdp_in_use', default=False, type=bool) parser.add_argument('--sophia_rho', default=0.04, type=float) +parser.add_argument('--clipping_type', default='no', choices=['no', 'local', 'elementwise']) # for methods with clipping +parser.add_argument('--clipping_eta', default=1.0, type=float) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) @@ -123,7 +125,6 @@ parser.add_argument('--log_dynamics', action='store_true') # Distributed args parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends()) # distributed backend type (e.g. nccl) -# parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) ``` ## Using WandB diff --git a/src/config/base.py b/src/config/base.py index e27fc29..a0de09d 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -95,6 +95,10 @@ def parse_args(base_parser, args, namespace): "sophiag", "shampoo", "adopt", + "clip-adagrad", + "clip-adagrad-delay-eta", + "clip-adam", + "clip-adam-delay-eta", ], ) parser.add_argument("--batch_size", default=50, type=int) @@ -133,6 +137,10 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) parser.add_argument("--sophia_rho", default=0.04, type=float) + parser.add_argument( + "--clipping_type", default="no", choices=["no", "local", "elementwise"] + ) + parser.add_argument("--clip_eta", default=1.0, type=float) # Dataset params parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") diff --git a/src/distributed/backend.py b/src/distributed/backend.py index 9fc0539..5faa6c2 100644 --- a/src/distributed/backend.py +++ b/src/distributed/backend.py @@ -2,7 +2,6 @@ class DistributedBackend(object): - def __init__(self, args): pass diff --git a/src/distributed/ddp.py b/src/distributed/ddp.py index f805cc3..e66b5f6 100644 --- a/src/distributed/ddp.py +++ b/src/distributed/ddp.py @@ -2,14 +2,14 @@ import os from contextlib import contextmanager -from torch.distributed import destroy_process_group, get_world_size, init_process_group +from torch.distributed import (destroy_process_group, get_world_size, + init_process_group) from torch.nn.parallel import DistributedDataParallel as DDP from .backend import DistributedBackend class DataParallelDistributedBackend(DistributedBackend): - def __init__(self, args): self.rank = int(os.environ.get("RANK", -1)) assert self.rank != -1, "DDP backend can not be used without rank" diff --git a/src/distributed/single.py b/src/distributed/single.py index 8ece239..5f8adb2 100644 --- a/src/distributed/single.py +++ b/src/distributed/single.py @@ -4,7 +4,6 @@ class SinlgeNodeBackend(DistributedBackend): - def __init__(self, args): super().__init__(args) self.rank = 0 diff --git a/src/main.py b/src/main.py index 88ede61..a87cd0b 100755 --- a/src/main.py +++ b/src/main.py @@ -9,16 +9,18 @@ import numpy as np import torch +import wandb import config import distributed -import wandb from data.utils import DataReader, get_dataset from models.utils import get_model from optim.adammini import Adam_mini from optim.ademamix import AdEMAMix from 
optim.ademamix2 import AdEMAMix2 from optim.base import train +from optim.clipped import (AdagradClip, AdaGradClipDelayedEta, AdamClip, + AdamClipDelayedEta) # from optim.distributed_shampoo.distributed_shampoo import DistributedShampoo # from optim.distributed_shampoo.shampoo_types import AdamGraftingConfig from optim.lion import Lion @@ -50,7 +52,6 @@ def get_args(): def main(args, parser): - distributed_backend = distributed.make_backend_from_args(args) args = distributed_backend.get_adjusted_args_for_process(args) args.world_size = distributed_backend.get_world_size() @@ -288,6 +289,44 @@ def main(args, parser): # epsilon=1e-8, # ), ) + elif args.opt == "adopt": + raise NotImplementedError("Have not implemented yet") + elif args.opt in [ + "clip-adagrad", + "clip-adagrad-delay-eta", + "clip-adam", + "clip-adam-delay-eta", + ]: + clipped_adagrad_cfg = { + "lr": args.lr, + "eps": 1e-8, + "weight_decay": args.weight_decay, + "clipping": args.clipping_type, + "max_grad_norm": 1.0, + } + if args.opt == "clip-adagrad": + opt = AdagradClip(**clipped_adagrad_cfg) + clipped_adagrad_delay_eta_cfg = { + **clipped_adagrad_cfg, + "exp_avg_sq_value": 0.0001, + "etta": args.clipping_eta, + } + if args.opt == "clip-adagrad-delay-eta": + opt = AdaGradClipDelayedEta(**clipped_adagrad_delay_eta_cfg) + clipped_adam_cfg = { + **clipped_adagrad_cfg, + "betas": (args.beta1, args.beta2), + "correct_bias": args.correct_bias, + } + if args.opt == "clip-adam": + opt = AdamClip(**clipped_adam_cfg) + clipped_adam_delay_eta_cfg = { + **clipped_adam_cfg, + "exp_avg_sq_value": 0.00001, + "etta": args.clipping_eta, + } + if args.opt == "clip-adam-delay-eta": + opt = AdamClipDelayedEta(**clipped_adam_delay_eta_cfg) else: opt = torch.optim.SGD( group_specs, diff --git a/src/models/llama.py b/src/models/llama.py index ebb7430..e6aaec6 100644 --- a/src/models/llama.py +++ b/src/models/llama.py @@ -93,7 +93,6 @@ def forward(self, x): class LlamaAttention(CausalSelfAttention): - def forward(self, x, freqs_cis): # batch size, sequence length, embedding dimensionality (n_embd) ( diff --git a/src/models/test.py b/src/models/test.py index d146880..1deb696 100644 --- a/src/models/test.py +++ b/src/models/test.py @@ -85,7 +85,6 @@ def forward(self, x): class LlamaAttention(CausalSelfAttention): - def forward(self, x, freqs_cis): # batch size, sequence length, embedding dimensionality (n_embd) ( diff --git a/src/optim/ademamix.py b/src/optim/ademamix.py index 9a633e3..035697a 100644 --- a/src/optim/ademamix.py +++ b/src/optim/ademamix.py @@ -16,7 +16,6 @@ def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): - def f(beta, eps=1e-8): return math.log(0.5) / math.log(beta + eps) - 1 @@ -100,7 +99,6 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - lr = group["lr"] lmbda = group["weight_decay"] eps = group["eps"] diff --git a/src/optim/ademamix2.py b/src/optim/ademamix2.py index 1db0aad..9cc5084 100644 --- a/src/optim/ademamix2.py +++ b/src/optim/ademamix2.py @@ -16,7 +16,6 @@ def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): - def f(beta, eps=1e-8): return math.log(0.5) / math.log(beta + eps) - 1 @@ -100,7 +99,6 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - lr = group["lr"] lmbda = group["weight_decay"] eps = group["eps"] diff --git a/src/optim/base.py b/src/optim/base.py index 
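The "local" and "elementwise" clipping options used by the Clip-* optimizers map onto two standard PyTorch utilities; a small sketch of the difference on a single parameter:

import torch

p = torch.nn.Parameter(torch.randn(10, 10))
p.grad = torch.randn(10, 10) * 5.0

# "local": rescale this parameter's gradient so its L2 norm is at most max_grad_norm.
torch.nn.utils.clip_grad_norm_(p, max_norm=1.0)

# "elementwise": clamp every gradient entry into [-max_grad_norm, max_grad_norm].
torch.nn.utils.clip_grad_value_(p, clip_value=1.0)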
69eb063..7784104 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -4,19 +4,12 @@ from pathlib import Path import torch -import yaml - import wandb +import yaml # from logger.logger import DynamicsLogger -from .utils import ( - eval, - get_batch, - load_checkpoint, - load_worker_state, - save_checkpoint, - save_worker_state, -) +from .utils import (eval, get_batch, load_checkpoint, load_worker_state, + save_checkpoint, save_worker_state) def train( diff --git a/src/optim/clipped.py b/src/optim/clipped.py new file mode 100644 index 0000000..0f0fc6f --- /dev/null +++ b/src/optim/clipped.py @@ -0,0 +1,414 @@ +""" +Here is an original implementation of Clip-Adam and Clip-Adagrad. +Source: https://github.com/yaroslavkliukin/Clipped-AdaGrad-and-Adam +""" + +import math + +import torch + + +class AdamClip(torch.optim.Optimizer): + """ + Parameters: + lr (float): learning rate. Default 1e-3. + betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.0 + correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. + clipping (str): "no", "local", "elementwise". Default: "no" + max_grad_norm (float): value to which we clip the gradient. Default: 1.0 + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.0, + correct_bias=True, + clipping="no", + max_grad_norm=1.0, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]) + ) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + clipping=clipping, + max_grad_norm=max_grad_norm, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
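A minimal end-to-end usage sketch of AdamClip as defined above, with per-parameter ("local") clipping on a toy model and data:

import torch
from optim.clipped import AdamClip  # as imported in src/main.py

model = torch.nn.Linear(16, 4)
opt = AdamClip(model.parameters(), lr=1e-3, clipping="local", max_grad_norm=1.0)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()        # each parameter's gradient is norm-clipped to 1.0 before its Adam update
opt.zero_grad()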
+ """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + grad = p.grad.data + + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p.data) + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = ( + step_size * math.sqrt(bias_correction2) / bias_correction1 + ) + + p.data.addcdiv_(tensor1=exp_avg, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss + + +class AdamClipDelayedEta(torch.optim.Optimizer): + """ + Parameters: + lr (float): learning rate. Default 1e-3. + betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.0 + correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default False. + clipping (str): "no", "local", "elementwise". Default: "no" + max_grad_norm (float): value to which we clip the gradient. Default: 1.0 + exp_avg_sq_value (float): value used to initialise the second moment in the first gradient update step. Default: 0.00001 + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.0, + correct_bias=False, + clipping="no", + max_grad_norm=1.0, + exp_avg_sq_value=0.00001, + etta=1.0, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]) + ) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + clipping=clipping, + max_grad_norm=max_grad_norm, + exp_avg_sq_value=exp_avg_sq_value, + etta=etta, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + if p.grad.data.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p.data) + state["exp_avg_sq"] = ( + torch.ones_like(p.data) * group["exp_avg_sq_value"] + ) + # Gradient from previous step + state["prev_grad"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq, prev_grad = ( + state["exp_avg"], + state["exp_avg_sq"], + state["prev_grad"], + ) + beta1, beta2 = group["betas"] + + state["step"] += 1 + + exp_avg.mul_(beta1).add_(p.grad.data, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_( + prev_grad, prev_grad, value=(1.0 - beta2) * group["etta"] + ) + state["prev_grad"] = p.grad.data + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** state["step"] + bias_correction2 = 1.0 - beta2 ** state["step"] + step_size = ( + step_size * math.sqrt(bias_correction2) / bias_correction1 + ) + + p.data.addcdiv_(tensor1=exp_avg, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss + + +class AdagradClip(torch.optim.Optimizer): + """Implements Adagrad algorithm with clipping. + Parameters: + lr (float): learning rate. Default 1e-3. + eps (float): Adagrad epsilon. Default: 1e-10 + weight_decay (float): Weight decay. Default: 0.0 + clipping (str): "no", "global", "local", "elementwise". Default: "local" + max_grad_norm (bool): value to which we clip the gradient. Default: 1.0 + """ + + def __init__( + self, + params, + lr=1e-3, + eps=1e-10, + weight_decay=0.0, + clipping="local", + max_grad_norm=1.0, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + eps=eps, + weight_decay=weight_decay, + clipping=clipping, + max_grad_norm=max_grad_norm, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + grad = p.grad.data + + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + + exp_avg_sq.addcmul_(grad, grad, value=1.0) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + p.data.addcdiv_(tensor1=grad, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss + + +class AdaGradClipDelayedEta(torch.optim.Optimizer): + """Implements Adagrad algorithm with clipping, delay and reweighing. + Parameters: + lr (float): learning rate. Default 1e-3. + eps (float): Adagrad epsilon. Default: 1e-10 + weight_decay (float): Weight decay. Default: 0.0 + clipping (str): "no", "global", "local", "elementwise". Default: "local" + max_grad_norm (bool): value to which we clip the gradient. Default: 1.0 + etta (float): reweighing parameter. Default: 1.0 + exp_avg_sq_value (float): initial value to imitate first gradient + """ + + def __init__( + self, + params, + lr=1e-3, + eps=1e-10, + weight_decay=0.0, + clipping="local", + max_grad_norm=1.0, + etta=1.0, + exp_avg_sq_value=0.0001, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) + defaults = dict( + lr=lr, + eps=eps, + weight_decay=weight_decay, + clipping=clipping, + max_grad_norm=max_grad_norm, + etta=etta, + exp_avg_sq_value=exp_avg_sq_value, + ) + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
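A small sketch of the "delayed eta" second moment shared by the *DelayedEta variants above, following the Adam form (the Adagrad form accumulates without the beta2 decay): the denominator at step t is built from the gradient of step t-1, reweighted by etta, with the second moment initialised at exp_avg_sq_value.

import torch

beta1, beta2, etta, eps, lr = 0.9, 0.999, 1.0, 1e-6, 1e-3
p = torch.randn(4)                   # toy parameter
m = torch.zeros(4)                   # exp_avg
v = torch.full((4,), 1e-5)           # exp_avg_sq, started at exp_avg_sq_value
prev_g = torch.zeros(4)              # gradient from the previous step

for _ in range(3):
    g = torch.randn(4)               # stands in for p.grad
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    # Second moment uses the *previous* gradient, reweighted by etta.
    v.mul_(beta2).addcmul_(prev_g, prev_g, value=(1 - beta2) * etta)
    prev_g = g.clone()
    p.addcdiv_(m, v.sqrt().add_(eps), value=-lr)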
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + if group["clipping"] == "global": + torch.nn.utils.clip_grad_norm_(group["params"], group["max_grad_norm"]) + + for p in group["params"]: + if p.grad is None: + continue + + if group["clipping"] == "local": + torch.nn.utils.clip_grad_norm_(p, group["max_grad_norm"]) + + if group["clipping"] == "elementwise": + torch.nn.utils.clip_grad_value_(p, group["max_grad_norm"]) + + if p.grad.data.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg_sq"] = ( + torch.ones_like(p.data) * group["exp_avg_sq_value"] + ) + state["prev_grad"] = torch.zeros_like(p.data) + + exp_avg_sq, prev_grad = state["exp_avg_sq"], state["prev_grad"] + state["step"] += 1 + + exp_avg_sq.addcmul_(prev_grad, prev_grad, value=group["etta"]) + state["prev_grad"] = p.grad.data + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + + p.data.addcdiv_(tensor1=prev_grad, tensor2=denom, value=-step_size) + + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) + + return loss diff --git a/src/optim/muon.py b/src/optim/muon.py index e8729ab..705bc37 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -90,9 +90,7 @@ def __init__( super().__init__(params, defaults) def step(self): - for group in self.param_groups: - lr = group["lr"] momentum = group["momentum"] zeropower_backend = zeropower_backends[group["backend"]] @@ -149,9 +147,12 @@ def separate_params(param_groups): param_groups = [param_groups] # param_groups is a list of dicts for group in param_groups: - params_2d, params_non2d, param_2d_count, param_non2d_count = ( - separate_params(group["params"]) - ) + ( + params_2d, + params_non2d, + param_2d_count, + param_non2d_count, + ) = separate_params(group["params"]) param_group_2d = {"params": params_2d} param_group_non2d = {"params": params_non2d} # Copy the group dict and replace the 'params' key with the separated params diff --git a/src/optim/prodigy.py b/src/optim/prodigy.py index 2f6c049..aa7f5b9 100644 --- a/src/optim/prodigy.py +++ b/src/optim/prodigy.py @@ -1,3 +1,8 @@ +""" +Here is an original implementation of Prodigy. 
+Source: https://github.com/konstmish/prodigy +""" + import math import torch diff --git a/src/optim/schedulefree.py b/src/optim/schedulefree.py index bbf18fa..3c5f4f2 100644 --- a/src/optim/schedulefree.py +++ b/src/optim/schedulefree.py @@ -62,7 +62,6 @@ def __init__( weight_lr_power: float = 2.0, foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), ): - defaults = dict( lr=lr, betas=betas, diff --git a/src/optim/sign.py b/src/optim/sign.py index c1ac3b8..80f8c1a 100644 --- a/src/optim/sign.py +++ b/src/optim/sign.py @@ -52,7 +52,6 @@ def _init_state(self, example, state=None): def _compute_update( self, grad, state, lr, momentum, nesterov, dampening, sign_update, **kwargs ): - if momentum != 0: # Signum check buf = state["momentum_buffer"] buf.mul_(momentum).add_(grad, alpha=1 - dampening) diff --git a/src/optim/sophia.py b/src/optim/sophia.py index 142ed14..e0cdc4d 100644 --- a/src/optim/sophia.py +++ b/src/optim/sophia.py @@ -169,7 +169,6 @@ def sophiag( weight_decay: float, maximize: bool ): - if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( "API has changed, `state_steps` argument must contain a list of singleton tensors" @@ -210,7 +209,6 @@ def _single_tensor_sophiag( maximize: bool, capturable: bool ): - for i, param in enumerate(params): grad = grads[i] if not maximize else -grads[i] exp_avg = exp_avgs[i] diff --git a/src/optim/utils.py b/src/optim/utils.py index 1bdd01e..12185ac 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -7,7 +7,6 @@ import torch import torch.distributed as dist import torch.nn.functional as F - import wandb From 0f05b199ff7e6a32c3455e9c539bd83eddbdc4bf Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sun, 10 Nov 2024 17:18:07 +0300 Subject: [PATCH 42/58] updates in muon --- README.md | 5 +- src/config/base.py | 5 +- src/main.py | 36 ++++---- src/optim/base.py | 3 +- src/optim/muon.py | 202 ++++++++++++++++++++++----------------------- src/optim/utils.py | 1 + 6 files changed, 123 insertions(+), 129 deletions(-) diff --git a/README.md b/README.md index dbf7fc4..6048bfe 100755 --- a/README.md +++ b/README.md @@ -65,9 +65,8 @@ parser.add_argument('--normalize_grads', default=False, type=bool) parser.add_argument('--soap_data_format', default='channels_first', type=str) parser.add_argument('--correct_bias', default=True, type=bool) parser.add_argument('--nesterov', default=False, type=bool) # whether to use Nesterov-style momentum -parser.add_argument('--muon_backend', default='newtonschulz5', type=str) # the chosen backend for the orthogonalization step -parser.add_argument('--muon_backend_steps', default=5, type=int) # the number of iteration steps to use in the muon_backend, if it is iterative -parser.add_argument('--muon_lr_factor', default=0.1, type=float) # a factor by which to reduce the lr for muon +parser.add_argument('--muon_ns_steps', default=5, type=int) # the number of steps to use in the newton schulz, if it is iterative +parser.add_argument('--muon_lr_factor', default=0.02, type=float) # a factor by which to reduce the lr for muon parser.add_argmunet('--adema_beta3', default=0.9, type=float) # beta3 in AdEMAMix parser.add_argument('--adema_alpha', default=2.0, type=float) # alpha in AdEMAMix parser.add_argument('--adema_beta3_warmup', default=None, type=int) # AdEMAMix hyperparameter diff --git a/src/config/base.py b/src/config/base.py index a0de09d..f6cc936 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -119,9 +119,8 @@ def parse_args(base_parser, args, namespace): 
parser.add_argument("--soap_data_format", default="channels_first", type=str) parser.add_argument("--correct_bias", default=True, type=bool) parser.add_argument("--nesterov", default=False, type=bool) - parser.add_argument("--muon_backend", default="newtonschulz5", type=str) - parser.add_argument("--muon_backend_steps", default=5, type=int) - parser.add_argument("--muon_lr_factor", default=0.1, type=float) + parser.add_argument("--muon_ns_steps", default=5, type=int) + parser.add_argument("--muon_lr_factor", default=1.0, type=float) parser.add_argument("--adema_beta3", default=0.9, type=float) parser.add_argument("--adema_alpha", default=2.0, type=float) parser.add_argument("--adema_beta3_warmup", default=None, type=int) diff --git a/src/main.py b/src/main.py index a87cd0b..2c60aa5 100755 --- a/src/main.py +++ b/src/main.py @@ -9,10 +9,10 @@ import numpy as np import torch -import wandb import config import distributed +import wandb from data.utils import DataReader, get_dataset from models.utils import get_model from optim.adammini import Adam_mini @@ -24,7 +24,7 @@ # from optim.distributed_shampoo.distributed_shampoo import DistributedShampoo # from optim.distributed_shampoo.shampoo_types import AdamGraftingConfig from optim.lion import Lion -from optim.muon import CombinedOptimizer, Muon +from optim.muon import Muon, separate_params from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, wsd_schedule) @@ -149,24 +149,20 @@ def main(args, parser): correct_bias=args.correct_bias, ) elif args.opt == "muon": - opt = CombinedOptimizer( - group_specs, - [Muon, torch.optim.AdamW], - [ - { - "lr": args.muon_lr_factor * args.lr, - "momentum": args.momentum, - "nesterov": args.nesterov, - "backend": args.muon_backend, - "backend_steps": args.muon_backend_steps, - }, - { - "lr": args.lr, - "betas": (args.beta1, args.beta2), - "weight_decay": args.weight_decay, - "fused": True, - }, - ], + param_groups_2d, param_groups_non2d, _, _ = separate_params(group_specs) + print(len(param_groups_2d)) + print(len(param_groups_non2d)) + opt = Muon( + muon_params=param_groups_2d[0]["params"], + lr=args.muon_lr_factor, # since adamw_lr_ration = adamw_lr / muon_lr + momentum=args.momentum, + nesterov=args.nesterov, # always use nesterov momentum for Muon + ns_steps=args.muon_ns_steps, + adamw_params=param_groups_non2d[0]["params"], + adamw_lr=args.lr, + adamw_betas=(args.beta1, args.beta2), + adamw_eps=1e-8, + adamw_wd=args.weight_decay, ) elif args.opt == "ademamix": opt = AdEMAMix( diff --git a/src/optim/base.py b/src/optim/base.py index 7784104..5470f3e 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -4,9 +4,10 @@ from pathlib import Path import torch -import wandb import yaml +import wandb + # from logger.logger import DynamicsLogger from .utils import (eval, get_batch, load_checkpoint, load_worker_state, save_checkpoint, save_worker_state) diff --git a/src/optim/muon.py b/src/optim/muon.py index 705bc37..3cf0062 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -9,11 +9,6 @@ import torch.distributed as dist -def zeropower_via_svd(G, steps=None): - U, S, V = G.svd() - return U @ V.T - - @torch.compile def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): """ @@ -33,18 +28,13 @@ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): X = X.T for _ in range(steps): A = X @ X.T - B = A @ X - X = a * X + b * B + c * A @ B + B = b * A + c * A @ A + X = a * X + B @ X if G.size(0) > G.size(1): X = X.T return X -zeropower_backends = dict( - 
svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5 -) - - class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -55,56 +45,96 @@ class Muon(torch.optim.Optimizer): the advantage that it can be stably run in bfloat16 on the GPU. Some warnings: - - This optimizer assumes that all parameters passed in are 2D. - - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D - parameters; those should all be optimized by a standard method (e.g., AdamW). - - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. - - We believe it is unlikely to work well for training with small batch size. + - We believe this optimizer is unlikely to work well for training with small batch size. - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). Arguments: - lr: The learning rate used by the internal SGD. - momentum: The momentum used by the internal SGD. + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5') - backend_steps: The number of iteration steps to use in the backend, if it is iterative. + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. 
""" def __init__( self, - params, + muon_params, lr=0.02, momentum=0.95, nesterov=True, - backend="newtonschulz5", - backend_steps=5, + ns_steps=6, + adamw_params=None, + adamw_lr=3e-4, + adamw_betas=(0.95, 0.95), + adamw_eps=1e-8, + adamw_wd=0, ): + defaults = dict( lr=lr, momentum=momentum, nesterov=nesterov, - backend=backend, - backend_steps=backend_steps, + ns_steps=ns_steps, + adamw_lr_ratio=adamw_lr / lr, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + adamw_wd=adamw_wd, ) + + params = list(muon_params) + adamw_params = list(adamw_params) if adamw_params is not None else [] + params.extend(adamw_params) super().__init__(params, defaults) + # Sort parameters into those for which we will use Muon, and those for which we will not + for p in muon_params: + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + if p.ndim >= 2 and p.size(0) < 10000: + self.state[p]["use_muon"] = True + # self.state[p]["use_muon"] = True + else: + self.state[p]["use_muon"] = False + for p in adamw_params: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + if "WORLD_SIZE" in os.environ: + self.world_size = int(os.environ["WORLD_SIZE"]) + self.rank = int(os.environ["RANK"]) + else: + self.world_size = 1 + self.rank = 0 + def step(self): + for group in self.param_groups: + + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] lr = group["lr"] momentum = group["momentum"] - zeropower_backend = zeropower_backends[group["backend"]] # generate weight updates in distributed fashion - total_params = sum(p.numel() for p in group["params"]) + total_params = sum(p.numel() for p in params) updates_flat = torch.zeros( total_params, device="cuda", dtype=torch.bfloat16 ) curr_idx = 0 - for i, p in enumerate(group["params"]): + for i, p in enumerate(params): # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs - if i % int(os.environ["WORLD_SIZE"]) == int(os.environ["RANK"]): + if i % self.world_size == self.rank: g = p.grad + if g.ndim > 2: + g = g.view(g.size(0), -1) assert g is not None state = self.state[p] if "momentum_buffer" not in state: @@ -113,17 +143,18 @@ def step(self): buf.mul_(momentum).add_(g) if group["nesterov"]: g = g.add(buf, alpha=momentum) - g = zeropower_backend(g, steps=group["backend_steps"]) + g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) g *= max(1, g.size(0) / g.size(1)) ** 0.5 updates_flat[curr_idx : curr_idx + p.numel()] = g.flatten() curr_idx += p.numel() # sync updates across devices. 
we are not memory-constrained so can do this simple deserialization - dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + if self.world_size > 1: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) # deserialize and apply updates curr_idx = 0 - for p in group["params"]: + for p in params: g = ( updates_flat[curr_idx : curr_idx + p.numel()] .view_as(p.data) @@ -132,6 +163,41 @@ def step(self): p.data.add_(g, alpha=-lr) curr_idx += p.numel() + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = ( + group["adamw_lr_ratio"] * group["lr"] + ) # in order for lr schedule to work + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["adamw_wd"] + + for p in params: + g = p.grad + assert g is not None + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + def separate_params(param_groups): param_groups_2d = [] @@ -186,71 +252,3 @@ def separate_params(param_groups): return params_2d, params_non2d, len(params_2d), len(params_non2d) else: breakpoint() - - -class CombinedOptimizer(torch.optim.Optimizer): - """ -# Note that CombinedOptimizer is not a torch.optim.Optimizer, but a wrapper around multiple optimizers. -# Original Example: - optimizer = CombinedOptimizer([ - torch.optim.AdamW(self.lm_head.parameters(), lr=learning_rate, betas=betas, weight_decay=0, fused=True), - OrthogonalNesterov(self.transformer.h.parameters(), lr=0.1*learning_rate, momentum=0.95) - ]) -# Refactored Example: - optimizer = CombinedOptimizer(\ - self.parameters(), - [OrthogonalNesterov, torch.optim.AdamW], - [{'lr': 0.1*learning_rate, 'momentum': 0.95}, - {'lr': learning_rate, 'betas': betas, 'weight_decay': 0, 'fused': True} - ]) -""" - - def __init__(self, params, optimizer_types, configs): - # Separate 2D and non-2D parameters. - # param_groups_2d_non2d: (param_groups_2d, param_groups_non2d). - # If params is a list of tensors, then each of param_groups_2d and param_groups_non2d - # will be a list of tensors. - # If params is a list of dicts, then each of param_groups_2d and param_groups_non2d - # will be a list of dicts. - # If params is a dict, then each of param_groups_2d and param_groups_non2d will - # be a list of dicts containing only one dict. 
- ( - param_groups_2d, - param_groups_non2d, - total_param_2d_count, - total_param_non2d_count, - ) = separate_params(params) - param_groups_2d_non2d = (param_groups_2d, param_groups_non2d) - print( - f"Total 2D params: {total_param_2d_count}, Total non-2D params: {total_param_non2d_count}" - ) - - assert ( - len(optimizer_types) == len(configs) == 2 - ), "You must use only two optimizers" - assert optimizer_types[0] == Muon, "The first optimizer must be Muon" - self.optimizers = [ - optimizer_types[i](param_groups_2d_non2d[i], **configs[i]) for i in range(2) - ] - self.param_groups = [pg for opt in self.optimizers for pg in opt.param_groups] - self.base_lrs = [opt.param_groups[0]["lr"] for opt in self.optimizers] - # Combine the state dicts of all opt in self.optimizers into a single dict - self.state = {k: v for opt in self.optimizers for k, v in opt.state.items()} - # Initially all states are empty. So no point to print their counts. - # Only use the defaults of the OrthogonalNesterov optimizer - self.defaults = self.optimizers[0].defaults - - def step(self, *args, **kwargs): - for opt in self.optimizers: - opt.step(*args, **kwargs) - - def zero_grad(self, **kwargs): - for opt in self.optimizers: - opt.zero_grad(**kwargs) - - def scale_lrs(self, lr_scale): - for base_lr, opt in zip(self.base_lrs, self.optimizers): - opt.param_groups[0]["lr"] = base_lr * lr_scale - - def state_dict(self): - return [opt.state_dict() for opt in self.optimizers] diff --git a/src/optim/utils.py b/src/optim/utils.py index 12185ac..1bdd01e 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F + import wandb From 729c820b83a9e7e10dfe99d18c9e89b1f41f45c9 Mon Sep 17 00:00:00 2001 From: Andrei Semenov <67924720+Andron00e@users.noreply.github.com> Date: Sun, 10 Nov 2024 15:49:40 +0100 Subject: [PATCH 43/58] -- micro --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6048bfe..191e670 100755 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ parser.add_argument('--sophia_rho', default=0.04, type=float) parser.add_argument('--clipping_type', default='no', choices=['no', 'local', 'elementwise']) # for methods with clipping parser.add_argument('--clipping_eta', default=1.0, type=float) # Dataset params -parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', "arxiv2000", 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) +parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', 'arxiv2000', 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) parser.add_argument('--vocab_size', default=50304, type=int) parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this From d81d0b2ab0c2c11b9a9037d9a66fe752f56ab8f7 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Mon, 11 Nov 2024 04:04:28 +0300 Subject: [PATCH 44/58] implemented a scheduler for muon --- src/main.py | 44 ++++++++++++++--------- src/optim/muon.py | 91 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 17 deletions(-) diff --git a/src/main.py b/src/main.py index 2c60aa5..c48156a 100755 --- a/src/main.py +++ 
b/src/main.py @@ -24,7 +24,7 @@ # from optim.distributed_shampoo.distributed_shampoo import DistributedShampoo # from optim.distributed_shampoo.shampoo_types import AdamGraftingConfig from optim.lion import Lion -from optim.muon import Muon, separate_params +from optim.muon import CombinedScheduler, Muon, separate_params from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, wsd_schedule) @@ -150,8 +150,6 @@ def main(args, parser): ) elif args.opt == "muon": param_groups_2d, param_groups_non2d, _, _ = separate_params(group_specs) - print(len(param_groups_2d)) - print(len(param_groups_non2d)) opt = Muon( muon_params=param_groups_2d[0]["params"], lr=args.muon_lr_factor, # since adamw_lr_ration = adamw_lr / muon_lr @@ -340,18 +338,22 @@ def main(args, parser): if args.scheduler in ["cos", "linear"]: # initial lr is args.lr / div_factor # final lr is initial_lr/final_div_factor = args.lr / div_factor / final_div_factor - scheduler = torch.optim.lr_scheduler.OneCycleLR( - optimizer=opt, - max_lr=[ - group.get("lr", args.lr) for group in group_specs - ], # it was args.lr - total_steps=args.iterations, - pct_start=args.warmup_steps - / args.iterations, # it was args.warmup_percent - anneal_strategy=args.scheduler, - cycle_momentum=False, - div_factor=1e2, - final_div_factor=1, + scheduler = ( + torch.optim.lr_scheduler.OneCycleLR( + optimizer=opt, + max_lr=[ + group.get("lr", args.lr) for group in group_specs + ], # it was args.lr + total_steps=args.iterations, + pct_start=args.warmup_steps + / args.iterations, # it was args.warmup_percent + anneal_strategy=args.scheduler, + cycle_momentum=False, + div_factor=1e2, + final_div_factor=1, + ) + if args.opt != "muon" + else CombinedScheduler(opt, args) ) elif args.scheduler == "cos_inf": lambda_schedule = cos_inf_schedule( @@ -361,7 +363,11 @@ def main(args, parser): div_factor=1e2, final_div_factor=0.1, ) - scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + scheduler = ( + torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + if args.opt != "muon" + else CombinedScheduler(opt, args) + ) elif args.scheduler == "wsd": lambda_schedule = wsd_schedule( n_iterations=args.iterations, @@ -371,7 +377,11 @@ def main(args, parser): final_lr_factor=args.wsd_final_lr_scale, # should be 0 here decay_type=args.decay_type, ) - scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + scheduler = ( + torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + if args.opt != "muon" + else CombinedScheduler(opt, args) + ) elif args.scheduler == "cos_wsd": lambda_schedule = cosine_wsd_decay_schedule( n_iterations=args.iterations, diff --git a/src/optim/muon.py b/src/optim/muon.py index 3cf0062..f0d6e49 100644 --- a/src/optim/muon.py +++ b/src/optim/muon.py @@ -8,6 +8,8 @@ import torch import torch.distributed as dist +from .schedule import cos_inf_schedule, cosine_wsd_decay_schedule, wsd_schedule + @torch.compile def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): @@ -81,6 +83,7 @@ def __init__( momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, + adamw_lr=adamw_lr, adamw_lr_ratio=adamw_lr / lr, adamw_betas=adamw_betas, adamw_eps=adamw_eps, @@ -252,3 +255,91 @@ def separate_params(param_groups): return params_2d, params_non2d, len(params_2d), len(params_non2d) else: breakpoint() + + +class CombinedScheduler: + """ + CombinedScheduler implements a scheduler for the Muon optimizer: it leverages both Muon and AdamW learning rates, and applies the same sort of scheduler for both 
of them. + + Arguments: + optimizer: Muon optimizer. + cfg: arguments used for schedulers. + muon_lr_key: defaults["lr"] is responsible for the Muon learning rate. + adamw_lr_key: defaults["adamw_r"] is responsible for the AdamW learning rate. + """ + + def __init__(self, optimizer, cfg, muon_lr_key="lr", adamw_lr_key="adamw_lr"): + self.schedulers = [] + scheduler_map = { + "cos": torch.optim.lr_scheduler.OneCycleLR, + "linear": torch.optim.lr_scheduler.OneCycleLR, + "cos_inf": lambda opt, lr: torch.optim.lr_scheduler.LambdaLR( + opt, + cos_inf_schedule( + n_iterations=cfg.iterations, + n_warmup=cfg.warmup_steps, + n_inf=cfg.cos_inf_steps, + div_factor=1e2, + final_div_factor=0.1, + ), + ), + "wsd": lambda opt, lr: torch.optim.lr_scheduler.LambdaLR( + opt, + wsd_schedule( + n_iterations=cfg.iterations, + n_warmup=cfg.warmup_steps, + fract_decay=cfg.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=cfg.wsd_final_lr_scale, + decay_type=cfg.decay_type, + ), + ), + "cos_wsd": lambda opt, lr: torch.optim.lr_scheduler.LambdaLR( + opt, + cosine_wsd_decay_schedule( + n_iterations=cfg.iterations, + n_warmup=cfg.warmup_steps, + anneal_end_factor=0.15, + fract_decay=cfg.wsd_fract_decay, + init_div_factor=1e2, + final_lr_factor=0.1, + decay_type=cfg.decay_type, + ), + ), + } + + for group in optimizer.param_groups: + lr_key = muon_lr_key if muon_lr_key in group else adamw_lr_key + if lr_key in group: + scheduler_cls = scheduler_map.get(cfg.scheduler, None) + if scheduler_cls: + if cfg.scheduler in ["cos", "linear"]: + scheduler = scheduler_cls( + optimizer, + max_lr=[group.get(lr_key, getattr(cfg, lr_key.lower()))], + total_steps=cfg.iterations, + pct_start=cfg.warmup_steps / cfg.iterations, + anneal_strategy=cfg.scheduler, + cycle_momentum=False, + div_factor=1e2, + final_div_factor=1, + ) + else: + scheduler = scheduler_cls( + optimizer, group.get(lr_key, getattr(cfg, lr_key.lower())) + ) + self.schedulers.append(scheduler) + + def step(self): + for scheduler in self.schedulers: + scheduler.step() + + def state_dict(self): + state_dict = {} + for i, scheduler in enumerate(self.schedulers): + state_dict[f"scheduler_{i}"] = scheduler.state_dict() + return state_dict + + def load_state_dict(self, state_dict): + for i, scheduler in enumerate(self.schedulers): + scheduler.load_state_dict(state_dict[f"scheduler_{i}"]) From 6a8a1595b7933102421c3693e78c5bd022bbe2f2 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Mon, 11 Nov 2024 20:20:30 +0300 Subject: [PATCH 45/58] new scheduler, double decay, test it --- README.md | 4 +- src/config/base.py | 8 ++- src/main.py | 15 ++++- src/optim/schedule.py | 144 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 191e670..588fb2b 100755 --- a/README.md +++ b/README.md @@ -46,10 +46,12 @@ parser.add_argument('--lr', default=1e-3, type=float) parser.add_argument('--wsd_final_lr_scale', default=0.0, type=float) # wsd scheduler parser.add_argument('--wsd_fract_decay', default=0.1, type=float) # wsd scheduler parser.add_argument('--decay_type', default='linear', choices=['linear', 'cosine', 'exp', 'miror_cosine', 'square', 'sqrt']) +parser.add_argument('--dd_second_decay_type', default='linear', choices=['linear', 'cosine', 'exp', 'miror_cosine', 'square', 'sqrt']) +parser.add_argument("--dd_first_lr_factor", default=1e-2, type=float) parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise 
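As a quick orientation to the new 'dd' (double decay) options just above: a minimal sketch, not taken from the patch, of how they feed the dd_schedule helper this commit adds to src/optim/schedule.py. The wiring mirrors the args.scheduler == "dd" branch in src/main.py below; the concrete numbers here are placeholders for the corresponding flags.

    import torch
    from optim.schedule import dd_schedule

    lam = dd_schedule(
        n_iterations=25_000,          # --iterations
        n_warmup=1_250,               # warmup steps
        fract_fisrt_decay=0.1,        # --wsd_fract_decay (keyword spelled as in the patch)
        max_lr=1e-3,                  # --lr
        first_final_lr_factor=1e-2,   # --dd_first_lr_factor
        second_final_lr_factor=0.0,   # anneal all the way to zero
        div_factor=1e2,
        first_decay_type="linear",    # --decay_type
        second_decay_type="linear",   # --dd_second_decay_type
    )
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)  # any optimizer works here
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lam)
    # lam(step) is the multiplicative LR factor at a given step, so the warmup,
    # first-decay and second-decay phases can be inspected directly.
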
parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter -parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none']) +parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none', 'dd']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations diff --git a/src/config/base.py b/src/config/base.py index f6cc936..a5f1b7c 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -56,7 +56,7 @@ def parse_args(base_parser, args, namespace): parser.add_argument( "--scheduler", default="cos", - choices=["linear", "cos", "wsd", "none", "cos_inf", "cos_wsd"], + choices=["linear", "cos", "wsd", "none", "cos_inf", "cos_wsd", "dd"], ) parser.add_argument("--cos_inf_steps", default=0, type=int) # parser.add_argument("--cos-final-lr", default=1e-6, type=float) @@ -72,6 +72,12 @@ def parse_args(base_parser, args, namespace): default="linear", choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], ) + parser.add_argument( + "--dd_second_decay_type", + default="linear", + choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"], + ) + parser.add_argument("--dd_first_lr_factor", default=1e-2, type=float) # Optimization parser.add_argument( diff --git a/src/main.py b/src/main.py index c48156a..de11359 100755 --- a/src/main.py +++ b/src/main.py @@ -27,7 +27,7 @@ from optim.muon import CombinedScheduler, Muon, separate_params from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, - wsd_schedule) + dd_schedule, wsd_schedule) from optim.schedulefree import AdamWScheduleFree, SGDScheduleFree from optim.sgdf import SGDF from optim.shampoo import DistributedShampoo @@ -393,6 +393,19 @@ def main(args, parser): decay_type=args.decay_type, ) scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) + elif args.scheduler == "dd": + lambda_schedule = dd_schedule( + n_iterations=args.iterations, + n_warmup=args.warmup_steps, + fract_fisrt_decay=args.wsd_fract_decay, # this will be responsible for the first decay phase + max_lr=[group.get("lr", args.lr) for group in group_specs], + first_final_lr_factor=args.dd_first_lr_factor, + second_final_lr_factor=0.0, # stop with zero lr + div_factor=1e2, + first_decay_type=args.decay_type, + second_decay_type=args.dd_second_decay_type, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lambda_schedule) else: raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") else: diff --git a/src/optim/schedule.py b/src/optim/schedule.py index 89538d4..8d8ccab 100644 --- a/src/optim/schedule.py +++ b/src/optim/schedule.py @@ -195,3 +195,147 @@ def schedule(step): return final_lr_factor return schedule + + +def dd_schedule( + n_iterations, + n_warmup, + fract_fisrt_decay, + max_lr, + first_final_lr_factor=1e-2, + second_final_lr_factor=0.0, + div_factor=1e2, + first_decay_type="cosine", + second_decay_type="linear", +): + """Warmup, cosine annealing, and linear decay schedule. 
+ Args: + n_iterations: total number of iterations + n_warmup: number of warmup iterations + fract_fisrt_decay: fraction of iterations for the first decay phase + max_lr: the mamimum value of learning rate during the training + first_final_lr_factor: factor by which to reduce max_lr at the end of the first decay phase + second_final_lr_factor: factor by which to reduce first_final_lr_factor at the end of the second decay phase + div_factor: initial division factor for warmup + first_decay_type: which decay approach to use during the fisrt decay phase + second_decay_type: which decay approach to use during the second decay phase + Returns: + schedule: a function that takes the current iteration and + returns the multiplicative factor for the learning rate + """ + if fract_fisrt_decay > 1.0: + raise ValueError( + "Invalid fract_fisrt_decay value: {}".format(fract_fisrt_decay) + ) + n_fisrt_decay = int(fract_fisrt_decay * n_iterations) + + def schedule(step): + if step < n_warmup: + return (step / n_warmup) + (1 - step / n_warmup) / div_factor + elif step < n_warmup + n_fisrt_decay: + if first_decay_type == "cosine": + return first_final_lr_factor + 0.5 * ( + max_lr - first_final_lr_factor + ) * (1 + math.cos(math.pi * (step - n_warmup) / n_fisrt_decay)) + elif first_decay_type == "linear": + return first_final_lr_factor + (max_lr - first_final_lr_factor) * ( + 1 - (step - n_warmup) / n_fisrt_decay + ) + elif first_decay_type == "exp": + return first_final_lr_factor ** ((step - n_warmup) / n_fisrt_decay) + elif first_decay_type == "mirror_cosine": + cosine_value = ( + first_final_lr_factor + + (max_lr - first_final_lr_factor) + * (1 + math.cos(math.pi * (step - n_warmup) / n_fisrt_decay)) + * 0.5 + ) + linear_value = first_final_lr_factor + ( + max_lr - first_final_lr_factor + ) * (1 - (step - n_warmup) / n_fisrt_decay) + return linear_value * 2 - cosine_value + elif first_decay_type == "square": + return first_final_lr_factor + (max_lr - first_final_lr_factor) * ( + 1 - ((step - n_warmup) / n_fisrt_decay) ** 2 + ) + elif first_decay_type == "sqrt": + return first_final_lr_factor + (max_lr - first_final_lr_factor) * ( + 1 - math.sqrt((step - n_warmup) / n_fisrt_decay) + ) + else: + raise ValueError( + f"decay type {first_decay_type} is not in ['cosine','miror_cosine','linear','exp']" + ) + elif step < n_iterations: + if second_decay_type == "linear": + return second_final_lr_factor + ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + - (step - n_warmup - n_fisrt_decay) / (n_iterations - n_fisrt_decay) + ) + elif second_decay_type == "cosine": + return second_final_lr_factor + 0.5 * ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + + math.cos( + math.pi + * (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ) + elif second_decay_type == "exp": + return first_final_lr_factor ** ( + (step - n_warmup - n_fisrt_decay) / (n_iterations - n_fisrt_decay) + ) + elif second_decay_type == "mirror_cosine": + cosine_value = ( + second_final_lr_factor + + (first_final_lr_factor - second_final_lr_factor) + * ( + 1 + + math.cos( + math.pi + * (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ) + * 0.5 + ) + linear_value = second_final_lr_factor + ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + - (step - n_warmup - n_fisrt_decay) / (n_iterations - n_fisrt_decay) + ) + return linear_value * 2 - cosine_value + elif second_decay_type == "square": + return second_final_lr_factor + ( + first_final_lr_factor - 
second_final_lr_factor + ) * ( + 1 + - ( + (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ** 2 + ) + elif second_decay_type == "sqrt": + return second_final_lr_factor + ( + first_final_lr_factor - second_final_lr_factor + ) * ( + 1 + - math.sqrt( + (step - n_warmup - n_fisrt_decay) + / (n_iterations - n_fisrt_decay) + ) + ) + else: + raise ValueError( + f"decay type {second_decay_type} is not in ['cosine','miror_cosine','linear','exp']" + ) + else: + return second_final_lr_factor + + return schedule From c30eecaacf0640e75c2a494bcd7d8392e7818430 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 12 Nov 2024 16:55:28 +0100 Subject: [PATCH 46/58] minor --- src/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.py b/src/main.py index de11359..0c1bbcf 100755 --- a/src/main.py +++ b/src/main.py @@ -398,7 +398,7 @@ def main(args, parser): n_iterations=args.iterations, n_warmup=args.warmup_steps, fract_fisrt_decay=args.wsd_fract_decay, # this will be responsible for the first decay phase - max_lr=[group.get("lr", args.lr) for group in group_specs], + max_lr=args.lr,#[group.get("lr", args.lr) for group in group_specs], first_final_lr_factor=args.dd_first_lr_factor, second_final_lr_factor=0.0, # stop with zero lr div_factor=1e2, From 82dd1af15e745fd2ad90e12cbed08aac370a40f8 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 19 Nov 2024 01:45:08 +0300 Subject: [PATCH 47/58] adopt is here --- src/main.py | 10 +++- src/optim/adopt.py | 103 ++++++++++++++++++++++++++++++++++++++++++ src/optim/schedule.py | 3 +- 3 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 src/optim/adopt.py diff --git a/src/main.py b/src/main.py index 0c1bbcf..8981b42 100755 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ from optim.adammini import Adam_mini from optim.ademamix import AdEMAMix from optim.ademamix2 import AdEMAMix2 +from optim.adopt import ADOPT from optim.base import train from optim.clipped import (AdagradClip, AdaGradClipDelayedEta, AdamClip, AdamClipDelayedEta) @@ -284,7 +285,12 @@ def main(args, parser): # ), ) elif args.opt == "adopt": - raise NotImplementedError("Have not implemented yet") + opt = ADOPT( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + ) elif args.opt in [ "clip-adagrad", "clip-adagrad-delay-eta", @@ -398,7 +404,7 @@ def main(args, parser): n_iterations=args.iterations, n_warmup=args.warmup_steps, fract_fisrt_decay=args.wsd_fract_decay, # this will be responsible for the first decay phase - max_lr=args.lr,#[group.get("lr", args.lr) for group in group_specs], + max_lr=args.lr, # [group.get("lr", args.lr) for group in group_specs], first_final_lr_factor=args.dd_first_lr_factor, second_final_lr_factor=0.0, # stop with zero lr div_factor=1e2, diff --git a/src/optim/adopt.py b/src/optim/adopt.py new file mode 100644 index 0000000..80a9cd1 --- /dev/null +++ b/src/optim/adopt.py @@ -0,0 +1,103 @@ +""" +Here is an original implementation of ADOPT. +Source: https://github.com/iShohei220/adopt +""" + +import math + +import torch + + +class ADOPT(torch.optim.Optimizer): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.9999), + eps=1e-6, + weight_decay=0.0, + decoupled=False, + ): + """ + Args: + params: iterable of parameters to optimize or dictionaries defining parameter groups. + lr: learning rate. + betas: coefficients used for computing running averages of gradient and gradient squared. 
+ eps: term added to the denominator to improve numerical stability. + weight_decay: weight decay. + decoupled: whether to use decoupled weight decay. + """ + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta values: {betas}") + if eps <= 0.0: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + decoupled=decoupled, + ) + super(ADOPT, self).__init__(params, defaults) + + @torch.no_grad + def step(self, closure=None): + """Performs a single optimization step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("ADOPT does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + if group["weight_decay"] != 0: + if group["decoupled"]: + p.add_(p, alpha=-group["lr"] * group["decoupled"]) + else: + grad = grad.add(p, alpha=group["weight_decay"]) + + if state["step"] == 1: + exp_avg_sq.addcmul_(grad, grad) + continue + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( + group["eps"] + ) + + if state["step"] == 2: + exp_avg.addcdiv_(grad, denom) + else: + exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) + + p.add_(exp_avg.div(bias_correction1), alpha=-group["lr"]) + + return loss diff --git a/src/optim/schedule.py b/src/optim/schedule.py index 8d8ccab..a6a6730 100644 --- a/src/optim/schedule.py +++ b/src/optim/schedule.py @@ -192,7 +192,8 @@ def schedule(step): 1 - math.sqrt(progress) ) - return final_lr_factor + else: + return final_lr_factor return schedule From 755cf59c1c1c7825a05f7206c0438376cdedaf14 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 19 Nov 2024 02:01:06 +0300 Subject: [PATCH 48/58] adopt fix --- src/optim/adopt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optim/adopt.py b/src/optim/adopt.py index 80a9cd1..9458c0d 100644 --- a/src/optim/adopt.py +++ b/src/optim/adopt.py @@ -45,7 +45,7 @@ def __init__( ) super(ADOPT, self).__init__(params, defaults) - @torch.no_grad + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step.""" loss = None From 0c53df3147a5deff78dd964b255fef95fcc21ee9 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 19 Nov 2024 02:33:15 +0300 Subject: [PATCH 49/58] adopt fix again --- src/optim/adopt.py | 44 ++++++++++++++++++-------------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/src/optim/adopt.py b/src/optim/adopt.py index 9458c0d..b91f32b 100644 --- a/src/optim/adopt.py +++ b/src/optim/adopt.py @@ -3,8 +3,6 @@ Source: https://github.com/iShohei220/adopt """ -import math - import torch @@ -13,10 +11,9 @@ def __init__( self, params, lr=1e-3, - betas=(0.9, 0.9999), + betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, - decoupled=False, ): """ 
Args: @@ -29,11 +26,11 @@ def __init__( """ if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") - if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] < 1.0: + if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] <= 1.0: raise ValueError(f"Invalid beta values: {betas}") if eps <= 0.0: raise ValueError(f"Invalid epsilon value: {eps}") - if not 0.0 <= weight_decay: + if weight_decay < 0.0: raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( @@ -41,7 +38,6 @@ def __init__( betas=betas, eps=eps, weight_decay=weight_decay, - decoupled=decoupled, ) super(ADOPT, self).__init__(params, defaults) @@ -50,16 +46,21 @@ def step(self, closure=None): """Performs a single optimization step.""" loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue + grad = p.grad if grad.is_sparse: raise RuntimeError("ADOPT does not support sparse gradients") + if group["weight_decay"] != 0: + grad = grad.add(p, alpha=group["weight_decay"]) + state = self.state[p] # State initialization @@ -72,32 +73,23 @@ def step(self, closure=None): p, memory_format=torch.preserve_format ) + state["step"] += 1 + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = group["betas"] + eps = group["eps"] - state["step"] += 1 bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - if group["weight_decay"] != 0: - if group["decoupled"]: - p.add_(p, alpha=-group["lr"] * group["decoupled"]) - else: - grad = grad.add(p, alpha=group["weight_decay"]) - - if state["step"] == 1: - exp_avg_sq.addcmul_(grad, grad) - continue + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( - group["eps"] - ) + exp_avg_sq = exp_avg_sq.div(bias_correction2).sqrt() + denom = torch.maximum(exp_avg_sq, torch.tensor(eps, device=grad.device)) - if state["step"] == 2: - exp_avg.addcdiv_(grad, denom) - else: - exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) + exp_avg.mul_(beta1).add_(grad.div(denom), alpha=1 - beta1) - p.add_(exp_avg.div(bias_correction1), alpha=-group["lr"]) + step_size = group["lr"] * (exp_avg.div(bias_correction1)) + p.add_(-step_size) return loss From 6cfd01b9e1153032d0a975a3a61446d493cfeadf Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 19 Nov 2024 18:00:44 +0300 Subject: [PATCH 50/58] adopt again --- README.md | 2 +- src/main.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 588fb2b..3a15b49 100755 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ parser.add_argument('--wsd_final_lr_scale', default=0.0, type=float) # wsd sched parser.add_argument('--wsd_fract_decay', default=0.1, type=float) # wsd scheduler parser.add_argument('--decay_type', default='linear', choices=['linear', 'cosine', 'exp', 'miror_cosine', 'square', 'sqrt']) parser.add_argument('--dd_second_decay_type', default='linear', choices=['linear', 'cosine', 'exp', 'miror_cosine', 'square', 'sqrt']) -parser.add_argument("--dd_first_lr_factor", default=1e-2, type=float) +parser.add_argument('--dd_first_lr_factor', default=1e-2, type=float) parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter diff --git 
a/src/main.py b/src/main.py index 8981b42..6b44cdd 100755 --- a/src/main.py +++ b/src/main.py @@ -289,6 +289,7 @@ def main(args, parser): group_specs, lr=args.lr, betas=(args.beta1, args.beta2), + eps=1e-6, weight_decay=args.weight_decay, ) elif args.opt in [ From d4585ae812cc3c21458a0353d9ecb37333dbcb18 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 19 Nov 2024 21:28:52 +0300 Subject: [PATCH 51/58] mars is here, ready to try --- README.md | 4 +- src/config/base.py | 12 ++ src/main.py | 16 +++ src/optim/base.py | 6 +- src/optim/mars.py | 295 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 331 insertions(+), 2 deletions(-) create mode 100644 src/optim/mars.py diff --git a/README.md b/README.md index 3a15b49..247f351 100755 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none', 'dd']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta', 'mars']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -87,6 +87,8 @@ parser.add_argument('--prodigy_fsdp_in_use', default=False, type=bool) parser.add_argument('--sophia_rho', default=0.04, type=float) parser.add_argument('--clipping_type', default='no', choices=['no', 'local', 'elementwise']) # for methods with clipping parser.add_argument('--clipping_eta', default=1.0, type=float) +parser.add_argument('--mars_type', default='mars-adamw', choices=['mars-adamw', 'mars-lion', 'mars-shampoo'],) +parser.add_argument('--mars_vr_gamma', default=0.025, type=float) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', 'arxiv2000', 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) diff --git a/src/config/base.py b/src/config/base.py index a5f1b7c..be2f901 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -105,6 +105,7 @@ def parse_args(base_parser, args, namespace): "clip-adagrad-delay-eta", "clip-adam", "clip-adam-delay-eta", + "mars", ], ) parser.add_argument("--batch_size", default=50, type=int) @@ -142,10 +143,21 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--prodigy_safeguard_warmup", default=False, type=bool) parser.add_argument("--prodigy_fsdp_in_use", default=False, type=bool) parser.add_argument("--sophia_rho", default=0.04, type=float) + parser.add_argument("--sophia_bs", default=480, type=int) 
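The --mars_vr_gamma flag introduced here controls the strength of the variance-reduced gradient correction at the heart of MARS. A short standalone sketch, distilled from the update_fn this commit adds in src/optim/mars.py, where grad and last_grad stand for the current and previous gradients of one parameter:

    import torch

    def mars_corrected_grad(grad, last_grad, beta1=0.95, gamma=0.025):
        # c_t = grad + gamma * beta1 / (1 - beta1) * (grad - last_grad), clipped to unit norm;
        # this corrected gradient is what MARS feeds into the usual Adam-style moment updates.
        c_t = (grad - last_grad).mul(gamma * (beta1 / (1.0 - beta1))).add(grad)
        c_t_norm = torch.norm(c_t)
        if c_t_norm > 1.0:
            c_t = c_t / c_t_norm
        return c_t
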
parser.add_argument( "--clipping_type", default="no", choices=["no", "local", "elementwise"] ) parser.add_argument("--clip_eta", default=1.0, type=float) + parser.add_argument( + "--mars_type", + default="mars-adamw", + choices=["mars-adamw", "mars-lion", "mars-shampoo"], + ) + parser.add_argument("--mars_vr_gamma", default=0.025, type=float) + parser.add_argument("--mars_is_approx", default=True, type=float) + parser.add_argument("--mars_lr", default=3e-3, type=float) + parser.add_argument("--mars_beta1", default=0.95, type=float) + parser.add_argument("--mars_beta2", default=0.99, type=float) # Dataset params parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") diff --git a/src/main.py b/src/main.py index 6b44cdd..0a9db23 100755 --- a/src/main.py +++ b/src/main.py @@ -25,6 +25,7 @@ # from optim.distributed_shampoo.distributed_shampoo import DistributedShampoo # from optim.distributed_shampoo.shampoo_types import AdamGraftingConfig from optim.lion import Lion +from optim.mars import MARS from optim.muon import CombinedScheduler, Muon, separate_params from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, @@ -328,6 +329,21 @@ def main(args, parser): } if args.opt == "clip-adam-delay-eta": opt = AdamClipDelayedEta(**clipped_adam_delay_eta_cfg) + elif args.opt == "mars": + opt = MARS( + group_specs, + lr=args.mars_lr, + betas=(args.mars_beta1, args.mars_beta2), + weight_decay=args.weight_decay, + amsgrad=False, + gamma=args.mars_vr_gamma, + is_approx=args.mars_is_approx, + mars_type=args.mars_type, + optimize_1d=False, # we set in order to optimize 1D parameters with AdamW + lr_1d=args.lr, # AdamW's lr when optimize_1d=False + betas_1d=(args.beta1, args.beta2), # AdamW's betas when optimize_1d=False + weight_decay_1d=0.1, # AdamW's weight decay + ) else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/base.py b/src/optim/base.py index 5470f3e..d37473a 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -144,7 +144,11 @@ def train( if cfg.opt == "sf-sgd" or cfg.opt == "sf-adamw": opt.train() - opt.step() if cfg.opt != "sophiag" else opt.step(bs=480 * cfg.sequence_length) + ( + opt.step() + if cfg.opt != "sophiag" + else opt.step(bs=cfg.sophia_bs * cfg.sequence_length) + ) if cfg.scheduler != "none": scheduler.step() if cfg.opt == "sophiag": diff --git a/src/optim/mars.py b/src/optim/mars.py new file mode 100644 index 0000000..ace3869 --- /dev/null +++ b/src/optim/mars.py @@ -0,0 +1,295 @@ +""" +Here is an original implementation of MARS. +Source: https://github.com/AGI-Arena/MARS +""" + +# Copyright (c) 2024 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +import math + +import torch + +from .muon import zeropower_via_newtonschulz5 + + +def exists(val): + return val is not None + + +def update_fn( + p, + grad, + exp_avg, + exp_avg_sq, + lr, + wd, + beta1, + beta2, + last_grad, + eps, + amsgrad, + max_exp_avg_sq, + step, + gamma, + mars_type, + is_grad_2d, + optimize_1d, + lr_1d_factor, + betas_1d, + weight_decay_1d, +): + # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para + if optimize_1d or is_grad_2d: + c_t = (grad - last_grad).mul(gamma * (beta1 / (1.0 - beta1))).add(grad) + c_t_norm = torch.norm(c_t) + if c_t_norm > 1.0: + c_t = c_t / c_t_norm + exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1) + if (mars_type == "mars-adamw") or ( + mars_type == "mars-shampoo" and not is_grad_2d + ): + exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2) + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + if amsgrad: + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + denom = ( + max_exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + else: + denom = ( + exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.div(denom)) + elif mars_type == "mars-lion": + real_update_tmp = -lr * torch.mul(p.data, wd).add(exp_avg.sign()) + elif mars_type == "mars-shampoo" and is_grad_2d: + factor = max(1, grad.size(0) / grad.size(1)) ** 0.5 + real_update_tmp = ( + zeropower_via_newtonschulz5(exp_avg.mul(1.0 / (1.0 - beta1)), eps=eps) + .mul(factor) + .add(wd, p.data) + .mul(-lr) + ) + p.data.add_(real_update_tmp) + else: + beta1_1d, beta2_1d = betas_1d + exp_avg.mul_(beta1_1d).add_(grad, alpha=1 - beta1_1d) + exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1 - beta2_1d) + bias_correction1 = 1.0 - beta1_1d**step + bias_correction2 = 1.0 - beta2_1d**step + if amsgrad: + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + denom = ( + max_exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + else: + denom = ( + exp_avg_sq.sqrt() + .mul(1 / math.sqrt(bias_correction2)) + .add(eps) + .mul(bias_correction1) + ) + real_update_tmp = ( + -lr + * lr_1d_factor + * torch.mul(p.data, weight_decay_1d).add(exp_avg.div(denom)) + ) + p.data.add_(real_update_tmp) + return exp_avg, exp_avg_sq + + +class MARS(torch.optim.Optimizer): + def __init__( + self, + params, + lr=3e-3, + betas=(0.95, 0.99), + eps=1e-8, + weight_decay=0.0, + amsgrad=False, + gamma=0.025, + is_approx=True, + mars_type="mars-adamw", + optimize_1d=False, + lr_1d=3e-3, + betas_1d=(0.9, 0.95), + weight_decay_1d=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + assert mars_type in [ + "mars-adamw", + "mars-lion", + "mars-shampoo", + ], "MARS type not supported" + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + mars_type=mars_type, + gamma=gamma, + optimize_1d=optimize_1d, + weight_decay_1d=weight_decay_1d, + ) + super(MARS, self).__init__(params, defaults) + self.eps = eps + self.update_fn = update_fn + self.lr 
= lr + self.weight_decay = weight_decay + self.amsgrad = amsgrad + self.step_num = 0 + self.is_approx = is_approx + self.gamma = gamma + self.mars_type = mars_type + self.optimize_1d = optimize_1d + self.lr_1d_factor = lr_1d / lr + self.weight_decay_1d = weight_decay_1d + self.betas_1d = betas_1d + + @torch.no_grad() + def update_last_grad(self): + if not self.is_approx: + for group in self.param_groups: + for p in group["params"]: + state = self.state[p] + if "last_grad" not in state: + state["last_grad"] = torch.zeros_like(p) + state["last_grad"].zero_().add_(state["previous_grad"], alpha=1.0) + + @torch.no_grad() + def update_previous_grad(self): + if not self.is_approx: + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + print(p, "grad is none") + continue + state = self.state[p] + if "previous_grad" not in state: + state["previous_grad"] = torch.zeros_like(p) + state["previous_grad"].zero_().add_(p.grad, alpha=1.0) + + def __setstate__(self, state): + super(MARS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + @torch.no_grad() + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + grad_scaler=None, + ): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError( + "FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." + ) + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + gamma = self.gamma + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group["params"]): + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group["amsgrad"] + + state = self.state[p] + # ('----- starting a parameter state', state.keys(), 'Length of state', len(state)) + # State initialization + if len(state) <= 1: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Last Gradient + state["last_grad"] = torch.zeros_like(p) + # state['previous_grad'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state["max_exp_avg_sq"] = torch.zeros_like(p.data) + # import pdb + # pdb.set_trace() + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + last_grad = state["last_grad"] + lr, wd, beta1, beta2 = ( + group["lr"], + group["weight_decay"], + *group["betas"], + ) + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + else: + max_exp_avg_sq = 0 + + if "step" in state: + state["step"] += 1 + else: + state["step"] = 1 + step = state["step"] + is_grad_2d = len(grad.shape) == 2 + exp_avg, exp_avg_sq = self.update_fn( + p, + grad, + exp_avg, + exp_avg_sq, + lr, + wd, + beta1, + beta2, + last_grad, + self.eps, + amsgrad, + max_exp_avg_sq, + step, + gamma, + mars_type=self.mars_type, + is_grad_2d=is_grad_2d, + optimize_1d=self.optimize_1d, + lr_1d_factor=self.lr_1d_factor, + betas_1d=self.betas_1d, + weight_decay_1d=( + self.weight_decay if self.optimize_1d else self.weight_decay_1d + ), + ) + if self.is_approx: + state["last_grad"] = grad + self.step_num = step + + return loss From 9700a5020c08202b26075a197753d4d4b9133daf Mon Sep 17 00:00:00 2001 From: Andron00e Date: Tue, 19 Nov 2024 21:33:28 +0300 Subject: [PATCH 52/58] --fix info --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 247f351..2dc7172 100755 --- a/README.md +++ b/README.md @@ -89,6 +89,10 @@ parser.add_argument('--clipping_type', default='no', choices=['no', 'local', 'el parser.add_argument('--clipping_eta', default=1.0, type=float) parser.add_argument('--mars_type', default='mars-adamw', choices=['mars-adamw', 'mars-lion', 'mars-shampoo'],) parser.add_argument('--mars_vr_gamma', default=0.025, type=float) +parser.add_argument('--mars_is_approx', default=True, type=float) +parser.add_argument('--mars_lr', default=3e-3, type=float) +parser.add_argument('--mars_beta1', default=0.95, type=float) +parser.add_argument('--mars_beta2', default=0.99, type=float) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', 'arxiv2000', 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) From 29ce41b3b6519967c417616134a12398d3063d86 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Wed, 20 Nov 2024 01:52:19 +0300 Subject: [PATCH 53/58] --small changes in mars train --- src/optim/base.py | 3 +++ src/optim/mars.py | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/optim/base.py b/src/optim/base.py index d37473a..021690b 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -168,6 +168,9 @@ def train( opt.update_hessian() opt.zero_grad(set_to_none=True) model.zero_grad() + elif cfg.opt == "mars": + opt.zero_grad(set_to_none=True) + opt.update_last_grad() else: opt.zero_grad(set_to_none=True) # opt.zero_grad(set_to_none=True) diff --git a/src/optim/mars.py b/src/optim/mars.py index ace3869..ad34576 100644 --- a/src/optim/mars.py +++ b/src/optim/mars.py @@ -244,8 +244,6 @@ def step( if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. 
values state["max_exp_avg_sq"] = torch.zeros_like(p.data) - # import pdb - # pdb.set_trace() exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] last_grad = state["last_grad"] lr, wd, beta1, beta2 = ( From aaff7fcf21ad31f9f4e81a4cca80bbf1f940a0d4 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Thu, 21 Nov 2024 19:27:10 +0300 Subject: [PATCH 54/58] adafactor and lamb --- README.md | 4 +- src/config/base.py | 4 + src/main.py | 20 ++++ src/optim/adafactor.py | 205 +++++++++++++++++++++++++++++++++++++++++ src/optim/lamb.py | 126 +++++++++++++++++++++++++ 5 files changed, 358 insertions(+), 1 deletion(-) create mode 100644 src/optim/adafactor.py create mode 100644 src/optim/lamb.py diff --git a/README.md b/README.md index 2dc7172..73f65f8 100755 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none', 'dd']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta', 'mars']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta', 'mars', 'adafactor', 'lamb']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT @@ -93,6 +93,8 @@ parser.add_argument('--mars_is_approx', default=True, type=float) parser.add_argument('--mars_lr', default=3e-3, type=float) parser.add_argument('--mars_beta1', default=0.95, type=float) parser.add_argument('--mars_beta2', default=0.99, type=float) +parser.add_argument('--adafactor_decay_rate', default=-0.8, type=float) +parser.add_argument('--lamb_use_bias_correction', default=False, type=bool) # Dataset params parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', 'shakespeare-char', 'arxiv', 'arxiv2000', 'arxiv+wiki', 'openwebtext2', 'redpajama', 'redpajamav2', 'slimpajama_chunk1', 'fineweb', 'finewebedu']) parser.add_argument('--tokenizer', default='gpt2', type=str, choices=['gpt2', 'mistral']) diff --git a/src/config/base.py b/src/config/base.py index be2f901..bdd8112 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -106,6 +106,8 @@ def parse_args(base_parser, args, namespace): "clip-adam", "clip-adam-delay-eta", "mars", + "adafactor", + "lamb", ], ) parser.add_argument("--batch_size", default=50, type=int) @@ -158,6 +160,8 @@ def parse_args(base_parser, args, namespace): parser.add_argument("--mars_lr", default=3e-3, type=float) parser.add_argument("--mars_beta1", default=0.95, type=float) parser.add_argument("--mars_beta2", default=0.99, type=float) + parser.add_argument("--adafactor_decay_rate", default=-0.8, type=float) + parser.add_argument("--lamb_use_bias_correction", default=False, type=bool) # Dataset 
params parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/") diff --git a/src/main.py b/src/main.py index 0a9db23..5df2c62 100755 --- a/src/main.py +++ b/src/main.py @@ -15,6 +15,7 @@ import wandb from data.utils import DataReader, get_dataset from models.utils import get_model +from optim.adafactor import Adafactor from optim.adammini import Adam_mini from optim.ademamix import AdEMAMix from optim.ademamix2 import AdEMAMix2 @@ -22,6 +23,7 @@ from optim.base import train from optim.clipped import (AdagradClip, AdaGradClipDelayedEta, AdamClip, AdamClipDelayedEta) +from optim.lamb import Lamb # from optim.distributed_shampoo.distributed_shampoo import DistributedShampoo # from optim.distributed_shampoo.shampoo_types import AdamGraftingConfig from optim.lion import Lion @@ -344,6 +346,24 @@ def main(args, parser): betas_1d=(args.beta1, args.beta2), # AdamW's betas when optimize_1d=False weight_decay_1d=0.1, # AdamW's weight decay ) + elif args.opt == "adafactor": + opt = Adafactor( + group_specs, + lr=args.lr, + decay_rate=args.adafactor_decay_rate, + beta1=args.beta1, + clip_threshold=1.0, + weight_decay=args.weight_decay, + ) + elif args.opt == "lamb": + opt = Lamb( + group_specs, + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + adam=False, + bias_correction=args.lamb_use_bias_correction, + ) else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/adafactor.py b/src/optim/adafactor.py new file mode 100644 index 0000000..2381f92 --- /dev/null +++ b/src/optim/adafactor.py @@ -0,0 +1,205 @@ +""" +Here is an implementation of Adafactor. +Source: https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adafactor.py +""" + +import math +from typing import Any, Callable, Dict, Optional, Tuple + +import torch + +LossClosure = Callable[[], float] +OptLossClosure = Optional[LossClosure] +OptFloat = Optional[float] +ParamGroup = Dict[str, Any] +State = Dict[str, Any] + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + + It has been proposed in: `Adafactor: Adaptive Learning Rates with + Sublinear Memory Cost`__. 
+ + Arguments: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: external learning rate (default: None) + eps2: regularization constans for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold: threshold of root mean square of + final gradient update (default: 1.0) + decay_rate: coefficient used to compute running averages of square + gradient (default: -0.8) + beta1: coefficient used for computing running averages of gradient + (default: None) + weight_decay: weight decay (L2 penalty) (default: 0) + scale_parameter: if true, learning rate is scaled by root mean square + of parameter (default: True) + relative_step: if true, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init: time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + + Example: + >>> import torch_optimizer as optim + >>> optimizer = optim.Adafactor(model.parameters()) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + + __ https://arxiv.org/abs/1804.04235 + + Note: + Reference code: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py # noqa + """ + + def __init__( + self, + params, + lr=1e-3, + eps2=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and lr <= 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict( + lr=lr, + eps2=eps2, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + ) + super(Adafactor, self).__init__(params, defaults) + + def _get_lr(self, param_group: ParamGroup, param_state: State) -> float: + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = ( + 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + ) + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps2"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + def _get_options( + self, param_group: ParamGroup, param_shape: Tuple[int, ...] + ) -> Tuple[bool, bool]: + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + def _rms(self, tensor: torch.Tensor) -> float: + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad( + self, + exp_avg_sq_row: torch.Tensor, + exp_avg_sq_col: torch.Tensor, + output: torch.Tensor, + ) -> None: + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + torch.mul(r_factor, c_factor, out=output) + + def step(self, closure: OptLossClosure = None) -> OptFloat: + r"""Performs a single optimization step. + + Arguments: + closure: A closure that reevaluates the model and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + grad, memory_format=torch.preserve_format + ) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).type_as( + grad + ) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).type_as(grad) + else: + state["exp_avg_sq"] = torch.zeros_like( + grad, memory_format=torch.preserve_format + ) + + state["RMS"] = 0 + + state["step"] += 1 + state["RMS"] = self._rms(p.data) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps2"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) + + # Approximation of exponential moving average of square + # of gradient + self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + torch.rsqrt(exp_avg_sq, out=update).mul_(grad) + + update.div_(max(1.0, self._rms(update) / group["clip_threshold"])) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) + update = exp_avg + + if group["weight_decay"] != 0: + p.data.add_(p.data, alpha=-group["weight_decay"] * lr) + + p.data.add_(-update) + + return loss diff --git a/src/optim/lamb.py b/src/optim/lamb.py new file mode 100644 index 0000000..baa63bf --- /dev/null +++ b/src/optim/lamb.py @@ -0,0 +1,126 @@ +""" +Here is an official implementation of LAMB. +Source: https://github.com/cybertronai/pytorch-lamb +""" + +import math + +import torch + + +class Lamb(torch.optim.Optimizer): + r"""Implements Lamb algorithm. + + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + + .. 
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0, + adam=False, + bias_correction=False, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.adam = adam + self.bias_correction = bias_correction + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients, consider SparseAdam instad." + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Paper v3 does not use debiasing. 
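+                # In this port, debiasing is opt-in (bias_correction is wired to
+                # --lamb_use_bias_correction in src/config/base.py); when enabled,
+                # the step size is rescaled by sqrt(1 - beta2**t) / (1 - beta1**t),
+                # as in Adam. The layer-wise trust ratio computed below is
+                # ||w|| / ||adam_step||, with the weight norm clamped to [0, 10]
+                # and the ratio forced to 1 when either norm is zero (or when
+                # adam=True); it scales the final parameter update.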
+ if self.bias_correction: + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) + else: + step_size = group["lr"] + + weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) + + adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + adam_step.add_(p.data, alpha=group["weight_decay"]) + + adam_norm = adam_step.pow(2).sum().sqrt() + if weight_norm == 0 or adam_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / adam_norm + state["weight_norm"] = weight_norm + state["adam_norm"] = adam_norm + state["trust_ratio"] = trust_ratio + if self.adam: + trust_ratio = 1 + + p.data.add_(adam_step, alpha=-step_size * trust_ratio) + + return loss From 7ece053646581a0431a78a0b63a60aeae1fed5a4 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sat, 23 Nov 2024 20:43:06 +0300 Subject: [PATCH 55/58] --fix adopt --- src/main.py | 2 +- src/optim/adopt.py | 111 ++++++++++++++++++--------------------------- src/optim/base.py | 3 +- src/optim/utils.py | 3 -- 4 files changed, 46 insertions(+), 73 deletions(-) diff --git a/src/main.py b/src/main.py index 5df2c62..689dca5 100755 --- a/src/main.py +++ b/src/main.py @@ -9,10 +9,10 @@ import numpy as np import torch +import wandb import config import distributed -import wandb from data.utils import DataReader, get_dataset from models.utils import get_model from optim.adafactor import Adafactor diff --git a/src/optim/adopt.py b/src/optim/adopt.py index b91f32b..aa6dc16 100644 --- a/src/optim/adopt.py +++ b/src/optim/adopt.py @@ -6,90 +6,67 @@ import torch -class ADOPT(torch.optim.Optimizer): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=0.0, - ): - """ - Args: - params: iterable of parameters to optimize or dictionaries defining parameter groups. - lr: learning rate. - betas: coefficients used for computing running averages of gradient and gradient squared. - eps: term added to the denominator to improve numerical stability. - weight_decay: weight decay. - decoupled: whether to use decoupled weight decay. 
- """ - if lr < 0.0: - raise ValueError(f"Invalid learning rate: {lr}") - if not 0.0 <= betas[0] < 1.0 or not 0.0 <= betas[1] <= 1.0: - raise ValueError(f"Invalid beta values: {betas}") - if eps <= 0.0: - raise ValueError(f"Invalid epsilon value: {eps}") - if weight_decay < 0.0: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") +def exists(val): + return val is not None - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - ) + +class ADOPT(torch.optim.Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(ADOPT, self).__init__(params, defaults) + self.eps = eps + + def __setstate__(self, state): + super(ADOPT, self).__setstate__(state) @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step.""" + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + grad_scaler=None, + ): loss = None - if closure is not None: + if exists(closure): with torch.enable_grad(): loss = closure() for group in self.param_groups: - for p in group["params"]: + for p in filter(lambda p: exists(p.grad), group["params"]): if p.grad is None: continue - - grad = p.grad - if grad.is_sparse: - raise RuntimeError("ADOPT does not support sparse gradients") - - if group["weight_decay"] != 0: - grad = grad.add(p, alpha=group["weight_decay"]) - + grad = p.grad.data + grad.add_(p.data, alpha=group["weight_decay"]) state = self.state[p] - - # State initialization if len(state) == 0: state["step"] = 0 - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - state["step"] += 1 - + state["exp_avg"] = torch.zeros_like(p.data) + state["exp_avg_sq"] = grad.mul(grad) + continue exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if "step" in state: + state["step"] += 1 + else: + state["step"] = 1 beta1, beta2 = group["betas"] - eps = group["eps"] - - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] + denom = torch.maximum(exp_avg_sq.sqrt(), torch.tensor(self.eps)) + if state["step"] == 1: + exp_avg = grad.div(denom) + else: + exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) + p.data.add_(state["exp_avg"], alpha=-group["lr"]) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - exp_avg_sq = exp_avg_sq.div(bias_correction2).sqrt() - denom = torch.maximum(exp_avg_sq, torch.tensor(eps, device=grad.device)) - - exp_avg.mul_(beta1).add_(grad.div(denom), alpha=1 - beta1) - - step_size = group["lr"] * (exp_avg.div(bias_correction1)) - p.add_(-step_size) - return loss diff --git a/src/optim/base.py b/src/optim/base.py index 021690b..c992b56 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -4,9 +4,8 @@ from pathlib import Path import torch -import yaml - import wandb +import yaml # from logger.logger import DynamicsLogger from .utils import (eval, get_batch, load_checkpoint, load_worker_state, diff --git a/src/optim/utils.py 
b/src/optim/utils.py index 1bdd01e..6db4c39 100755 --- a/src/optim/utils.py +++ b/src/optim/utils.py @@ -6,9 +6,6 @@ import numpy as np import torch import torch.distributed as dist -import torch.nn.functional as F - -import wandb def get_batch(datareader, device="cpu"): From 10a2a3a57720219f99777596eca6b49815e8a75f Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sun, 24 Nov 2024 19:06:36 +0300 Subject: [PATCH 56/58] muon-debug --- src/main.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/main.py b/src/main.py index 689dca5..74ba249 100755 --- a/src/main.py +++ b/src/main.py @@ -153,14 +153,30 @@ def main(args, parser): correct_bias=args.correct_bias, ) elif args.opt == "muon": - param_groups_2d, param_groups_non2d, _, _ = separate_params(group_specs) + param_list = list(model.parameters()) if args.distributed_backend is None else list(model.module.parameters()) + assert sum(p.numel() for p in param_list) == params_cnt, "number of parameters must be the same" + # param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count = separate_params(param_list) #group_specs + # if args.distributed_backend is not None: + # param_groups_non2d.extend(list(model.module.lm_head.parameters())) + # else: + # param_groups_non2d.extend(list(model.lm_head.parameters())) + # print("param_groups_non2d", param_groups_non2d) + # print("param_groups_2d", param_groups_2d) + # print("2D params", total_param_2d_count) + # print("non-2D params", total_param_non2d_count) + # muon_params = [p for p in model.parameters() if p.ndim >=2] + # adamw_params = [p for p in model.parameters() if p.ndim < 2] + # adamw_params.extend(model.lm_head.parameters()) + # adamw_params.extend(model.embed.parameters()) + # print(model) + # print(len(muon_params), len(adamw_params)) opt = Muon( - muon_params=param_groups_2d[0]["params"], + muon_params=param_list,#param_groups_2d[0]["params"], lr=args.muon_lr_factor, # since adamw_lr_ration = adamw_lr / muon_lr momentum=args.momentum, nesterov=args.nesterov, # always use nesterov momentum for Muon ns_steps=args.muon_ns_steps, - adamw_params=param_groups_non2d[0]["params"], + adamw_params=None,#param_groups_non2d[0]["params"], adamw_lr=args.lr, adamw_betas=(args.beta1, args.beta2), adamw_eps=1e-8, From 4f5c061e11a59055c6c7488246c494520bab26f3 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Mon, 25 Nov 2024 22:43:13 +0300 Subject: [PATCH 57/58] --signum fix --- src/main.py | 39 ++++++++++++++------------------------- src/optim/base.py | 3 ++- src/optim/sign.py | 2 +- 3 files changed, 17 insertions(+), 27 deletions(-) diff --git a/src/main.py b/src/main.py index 74ba249..4908696 100755 --- a/src/main.py +++ b/src/main.py @@ -9,10 +9,10 @@ import numpy as np import torch -import wandb import config import distributed +import wandb from data.utils import DataReader, get_dataset from models.utils import get_model from optim.adafactor import Adafactor @@ -24,11 +24,9 @@ from optim.clipped import (AdagradClip, AdaGradClipDelayedEta, AdamClip, AdamClipDelayedEta) from optim.lamb import Lamb -# from optim.distributed_shampoo.distributed_shampoo import DistributedShampoo -# from optim.distributed_shampoo.shampoo_types import AdamGraftingConfig from optim.lion import Lion from optim.mars import MARS -from optim.muon import CombinedScheduler, Muon, separate_params +from optim.muon import CombinedScheduler, Muon from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, dd_schedule, 
wsd_schedule) @@ -153,30 +151,21 @@ def main(args, parser): correct_bias=args.correct_bias, ) elif args.opt == "muon": - param_list = list(model.parameters()) if args.distributed_backend is None else list(model.module.parameters()) - assert sum(p.numel() for p in param_list) == params_cnt, "number of parameters must be the same" - # param_groups_2d, param_groups_non2d, total_param_2d_count, total_param_non2d_count = separate_params(param_list) #group_specs - # if args.distributed_backend is not None: - # param_groups_non2d.extend(list(model.module.lm_head.parameters())) - # else: - # param_groups_non2d.extend(list(model.lm_head.parameters())) - # print("param_groups_non2d", param_groups_non2d) - # print("param_groups_2d", param_groups_2d) - # print("2D params", total_param_2d_count) - # print("non-2D params", total_param_non2d_count) - # muon_params = [p for p in model.parameters() if p.ndim >=2] - # adamw_params = [p for p in model.parameters() if p.ndim < 2] - # adamw_params.extend(model.lm_head.parameters()) - # adamw_params.extend(model.embed.parameters()) - # print(model) - # print(len(muon_params), len(adamw_params)) + param_list = ( + list(model.parameters()) + if args.distributed_backend is None + else list(model.module.parameters()) + ) + assert ( + sum(p.numel() for p in param_list) == params_cnt + ), "number of parameters must be the same" opt = Muon( - muon_params=param_list,#param_groups_2d[0]["params"], - lr=args.muon_lr_factor, # since adamw_lr_ration = adamw_lr / muon_lr + muon_params=param_list, + lr=args.muon_lr_factor, momentum=args.momentum, - nesterov=args.nesterov, # always use nesterov momentum for Muon + nesterov=args.nesterov, ns_steps=args.muon_ns_steps, - adamw_params=None,#param_groups_non2d[0]["params"], + adamw_params=None, adamw_lr=args.lr, adamw_betas=(args.beta1, args.beta2), adamw_eps=1e-8, diff --git a/src/optim/base.py b/src/optim/base.py index c992b56..021690b 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -4,9 +4,10 @@ from pathlib import Path import torch -import wandb import yaml +import wandb + # from logger.logger import DynamicsLogger from .utils import (eval, get_batch, load_checkpoint, load_worker_state, save_checkpoint, save_worker_state) diff --git a/src/optim/sign.py b/src/optim/sign.py index 80f8c1a..d8e9e37 100644 --- a/src/optim/sign.py +++ b/src/optim/sign.py @@ -88,7 +88,7 @@ def step(self, closure=None): state = self.state[p] if group["weight_decay"] != 0: - grad = grad.add(p, alpha=group["weight_decay"]) + p.mul_(1 - group["lr"] * group["weight_decay"]) if len(state) == 0: self._init_state(example=p, state=state) From c642f63cae093588703b178fdd6909b99dc975c7 Mon Sep 17 00:00:00 2001 From: Andron00e Date: Sat, 7 Dec 2024 21:53:18 +0300 Subject: [PATCH 58/58] normalized sgd + sophia removed hardcoded precondition frequency --- README.md | 8 +-- src/config/base.py | 1 + src/main.py | 11 ++++ src/optim/base.py | 2 +- src/optim/normalized.py | 116 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 133 insertions(+), 5 deletions(-) create mode 100644 src/optim/normalized.py diff --git a/README.md b/README.md index 73f65f8..c5f651e 100755 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # LLM-baselines -A modular codebase to experiment with transformers, inspired by NanoGPT. +A modular codebase to experiment with transformers, inspired by nanoGPT. 
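The src/optim/sign.py hunk in the patch above replaces L2-style weight decay (decay folded into the gradient before the sign is taken) with decoupled weight decay (a multiplicative shrink of the weights). A minimal sketch of the difference, with made-up numbers and assuming a plain sign update without momentum:

import torch

w = torch.tensor([1.0, -2.0, 0.05])
g = torch.tensor([0.3, -0.1, 0.0])
lr, wd = 1e-2, 0.1

# Old behaviour: decay is added to the gradient and then the sign is taken,
# so it only matters where it changes a sign (the zero-gradient coordinate here).
w_coupled = w - lr * torch.sign(g + wd * w)

# New behaviour (as in the fix): weights shrink by (1 - lr * wd) every step,
# independently of the gradient, and the sign update uses the raw gradient.
w_decoupled = w * (1 - lr * wd) - lr * torch.sign(g)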
## Quickstart @@ -53,13 +53,13 @@ parser.add_argument('--beta1', default=0.9, type=float) # adam parameter parser.add_argument('--beta2', default=0.95, type=float) # adam parameter parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'wsd', 'cos_inf', 'none', 'dd']) parser.add_argument('--cos_inf_steps', default=0, type=int) # cos_inf scheduler -parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta', 'mars', 'adafactor', 'lamb']) +parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd', 'muon', 'soap', 'ademamix', 'ademamix2', 'lion', 'sf-adamw', 'sf-sgd', 'signsgd', 'signum', 'sgdf', 'prodigy', 'sophiag', 'shampoo', 'adopt', 'clip-adagrad', 'clip-adagrad-delay-eta', 'clip-adam', 'clip-adam-delay-eta', 'mars', 'adafactor', 'lamb', 'normalized-sgd']) parser.add_argument('--eval_freq', default=200, type=int) # in iterations parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved -parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT +parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in nanoGPT parser.add_argument('--momentum', default=0.9, type=float) parser.add_argument('--shampoo_beta', default=-1.0, type=float) -parser.add_argument('--precondition_frequency', default=10, type=int) +parser.add_argument('--precondition_frequency', default=10, type=int) #for SOAP and Sophia parser.add_argument('--max_precond_dim', default=10000, type=int) parser.add_argument('--merge_dims', default=False, type=bool) # merge dimensions till the product of the dimensions is less than or equal to max_precond_dim parser.add_argument('--precondition_1d', default=False, type=bool) diff --git a/src/config/base.py b/src/config/base.py index bdd8112..4c80165 100644 --- a/src/config/base.py +++ b/src/config/base.py @@ -108,6 +108,7 @@ def parse_args(base_parser, args, namespace): "mars", "adafactor", "lamb", + "normalized-sgd", ], ) parser.add_argument("--batch_size", default=50, type=int) diff --git a/src/main.py b/src/main.py index 4908696..b963956 100755 --- a/src/main.py +++ b/src/main.py @@ -27,6 +27,7 @@ from optim.lion import Lion from optim.mars import MARS from optim.muon import CombinedScheduler, Muon +from optim.normalized import NormalizedSGD from optim.prodigy import Prodigy from optim.schedule import (cos_inf_schedule, cosine_wsd_decay_schedule, dd_schedule, wsd_schedule) @@ -369,6 +370,16 @@ def main(args, parser): adam=False, bias_correction=args.lamb_use_bias_correction, ) + elif args.opt == "normalized-sgd": + opt = NormalizedSGD( + group_specs, + lr=args.lr, + momentum=args.momentum, + dampening=args.dampening, + weight_decay=args.weight_decay, + nesterov=args.nesterov, + sign_update=False, + ) else: opt = torch.optim.SGD( group_specs, diff --git a/src/optim/base.py b/src/optim/base.py index 021690b..19875d0 100755 --- a/src/optim/base.py +++ b/src/optim/base.py @@ -153,7 +153,7 @@ def train( scheduler.step() if cfg.opt == "sophiag": opt.zero_grad(set_to_none=True) - if curr_iter % 10 == 10 - 1: + if curr_iter % cfg.precondition_frequency == cfg.precondition_frequency - 1: sample_again = model(x, targets=y, get_logits=True) samp_dist = torch.distributions.Categorical( logits=sample_again["logits"] diff 
--git a/src/optim/normalized.py b/src/optim/normalized.py new file mode 100644 index 0000000..7a93c80 --- /dev/null +++ b/src/optim/normalized.py @@ -0,0 +1,116 @@ +""" +Version of different optimizers with normalized update +""" + +from typing import Dict + +import torch + + +class NormalizedSGD(torch.optim.Optimizer): + def __init__( + self, + params, + lr=1e-3, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + sign_update=False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + sign_update=sign_update, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def _init_state(self, example, state=None): + assert isinstance(example, torch.Tensor) + assert isinstance(state, Dict) or state is None + if state is None: + state = {} + state["step"] = 0 + state["momentum_buffer"] = torch.clone(example).detach() + return state + + @torch.no_grad() + def _compute_update( + self, grad, state, lr, momentum, nesterov, dampening, sign_update, **kwargs + ): + if momentum != 0: + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + if sign_update: + grad = grad.sign() + + return grad / (grad.norm() + 1e-8) * (-lr) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + state = self.state[p] + + if group["weight_decay"] != 0: + p.mul_(1 - group["lr"] * group["weight_decay"]) + + if len(state) == 0: + self._init_state(example=p, state=state) + if not group["momentum"]: + state.pop("momentum_buffer", None) + + state["step"] += 1 + + update = self._compute_update( + grad, + state, + group["lr"], + group["momentum"], + group["nesterov"], + group["dampening"], + group["sign_update"], + ) + + p.add_(update) + + return loss
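A minimal smoke test for the NormalizedSGD optimizer added above, assuming it is run from src/ so that optim.normalized is importable; the toy model, data, and hyperparameter values are illustrative only and simply mirror the constructor signature introduced in this patch:

import torch
from optim.normalized import NormalizedSGD

model = torch.nn.Linear(16, 4)
opt = NormalizedSGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
    dampening=0.0,
    weight_decay=0.1,   # applied as a decoupled shrink: p *= 1 - lr * wd
    nesterov=False,
    sign_update=False,  # True would take the sign of the update direction before normalizing
)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()

In the training script itself the same constructor is reached through --opt normalized-sgd, which forwards the generic lr, momentum, dampening, weight_decay and nesterov arguments and hard-codes sign_update=False.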