From 266b738f851e8adef525356be8a1f77a3ecd18c9 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 11 Mar 2024 12:40:53 +0000 Subject: [PATCH 01/39] add a run_generate for Mamba --- examples/mamba/run_generate.py | 244 +++++++++++++++++++++++++++++++++ examples/mamba/run_generate.sh | 5 + 2 files changed, 249 insertions(+) create mode 100644 examples/mamba/run_generate.py create mode 100755 examples/mamba/run_generate.sh diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py new file mode 100644 index 00000000..47f45bdd --- /dev/null +++ b/examples/mamba/run_generate.py @@ -0,0 +1,244 @@ +""" +Nanotron Inference Script + +Usage: +``` +export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations +torchrun --nproc_per_node=4 run_generate.py ---ckpt-path checkpoints/test/4 +``` +""" + +import argparse +import os +from pathlib import Path + +import torch +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ( + GenerationArgs, + LoggingArgs, + ParallelismArgs, + get_config_from_file, +) +from nanotron.generation.decode import ( + GenerationInput, + TokenizerConfig, + decode_text, + decode_tokenized, +) +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import ( + OneForwardOneBackwardPipelineEngine, +) +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.random import ( + RandomStates, + get_current_random_state, + get_synced_random_state, + set_random_seed, +) +from nanotron.serialize import load_weights +from nanotron.trainer import CONFIG_TO_MODEL_CLASS, mark_tied_parameters +from config import MambaConfig, MambaModelConfig +from mamba import MambaForTraining + +try: + from transformers import AutoTokenizer +except ImportError: + AutoTokenizer = None + +logger = logging.get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") + parser.add_argument("--dp", type=int, default=0) + parser.add_argument("--pp", type=int, default=0) + parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") + return parser.parse_args() + + +def main(): + args = get_args() + + assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist" + + config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(), config_class=MambaConfig, model_config_class=MambaModelConfig) + model_config = config.model.model_config + tokenizer_path = config.tokenizer.tokenizer_name_or_path + + parallel_config = ParallelismArgs( + dp=args.dp or config.parallelism.dp, + pp=args.pp or config.parallelism.pp, + tp=args.tp or config.parallelism.tp, + pp_engine=OneForwardOneBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + # Initialise all process groups + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + # Set log levels + logging_config = LoggingArgs( + log_level="info", + log_level_replica="info", + ) + + # Set log levels + set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config) + + log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"tokenizer_path: {tokenizer_path}", logger=logger, level=logging.INFO, rank=0) + + dtype = torch.bfloat16 + + # Set random states + set_random_seed(42) + + # Get synchronized random states + if parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE: + random_states = RandomStates( + {"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg)} + ) + else: + # We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER) + random_states = RandomStates({}) + + model = build_model( + model_builder=lambda: MambaForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=random_states, + ), + dtype=dtype, + parallel_context=parallel_context, + ) + + # Mark some parameters as tied + # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead? + mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + + # Sanity check model + sanity_check(root_module=model) + + # Load checkpoint + checkpoint_path = args.ckpt_path + log_rank( + f"Loading checkpoint from {checkpoint_path}:", + logger=logger, + level=logging.INFO, + rank=0, + ) + load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) + + model.eval() + if AutoTokenizer is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + # tokenizer.pad_token_id = tokenizer.eos_token_id + if tokenizer.pad_token_id is None: + if tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + elif getattr(model.config, "pad_token_id", None) is not None: + tokenizer.pad_token_id = int(model.config.pad_token_id) + elif getattr(model.config, "eos_token_id", None) is not None: + tokenizer.pad_token_id = int(model.config.eos_token_id) + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? + dummy_inputs = [ + # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + "def fib(n)", + # "This film was probably inspired by Godzilla", + ] + + outputs = decode_text( + input_iter=(GenerationInput(text=text) for text in dummy_inputs), + tokenizer=tokenizer, + # TODO @thomasw21: From ModelWithLoss extract the model. + model=model.model, + parallel_context=parallel_context, + max_new_tokens=args.max_new_tokens, + max_micro_batch_size=2, + generation_config=GenerationArgs(sampler="greedy", use_cache=True), + tokenizer_config=TokenizerConfig(max_input_length=None), + is_bench=os.environ.get("USE_BENCH", "0") == "1", + ) + for output in outputs: + input_ids = output.input_ids + generated_ids = output.generation_ids + if isinstance(input_ids, TensorPointer): + assert isinstance(generated_ids, TensorPointer) + continue + assert isinstance(generated_ids, torch.Tensor) + + log_rank( + f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + log_rank( + f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + log_rank( + "--------------------------------------------------", + logger=logger, + level=logging.INFO, + rank=0, + ) + else: + outputs = decode_tokenized( + input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), + input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"), + model=model.model, + parallel_context=parallel_context, + generation_config=GenerationArgs(sampler="greedy", use_cache=True), + max_micro_batch_size=1, + max_new_tokens=12, + returns_logits=False, + ) + for output in outputs: + input_ids = output.input_ids + generated_ids = output.generation_ids + if isinstance(input_ids, TensorPointer): + assert isinstance(generated_ids, TensorPointer) + continue + assert isinstance(generated_ids, torch.Tensor) + log_rank( + f"generation: {generated_ids[len(input_ids) :]}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + log_rank( + "--------------------------------------------------", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/examples/mamba/run_generate.sh b/examples/mamba/run_generate.sh new file mode 100755 index 00000000..85ff1331 --- /dev/null +++ b/examples/mamba/run_generate.sh @@ -0,0 +1,5 @@ +if [ -n "$DEBUG" ]; then + debugpy-run -m torch.distributed.run -- --nproc_per_node=1 run_generate.py --ckpt-path /fsx/ferdinandmom/ferdinand-hf/nanotron/examples/checkpoints/100 +else + torchrun --nproc_per_node=1 run_generate.py --ckpt-path /fsx/ferdinandmom/ferdinand-hf/nanotron/examples/checkpoints/100 +fi \ No newline at end of file From 44e563d8033a58d3e7949372540f8f6eb94f767c Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 11 Mar 2024 13:53:33 +0000 Subject: [PATCH 02/39] add converter hf to nanotron --- examples/mamba/config.py | 4 +- examples/mamba/convert_hf_to_nanotron.py | 261 +++++++++++++++++++++++ examples/mamba/run_generate.py | 5 +- 3 files changed, 267 insertions(+), 3 deletions(-) create mode 100644 examples/mamba/convert_hf_to_nanotron.py diff --git a/examples/mamba/config.py b/examples/mamba/config.py index c7bdc7b9..a599620a 100644 --- a/examples/mamba/config.py +++ b/examples/mamba/config.py @@ -11,8 +11,8 @@ class MambaInit: # mamba_ssm.models.mixer_seq_simple._init_weights initializer_range: float = 0.02 - rescale_prenorm_residual: bool = (True,) - n_residuals_per_layer: int = (1,) # Change to 2 if we have MLP + rescale_prenorm_residual: bool = True + n_residuals_per_layer: int = 1 # Change to 2 if we have MLP @dataclass diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py new file mode 100644 index 00000000..3e1a1335 --- /dev/null +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -0,0 +1,261 @@ +# ruff: noqa: E402 +""" +Converts a HF model from (https://huggingface.co/state-spaces/) to a Brrr model + +Command: + torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --model 130M --save_path nanotron-weights +""" +import argparse +import torch +import yaml +from pathlib import Path +from tqdm import tqdm +from typing import Dict +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel +import lovely_tensors as lt; lt.monkey_patch() + +from nanotron.config import ( + AllForwardAllBackwardPipelineEngine, + ParallelismArgs, + TensorParallelLinearMode, +) +from config import MambaModelConfig, MambaConfig, MambaInit +from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs + +from nanotron.distributed import dist +from nanotron.helpers import _vocab_size_with_padding +from nanotron.models import build_model +from mamba import MambaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter, sanity_check +from nanotron.serialize import save_meta, save_weights +from nanotron.trainer import mark_tied_parameters + + +def get_weight_from_hf( + name: str, + ref_module_state_dict: Dict[str, torch.Tensor], + ref_module: MambaLMHeadModel, + nanotron_to_hf: Dict[str, str], + get_grad: bool = False +) -> torch.Tensor: + """From our brrr implementation, we get the equivalent tensor in transformers implementation""" + + hf_name = nanotron_to_hf[name] + + if get_grad is False: + + def get_tensor(path: str): + return ref_module_state_dict[path] + else: + + def get_tensor(path: str): + weight = ref_module.get_parameter(path) + return weight.grad + + return get_tensor(hf_name) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert HF weights from states-space repo to brrr weights") + parser.add_argument("--model", type=str, default="130M", help="130M | 370M | 790M | 1.4B | 2.8B") + parser.add_argument("--save_path", type=str, default="mamba-nanotron") + parser.add_argument("--dp", type=int, default=1) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--tp", type=int, default=1) + args = parser.parse_args() + + if args.model not in ["130M", "370M", "790M", "1.4B", "2.8B"]: + raise ValueError("Model should be one of 130M, 370M, 790M, 1.4B, 2.8B") + + if args.tp > 1: + raise ValueError("Tensor parallelism not supported yet (as A_log is nn.Parameter which is not a NanotronParameter and thus cannot be marked as is_sharded)") + + save_path = Path(args.save_path) + + #TODO: Do it this way so that we can choose the dp pp and tp we want + # https://github.com/huggingface/brrr/blob/main/legacy_examples/starcoder2/convert_brrr_to_trfrs.py + parallel_config = ParallelismArgs( + dp=args.dp, + pp=args.pp, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + d_model = None + num_hidden_layers = None + pretrained_model_name = None + + if args.model == "130M": + d_model = 768 + num_hidden_layers = 24 + pretrained_model_name = "state-spaces/mamba-130m" + elif args.model == "370M": + d_model = 1024 + num_hidden_layers = 48 + pretrained_model_name = "state-spaces/mamba-370m" + elif args.model == "790M": + d_model = 1536 + num_hidden_layers = 24 + pretrained_model_name = "state-spaces/mamba-790m" + elif args.model == "1.4B": + d_model = 2048 + num_hidden_layers = 48 + pretrained_model_name = "state-spaces/mamba-1.4b" + elif args.model == "2.8B": + d_model = 2560 + num_hidden_layers = 64 + pretrained_model_name = "state-spaces/mamba-2.8b" + + yaml_content = f""" + is_mamba_config: true + d_model: {d_model} + dtype: bfloat16 + fused_add_norm: true + is_mamba_config: true + num_hidden_layers: {num_hidden_layers} + pad_token_id: null + pad_vocab_size_multiple: 8 + residual_in_fp32: true + rms_norm: true + rms_norm_eps: 1.0e-05 + ssm_cfg: null + vocab_size: 50277 + """ + + str_to_dtype = { + "float32": torch.float32, + "float64": torch.float64, + "complex64": torch.complex64, + "complex128": torch.complex128, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "uint8": torch.uint8, + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, + "bool": torch.bool, + } + + attrs = yaml.safe_load(yaml_content) + model_config = MambaModelConfig(**attrs) + + # # Initiliaze Brrr model + # model_config.vocab_size = _vocab_size_with_padding( + # model_config.vocab_size, + # pg_size=parallel_context.tp_pg.size(), + # make_vocab_size_divisible_by=1, + # ) + + nanotron_model = build_model( + model_builder=lambda: MambaForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=str_to_dtype[model_config.dtype], + device=torch.device("cpu") + ) + + device_map = {} + current_pp_rank = dist.get_rank(parallel_context.pp_pg) + + tied_embs_ranks = [nanotron_model.model.token_position_embeddings.rank, nanotron_model.model.lm_head.rank] + + device_map["backbone.embedding"] = ( + nanotron_model.model.token_position_embeddings.rank if current_pp_rank in tied_embs_ranks else "meta" + ) + + for i in range(model_config.num_hidden_layers): + device_map[f"backbone.layers[{i}]"] = ( + nanotron_model.model.decoder[i].rank if current_pp_rank == nanotron_model.model.decoder[i].rank else "meta" + ) + + device_map["lm_head"] = nanotron_model.model.lm_head.rank if current_pp_rank in tied_embs_ranks else "meta" + + model_ref = MambaLMHeadModel.from_pretrained(pretrained_model_name, device="cpu", dtype=str_to_dtype[model_config.dtype]) + + # Create a mapping from nanotron to hf + nanotron_to_hf = {} + + for i in range(model_config.num_hidden_layers): + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.A_log'] = f'backbone.layers.{i}.mixer.A_log' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.D'] = f'backbone.layers.{i}.mixer.D' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.in_proj.weight'] = f'backbone.layers.{i}.mixer.in_proj.weight' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.conv1d.weight'] = f'backbone.layers.{i}.mixer.conv1d.weight' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.conv1d.bias'] = f'backbone.layers.{i}.mixer.conv1d.bias' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.x_proj.weight'] = f'backbone.layers.{i}.mixer.x_proj.weight' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.x_proj.bias'] = f'backbone.layers.{i}.mixer.x_proj.bias' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.dt_proj.weight'] = f'backbone.layers.{i}.mixer.dt_proj.weight' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.dt_proj.bias'] = f'backbone.layers.{i}.mixer.dt_proj.bias' + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.out_proj.weight'] = f'backbone.layers.{i}.mixer.out_proj.weight' + #TODO: Maybe check if bias exists? + nanotron_to_hf[f'decoder.{i}.pp_block.mixer.out_proj.bias'] = f'backbone.layers.{i}.mixer.out_proj.bias' + nanotron_to_hf[f'decoder.{i}.pp_block.norm.weight'] = f'backbone.layers.{i}.norm.weight' + + nanotron_to_hf['token_position_embeddings.pp_block.token_embedding.weight'] = 'backbone.embedding.weight' + nanotron_to_hf['final_layer_norm.pp_block.weight'] = 'backbone.norm_f.weight' + nanotron_to_hf['lm_head.pp_block.weight'] = 'lm_head.weight' + + # Sync weights + ref_state_dict = model_ref.state_dict() + for name, param in tqdm(nanotron_model.model.named_parameters(), total=len(list(nanotron_model.model.named_parameters())), desc="Converting"): + ref_param = get_weight_from_hf(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref, nanotron_to_hf=nanotron_to_hf) + + param_is_tp_sharded = ( + isinstance(param, NanotronParameter) + and param.is_sharded + and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg + ) + + if param_is_tp_sharded: + sharded_info = param.get_sharded_info() + # copy param data (not just the reference) + with torch.no_grad(): + for local_global_slices_pair in sharded_info.local_global_slices_pairs: + local_slices = local_global_slices_pair.local_slices + global_slices = local_global_slices_pair.global_slices + param[local_slices].copy_(ref_param[global_slices]) + else: + assert ( + ref_param.shape == param.shape + ), f"Parameter shape don't match for {name}\n{ref_param.shape} != {param.shape}" + # copy param data (not just the reference) + with torch.no_grad(): + param.copy_(ref_param) + ref_param = None + torch.cuda.empty_cache() + + # Marks parameters as NanotronParameters + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + + sanity_check(root_module=nanotron_model) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) + checkpoint_metadata = { + "last_train_step": 0, + "consumed_train_samples": 0, + } + save_meta(root_folder=save_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata) + + with open(save_path / "config.yaml", "w") as f: + config = MambaConfig( + general=GeneralArgs(project="test", run="mamba"), + parallelism=parallel_config, + model=ModelArgs( + init_method=MambaInit(), + model_config=model_config, + ), + tokenizer=TokenizerArgs("EleutherAI/gpt-neox-20b"), + ) + print("Saving config ...") + yaml.dump(config.as_dict(), f) \ No newline at end of file diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index 47f45bdd..a2e61e45 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -73,6 +73,8 @@ def main(): config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(), config_class=MambaConfig, model_config_class=MambaModelConfig) model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path + + assert "EleutherAI/gpt-neox-20b" == tokenizer_path; f"Should be EleutherAI/gpt-neox-20b tokenizer and not '{tokenizer_path}'" parallel_config = ParallelismArgs( dp=args.dp or config.parallelism.dp, @@ -161,8 +163,9 @@ def main(): tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? dummy_inputs = [ # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - "def fib(n)", + # "def fib(n)", # "This film was probably inspired by Godzilla", + "Hello" ] outputs = decode_text( From cdb73a9076d3d9869806be5420da69c97364861e Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 11 Mar 2024 14:51:11 +0000 Subject: [PATCH 03/39] add sanity check weights for converter --- examples/mamba/convert_hf_to_nanotron.py | 50 +++++++++++++++++++----- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 3e1a1335..885eeb53 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -32,6 +32,33 @@ from nanotron.trainer import mark_tied_parameters +def sanity_check_weights(model, model_ref): + + def sort_key(name_param_pair): + name, _ = name_param_pair + # Split the name and take the last part as the key for sorting + return name.split('.')[-1] + + total, fail = 0, 0 + + for (name_ref, param_ref), (name, param) in zip( + sorted(model_ref.named_parameters(), key=sort_key), + sorted(model.model.named_parameters(), key=sort_key) + ): + + total += 1 + try: + torch.testing.assert_allclose(param_ref, param, rtol=1e-10, atol=1e-10) + except AssertionError as e: + print(f"{name_ref} and {name} are not equal") + fail += 1 + + print(f"{fail}/{total} parameters are not equal") + + if fail > 0: + raise AssertionError("Some parameters are not equal") + + def get_weight_from_hf( name: str, ref_module_state_dict: Dict[str, torch.Tensor], @@ -117,7 +144,7 @@ def get_tensor(path: str): yaml_content = f""" is_mamba_config: true d_model: {d_model} - dtype: bfloat16 + dtype: float32 fused_add_norm: true is_mamba_config: true num_hidden_layers: {num_hidden_layers} @@ -147,13 +174,15 @@ def get_tensor(path: str): attrs = yaml.safe_load(yaml_content) model_config = MambaModelConfig(**attrs) - - # # Initiliaze Brrr model - # model_config.vocab_size = _vocab_size_with_padding( - # model_config.vocab_size, - # pg_size=parallel_context.tp_pg.size(), - # make_vocab_size_divisible_by=1, - # ) + + assert model_config.dtype == "float32", "Convert weights only in float32" + + # Initiliaze Brrr model + model_config.vocab_size = _vocab_size_with_padding( + model_config.vocab_size, + pg_size=parallel_context.tp_pg.size(), + make_vocab_size_divisible_by=4, # 50277 -> 50280 + ) nanotron_model = build_model( model_builder=lambda: MambaForTraining( @@ -240,6 +269,9 @@ def get_tensor(path: str): mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) sanity_check(root_module=nanotron_model) + + sanity_check_weights(model=nanotron_model, model_ref=model_ref) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) checkpoint_metadata = { "last_train_step": 0, @@ -255,7 +287,7 @@ def get_tensor(path: str): init_method=MambaInit(), model_config=model_config, ), - tokenizer=TokenizerArgs("EleutherAI/gpt-neox-20b"), + tokenizer=TokenizerArgs(pretrained_model_name + "-hf"), ) print("Saving config ...") yaml.dump(config.as_dict(), f) \ No newline at end of file From 8bb8e02cdd4608de08013c0607419aced2cd8804 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 11 Mar 2024 15:29:46 +0000 Subject: [PATCH 04/39] set reference in run generate --- examples/mamba/run_generate.py | 81 ++++++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 4 deletions(-) diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index a2e61e45..01ff43fb 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -54,6 +54,37 @@ logger = logging.get_logger(__name__) +# from transformers import MambaForCausalLM +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + +import lovely_tensors as lt; lt.monkey_patch() + +def sanity_check_weights(model, model_ref): + + def sort_key(name_param_pair): + name, _ = name_param_pair + # Split the name and take the last part as the key for sorting + return name.split('.')[-1] + + total, fail = 0, 0 + + for (name_ref, param_ref), (name, param) in zip( + sorted(model_ref.named_parameters(), key=sort_key), + sorted(model.model.named_parameters(), key=sort_key) + ): + + total += 1 + try: + torch.testing.assert_close(param_ref, param, rtol=1e-10, atol=1e-10) + except AssertionError as e: + print(f"{name_ref} and {name} are not equal. Error: {e}") + fail += 1 + + print(f"{fail}/{total} parameters are not equal.") + + if fail > 0: + raise AssertionError("Some parameters are not equal") def get_args(): parser = argparse.ArgumentParser() @@ -74,7 +105,7 @@ def main(): model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path - assert "EleutherAI/gpt-neox-20b" == tokenizer_path; f"Should be EleutherAI/gpt-neox-20b tokenizer and not '{tokenizer_path}'" + assert "state-spaces/mamba-130m-hf" == tokenizer_path; f"Should be 'state-spaces/mamba-130m-hf' tokenizer and not '{tokenizer_path}'" parallel_config = ParallelismArgs( dp=args.dp or config.parallelism.dp, @@ -104,8 +135,6 @@ def main(): log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0) log_rank(f"tokenizer_path: {tokenizer_path}", logger=logger, level=logging.INFO, rank=0) - dtype = torch.bfloat16 - # Set random states set_random_seed(42) @@ -117,6 +146,21 @@ def main(): else: # We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER) random_states = RandomStates({}) + + str_to_dtype = { + "float32": torch.float32, + "float64": torch.float64, + "complex64": torch.complex64, + "complex128": torch.complex128, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "uint8": torch.uint8, + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, + "bool": torch.bool, + } model = build_model( model_builder=lambda: MambaForTraining( @@ -125,9 +169,13 @@ def main(): parallel_config=parallel_config, random_states=random_states, ), - dtype=dtype, + dtype=str_to_dtype[model_config.dtype], parallel_context=parallel_context, ) + + assert str_to_dtype[model_config.dtype] == torch.float32, f"Model dtype {str_to_dtype[model_config.dtype]} should be torch.float32" + + model_ref = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m").to("cuda") # Mark some parameters as tied # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead? @@ -146,7 +194,11 @@ def main(): ) load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) + sanity_check_weights(model=model, model_ref=model_ref) + model.eval() + model_ref.eval() + if AutoTokenizer is not None: tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) # tokenizer.pad_token_id = tokenizer.eos_token_id @@ -208,6 +260,27 @@ def main(): level=logging.INFO, rank=0, ) + + # Model ref + tokens = tokenizer(dummy_inputs, return_tensors="pt") + input_ids = tokens.input_ids.to(device="cuda") + + output_ref = model_ref.generate( + input_ids=input_ids, + max_length=args.max_new_tokens, + cg=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=False, + temperature=1.0, + top_k=1, + top_p=1.0, + min_p=0.0, + repetition_penalty=1.0 + ) + + log_rank(f"input REF: {tokenizer.decode(input_ids[0], clean_up_tokenization_spaces=False)}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"generation REF: {tokenizer.batch_decode(output_ref.sequences.tolist())}", logger=logger, level=logging.INFO, rank=0) else: outputs = decode_tokenized( input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), From d39217e68de75a1e29d98654260418bbb191afb5 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 12 Mar 2024 16:51:39 +0000 Subject: [PATCH 05/39] can now load checkpoint from HF model with different TP --- examples/mamba/convert_hf_to_nanotron.py | 61 +++++++++++++++--------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 885eeb53..eaf2cf56 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -30,34 +30,48 @@ from nanotron.parallel.parameters import NanotronParameter, sanity_check from nanotron.serialize import save_meta, save_weights from nanotron.trainer import mark_tied_parameters +from nanotron import logging +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.config import LoggingArgs +logger = logging.get_logger(__name__) -def sanity_check_weights(model, model_ref): - - def sort_key(name_param_pair): + +def sanity_check_weights(model, model_ref, tp_size): + def _sort_key(name_param_pair): name, _ = name_param_pair # Split the name and take the last part as the key for sorting return name.split('.')[-1] + def _split_weight(data: torch.Tensor, dim: int) -> torch.Tensor: + rank = dist.get_rank() + world_size = dist.get_world_size() + chunks = torch.chunk(data, world_size, dim=dim) + return chunks[rank].contiguous() + total, fail = 0, 0 for (name_ref, param_ref), (name, param) in zip( - sorted(model_ref.named_parameters(), key=sort_key), - sorted(model.model.named_parameters(), key=sort_key) + sorted(model_ref.named_parameters(), key=_sort_key), + sorted(model.model.named_parameters(), key=_sort_key) ): total += 1 try: - torch.testing.assert_allclose(param_ref, param, rtol=1e-10, atol=1e-10) + param_shard_ref = param_ref + if isinstance(param, NanotronParameter) and param.is_sharded and tp_size > 1: + dim = next(index for index, (dim1, dim2) in enumerate(zip(param.shape, param_ref.shape)) if dim1 != dim2) + param_shard_ref = _split_weight(param_ref, dim) + + torch.testing.assert_close(param_shard_ref, param, rtol=1e-10, atol=1e-10) except AssertionError as e: - print(f"{name_ref} and {name} are not equal") + log_rank(f"{name_ref} and {name} are not equal. {e}", logger=logger, level=logging.INFO, rank=0) fail += 1 - print(f"{fail}/{total} parameters are not equal") + log_rank(f"{fail}/{total} parameters are not equal", logger=logger, level=logging.INFO, rank=0) if fail > 0: raise AssertionError("Some parameters are not equal") - def get_weight_from_hf( name: str, @@ -66,12 +80,10 @@ def get_weight_from_hf( nanotron_to_hf: Dict[str, str], get_grad: bool = False ) -> torch.Tensor: - """From our brrr implementation, we get the equivalent tensor in transformers implementation""" - + """From our brrr implementation, we get the equivalent tensor in transformers implementation""" hf_name = nanotron_to_hf[name] if get_grad is False: - def get_tensor(path: str): return ref_module_state_dict[path] else: @@ -79,7 +91,7 @@ def get_tensor(path: str): def get_tensor(path: str): weight = ref_module.get_parameter(path) return weight.grad - + return get_tensor(hf_name) if __name__ == "__main__": @@ -94,13 +106,8 @@ def get_tensor(path: str): if args.model not in ["130M", "370M", "790M", "1.4B", "2.8B"]: raise ValueError("Model should be one of 130M, 370M, 790M, 1.4B, 2.8B") - if args.tp > 1: - raise ValueError("Tensor parallelism not supported yet (as A_log is nn.Parameter which is not a NanotronParameter and thus cannot be marked as is_sharded)") - save_path = Path(args.save_path) - #TODO: Do it this way so that we can choose the dp pp and tp we want - # https://github.com/huggingface/brrr/blob/main/legacy_examples/starcoder2/convert_brrr_to_trfrs.py parallel_config = ParallelismArgs( dp=args.dp, pp=args.pp, @@ -116,6 +123,15 @@ def get_tensor(path: str): tensor_parallel_size=parallel_config.tp, ) + # Set log log levels + logging_config = LoggingArgs( + log_level="info", + log_level_replica="info", + ) + + # Set log levels + set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config) + d_model = None num_hidden_layers = None pretrained_model_name = None @@ -171,6 +187,7 @@ def get_tensor(path: str): "int64": torch.int64, "bool": torch.bool, } + device = torch.device("cuda") attrs = yaml.safe_load(yaml_content) model_config = MambaModelConfig(**attrs) @@ -193,7 +210,7 @@ def get_tensor(path: str): ), parallel_context=parallel_context, dtype=str_to_dtype[model_config.dtype], - device=torch.device("cpu") + device=device ) device_map = {} @@ -212,7 +229,7 @@ def get_tensor(path: str): device_map["lm_head"] = nanotron_model.model.lm_head.rank if current_pp_rank in tied_embs_ranks else "meta" - model_ref = MambaLMHeadModel.from_pretrained(pretrained_model_name, device="cpu", dtype=str_to_dtype[model_config.dtype]) + model_ref = MambaLMHeadModel.from_pretrained(pretrained_model_name, device=device, dtype=str_to_dtype[model_config.dtype]) # Create a mapping from nanotron to hf nanotron_to_hf = {} @@ -270,7 +287,7 @@ def get_tensor(path: str): sanity_check(root_module=nanotron_model) - sanity_check_weights(model=nanotron_model, model_ref=model_ref) + sanity_check_weights(model=nanotron_model, model_ref=model_ref, tp_size=parallel_context.tp_pg.size()) save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) checkpoint_metadata = { @@ -289,5 +306,5 @@ def get_tensor(path: str): ), tokenizer=TokenizerArgs(pretrained_model_name + "-hf"), ) - print("Saving config ...") + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) yaml.dump(config.as_dict(), f) \ No newline at end of file From 4bb1a7164d0cb8a1b56a7efdd5ad56360c409db0 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 12 Mar 2024 17:04:42 +0000 Subject: [PATCH 06/39] fix bug residual which was added twice --- examples/mamba/mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 5065ed53..f244f699 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -500,7 +500,7 @@ def forward( hidden_states, self.norm.weight, self.norm.bias, - residual=residual, + residual=None if (self.layer_idx == 0) else residual, prenorm=True, residual_in_fp32=self.residual_in_fp32, eps=self.norm.eps, From 3360de27e0c7c3491a9cdb6ce868c1e4c99d7fe2 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 13 Mar 2024 17:38:22 +0000 Subject: [PATCH 07/39] interleaved some weights to make HF checkpoints compatible with nanotron --- examples/mamba/convert_hf_to_nanotron.py | 45 +++++++++++++++++++----- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index eaf2cf56..2c6fac09 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) -def sanity_check_weights(model, model_ref, tp_size): +def sanity_check_weights(model, model_ref, tp_size): def _sort_key(name_param_pair): name, _ = name_param_pair # Split the name and take the last part as the key for sorting @@ -49,7 +49,7 @@ def _split_weight(data: torch.Tensor, dim: int) -> torch.Tensor: chunks = torch.chunk(data, world_size, dim=dim) return chunks[rank].contiguous() - total, fail = 0, 0 + total, fail, excluded = 0, 0, 0 for (name_ref, param_ref), (name, param) in zip( sorted(model_ref.named_parameters(), key=_sort_key), @@ -63,11 +63,17 @@ def _split_weight(data: torch.Tensor, dim: int) -> torch.Tensor: dim = next(index for index, (dim1, dim2) in enumerate(zip(param.shape, param_ref.shape)) if dim1 != dim2) param_shard_ref = _split_weight(param_ref, dim) + if "in_proj" in name_ref: + # Don't check this weight as we changed it manually (interleaved) + excluded += 1 + continue + torch.testing.assert_close(param_shard_ref, param, rtol=1e-10, atol=1e-10) except AssertionError as e: log_rank(f"{name_ref} and {name} are not equal. {e}", logger=logger, level=logging.INFO, rank=0) fail += 1 + log_rank(f"{excluded}/{total} parameters were not sanity check (interleaved)", logger=logger, level=logging.INFO, rank=0) log_rank(f"{fail}/{total} parameters are not equal", logger=logger, level=logging.INFO, rank=0) if fail > 0: @@ -81,18 +87,39 @@ def get_weight_from_hf( get_grad: bool = False ) -> torch.Tensor: """From our brrr implementation, we get the equivalent tensor in transformers implementation""" + + def _interleave_pattern(N): + """ + interleave_pattern(4) -> [0, 2, 1, 3] + interleave_pattern(8) -> [0, 4, 1, 5, 2, 6, 3, 7] + """ + assert N % 2 == 0, "N must be even" + pattern = [] + for i in range(N // 2): + pattern.append(i) + pattern.append(i + N // 2) + return pattern + hf_name = nanotron_to_hf[name] - + if get_grad is False: - def get_tensor(path: str): + def _get_tensor(path: str): return ref_module_state_dict[path] else: - - def get_tensor(path: str): - weight = ref_module.get_parameter(path) - return weight.grad + def _get_tensor(path: str): + param = ref_module.get_parameter(path) + return param.grad + + param = _get_tensor(hf_name) - return get_tensor(hf_name) + if "in_proj" in hf_name: + # In Nanotron, we do tensor parallel column so weight need to be split in the column dimension (i.e: xz.view(...)) + # However, the HF weights was trained such that it expected xz.chunk(...) to split the tensor in the row dimension + # Thus, we need to interleaved the HF weights to make it compatible with Nanotron + log_rank(f"Interleaving {hf_name} to make it compatible with Nanotron", logger=logger, level=logging.INFO, rank=0) + param = param[_interleave_pattern(param.shape[0]), :] + + return param if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert HF weights from states-space repo to brrr weights") From 2c475eba81684276e35a99299088f2f6a2cd4d18 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 13 Mar 2024 18:07:28 +0000 Subject: [PATCH 08/39] make sure vocab is 50280 for every value of TP --- examples/mamba/convert_hf_to_nanotron.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 2c6fac09..d9d90d52 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -225,9 +225,11 @@ def _get_tensor(path: str): model_config.vocab_size = _vocab_size_with_padding( model_config.vocab_size, pg_size=parallel_context.tp_pg.size(), - make_vocab_size_divisible_by=4, # 50277 -> 50280 + make_vocab_size_divisible_by=5, # So that every value of TP from 1 to 8 yield a vocab_size of 50280 ) + model_ref = MambaLMHeadModel.from_pretrained(pretrained_model_name, device=device, dtype=str_to_dtype[model_config.dtype]) + nanotron_model = build_model( model_builder=lambda: MambaForTraining( config=model_config, @@ -256,8 +258,6 @@ def _get_tensor(path: str): device_map["lm_head"] = nanotron_model.model.lm_head.rank if current_pp_rank in tied_embs_ranks else "meta" - model_ref = MambaLMHeadModel.from_pretrained(pretrained_model_name, device=device, dtype=str_to_dtype[model_config.dtype]) - # Create a mapping from nanotron to hf nanotron_to_hf = {} From 737ba3ac401082e9650e361fc764106a8d57fbc4 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 13 Mar 2024 18:23:42 +0000 Subject: [PATCH 09/39] remove environement variable from Mamba --- examples/mamba/mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index f244f699..66955c60 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -220,7 +220,7 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p # In the backward pass we write dx and dz next to each other to avoid torch.cat if ( - self.use_fast_path and inference_params is None and os.environ.get("FAST_PATH", "0") == "1" + self.use_fast_path and inference_params is None ): # Doesn't support outputting the states y = mamba_inner_fn( d_inner=self.d_inner, From b62392f02089102321d6de1c3aa2a87225a0444d Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 13 Mar 2024 18:27:32 +0000 Subject: [PATCH 10/39] cleaner sharding of weights (they are now NanotronParameter and thus marked as sharded) --- examples/mamba/mamba.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 66955c60..b890ff30 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -133,6 +133,10 @@ def __init__( **factory_kwargs, ) + self.conv1d.weight = create_sharded_parameter_from_config(parameter=self.conv1d.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + if conv_bias: + self.conv1d.bias = create_sharded_parameter_from_config(parameter=self.conv1d.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + self.activation = "silu" self.act = nn.SiLU() @@ -148,16 +152,6 @@ def __init__( self.dt_proj = nn.Linear(self.dt_rank, self.d_inner // self.tp_pg.size(), bias=True, **factory_kwargs) - # Initialize special dt projection to preserve variance at initialization - # Perform in `def init_model_randomly` - # dt_init_std = self.dt_rank**-0.5 * dt_scale - # if dt_init == "constant": - # nn.init.constant_(self.dt_proj.weight, dt_init_std) - # elif dt_init == "random": - # nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - # else: - # raise NotImplementedError - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(self.d_inner // self.tp_pg.size(), **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) @@ -170,6 +164,9 @@ def __init__( # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit self.dt_proj.bias._no_reinit = True + self.dt_proj.weight = create_sharded_parameter_from_config(parameter=self.dt_proj.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + self.dt_proj.bias = create_sharded_parameter_from_config(parameter=self.dt_proj.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + # S4D real initialization A = repeat( torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), @@ -177,11 +174,11 @@ def __init__( d=self.d_inner // self.tp_pg.size(), ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) + self.A_log = create_sharded_parameter_from_config(parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) self.A_log._no_weight_decay = True # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner // self.tp_pg.size(), device=device)) # Keep in fp32 + self.D = create_sharded_parameter_from_config(parameter=torch.ones(self.d_inner // self.tp_pg.size(), device=device), pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) self.D._no_weight_decay = True # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) From 21ee315bfba2b91b716734724ebe7510f657ff21 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 13 Mar 2024 18:28:20 +0000 Subject: [PATCH 11/39] add inference part to Mamba --- examples/mamba/mamba.py | 58 +++++++++++++++++++++++++++------- examples/mamba/run_generate.py | 40 ++++++----------------- examples/mamba/run_generate.sh | 5 --- 3 files changed, 57 insertions(+), 46 deletions(-) delete mode 100755 examples/mamba/run_generate.sh diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index b890ff30..677282de 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -65,6 +65,28 @@ logger = logging.get_logger(__name__) +from dataclasses import dataclass, field +from torch import Tensor +from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + max_seqlen: int + max_batch_size: int + seqlen_offset: int = 0 + batch_size_offset: int = 0 + key_value_memory_dict: dict = field(default_factory=dict) + lengths_per_sample: Optional[Tensor] = None + + def reset(self, max_seqlen, max_batch_size): + self.max_seqlen = max_seqlen + self.max_batch_size = max_batch_size + self.seqlen_offset = 0 + if self.lengths_per_sample is not None: + self.lengths_per_sample.zero_() class Mamba(nn.Module): def __init__( @@ -197,10 +219,6 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p hidden_states: (B, L, D) Returns: same shape as hidden_states """ - - if inference_params is not None: - raise NotImplementedError("Inference params not tested yet.") - batch, seqlen, dim = hidden_states.shape conv_state, ssm_state = None, None @@ -291,11 +309,14 @@ def step( conv_state: torch.Tensor, ssm_state: torch.Tensor, ): + batch, seqlen, dim = hidden_states.shape dtype = hidden_states.dtype - assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + assert seqlen == 1, "Only support decoding with 1 token at a time for now" xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.chunk(2, dim=-1) # (B D) - + x, z = xz.view(batch, self.d_inner // self.tp_pg.size(), 2).chunk(2, dim=2) + x = x.squeeze(2) # (B D) + z = z.squeeze(2) # (B D) + # Conv step if causal_conv1d_update is None: conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) @@ -372,14 +393,14 @@ def _get_states_from_cache(self, inference_params, batch_size: int, initialize_s if self.layer_idx not in inference_params.key_value_memory_dict: conv_state = torch.zeros( batch_size, - self.d_model * self.expand, + self.d_model * self.expand // self.tp_pg.size(), self.d_conv, device=self.conv1d.weight.device, dtype=self.conv1d.weight.dtype, ) ssm_state = torch.zeros( batch_size, - self.d_model * self.expand, + self.d_model * self.expand // self.tp_pg.size(), self.d_state, device=self.dt_proj.weight.device, dtype=self.dt_proj.weight.dtype, @@ -559,7 +580,7 @@ def __init__( "device": self.p2p.device, "dtype": cast_str_to_torch_dtype(config.dtype), }, - module_input_keys={"hidden_states", "sequence_mask", "residual"}, + module_input_keys={"hidden_states", "sequence_mask", "residual", "inference_params"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, ) for layer_idx in range(config.num_hidden_layers) @@ -598,6 +619,19 @@ def __init__( module_input_keys={"x"}, module_output_keys={"output"}, ) + + self.inference_params = None + + def setup_inference_params(self, max_length, max_batch_size): + self.inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=max_batch_size) + + def update_inference_params(self, logits: torch.Tensor): + if self.inference_params is not None: + self.inference_params.seqlen_offset += 1 # We are processing only one token at a time + # We need to transpose to make it compatible with decode.py (which will tranpose again, thus cancelling this transpose) + logits = logits.transpose(0, 1) # [batch_size, seq_length, vocab_size] + return logits + def forward( self, @@ -622,7 +656,7 @@ def forward_with_hidden_states( } for block in self.decoder: - hidden_encoder_states = block(**hidden_encoder_states) + hidden_encoder_states = block(**hidden_encoder_states, inference_params=self.inference_params) hidden_states = self.final_layer_norm( x=hidden_encoder_states["hidden_states"], @@ -632,6 +666,8 @@ def forward_with_hidden_states( sharded_logits = self.lm_head(x=hidden_states)["logits"] fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + fp32_sharded_logits = self.update_inference_params(fp32_sharded_logits) + return fp32_sharded_logits, hidden_states def get_block_compute_costs(self): diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index 01ff43fb..182b9eab 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -57,35 +57,10 @@ # from transformers import MambaForCausalLM from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel +from nanotron.parallel.parameters import NanotronParameter import lovely_tensors as lt; lt.monkey_patch() -def sanity_check_weights(model, model_ref): - - def sort_key(name_param_pair): - name, _ = name_param_pair - # Split the name and take the last part as the key for sorting - return name.split('.')[-1] - - total, fail = 0, 0 - - for (name_ref, param_ref), (name, param) in zip( - sorted(model_ref.named_parameters(), key=sort_key), - sorted(model.model.named_parameters(), key=sort_key) - ): - - total += 1 - try: - torch.testing.assert_close(param_ref, param, rtol=1e-10, atol=1e-10) - except AssertionError as e: - print(f"{name_ref} and {name} are not equal. Error: {e}") - fail += 1 - - print(f"{fail}/{total} parameters are not equal.") - - if fail > 0: - raise AssertionError("Some parameters are not equal") - def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") @@ -116,6 +91,8 @@ def main(): tp_linear_async_communication=False, ) + print(parallel_config) + # Initialise all process groups parallel_context = ParallelContext( data_parallel_size=parallel_config.dp, @@ -193,9 +170,6 @@ def main(): rank=0, ) load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) - - sanity_check_weights(model=model, model_ref=model_ref) - model.eval() model_ref.eval() @@ -220,6 +194,12 @@ def main(): "Hello" ] + log_rank("Setup Inference mode for mamba model", logger=logger, level=logging.INFO, rank=0) + #TODO: make it work with batch of inputs + assert len(dummy_inputs) == 1, f"Only one input is supported for now. Got {len(dummy_inputs)} inputs" + batch_size = tokenizer(dummy_inputs, return_tensors="pt").input_ids.size(0) + model.model.setup_inference_params(max_length=args.max_new_tokens, max_batch_size=batch_size) + outputs = decode_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), tokenizer=tokenizer, @@ -268,7 +248,7 @@ def main(): output_ref = model_ref.generate( input_ids=input_ids, max_length=args.max_new_tokens, - cg=True, + cg=False, return_dict_in_generate=True, output_scores=True, enable_timing=False, diff --git a/examples/mamba/run_generate.sh b/examples/mamba/run_generate.sh deleted file mode 100755 index 85ff1331..00000000 --- a/examples/mamba/run_generate.sh +++ /dev/null @@ -1,5 +0,0 @@ -if [ -n "$DEBUG" ]; then - debugpy-run -m torch.distributed.run -- --nproc_per_node=1 run_generate.py --ckpt-path /fsx/ferdinandmom/ferdinand-hf/nanotron/examples/checkpoints/100 -else - torchrun --nproc_per_node=1 run_generate.py --ckpt-path /fsx/ferdinandmom/ferdinand-hf/nanotron/examples/checkpoints/100 -fi \ No newline at end of file From 249a8a0de4398c89ea5f820a17d58fbf6bbed18f Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 14 Mar 2024 11:54:54 +0000 Subject: [PATCH 12/39] save mode_config as json as well --- examples/mamba/convert_hf_to_nanotron.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index d9d90d52..7de5132f 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -8,6 +8,7 @@ import argparse import torch import yaml +import json from pathlib import Path from tqdm import tqdm from typing import Dict @@ -334,4 +335,8 @@ def _get_tensor(path: str): tokenizer=TokenizerArgs(pretrained_model_name + "-hf"), ) log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) - yaml.dump(config.as_dict(), f) \ No newline at end of file + yaml.dump(config.as_dict(), f) + + with open(save_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(attrs, f) \ No newline at end of file From 28a35f60c1a6148f25d8ee3c188c91869a5b4768 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 14 Mar 2024 12:03:11 +0000 Subject: [PATCH 13/39] better convert dataclass to json instead (to get vocab size updated by padding) --- examples/mamba/convert_hf_to_nanotron.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 7de5132f..0aff2da4 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -13,7 +13,7 @@ from tqdm import tqdm from typing import Dict from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel -import lovely_tensors as lt; lt.monkey_patch() +from dataclasses import asdict from nanotron.config import ( AllForwardAllBackwardPipelineEngine, @@ -339,4 +339,4 @@ def _get_tensor(path: str): with open(save_path / "model_config.json", "w") as f: log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) - json.dump(attrs, f) \ No newline at end of file + json.dump(asdict(model_config), f) \ No newline at end of file From 07ee37c14d28dc208b1088eed8a13a8d0ef83ec3 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 14 Mar 2024 14:25:19 +0000 Subject: [PATCH 14/39] add nanotron to HF now converter (works with every TP values) --- examples/mamba/convert_nanotron_to_hf.py | 223 +++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 examples/mamba/convert_nanotron_to_hf.py diff --git a/examples/mamba/convert_nanotron_to_hf.py b/examples/mamba/convert_nanotron_to_hf.py new file mode 100644 index 00000000..421a2736 --- /dev/null +++ b/examples/mamba/convert_nanotron_to_hf.py @@ -0,0 +1,223 @@ +# ruff: noqa: E402 +""" +Converts a nanotron model to HF format +Command: + torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=weights-tp1 --save_path=HF_130M +""" +import argparse +import torch +import yaml +import json +from pathlib import Path + +from nanotron.config import ( + AllForwardAllBackwardPipelineEngine, + ParallelismArgs, + TensorParallelLinearMode, +) +from config import MambaModelConfig + +from nanotron.distributed import dist +from nanotron.models import build_model +from mamba import MambaForTraining +from nanotron.parallel import ParallelContext +from nanotron.trainer import mark_tied_parameters +from nanotron import logging +from nanotron.serialize import load_weights +from nanotron.models import init_on_device_and_dtype + +from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer + +logger = logging.get_logger(__name__) + +import lovely_tensors as lt; lt.monkey_patch() + +HARDCODED_HF_MODEL_NAME = "state-spaces/mamba-130m-hf" +HARCODED_PROMPT = "Hello" + +str_to_dtype = { + "float32": torch.float32, + "float64": torch.float64, + "complex64": torch.complex64, + "complex128": torch.complex128, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "uint8": torch.uint8, + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, + "bool": torch.bool, +} + +def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): + device = torch.device("cuda") + + with open(checkpoint_path / "model_config.json", "r") as f: + attrs = json.load(f) + model_config = MambaModelConfig(**attrs) + + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=1, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + parallel_context = ParallelContext( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=1, + ) + + model_nanotron = build_model( + model_builder=lambda: MambaForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=str_to_dtype[model_config.dtype], + device=device + ) + + mark_tied_parameters(model=model_nanotron, parallel_context=parallel_context) + + # Load checkpoint directly in memory and then only keep the state dictionary + load_weights(model=model_nanotron, parallel_context=parallel_context, root_folder=checkpoint_path) + model_nanotron_state_dict = model_nanotron.state_dict() + del model_nanotron + + # Init the HF mode + if model_config.ssm_cfg is None: + model_config_hf = MambaConfig( + vocab_size=model_config.vocab_size, + num_hidden_layers=model_config.num_hidden_layers, + residual_in_fp32=model_config.residual_in_fp32, + layer_norm_epsilon=model_config.rms_norm_eps, + hidden_size=model_config.d_model, + ) + else: + model_config_hf = MambaConfig( + vocab_size=model_config.vocab_size, + num_hidden_layers=model_config.num_hidden_layers, + residual_in_fp32=model_config.residual_in_fp32, + layer_norm_epsilon=model_config.rms_norm_eps, + hidden_size=model_config.d_model, + state_size=model_config.ssm_cfg["d_state"], + expand=model_config.ssm_cfg["expand"], + conv_kernel=model_config.ssm_cfg["d_conv"], + use_bias=model_config.ssm_cfg["bias"], + use_conv_bias=model_config.ssm_cfg["conv_bias"], + time_step_rank=model_config.ssm_cfg["dt_rank"], + time_step_scale=model_config.ssm_cfg["dt_scale"], + time_step_min=model_config.ssm_cfg["dt_min"], + time_step_max=model_config.ssm_cfg["dt_max"], + time_step_init_scheme=model_config.ssm_cfg["dt_init"], + time_step_floor=model_config.ssm_cfg["dt_init_floor"], + ) + + # Initialised HF model + with init_on_device_and_dtype(device, str_to_dtype[model_config.dtype]): + model_hf = MambaForCausalLM._from_config(model_config_hf) + + # Get mapping of Nanotron layer and HF layer + hf_to_nanotron = {} + + # Static mappings + hf_to_nanotron['backbone.embeddings.weight'] = 'token_position_embeddings.pp_block.token_embedding.weight' + hf_to_nanotron['backbone.norm_f.weight'] = 'final_layer_norm.pp_block.weight' + hf_to_nanotron['lm_head.weight'] = 'lm_head.pp_block.weight' + + # Dynamic mappings within a loop + for i in range(model_config.num_hidden_layers): + hf_to_nanotron[f'backbone.layers.{i}.mixer.A_log'] = f'decoder.{i}.pp_block.mixer.A_log' + hf_to_nanotron[f'backbone.layers.{i}.mixer.D'] = f'decoder.{i}.pp_block.mixer.D' + hf_to_nanotron[f'backbone.layers.{i}.mixer.in_proj.weight'] = f'decoder.{i}.pp_block.mixer.in_proj.weight' + hf_to_nanotron[f'backbone.layers.{i}.mixer.conv1d.weight'] = f'decoder.{i}.pp_block.mixer.conv1d.weight' + hf_to_nanotron[f'backbone.layers.{i}.mixer.conv1d.bias'] = f'decoder.{i}.pp_block.mixer.conv1d.bias' + hf_to_nanotron[f'backbone.layers.{i}.mixer.x_proj.weight'] = f'decoder.{i}.pp_block.mixer.x_proj.weight' + hf_to_nanotron[f'backbone.layers.{i}.mixer.x_proj.bias'] = f'decoder.{i}.pp_block.mixer.x_proj.bias' + hf_to_nanotron[f'backbone.layers.{i}.mixer.dt_proj.weight'] = f'decoder.{i}.pp_block.mixer.dt_proj.weight' + hf_to_nanotron[f'backbone.layers.{i}.mixer.dt_proj.bias'] = f'decoder.{i}.pp_block.mixer.dt_proj.bias' + hf_to_nanotron[f'backbone.layers.{i}.mixer.out_proj.weight'] = f'decoder.{i}.pp_block.mixer.out_proj.weight' + hf_to_nanotron[f'backbone.layers.{i}.mixer.out_proj.bias'] = f'decoder.{i}.pp_block.mixer.out_proj.bias' + hf_to_nanotron[f'backbone.layers.{i}.norm.weight'] = f'decoder.{i}.pp_block.norm.weight' + + + def _reverse_interleave_pattern(N): + """ + Compute the reverse of the interleave pattern given by _interleave_pattern. + Example: + reverse_interleave_pattern(4) -> [0, 2, 1, 3] + reverse_interleave_pattern(8) -> [0, 2, 4, 6, 1, 3, 5, 7] + """ + assert N % 2 == 0, "N must be even" + def __interleave_pattern(N): + """ + interleave_pattern(4) -> [0, 2, 1, 3] + interleave_pattern(8) -> [0, 4, 1, 5, 2, 6, 3, 7] + """ + assert N % 2 == 0, "N must be even" + pattern = [] + for i in range(N // 2): + pattern.append(i) + pattern.append(i + N // 2) + return pattern + + interleaved_pattern = __interleave_pattern(N) + reverse_pattern = [0] * N + for original_index, interleaved_index in enumerate(interleaved_pattern): + reverse_pattern[interleaved_index] = original_index + return reverse_pattern + + # Loop over the state dict and convert the keys to HF format + for module_name_hf, module_hf in model_hf.named_modules(): + for param_name_hf, param_hf in module_hf.named_parameters(recurse=False): + # Get the Nanotron parameter + nanotron_key = "model." + hf_to_nanotron[f"{module_name_hf}.{param_name_hf}"] + param = model_nanotron_state_dict[nanotron_key] + + if "in_proj" in nanotron_key: + # Undo the interleaving weights in Nanotron to make it HF compatible + param = param[_reverse_interleave_pattern(param.shape[0]), :] + + with torch.no_grad(): + param_hf.copy_(param) + + # Save the model + model_hf.save_pretrained(save_path) + print(f"Model saved to {save_path}") + +def check_converted_model_generation(save_path: Path, hf_reference_model_name: str): + tokenizer = AutoTokenizer.from_pretrained(hf_reference_model_name) + input_ids = tokenizer(HARCODED_PROMPT, return_tensors="pt")["input_ids"] + print("Inputs:", tokenizer.batch_decode(input_ids)) + + # Ref + model = MambaForCausalLM.from_pretrained(hf_reference_model_name) + out = model.generate(input_ids, max_new_tokens=20) + print("Generation (ref): ", tokenizer.batch_decode(out)) + + # Converted + model = MambaForCausalLM.from_pretrained(save_path) + out = model.generate(input_ids, max_new_tokens=20) + print("Generation (converted): ", tokenizer.batch_decode(out)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Nanotron weights to HF format") + parser.add_argument("--checkpoint_path", type=str, default="mamba-130m") + parser.add_argument("--save_path", type=str, default="mamba-hf") + args = parser.parse_args() + + save_path = Path(args.save_path) + checkpoint_path = Path(args.checkpoint_path) + + # Convert Nanotron model to HF format + convert_checkpoint_and_save(checkpoint_path=checkpoint_path, save_path=save_path) + + # check if the conversion was successful by generating some text + check_converted_model_generation(save_path=save_path, hf_reference_model_name=HARDCODED_HF_MODEL_NAME) \ No newline at end of file From d0583865a00d7c84f0b7da31a28a436eef8bdc74 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 15 Mar 2024 16:25:02 +0000 Subject: [PATCH 15/39] rename config to model_config --- examples/mamba/convert_hf_to_nanotron.py | 2 +- examples/mamba/convert_nanotron_to_hf.py | 2 +- examples/mamba/mamba.py | 59 ++++++++++++------------ examples/mamba/run_generate.py | 2 +- 4 files changed, 33 insertions(+), 32 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 0aff2da4..69786a93 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -233,7 +233,7 @@ def _get_tensor(path: str): nanotron_model = build_model( model_builder=lambda: MambaForTraining( - config=model_config, + model_config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=None, diff --git a/examples/mamba/convert_nanotron_to_hf.py b/examples/mamba/convert_nanotron_to_hf.py index 421a2736..af5f30c6 100644 --- a/examples/mamba/convert_nanotron_to_hf.py +++ b/examples/mamba/convert_nanotron_to_hf.py @@ -74,7 +74,7 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): model_nanotron = build_model( model_builder=lambda: MambaForTraining( - config=model_config, + model_config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=None, diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 677282de..e4bed659 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -423,14 +423,14 @@ class Embedding(nn.Module, AttachableStore): def __init__( self, tp_pg: dist.ProcessGroup, - config: MambaModelConfig, + model_config: MambaModelConfig, parallel_config: Optional[ParallelismArgs], ): super().__init__() self.token_embedding = TensorParallelEmbedding( - num_embeddings=config.vocab_size, - embedding_dim=config.d_model, - padding_idx=config.pad_token_id, + num_embeddings=model_config.vocab_size, + embedding_dim=model_config.d_model, + padding_idx=model_config.pad_token_id, pg=tp_pg, mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, ) @@ -457,7 +457,7 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ class MambaDecoderLayer(nn.Module): def __init__( self, - config: MambaModelConfig, + model_config: MambaModelConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, @@ -468,17 +468,17 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} - if config.ssm_cfg is None: + if model_config.ssm_cfg is None: ssm_cfg = {} else: - ssm_cfg = config.ssm_cfg + ssm_cfg = model_config.ssm_cfg self.layer_idx = layer_idx - self.residual_in_fp32 = config.residual_in_fp32 - self.fused_add_norm = config.fused_add_norm + self.residual_in_fp32 = model_config.residual_in_fp32 + self.fused_add_norm = model_config.fused_add_norm self.mixer = Mamba( - d_model=config.d_model, + d_model=model_config.d_model, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, @@ -487,10 +487,10 @@ def __init__( ) self.norm = partial( - nn.LayerNorm if not config.rms_norm else RMSNorm, - eps=config.rms_norm_eps, + nn.LayerNorm if not model_config.rms_norm else RMSNorm, + eps=model_config.rms_norm_eps, **factory_kwargs, - )(config.d_model) + )(model_config.d_model) if self.fused_add_norm: assert RMSNorm is not None, "RMSNorm import fails" @@ -538,7 +538,7 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) class MambaModel(nn.Module): def __init__( self, - config: MambaModelConfig, + model_config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, @@ -547,7 +547,7 @@ def __init__( # Declare all the nodes self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) - self.config = config + self.model_config = model_config self.parallel_config = parallel_config self.parallel_context = parallel_context self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE @@ -560,7 +560,7 @@ def __init__( module_builder=Embedding, module_kwargs={ "tp_pg": parallel_context.tp_pg, - "config": config, + "model_config": model_config, "parallel_config": parallel_config, }, module_input_keys={"input_ids", "input_mask"}, @@ -573,24 +573,24 @@ def __init__( p2p=self.p2p, module_builder=MambaDecoderLayer, module_kwargs={ - "config": config, + "model_config": model_config, "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, "device": self.p2p.device, - "dtype": cast_str_to_torch_dtype(config.dtype), + "dtype": cast_str_to_torch_dtype(model_config.dtype), }, module_input_keys={"hidden_states", "sequence_mask", "residual", "inference_params"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, ) - for layer_idx in range(config.num_hidden_layers) + for layer_idx in range(model_config.num_hidden_layers) ] ) self.final_layer_norm = PipelineBlock( p2p=self.p2p, module_builder=RMSNorm, - module_kwargs={"hidden_size": config.d_model, "eps": config.rms_norm_eps}, + module_kwargs={"hidden_size": model_config.d_model, "eps": model_config.rms_norm_eps}, module_input_keys={"x", "residual"}, module_output_keys={"hidden_states"}, ) @@ -600,8 +600,8 @@ def __init__( # Understand that this means that we return sharded logits that are going to need to be gathered module_builder=TensorParallelColumnLinear, module_kwargs={ - "in_features": config.d_model, - "out_features": config.vocab_size, + "in_features": model_config.d_model, + "out_features": model_config.vocab_size, "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. @@ -768,17 +768,21 @@ def forward( class MambaForTraining(NanotronModel): def __init__( self, - config: MambaModelConfig, + model_config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, ): super().__init__() + self.parallel_context = parallel_context + self.model_config = model_config + self.parallel_config = parallel_config + self.model = MambaModel( - config=config, - parallel_context=parallel_context, - parallel_config=parallel_config, + model_config=self.model_config, + parallel_context=self.parallel_context, + parallel_config=self.parallel_config, random_states=random_states, ) @@ -793,9 +797,6 @@ def __init__( }, module_output_keys={"loss"}, ) - self.parallel_context = parallel_context - self.config = config - self.parallel_config = parallel_config def forward( self, diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index 182b9eab..13f94454 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -141,7 +141,7 @@ def main(): model = build_model( model_builder=lambda: MambaForTraining( - config=model_config, + model_config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=random_states, From 4625fc00876e76f3bfd6f6e6a96dc0fd9e23eb71 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 15 Mar 2024 16:26:54 +0000 Subject: [PATCH 16/39] refacto inference by using Attachable store --- examples/mamba/config.py | 5 +- examples/mamba/convert_hf_to_nanotron.py | 68 +++++++------- examples/mamba/mamba.py | 112 ++++++++++++----------- examples/mamba/run_generate.py | 15 ++- src/nanotron/generation/decode.py | 14 ++- 5 files changed, 114 insertions(+), 100 deletions(-) diff --git a/examples/mamba/config.py b/examples/mamba/config.py index a599620a..2c408033 100644 --- a/examples/mamba/config.py +++ b/examples/mamba/config.py @@ -6,7 +6,10 @@ from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs from nanotron.config.utils_config import cast_str_to_torch_dtype - +@dataclass +class MambaInferenceConfig: + max_new_tokens: int = 42 + @dataclass class MambaInit: # mamba_ssm.models.mixer_seq_simple._init_weights diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 69786a93..753a9228 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -38,47 +38,47 @@ logger = logging.get_logger(__name__) -def sanity_check_weights(model, model_ref, tp_size): - def _sort_key(name_param_pair): - name, _ = name_param_pair - # Split the name and take the last part as the key for sorting - return name.split('.')[-1] +# def sanity_check_weights(model, model_ref, tp_size): +# def _sort_key(name_param_pair): +# name, _ = name_param_pair +# # Split the name and take the last part as the key for sorting +# return name.split('.')[-1] - def _split_weight(data: torch.Tensor, dim: int) -> torch.Tensor: - rank = dist.get_rank() - world_size = dist.get_world_size() - chunks = torch.chunk(data, world_size, dim=dim) - return chunks[rank].contiguous() +# def _split_weight(data: torch.Tensor, dim: int) -> torch.Tensor: +# rank = dist.get_rank() +# world_size = dist.get_world_size() +# chunks = torch.chunk(data, world_size, dim=dim) +# return chunks[rank].contiguous() - total, fail, excluded = 0, 0, 0 +# total, fail, excluded = 0, 0, 0 - for (name_ref, param_ref), (name, param) in zip( - sorted(model_ref.named_parameters(), key=_sort_key), - sorted(model.model.named_parameters(), key=_sort_key) - ): +# for (name_ref, param_ref), (name, param) in zip( +# sorted(model_ref.named_parameters(), key=_sort_key), +# sorted(model.model.named_parameters(), key=_sort_key) +# ): - total += 1 - try: - param_shard_ref = param_ref - if isinstance(param, NanotronParameter) and param.is_sharded and tp_size > 1: - dim = next(index for index, (dim1, dim2) in enumerate(zip(param.shape, param_ref.shape)) if dim1 != dim2) - param_shard_ref = _split_weight(param_ref, dim) +# total += 1 +# try: +# param_shard_ref = param_ref +# if isinstance(param, NanotronParameter) and param.is_sharded and tp_size > 1: +# dim = next(index for index, (dim1, dim2) in enumerate(zip(param.shape, param_ref.shape)) if dim1 != dim2) +# param_shard_ref = _split_weight(param_ref, dim) - if "in_proj" in name_ref: - # Don't check this weight as we changed it manually (interleaved) - excluded += 1 - continue +# if "in_proj" in name_ref: +# # Don't check this weight as we changed it manually (interleaved) +# excluded += 1 +# continue - torch.testing.assert_close(param_shard_ref, param, rtol=1e-10, atol=1e-10) - except AssertionError as e: - log_rank(f"{name_ref} and {name} are not equal. {e}", logger=logger, level=logging.INFO, rank=0) - fail += 1 +# torch.testing.assert_close(param_shard_ref, param, rtol=1e-10, atol=1e-10) +# except AssertionError as e: +# log_rank(f"{name_ref} and {name} are not equal. {e}", logger=logger, level=logging.INFO, rank=0) +# fail += 1 - log_rank(f"{excluded}/{total} parameters were not sanity check (interleaved)", logger=logger, level=logging.INFO, rank=0) - log_rank(f"{fail}/{total} parameters are not equal", logger=logger, level=logging.INFO, rank=0) +# log_rank(f"{excluded}/{total} parameters were not sanity check (interleaved)", logger=logger, level=logging.INFO, rank=0) +# log_rank(f"{fail}/{total} parameters are not equal", logger=logger, level=logging.INFO, rank=0) - if fail > 0: - raise AssertionError("Some parameters are not equal") +# if fail > 0: +# raise AssertionError("Some parameters are not equal") def get_weight_from_hf( name: str, @@ -315,7 +315,7 @@ def _get_tensor(path: str): sanity_check(root_module=nanotron_model) - sanity_check_weights(model=nanotron_model, model_ref=model_ref, tp_size=parallel_context.tp_pg.size()) + # sanity_check_weights(model=nanotron_model, model_ref=model_ref, tp_size=parallel_context.tp_pg.size()) save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) checkpoint_metadata = { diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index e4bed659..a69d1723 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -22,13 +22,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from config import MambaModelConfig +from config import MambaModelConfig, MambaInferenceConfig from einops import rearrange, repeat from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs from nanotron.config.utils_config import cast_str_to_torch_dtype -from nanotron.generation.generate_store import AttachableStore +from nanotron.generation.generate_store import AttachableStore, Store from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.parallel import ParallelContext @@ -69,24 +69,6 @@ from torch import Tensor from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config -@dataclass -class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference.""" - - max_seqlen: int - max_batch_size: int - seqlen_offset: int = 0 - batch_size_offset: int = 0 - key_value_memory_dict: dict = field(default_factory=dict) - lengths_per_sample: Optional[Tensor] = None - - def reset(self, max_seqlen, max_batch_size): - self.max_seqlen = max_seqlen - self.max_batch_size = max_batch_size - self.seqlen_offset = 0 - if self.lengths_per_sample is not None: - self.lengths_per_sample.zero_() class Mamba(nn.Module): def __init__( @@ -109,6 +91,8 @@ def __init__( layer_idx: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + is_inference: bool = False, + store: Optional[Store] = None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -120,6 +104,7 @@ def __init__( self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank self.use_fast_path = use_fast_path self.layer_idx = layer_idx + self.is_inference = is_inference tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE assert tp_mode == TensorParallelLinearMode.ALL_REDUCE or parallel_config.tp_linear_async_communication is False @@ -129,6 +114,7 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) + self.store = store # Get current tensor parallel rank self.tp_pg = tp_pg self.tp_rank = dist.get_rank(self.tp_pg) @@ -214,7 +200,7 @@ def __init__( contiguous_chunks=None, ) - def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_params=None): + def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): """ hidden_states: (B, L, D) Returns: same shape as hidden_states @@ -222,9 +208,9 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p batch, seqlen, dim = hidden_states.shape conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: + if self.is_inference: + conv_state, ssm_state = self._get_states_from_cache(batch) + if self.store["seqlen_offset"] > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out @@ -235,7 +221,7 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p # In the backward pass we write dx and dz next to each other to avoid torch.cat if ( - self.use_fast_path and inference_params is None + self.use_fast_path and not self.is_inference ): # Doesn't support outputting the states y = mamba_inner_fn( d_inner=self.d_inner, @@ -367,7 +353,7 @@ def step( out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state - def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = None, **kwargs): + def allocate_inference_cache(self, batch_size: int, max_new_tokens: int, dtype: torch.dtype = None, **kwargs): device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype conv_state = torch.zeros( @@ -388,9 +374,10 @@ def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torc ) return conv_state, ssm_state - def _get_states_from_cache(self, inference_params, batch_size: int, initialize_states: bool = False): + def _get_states_from_cache(self, batch_size: int, initialize_states: bool = False): assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: + + if self.layer_idx not in self.store["key_value_memory_dict"]: conv_state = torch.zeros( batch_size, self.d_model * self.expand // self.tp_pg.size(), @@ -406,12 +393,13 @@ def _get_states_from_cache(self, inference_params, batch_size: int, initialize_s dtype=self.dt_proj.weight.dtype, # dtype=torch.float32, ) - inference_params.key_value_memory_dict[self.layer_idx] = ( + + self.store["key_value_memory_dict"][self.layer_idx] = ( conv_state, - ssm_state, + ssm_state ) else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + conv_state, ssm_state = self.store["key_value_memory_dict"][self.layer_idx] # TODO: What if batch size changes between generation, and we reuse the same states? if initialize_states: conv_state.zero_() @@ -463,6 +451,8 @@ def __init__( layer_idx: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + is_inference: bool = False, + store: Optional[Store] = None, ): super().__init__() @@ -482,6 +472,8 @@ def __init__( parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, + is_inference=is_inference, + store=store, **ssm_cfg, **factory_kwargs, ) @@ -503,7 +495,6 @@ def forward( hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], residual: Optional[Union[torch.Tensor, TensorPointer]], - inference_params=None, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: if not self.fused_add_norm: # self.layer_idx was assigned when calling create_block @@ -523,7 +514,7 @@ def forward( residual_in_fp32=self.residual_in_fp32, eps=self.norm.eps, ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) + hidden_states = self.mixer(hidden_states) return { "hidden_states": hidden_states, @@ -531,8 +522,8 @@ def forward( "residual": residual, } - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + def allocate_inference_cache(self, batch_size, max_new_tokens, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_new_tokens, dtype=dtype, **kwargs) class MambaModel(nn.Module): @@ -542,6 +533,8 @@ def __init__( parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, + is_inference: bool = False, + store: Optional[Store] = None, ): super().__init__() @@ -554,7 +547,10 @@ def __init__( tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - + + self.is_inference = is_inference + self.store = store + self.token_position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=Embedding, @@ -579,8 +575,10 @@ def __init__( "layer_idx": layer_idx, "device": self.p2p.device, "dtype": cast_str_to_torch_dtype(model_config.dtype), + "is_inference": is_inference, + "store": self.store }, - module_input_keys={"hidden_states", "sequence_mask", "residual", "inference_params"}, + module_input_keys={"hidden_states", "sequence_mask", "residual"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, ) for layer_idx in range(model_config.num_hidden_layers) @@ -619,25 +617,14 @@ def __init__( module_input_keys={"x"}, module_output_keys={"output"}, ) - - self.inference_params = None - - def setup_inference_params(self, max_length, max_batch_size): - self.inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=max_batch_size) - - def update_inference_params(self, logits: torch.Tensor): - if self.inference_params is not None: - self.inference_params.seqlen_offset += 1 # We are processing only one token at a time - # We need to transpose to make it compatible with decode.py (which will tranpose again, thus cancelling this transpose) - logits = logits.transpose(0, 1) # [batch_size, seq_length, vocab_size] - return logits - + def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] ): + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] def forward_with_hidden_states( @@ -656,7 +643,7 @@ def forward_with_hidden_states( } for block in self.decoder: - hidden_encoder_states = block(**hidden_encoder_states, inference_params=self.inference_params) + hidden_encoder_states = block(**hidden_encoder_states) hidden_states = self.final_layer_norm( x=hidden_encoder_states["hidden_states"], @@ -666,7 +653,8 @@ def forward_with_hidden_states( sharded_logits = self.lm_head(x=hidden_states)["logits"] fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - fp32_sharded_logits = self.update_inference_params(fp32_sharded_logits) + if self.is_inference: + self.store["seqlen_offset"] += 1 # We are processing only one token at a time return fp32_sharded_logits, hidden_states @@ -765,13 +753,15 @@ def forward( return {"loss": loss} -class MambaForTraining(NanotronModel): +class MambaForTraining(NanotronModel, AttachableStore): def __init__( self, model_config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, + is_inference: bool = False, + inference_config: Optional[MambaInferenceConfig] = None, ): super().__init__() @@ -779,11 +769,27 @@ def __init__( self.model_config = model_config self.parallel_config = parallel_config + store = None + + if is_inference and inference_config is not None: + self._attach_store(Store()) + self._store.update( + { + "max_new_tokens": inference_config.max_new_tokens, + "max_batch_size": 1, # We are processing only one token at a time + "seqlen_offset": 0, + "key_value_memory_dict": {}, + } + ) + store = self._store + self.model = MambaModel( model_config=self.model_config, parallel_context=self.parallel_context, parallel_config=self.parallel_config, random_states=random_states, + is_inference=is_inference, + store=store, ) self.loss = PipelineBlock( diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index 13f94454..9fe129b6 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -43,8 +43,8 @@ set_random_seed, ) from nanotron.serialize import load_weights -from nanotron.trainer import CONFIG_TO_MODEL_CLASS, mark_tied_parameters -from config import MambaConfig, MambaModelConfig +from nanotron.trainer import mark_tied_parameters +from config import MambaConfig, MambaModelConfig, MambaInferenceConfig from mamba import MambaForTraining try: @@ -54,11 +54,8 @@ logger = logging.get_logger(__name__) -# from transformers import MambaForCausalLM from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel -from nanotron.parallel.parameters import NanotronParameter - import lovely_tensors as lt; lt.monkey_patch() def get_args(): @@ -145,6 +142,8 @@ def main(): parallel_context=parallel_context, parallel_config=parallel_config, random_states=random_states, + is_inference=True, + inference_config=MambaInferenceConfig(max_new_tokens=args.max_new_tokens), ), dtype=str_to_dtype[model_config.dtype], parallel_context=parallel_context, @@ -195,10 +194,7 @@ def main(): ] log_rank("Setup Inference mode for mamba model", logger=logger, level=logging.INFO, rank=0) - #TODO: make it work with batch of inputs - assert len(dummy_inputs) == 1, f"Only one input is supported for now. Got {len(dummy_inputs)} inputs" - batch_size = tokenizer(dummy_inputs, return_tensors="pt").input_ids.size(0) - model.model.setup_inference_params(max_length=args.max_new_tokens, max_batch_size=batch_size) + # assert config.inference_params.max_batch_size == 1, "Only batch size 1 is supported for inference for now" outputs = decode_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), @@ -211,6 +207,7 @@ def main(): generation_config=GenerationArgs(sampler="greedy", use_cache=True), tokenizer_config=TokenizerConfig(max_input_length=None), is_bench=os.environ.get("USE_BENCH", "0") == "1", + is_logits_transpose=False, ) for output in outputs: input_ids = output.input_ids diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 48d801cc..fc1eddaa 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -166,6 +166,7 @@ def decode_text( max_micro_batch_size: int, max_new_tokens: int, is_bench: bool = False, + is_logits_transpose: bool = True, ) -> Generator[GenerationOutput, None, None]: """We assume the following: - Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input. @@ -243,12 +244,19 @@ def decode_text( new_decoder_states.append(state) # Get the new logits if generation_config.use_cache: - with attach_store(model=model, store=state.store): - # transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size] + if hasattr(model, "_store"): + # Some model like mamba already has store build up for generation sharded_logits = model( input_ids=state.new_input_ids, input_mask=state.new_input_mask, ) + else: + with attach_store(model=model, store=state.store): + # transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size] + sharded_logits = model( + input_ids=state.new_input_ids, + input_mask=state.new_input_mask, + ) else: if isinstance(state.new_input_ids, torch.Tensor): batch_generated_ids = torch.cat(state.generation_ids, dim=-1) @@ -261,7 +269,7 @@ def decode_text( input_mask=batch_generated_mask, ) - if isinstance(sharded_logits, torch.Tensor): + if isinstance(sharded_logits, torch.Tensor) and is_logits_transpose: sharded_logits = sharded_logits.transpose(0, 1) # Communicate # TODO @thomasw21: Make a diagram to show how this works From fc79250bfd1d475fa33569c9fb1537863c9b53a6 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 18 Mar 2024 10:12:25 +0000 Subject: [PATCH 17/39] only rank 0 dump config and model_config --- examples/mamba/convert_hf_to_nanotron.py | 33 ++++++++++++------------ 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 753a9228..21a7a1f7 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -324,19 +324,20 @@ def _get_tensor(path: str): } save_meta(root_folder=save_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata) - with open(save_path / "config.yaml", "w") as f: - config = MambaConfig( - general=GeneralArgs(project="test", run="mamba"), - parallelism=parallel_config, - model=ModelArgs( - init_method=MambaInit(), - model_config=model_config, - ), - tokenizer=TokenizerArgs(pretrained_model_name + "-hf"), - ) - log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) - yaml.dump(config.as_dict(), f) - - with open(save_path / "model_config.json", "w") as f: - log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) - json.dump(asdict(model_config), f) \ No newline at end of file + if dist.get_rank() == 0: + with open(save_path / "config.yaml", "w") as f: + config = MambaConfig( + general=GeneralArgs(project="test", run="mamba"), + parallelism=parallel_config, + model=ModelArgs( + init_method=MambaInit(), + model_config=model_config, + ), + tokenizer=TokenizerArgs(pretrained_model_name + "-hf"), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(save_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(model_config), f) \ No newline at end of file From a43e437c16f059d9c23a9350fd399b50056315c7 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 18 Mar 2024 10:12:48 +0000 Subject: [PATCH 18/39] revert back to config instead of model_config --- examples/mamba/convert_hf_to_nanotron.py | 2 +- examples/mamba/convert_nanotron_to_hf.py | 2 +- examples/mamba/mamba.py | 50 ++++++++++++------------ examples/mamba/run_generate.py | 2 +- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 21a7a1f7..c2a33ed9 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -233,7 +233,7 @@ def _get_tensor(path: str): nanotron_model = build_model( model_builder=lambda: MambaForTraining( - model_config=model_config, + config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=None, diff --git a/examples/mamba/convert_nanotron_to_hf.py b/examples/mamba/convert_nanotron_to_hf.py index af5f30c6..421a2736 100644 --- a/examples/mamba/convert_nanotron_to_hf.py +++ b/examples/mamba/convert_nanotron_to_hf.py @@ -74,7 +74,7 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): model_nanotron = build_model( model_builder=lambda: MambaForTraining( - model_config=model_config, + config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=None, diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index a69d1723..df48436e 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -411,14 +411,14 @@ class Embedding(nn.Module, AttachableStore): def __init__( self, tp_pg: dist.ProcessGroup, - model_config: MambaModelConfig, + config: MambaModelConfig, parallel_config: Optional[ParallelismArgs], ): super().__init__() self.token_embedding = TensorParallelEmbedding( - num_embeddings=model_config.vocab_size, - embedding_dim=model_config.d_model, - padding_idx=model_config.pad_token_id, + num_embeddings=config.vocab_size, + embedding_dim=config.d_model, + padding_idx=config.pad_token_id, pg=tp_pg, mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, ) @@ -445,7 +445,7 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ class MambaDecoderLayer(nn.Module): def __init__( self, - model_config: MambaModelConfig, + config: MambaModelConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, @@ -458,17 +458,17 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} - if model_config.ssm_cfg is None: + if config.ssm_cfg is None: ssm_cfg = {} else: - ssm_cfg = model_config.ssm_cfg + ssm_cfg = config.ssm_cfg self.layer_idx = layer_idx - self.residual_in_fp32 = model_config.residual_in_fp32 - self.fused_add_norm = model_config.fused_add_norm + self.residual_in_fp32 = config.residual_in_fp32 + self.fused_add_norm = config.fused_add_norm self.mixer = Mamba( - d_model=model_config.d_model, + d_model=config.d_model, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, @@ -479,10 +479,10 @@ def __init__( ) self.norm = partial( - nn.LayerNorm if not model_config.rms_norm else RMSNorm, - eps=model_config.rms_norm_eps, + nn.LayerNorm if not config.rms_norm else RMSNorm, + eps=config.rms_norm_eps, **factory_kwargs, - )(model_config.d_model) + )(config.d_model) if self.fused_add_norm: assert RMSNorm is not None, "RMSNorm import fails" @@ -529,7 +529,7 @@ def allocate_inference_cache(self, batch_size, max_new_tokens, dtype=None, **kwa class MambaModel(nn.Module): def __init__( self, - model_config: MambaModelConfig, + config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, @@ -540,7 +540,7 @@ def __init__( # Declare all the nodes self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) - self.model_config = model_config + self.config = config self.parallel_config = parallel_config self.parallel_context = parallel_context self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE @@ -556,7 +556,7 @@ def __init__( module_builder=Embedding, module_kwargs={ "tp_pg": parallel_context.tp_pg, - "model_config": model_config, + "config": config, "parallel_config": parallel_config, }, module_input_keys={"input_ids", "input_mask"}, @@ -569,26 +569,26 @@ def __init__( p2p=self.p2p, module_builder=MambaDecoderLayer, module_kwargs={ - "model_config": model_config, + "config": config, "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, "device": self.p2p.device, - "dtype": cast_str_to_torch_dtype(model_config.dtype), + "dtype": cast_str_to_torch_dtype(config.dtype), "is_inference": is_inference, "store": self.store }, module_input_keys={"hidden_states", "sequence_mask", "residual"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, ) - for layer_idx in range(model_config.num_hidden_layers) + for layer_idx in range(config.num_hidden_layers) ] ) self.final_layer_norm = PipelineBlock( p2p=self.p2p, module_builder=RMSNorm, - module_kwargs={"hidden_size": model_config.d_model, "eps": model_config.rms_norm_eps}, + module_kwargs={"hidden_size": config.d_model, "eps": config.rms_norm_eps}, module_input_keys={"x", "residual"}, module_output_keys={"hidden_states"}, ) @@ -598,8 +598,8 @@ def __init__( # Understand that this means that we return sharded logits that are going to need to be gathered module_builder=TensorParallelColumnLinear, module_kwargs={ - "in_features": model_config.d_model, - "out_features": model_config.vocab_size, + "in_features": config.d_model, + "out_features": config.vocab_size, "pg": parallel_context.tp_pg, "bias": False, # TODO @thomasw21: refactor so that we store that default in a single place. @@ -756,7 +756,7 @@ def forward( class MambaForTraining(NanotronModel, AttachableStore): def __init__( self, - model_config: MambaModelConfig, + config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, @@ -766,7 +766,7 @@ def __init__( super().__init__() self.parallel_context = parallel_context - self.model_config = model_config + self.config = config self.parallel_config = parallel_config store = None @@ -784,7 +784,7 @@ def __init__( store = self._store self.model = MambaModel( - model_config=self.model_config, + config=self.config, parallel_context=self.parallel_context, parallel_config=self.parallel_config, random_states=random_states, diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index 9fe129b6..1fefd437 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -138,7 +138,7 @@ def main(): model = build_model( model_builder=lambda: MambaForTraining( - model_config=model_config, + config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=random_states, From 9c04f7fede2dfbb913eed73f754be1372706a622 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 18 Mar 2024 10:27:01 +0000 Subject: [PATCH 19/39] remove commented code --- examples/mamba/convert_hf_to_nanotron.py | 43 ------------------------ 1 file changed, 43 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index c2a33ed9..0d6ab02e 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -37,49 +37,6 @@ logger = logging.get_logger(__name__) - -# def sanity_check_weights(model, model_ref, tp_size): -# def _sort_key(name_param_pair): -# name, _ = name_param_pair -# # Split the name and take the last part as the key for sorting -# return name.split('.')[-1] - -# def _split_weight(data: torch.Tensor, dim: int) -> torch.Tensor: -# rank = dist.get_rank() -# world_size = dist.get_world_size() -# chunks = torch.chunk(data, world_size, dim=dim) -# return chunks[rank].contiguous() - -# total, fail, excluded = 0, 0, 0 - -# for (name_ref, param_ref), (name, param) in zip( -# sorted(model_ref.named_parameters(), key=_sort_key), -# sorted(model.model.named_parameters(), key=_sort_key) -# ): - -# total += 1 -# try: -# param_shard_ref = param_ref -# if isinstance(param, NanotronParameter) and param.is_sharded and tp_size > 1: -# dim = next(index for index, (dim1, dim2) in enumerate(zip(param.shape, param_ref.shape)) if dim1 != dim2) -# param_shard_ref = _split_weight(param_ref, dim) - -# if "in_proj" in name_ref: -# # Don't check this weight as we changed it manually (interleaved) -# excluded += 1 -# continue - -# torch.testing.assert_close(param_shard_ref, param, rtol=1e-10, atol=1e-10) -# except AssertionError as e: -# log_rank(f"{name_ref} and {name} are not equal. {e}", logger=logger, level=logging.INFO, rank=0) -# fail += 1 - -# log_rank(f"{excluded}/{total} parameters were not sanity check (interleaved)", logger=logger, level=logging.INFO, rank=0) -# log_rank(f"{fail}/{total} parameters are not equal", logger=logger, level=logging.INFO, rank=0) - -# if fail > 0: -# raise AssertionError("Some parameters are not equal") - def get_weight_from_hf( name: str, ref_module_state_dict: Dict[str, torch.Tensor], From 4b703a7e5df9e1813947d88717db4d378d09af4a Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 18 Mar 2024 10:51:39 +0000 Subject: [PATCH 20/39] fix dtype + load any hf checkpoints --- examples/mamba/convert_hf_to_nanotron.py | 75 +++++++----------------- examples/mamba/convert_nanotron_to_hf.py | 21 ++----- examples/mamba/run_generate.py | 21 +------ 3 files changed, 28 insertions(+), 89 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 0d6ab02e..a32cc681 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -35,8 +35,15 @@ from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.config import LoggingArgs +from transformers.utils import WEIGHTS_NAME, CONFIG_NAME +from transformers.utils.hub import cached_file +from mamba_ssm.models.config_mamba import MambaConfig as HFMambaConfig logger = logging.get_logger(__name__) +def load_config_hf(model_name): + resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) + return json.load(open(resolved_archive_file)) + def get_weight_from_hf( name: str, ref_module_state_dict: Dict[str, torch.Tensor], @@ -81,16 +88,15 @@ def _get_tensor(path: str): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert HF weights from states-space repo to brrr weights") - parser.add_argument("--model", type=str, default="130M", help="130M | 370M | 790M | 1.4B | 2.8B") parser.add_argument("--save_path", type=str, default="mamba-nanotron") + parser.add_argument("--model", type=str, default="state-spaces/mamba-130m") parser.add_argument("--dp", type=int, default=1) parser.add_argument("--pp", type=int, default=1) parser.add_argument("--tp", type=int, default=1) args = parser.parse_args() - if args.model not in ["130M", "370M", "790M", "1.4B", "2.8B"]: - raise ValueError("Model should be one of 130M, 370M, 790M, 1.4B, 2.8B") - + assert "state-spaces" in args.model, "Only models from state-spaces repo are supported" + save_path = Path(args.save_path) parallel_config = ParallelismArgs( @@ -101,6 +107,7 @@ def _get_tensor(path: str): tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, ) + assert parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE and parallel_config.tp_linear_async_communication is False parallel_context = ParallelContext( data_parallel_size=parallel_config.dp, @@ -117,38 +124,18 @@ def _get_tensor(path: str): # Set log levels set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config) - d_model = None - num_hidden_layers = None - pretrained_model_name = None - - if args.model == "130M": - d_model = 768 - num_hidden_layers = 24 - pretrained_model_name = "state-spaces/mamba-130m" - elif args.model == "370M": - d_model = 1024 - num_hidden_layers = 48 - pretrained_model_name = "state-spaces/mamba-370m" - elif args.model == "790M": - d_model = 1536 - num_hidden_layers = 24 - pretrained_model_name = "state-spaces/mamba-790m" - elif args.model == "1.4B": - d_model = 2048 - num_hidden_layers = 48 - pretrained_model_name = "state-spaces/mamba-1.4b" - elif args.model == "2.8B": - d_model = 2560 - num_hidden_layers = 64 - pretrained_model_name = "state-spaces/mamba-2.8b" + hf_config_data = load_config_hf(args.model) + hf_config = HFMambaConfig(**hf_config_data) + dtype_str = "float32" + yaml_content = f""" is_mamba_config: true - d_model: {d_model} - dtype: float32 + d_model: {hf_config.d_model} + dtype: {dtype_str} fused_add_norm: true is_mamba_config: true - num_hidden_layers: {num_hidden_layers} + num_hidden_layers: {hf_config.n_layer} pad_token_id: null pad_vocab_size_multiple: 8 residual_in_fp32: true @@ -158,27 +145,11 @@ def _get_tensor(path: str): vocab_size: 50277 """ - str_to_dtype = { - "float32": torch.float32, - "float64": torch.float64, - "complex64": torch.complex64, - "complex128": torch.complex128, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "uint8": torch.uint8, - "int8": torch.int8, - "int16": torch.int16, - "int32": torch.int32, - "int64": torch.int64, - "bool": torch.bool, - } + dtype = getattr(torch, dtype_str) device = torch.device("cuda") attrs = yaml.safe_load(yaml_content) model_config = MambaModelConfig(**attrs) - - assert model_config.dtype == "float32", "Convert weights only in float32" - # Initiliaze Brrr model model_config.vocab_size = _vocab_size_with_padding( model_config.vocab_size, @@ -186,7 +157,7 @@ def _get_tensor(path: str): make_vocab_size_divisible_by=5, # So that every value of TP from 1 to 8 yield a vocab_size of 50280 ) - model_ref = MambaLMHeadModel.from_pretrained(pretrained_model_name, device=device, dtype=str_to_dtype[model_config.dtype]) + model_ref = MambaLMHeadModel.from_pretrained(args.model, device=device, dtype=dtype) nanotron_model = build_model( model_builder=lambda: MambaForTraining( @@ -196,7 +167,7 @@ def _get_tensor(path: str): random_states=None, ), parallel_context=parallel_context, - dtype=str_to_dtype[model_config.dtype], + dtype=dtype, device=device ) @@ -272,8 +243,6 @@ def _get_tensor(path: str): sanity_check(root_module=nanotron_model) - # sanity_check_weights(model=nanotron_model, model_ref=model_ref, tp_size=parallel_context.tp_pg.size()) - save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) checkpoint_metadata = { "last_train_step": 0, @@ -290,7 +259,7 @@ def _get_tensor(path: str): init_method=MambaInit(), model_config=model_config, ), - tokenizer=TokenizerArgs(pretrained_model_name + "-hf"), + tokenizer=TokenizerArgs(args.model + "-hf"), ) log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) yaml.dump(config.as_dict(), f) diff --git a/examples/mamba/convert_nanotron_to_hf.py b/examples/mamba/convert_nanotron_to_hf.py index 421a2736..0bcccb1d 100644 --- a/examples/mamba/convert_nanotron_to_hf.py +++ b/examples/mamba/convert_nanotron_to_hf.py @@ -35,21 +35,6 @@ HARDCODED_HF_MODEL_NAME = "state-spaces/mamba-130m-hf" HARCODED_PROMPT = "Hello" -str_to_dtype = { - "float32": torch.float32, - "float64": torch.float64, - "complex64": torch.complex64, - "complex128": torch.complex128, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "uint8": torch.uint8, - "int8": torch.int8, - "int16": torch.int16, - "int32": torch.int32, - "int64": torch.int64, - "bool": torch.bool, -} - def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): device = torch.device("cuda") @@ -57,6 +42,8 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): attrs = json.load(f) model_config = MambaModelConfig(**attrs) + dtype = getattr(torch, model_config.dtype) + parallel_config = ParallelismArgs( dp=1, pp=1, @@ -80,7 +67,7 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): random_states=None, ), parallel_context=parallel_context, - dtype=str_to_dtype[model_config.dtype], + dtype=dtype, device=device ) @@ -121,7 +108,7 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): ) # Initialised HF model - with init_on_device_and_dtype(device, str_to_dtype[model_config.dtype]): + with init_on_device_and_dtype(device, dtype): model_hf = MambaForCausalLM._from_config(model_config_hf) # Get mapping of Nanotron layer and HF layer diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index 1fefd437..e650854d 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -120,21 +120,7 @@ def main(): else: # We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER) random_states = RandomStates({}) - - str_to_dtype = { - "float32": torch.float32, - "float64": torch.float64, - "complex64": torch.complex64, - "complex128": torch.complex128, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "uint8": torch.uint8, - "int8": torch.int8, - "int16": torch.int16, - "int32": torch.int32, - "int64": torch.int64, - "bool": torch.bool, - } + model = build_model( model_builder=lambda: MambaForTraining( @@ -145,12 +131,9 @@ def main(): is_inference=True, inference_config=MambaInferenceConfig(max_new_tokens=args.max_new_tokens), ), - dtype=str_to_dtype[model_config.dtype], + dtype=getattr(torch, model_config.dtype), parallel_context=parallel_context, ) - - assert str_to_dtype[model_config.dtype] == torch.float32, f"Model dtype {str_to_dtype[model_config.dtype]} should be torch.float32" - model_ref = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m").to("cuda") # Mark some parameters as tied From b56d51c998411e8d00f3985f4a649515722967d1 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 21 Mar 2024 13:48:45 +0000 Subject: [PATCH 21/39] force vocab size to be 50280 in config --- examples/mamba/convert_hf_to_nanotron.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index a32cc681..f3bb75ce 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -142,7 +142,7 @@ def _get_tensor(path: str): rms_norm: true rms_norm_eps: 1.0e-05 ssm_cfg: null - vocab_size: 50277 + vocab_size: 50280 """ dtype = getattr(torch, dtype_str) @@ -150,12 +150,7 @@ def _get_tensor(path: str): attrs = yaml.safe_load(yaml_content) model_config = MambaModelConfig(**attrs) - # Initiliaze Brrr model - model_config.vocab_size = _vocab_size_with_padding( - model_config.vocab_size, - pg_size=parallel_context.tp_pg.size(), - make_vocab_size_divisible_by=5, # So that every value of TP from 1 to 8 yield a vocab_size of 50280 - ) + assert model_config.vocab_size == 50280 model_ref = MambaLMHeadModel.from_pretrained(args.model, device=device, dtype=dtype) From 26c1f3e3e35ed1b5a2c37e013fd91d6bf989c2c3 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 21 Mar 2024 15:27:36 +0000 Subject: [PATCH 22/39] cleaner way to pass store --- examples/mamba/config.py | 5 -- examples/mamba/mamba.py | 65 +++++++---------------- examples/mamba/run_generate.py | 5 +- src/nanotron/generation/decode.py | 26 +++++---- src/nanotron/generation/generate_store.py | 2 +- 5 files changed, 36 insertions(+), 67 deletions(-) diff --git a/examples/mamba/config.py b/examples/mamba/config.py index 2c408033..6634971d 100644 --- a/examples/mamba/config.py +++ b/examples/mamba/config.py @@ -5,11 +5,6 @@ from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs from nanotron.config.utils_config import cast_str_to_torch_dtype - -@dataclass -class MambaInferenceConfig: - max_new_tokens: int = 42 - @dataclass class MambaInit: # mamba_ssm.models.mixer_seq_simple._init_weights diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index df48436e..f4e29ff7 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from config import MambaModelConfig, MambaInferenceConfig +from config import MambaModelConfig from einops import rearrange, repeat from nanotron import distributed as dist from nanotron import logging @@ -61,8 +61,6 @@ except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None -# import lovely_tensors as lt; lt.monkey_patch() - logger = logging.get_logger(__name__) from dataclasses import dataclass, field @@ -70,7 +68,7 @@ from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config -class Mamba(nn.Module): +class Mamba(nn.Module, AttachableStore): def __init__( self, d_model: int, @@ -91,8 +89,6 @@ def __init__( layer_idx: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - is_inference: bool = False, - store: Optional[Store] = None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -104,7 +100,6 @@ def __init__( self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank self.use_fast_path = use_fast_path self.layer_idx = layer_idx - self.is_inference = is_inference tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE assert tp_mode == TensorParallelLinearMode.ALL_REDUCE or parallel_config.tp_linear_async_communication is False @@ -114,7 +109,6 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - self.store = store # Get current tensor parallel rank self.tp_pg = tp_pg self.tp_rank = dist.get_rank(self.tp_pg) @@ -208,9 +202,11 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): batch, seqlen, dim = hidden_states.shape conv_state, ssm_state = None, None - if self.is_inference: + + store = self.get_local_store() + if store is not None: conv_state, ssm_state = self._get_states_from_cache(batch) - if self.store["seqlen_offset"] > 0: + if store["seqlen_offset"] > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out @@ -221,7 +217,7 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): # In the backward pass we write dx and dz next to each other to avoid torch.cat if ( - self.use_fast_path and not self.is_inference + self.use_fast_path and store is None ): # Doesn't support outputting the states y = mamba_inner_fn( d_inner=self.d_inner, @@ -377,7 +373,9 @@ def allocate_inference_cache(self, batch_size: int, max_new_tokens: int, dtype: def _get_states_from_cache(self, batch_size: int, initialize_states: bool = False): assert self.layer_idx is not None - if self.layer_idx not in self.store["key_value_memory_dict"]: + store = self.get_local_store() + + if self.layer_idx not in store["key_value_memory_dict"]: conv_state = torch.zeros( batch_size, self.d_model * self.expand // self.tp_pg.size(), @@ -394,12 +392,12 @@ def _get_states_from_cache(self, batch_size: int, initialize_states: bool = Fals # dtype=torch.float32, ) - self.store["key_value_memory_dict"][self.layer_idx] = ( + store["key_value_memory_dict"][self.layer_idx] = ( conv_state, ssm_state ) else: - conv_state, ssm_state = self.store["key_value_memory_dict"][self.layer_idx] + conv_state, ssm_state = store["key_value_memory_dict"][self.layer_idx] # TODO: What if batch size changes between generation, and we reuse the same states? if initialize_states: conv_state.zero_() @@ -451,8 +449,6 @@ def __init__( layer_idx: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - is_inference: bool = False, - store: Optional[Store] = None, ): super().__init__() @@ -472,8 +468,6 @@ def __init__( parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, - is_inference=is_inference, - store=store, **ssm_cfg, **factory_kwargs, ) @@ -526,15 +520,13 @@ def allocate_inference_cache(self, batch_size, max_new_tokens, dtype=None, **kwa return self.mixer.allocate_inference_cache(batch_size, max_new_tokens, dtype=dtype, **kwargs) -class MambaModel(nn.Module): +class MambaModel(nn.Module, AttachableStore): def __init__( self, config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, - is_inference: bool = False, - store: Optional[Store] = None, ): super().__init__() @@ -548,9 +540,6 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - self.is_inference = is_inference - self.store = store - self.token_position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=Embedding, @@ -575,8 +564,6 @@ def __init__( "layer_idx": layer_idx, "device": self.p2p.device, "dtype": cast_str_to_torch_dtype(config.dtype), - "is_inference": is_inference, - "store": self.store }, module_input_keys={"hidden_states", "sequence_mask", "residual"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, @@ -653,8 +640,10 @@ def forward_with_hidden_states( sharded_logits = self.lm_head(x=hidden_states)["logits"] fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - if self.is_inference: - self.store["seqlen_offset"] += 1 # We are processing only one token at a time + store = self.get_local_store() + + if store is not None: + store["seqlen_offset"] += 1 # We are processing only one token at a time return fp32_sharded_logits, hidden_states @@ -753,43 +742,25 @@ def forward( return {"loss": loss} -class MambaForTraining(NanotronModel, AttachableStore): +class MambaForTraining(NanotronModel): def __init__( self, config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, - is_inference: bool = False, - inference_config: Optional[MambaInferenceConfig] = None, ): super().__init__() self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config - - store = None - - if is_inference and inference_config is not None: - self._attach_store(Store()) - self._store.update( - { - "max_new_tokens": inference_config.max_new_tokens, - "max_batch_size": 1, # We are processing only one token at a time - "seqlen_offset": 0, - "key_value_memory_dict": {}, - } - ) - store = self._store self.model = MambaModel( config=self.config, parallel_context=self.parallel_context, parallel_config=self.parallel_config, random_states=random_states, - is_inference=is_inference, - store=store, ) self.loss = PipelineBlock( diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index e650854d..80d92131 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -44,7 +44,7 @@ ) from nanotron.serialize import load_weights from nanotron.trainer import mark_tied_parameters -from config import MambaConfig, MambaModelConfig, MambaInferenceConfig +from config import MambaConfig, MambaModelConfig from mamba import MambaForTraining try: @@ -121,15 +121,12 @@ def main(): # We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER) random_states = RandomStates({}) - model = build_model( model_builder=lambda: MambaForTraining( config=model_config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=random_states, - is_inference=True, - inference_config=MambaInferenceConfig(max_new_tokens=args.max_new_tokens), ), dtype=getattr(torch, model_config.dtype), parallel_context=parallel_context, diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index fc1eddaa..576a18f5 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -216,11 +216,24 @@ def decode_text( is_max_nb_microbatches = number_states_in_buffer == max_nb_microbatches # Initialize decoder states + store = Store() + + if model.__class__.__name__ == "MambaModel": + + store.update( + { + "max_new_tokens": max_new_tokens, + "max_batch_size": 1, # We are processing only one token at a time + "seqlen_offset": 0, + "key_value_memory_dict": {}, + } + ) + decoder_states: Iterable[GenerationStates] = ( GenerationStates( new_input_ids=batch.input_ids, new_input_mask=batch.input_masks, - store=Store(), + store=store, generation_ids=[batch.input_ids], generation_mask=[batch.input_masks], ) @@ -244,19 +257,12 @@ def decode_text( new_decoder_states.append(state) # Get the new logits if generation_config.use_cache: - if hasattr(model, "_store"): - # Some model like mamba already has store build up for generation + with attach_store(model=model, store=state.store): + # transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size] sharded_logits = model( input_ids=state.new_input_ids, input_mask=state.new_input_mask, ) - else: - with attach_store(model=model, store=state.store): - # transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size] - sharded_logits = model( - input_ids=state.new_input_ids, - input_mask=state.new_input_mask, - ) else: if isinstance(state.new_input_ids, torch.Tensor): batch_generated_ids = torch.cat(state.generation_ids, dim=-1) diff --git a/src/nanotron/generation/generate_store.py b/src/nanotron/generation/generate_store.py index 99e70f49..1a565e71 100644 --- a/src/nanotron/generation/generate_store.py +++ b/src/nanotron/generation/generate_store.py @@ -31,7 +31,7 @@ def get_local_store(self): if hasattr(self, "_store"): if isinstance(self, nn.Module): assert self.training is False, "Store is used only in evaluation mode" - return self._store[id(self)] + return self._store else: return None From f9f29e1faa7a893ab995db380e4279a55a4be86f Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 21 Mar 2024 18:26:46 +0000 Subject: [PATCH 23/39] move store logic inside mamba instead --- examples/mamba/mamba.py | 48 +++++++---------------- src/nanotron/generation/decode.py | 15 +------ src/nanotron/generation/generate_store.py | 2 +- 3 files changed, 16 insertions(+), 49 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index f4e29ff7..927a3e69 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -205,8 +205,16 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): store = self.get_local_store() if store is not None: + + if "key_value_memory_list" not in store: + store["key_value_memory_list"] = [] + + if "seqlen_offset" not in self._store: + self._store["seqlen_offset"] = 0 + conv_state, ssm_state = self._get_states_from_cache(batch) - if store["seqlen_offset"] > 0: + + if self._store["seqlen_offset"] > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out @@ -349,33 +357,12 @@ def step( out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state - def allocate_inference_cache(self, batch_size: int, max_new_tokens: int, dtype: torch.dtype = None, **kwargs): - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_conv, - device=device, - dtype=conv_dtype, - ) - ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, - self.d_model * self.expand, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return conv_state, ssm_state - def _get_states_from_cache(self, batch_size: int, initialize_states: bool = False): assert self.layer_idx is not None store = self.get_local_store() - if self.layer_idx not in store["key_value_memory_dict"]: + if len(store["key_value_memory_list"]) == 0: conv_state = torch.zeros( batch_size, self.d_model * self.expand // self.tp_pg.size(), @@ -391,13 +378,12 @@ def _get_states_from_cache(self, batch_size: int, initialize_states: bool = Fals dtype=self.dt_proj.weight.dtype, # dtype=torch.float32, ) - - store["key_value_memory_dict"][self.layer_idx] = ( + store["key_value_memory_list"] = ( conv_state, ssm_state ) else: - conv_state, ssm_state = store["key_value_memory_dict"][self.layer_idx] + conv_state, ssm_state = store["key_value_memory_list"] # TODO: What if batch size changes between generation, and we reuse the same states? if initialize_states: conv_state.zero_() @@ -516,10 +502,6 @@ def forward( "residual": residual, } - def allocate_inference_cache(self, batch_size, max_new_tokens, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_new_tokens, dtype=dtype, **kwargs) - - class MambaModel(nn.Module, AttachableStore): def __init__( self, @@ -640,10 +622,8 @@ def forward_with_hidden_states( sharded_logits = self.lm_head(x=hidden_states)["logits"] fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - store = self.get_local_store() - - if store is not None: - store["seqlen_offset"] += 1 # We are processing only one token at a time + if self._store is not None: + self._store["seqlen_offset"] += 1 return fp32_sharded_logits, hidden_states diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 576a18f5..f570dec9 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -216,24 +216,11 @@ def decode_text( is_max_nb_microbatches = number_states_in_buffer == max_nb_microbatches # Initialize decoder states - store = Store() - - if model.__class__.__name__ == "MambaModel": - - store.update( - { - "max_new_tokens": max_new_tokens, - "max_batch_size": 1, # We are processing only one token at a time - "seqlen_offset": 0, - "key_value_memory_dict": {}, - } - ) - decoder_states: Iterable[GenerationStates] = ( GenerationStates( new_input_ids=batch.input_ids, new_input_mask=batch.input_masks, - store=store, + store=Store(), generation_ids=[batch.input_ids], generation_mask=[batch.input_masks], ) diff --git a/src/nanotron/generation/generate_store.py b/src/nanotron/generation/generate_store.py index 1a565e71..99e70f49 100644 --- a/src/nanotron/generation/generate_store.py +++ b/src/nanotron/generation/generate_store.py @@ -31,7 +31,7 @@ def get_local_store(self): if hasattr(self, "_store"): if isinstance(self, nn.Module): assert self.training is False, "Store is used only in evaluation mode" - return self._store + return self._store[id(self)] else: return None From a87bceb92368687e72e2b37d6a928fe9e01c5336 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 25 Mar 2024 13:16:22 +0000 Subject: [PATCH 24/39] dont use global store --- examples/mamba/mamba.py | 88 +++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 43 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 927a3e69..66f1aac4 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Mamba model. -""" +"""PyTorch Mamba model.""" + import math -import os from functools import partial from typing import Dict, Optional, Union @@ -28,13 +27,14 @@ from nanotron import logging from nanotron.config import ParallelismArgs from nanotron.config.utils_config import cast_str_to_torch_dtype -from nanotron.generation.generate_store import AttachableStore, Store +from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, @@ -63,10 +63,6 @@ logger = logging.get_logger(__name__) -from dataclasses import dataclass, field -from torch import Tensor -from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config - class Mamba(nn.Module, AttachableStore): def __init__( @@ -135,9 +131,13 @@ def __init__( **factory_kwargs, ) - self.conv1d.weight = create_sharded_parameter_from_config(parameter=self.conv1d.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + self.conv1d.weight = create_sharded_parameter_from_config( + parameter=self.conv1d.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0) + ) if conv_bias: - self.conv1d.bias = create_sharded_parameter_from_config(parameter=self.conv1d.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + self.conv1d.bias = create_sharded_parameter_from_config( + parameter=self.conv1d.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0) + ) self.activation = "silu" self.act = nn.SiLU() @@ -166,8 +166,12 @@ def __init__( # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit self.dt_proj.bias._no_reinit = True - self.dt_proj.weight = create_sharded_parameter_from_config(parameter=self.dt_proj.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) - self.dt_proj.bias = create_sharded_parameter_from_config(parameter=self.dt_proj.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + self.dt_proj.weight = create_sharded_parameter_from_config( + parameter=self.dt_proj.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0) + ) + self.dt_proj.bias = create_sharded_parameter_from_config( + parameter=self.dt_proj.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0) + ) # S4D real initialization A = repeat( @@ -176,11 +180,17 @@ def __init__( d=self.d_inner // self.tp_pg.size(), ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = create_sharded_parameter_from_config(parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + self.A_log = create_sharded_parameter_from_config( + parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0) + ) self.A_log._no_weight_decay = True # D "skip" parameter - self.D = create_sharded_parameter_from_config(parameter=torch.ones(self.d_inner // self.tp_pg.size(), device=device), pg=self.tp_pg, split_config=SplitConfig(split_dim=0)) + self.D = create_sharded_parameter_from_config( + parameter=torch.ones(self.d_inner // self.tp_pg.size(), device=device), + pg=self.tp_pg, + split_config=SplitConfig(split_dim=0), + ) self.D._no_weight_decay = True # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) @@ -199,34 +209,33 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): hidden_states: (B, L, D) Returns: same shape as hidden_states """ - batch, seqlen, dim = hidden_states.shape + batch_size, seqlen, dim = hidden_states.shape conv_state, ssm_state = None, None - + store = self.get_local_store() if store is not None: - if "key_value_memory_list" not in store: store["key_value_memory_list"] = [] - - if "seqlen_offset" not in self._store: - self._store["seqlen_offset"] = 0 - - conv_state, ssm_state = self._get_states_from_cache(batch) - if self._store["seqlen_offset"] > 0: + if "seqlen_offset" not in store: + store["seqlen_offset"] = 0 + + conv_state, ssm_state = self._get_states_from_cache(batch_size) + + if store["seqlen_offset"] > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out + store["seqlen_offset"] += 1 + # We do matmul and transpose BLH -> HBL at the same time xz = self.in_proj(hidden_states).transpose(1, 2) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat - if ( - self.use_fast_path and store is None - ): # Doesn't support outputting the states + if self.use_fast_path and store is None: # Doesn't support outputting the states y = mamba_inner_fn( d_inner=self.d_inner, tp_pg=self.tp_pg, @@ -244,7 +253,7 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): ) else: assert self.d_inner % self.tp_pg.size() == 0 - x, z = xz.view(batch, self.d_inner // self.tp_pg.size(), 2, seqlen).chunk(2, dim=2) + x, z = xz.view(batch_size, self.d_inner // self.tp_pg.size(), 2, seqlen).chunk(2, dim=2) x = x.squeeze(2) z = z.squeeze(2) # Compute short convolution @@ -299,14 +308,14 @@ def step( conv_state: torch.Tensor, ssm_state: torch.Tensor, ): - batch, seqlen, dim = hidden_states.shape + batch_size, seqlen, dim = hidden_states.shape dtype = hidden_states.dtype assert seqlen == 1, "Only support decoding with 1 token at a time for now" xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.view(batch, self.d_inner // self.tp_pg.size(), 2).chunk(2, dim=2) + x, z = xz.view(batch_size, self.d_inner // self.tp_pg.size(), 2).chunk(2, dim=2) x = x.squeeze(2) # (B D) z = z.squeeze(2) # (B D) - + # Conv step if causal_conv1d_update is None: conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) @@ -359,9 +368,9 @@ def step( def _get_states_from_cache(self, batch_size: int, initialize_states: bool = False): assert self.layer_idx is not None - + store = self.get_local_store() - + if len(store["key_value_memory_list"]) == 0: conv_state = torch.zeros( batch_size, @@ -378,10 +387,7 @@ def _get_states_from_cache(self, batch_size: int, initialize_states: bool = Fals dtype=self.dt_proj.weight.dtype, # dtype=torch.float32, ) - store["key_value_memory_list"] = ( - conv_state, - ssm_state - ) + store["key_value_memory_list"] = (conv_state, ssm_state) else: conv_state, ssm_state = store["key_value_memory_list"] # TODO: What if batch size changes between generation, and we reuse the same states? @@ -502,7 +508,8 @@ def forward( "residual": residual, } -class MambaModel(nn.Module, AttachableStore): + +class MambaModel(nn.Module): def __init__( self, config: MambaModelConfig, @@ -521,7 +528,7 @@ def __init__( tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - + self.token_position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=Embedding, @@ -587,13 +594,11 @@ def __init__( module_output_keys={"output"}, ) - def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] def forward_with_hidden_states( @@ -622,9 +627,6 @@ def forward_with_hidden_states( sharded_logits = self.lm_head(x=hidden_states)["logits"] fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - if self._store is not None: - self._store["seqlen_offset"] += 1 - return fp32_sharded_logits, hidden_states def get_block_compute_costs(self): From 4a20e50a0efc35584f77ad5201b8f059e7d2950b Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 25 Mar 2024 13:17:44 +0000 Subject: [PATCH 25/39] fix dtype --- examples/mamba/mamba.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 66f1aac4..433f5b9c 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -385,7 +385,6 @@ def _get_states_from_cache(self, batch_size: int, initialize_states: bool = Fals self.d_state, device=self.dt_proj.weight.device, dtype=self.dt_proj.weight.dtype, - # dtype=torch.float32, ) store["key_value_memory_list"] = (conv_state, ssm_state) else: From 5b8f72aa44cabfe6013731bc2681c29a1e9ba285 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 25 Mar 2024 13:26:04 +0000 Subject: [PATCH 26/39] remove reference model --- examples/mamba/convert_nanotron_to_hf.py | 94 +++++++++++------------- examples/mamba/run_generate.py | 45 +++--------- 2 files changed, 54 insertions(+), 85 deletions(-) diff --git a/examples/mamba/convert_nanotron_to_hf.py b/examples/mamba/convert_nanotron_to_hf.py index 0bcccb1d..e6c45a25 100644 --- a/examples/mamba/convert_nanotron_to_hf.py +++ b/examples/mamba/convert_nanotron_to_hf.py @@ -5,45 +5,39 @@ torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=weights-tp1 --save_path=HF_130M """ import argparse -import torch -import yaml import json from pathlib import Path +import torch +from config import MambaModelConfig +from mamba import MambaForTraining +from nanotron import logging from nanotron.config import ( AllForwardAllBackwardPipelineEngine, ParallelismArgs, TensorParallelLinearMode, ) -from config import MambaModelConfig - -from nanotron.distributed import dist -from nanotron.models import build_model -from mamba import MambaForTraining +from nanotron.models import build_model, init_on_device_and_dtype from nanotron.parallel import ParallelContext -from nanotron.trainer import mark_tied_parameters -from nanotron import logging from nanotron.serialize import load_weights -from nanotron.models import init_on_device_and_dtype - -from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer +from nanotron.trainer import mark_tied_parameters +from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM logger = logging.get_logger(__name__) -import lovely_tensors as lt; lt.monkey_patch() - -HARDCODED_HF_MODEL_NAME = "state-spaces/mamba-130m-hf" +TOKENIZER_NAME = "state-spaces/mamba-130m-hf" HARCODED_PROMPT = "Hello" + def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): device = torch.device("cuda") - + with open(checkpoint_path / "model_config.json", "r") as f: attrs = json.load(f) model_config = MambaModelConfig(**attrs) - + dtype = getattr(torch, model_config.dtype) - + parallel_config = ParallelismArgs( dp=1, pp=1, @@ -58,7 +52,7 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): pipeline_parallel_size=1, tensor_parallel_size=1, ) - + model_nanotron = build_model( model_builder=lambda: MambaForTraining( config=model_config, @@ -68,9 +62,9 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): ), parallel_context=parallel_context, dtype=dtype, - device=device + device=device, ) - + mark_tied_parameters(model=model_nanotron, parallel_context=parallel_context) # Load checkpoint directly in memory and then only keep the state dictionary @@ -80,7 +74,7 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): # Init the HF mode if model_config.ssm_cfg is None: - model_config_hf = MambaConfig( + model_config_hf = MambaConfig( vocab_size=model_config.vocab_size, num_hidden_layers=model_config.num_hidden_layers, residual_in_fp32=model_config.residual_in_fp32, @@ -106,7 +100,7 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): time_step_init_scheme=model_config.ssm_cfg["dt_init"], time_step_floor=model_config.ssm_cfg["dt_init_floor"], ) - + # Initialised HF model with init_on_device_and_dtype(device, dtype): model_hf = MambaForCausalLM._from_config(model_config_hf) @@ -115,26 +109,25 @@ def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): hf_to_nanotron = {} # Static mappings - hf_to_nanotron['backbone.embeddings.weight'] = 'token_position_embeddings.pp_block.token_embedding.weight' - hf_to_nanotron['backbone.norm_f.weight'] = 'final_layer_norm.pp_block.weight' - hf_to_nanotron['lm_head.weight'] = 'lm_head.pp_block.weight' + hf_to_nanotron["backbone.embeddings.weight"] = "token_position_embeddings.pp_block.token_embedding.weight" + hf_to_nanotron["backbone.norm_f.weight"] = "final_layer_norm.pp_block.weight" + hf_to_nanotron["lm_head.weight"] = "lm_head.pp_block.weight" # Dynamic mappings within a loop for i in range(model_config.num_hidden_layers): - hf_to_nanotron[f'backbone.layers.{i}.mixer.A_log'] = f'decoder.{i}.pp_block.mixer.A_log' - hf_to_nanotron[f'backbone.layers.{i}.mixer.D'] = f'decoder.{i}.pp_block.mixer.D' - hf_to_nanotron[f'backbone.layers.{i}.mixer.in_proj.weight'] = f'decoder.{i}.pp_block.mixer.in_proj.weight' - hf_to_nanotron[f'backbone.layers.{i}.mixer.conv1d.weight'] = f'decoder.{i}.pp_block.mixer.conv1d.weight' - hf_to_nanotron[f'backbone.layers.{i}.mixer.conv1d.bias'] = f'decoder.{i}.pp_block.mixer.conv1d.bias' - hf_to_nanotron[f'backbone.layers.{i}.mixer.x_proj.weight'] = f'decoder.{i}.pp_block.mixer.x_proj.weight' - hf_to_nanotron[f'backbone.layers.{i}.mixer.x_proj.bias'] = f'decoder.{i}.pp_block.mixer.x_proj.bias' - hf_to_nanotron[f'backbone.layers.{i}.mixer.dt_proj.weight'] = f'decoder.{i}.pp_block.mixer.dt_proj.weight' - hf_to_nanotron[f'backbone.layers.{i}.mixer.dt_proj.bias'] = f'decoder.{i}.pp_block.mixer.dt_proj.bias' - hf_to_nanotron[f'backbone.layers.{i}.mixer.out_proj.weight'] = f'decoder.{i}.pp_block.mixer.out_proj.weight' - hf_to_nanotron[f'backbone.layers.{i}.mixer.out_proj.bias'] = f'decoder.{i}.pp_block.mixer.out_proj.bias' - hf_to_nanotron[f'backbone.layers.{i}.norm.weight'] = f'decoder.{i}.pp_block.norm.weight' - - + hf_to_nanotron[f"backbone.layers.{i}.mixer.A_log"] = f"decoder.{i}.pp_block.mixer.A_log" + hf_to_nanotron[f"backbone.layers.{i}.mixer.D"] = f"decoder.{i}.pp_block.mixer.D" + hf_to_nanotron[f"backbone.layers.{i}.mixer.in_proj.weight"] = f"decoder.{i}.pp_block.mixer.in_proj.weight" + hf_to_nanotron[f"backbone.layers.{i}.mixer.conv1d.weight"] = f"decoder.{i}.pp_block.mixer.conv1d.weight" + hf_to_nanotron[f"backbone.layers.{i}.mixer.conv1d.bias"] = f"decoder.{i}.pp_block.mixer.conv1d.bias" + hf_to_nanotron[f"backbone.layers.{i}.mixer.x_proj.weight"] = f"decoder.{i}.pp_block.mixer.x_proj.weight" + hf_to_nanotron[f"backbone.layers.{i}.mixer.x_proj.bias"] = f"decoder.{i}.pp_block.mixer.x_proj.bias" + hf_to_nanotron[f"backbone.layers.{i}.mixer.dt_proj.weight"] = f"decoder.{i}.pp_block.mixer.dt_proj.weight" + hf_to_nanotron[f"backbone.layers.{i}.mixer.dt_proj.bias"] = f"decoder.{i}.pp_block.mixer.dt_proj.bias" + hf_to_nanotron[f"backbone.layers.{i}.mixer.out_proj.weight"] = f"decoder.{i}.pp_block.mixer.out_proj.weight" + hf_to_nanotron[f"backbone.layers.{i}.mixer.out_proj.bias"] = f"decoder.{i}.pp_block.mixer.out_proj.bias" + hf_to_nanotron[f"backbone.layers.{i}.norm.weight"] = f"decoder.{i}.pp_block.norm.weight" + def _reverse_interleave_pattern(N): """ Compute the reverse of the interleave pattern given by _interleave_pattern. @@ -143,6 +136,7 @@ def _reverse_interleave_pattern(N): reverse_interleave_pattern(8) -> [0, 2, 4, 6, 1, 3, 5, 7] """ assert N % 2 == 0, "N must be even" + def __interleave_pattern(N): """ interleave_pattern(4) -> [0, 2, 1, 3] @@ -154,7 +148,7 @@ def __interleave_pattern(N): pattern.append(i) pattern.append(i + N // 2) return pattern - + interleaved_pattern = __interleave_pattern(N) reverse_pattern = [0] * N for original_index, interleaved_index in enumerate(interleaved_pattern): @@ -171,7 +165,7 @@ def __interleave_pattern(N): if "in_proj" in nanotron_key: # Undo the interleaving weights in Nanotron to make it HF compatible param = param[_reverse_interleave_pattern(param.shape[0]), :] - + with torch.no_grad(): param_hf.copy_(param) @@ -179,21 +173,17 @@ def __interleave_pattern(N): model_hf.save_pretrained(save_path) print(f"Model saved to {save_path}") -def check_converted_model_generation(save_path: Path, hf_reference_model_name: str): - tokenizer = AutoTokenizer.from_pretrained(hf_reference_model_name) + +def check_converted_model_generation(save_path: Path, tokenizer_name: str): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) input_ids = tokenizer(HARCODED_PROMPT, return_tensors="pt")["input_ids"] print("Inputs:", tokenizer.batch_decode(input_ids)) - - # Ref - model = MambaForCausalLM.from_pretrained(hf_reference_model_name) - out = model.generate(input_ids, max_new_tokens=20) - print("Generation (ref): ", tokenizer.batch_decode(out)) - # Converted model = MambaForCausalLM.from_pretrained(save_path) - out = model.generate(input_ids, max_new_tokens=20) + out = model.generate(input_ids, max_new_tokens=100) print("Generation (converted): ", tokenizer.batch_decode(out)) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Nanotron weights to HF format") parser.add_argument("--checkpoint_path", type=str, default="mamba-130m") @@ -207,4 +197,4 @@ def check_converted_model_generation(save_path: Path, hf_reference_model_name: s convert_checkpoint_and_save(checkpoint_path=checkpoint_path, save_path=save_path) # check if the conversion was successful by generating some text - check_converted_model_generation(save_path=save_path, hf_reference_model_name=HARDCODED_HF_MODEL_NAME) \ No newline at end of file + check_converted_model_generation(save_path=save_path, tokenizer_name=TOKENIZER_NAME) diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index 80d92131..cb12e511 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -13,6 +13,8 @@ from pathlib import Path import torch +from config import MambaConfig, MambaModelConfig +from mamba import MambaForTraining from nanotron import distributed as dist from nanotron import logging from nanotron.config import ( @@ -44,8 +46,6 @@ ) from nanotron.serialize import load_weights from nanotron.trainer import mark_tied_parameters -from config import MambaConfig, MambaModelConfig -from mamba import MambaForTraining try: from transformers import AutoTokenizer @@ -54,9 +54,6 @@ logger = logging.get_logger(__name__) -from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel - -import lovely_tensors as lt; lt.monkey_patch() def get_args(): parser = argparse.ArgumentParser() @@ -73,11 +70,14 @@ def main(): assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist" - config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(), config_class=MambaConfig, model_config_class=MambaModelConfig) + config = get_config_from_file( + (args.ckpt_path / "config.yaml").as_posix(), config_class=MambaConfig, model_config_class=MambaModelConfig + ) model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path - - assert "state-spaces/mamba-130m-hf" == tokenizer_path; f"Should be 'state-spaces/mamba-130m-hf' tokenizer and not '{tokenizer_path}'" + + assert "state-spaces/mamba-130m-hf" == tokenizer_path + f"Should be 'state-spaces/mamba-130m-hf' tokenizer and not '{tokenizer_path}'" parallel_config = ParallelismArgs( dp=args.dp or config.parallelism.dp, @@ -111,7 +111,7 @@ def main(): # Set random states set_random_seed(42) - + # Get synchronized random states if parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE: random_states = RandomStates( @@ -131,8 +131,6 @@ def main(): dtype=getattr(torch, model_config.dtype), parallel_context=parallel_context, ) - model_ref = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m").to("cuda") - # Mark some parameters as tied # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead? mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) @@ -150,8 +148,7 @@ def main(): ) load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) model.eval() - model_ref.eval() - + if AutoTokenizer is not None: tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) # tokenizer.pad_token_id = tokenizer.eos_token_id @@ -168,14 +165,13 @@ def main(): tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? dummy_inputs = [ # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - # "def fib(n)", # "This film was probably inspired by Godzilla", "Hello" ] log_rank("Setup Inference mode for mamba model", logger=logger, level=logging.INFO, rank=0) # assert config.inference_params.max_batch_size == 1, "Only batch size 1 is supported for inference for now" - + outputs = decode_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), tokenizer=tokenizer, @@ -217,27 +213,10 @@ def main(): level=logging.INFO, rank=0, ) - + # Model ref tokens = tokenizer(dummy_inputs, return_tensors="pt") input_ids = tokens.input_ids.to(device="cuda") - - output_ref = model_ref.generate( - input_ids=input_ids, - max_length=args.max_new_tokens, - cg=False, - return_dict_in_generate=True, - output_scores=True, - enable_timing=False, - temperature=1.0, - top_k=1, - top_p=1.0, - min_p=0.0, - repetition_penalty=1.0 - ) - - log_rank(f"input REF: {tokenizer.decode(input_ids[0], clean_up_tokenization_spaces=False)}", logger=logger, level=logging.INFO, rank=0) - log_rank(f"generation REF: {tokenizer.batch_decode(output_ref.sequences.tolist())}", logger=logger, level=logging.INFO, rank=0) else: outputs = decode_tokenized( input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), From 9982ff2cbf2e310925cd294112b8f8ba002dc33c Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 27 Mar 2024 10:58:52 +0000 Subject: [PATCH 27/39] fix fan_in computation for bias in mamba --- examples/mamba/mamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 433f5b9c..181874ce 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -837,24 +837,22 @@ def init_model_randomly(self, config): module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers) elif isinstance(module, nn.Conv1d): - fan_in = None + fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) if "weight" == param_name: - fan_in, _ = init._calculate_fan_in_and_fan_out(param) init.kaiming_uniform_(module.weight, a=math.sqrt(5)) elif "bias" == param_name: - bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + bound = 1 / math.sqrt(fan_in) if (fan_in > 0) else 0 init.uniform_(module.bias, -bound, bound) else: raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, nn.Linear): - fan_in = None + fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) if "weight" == param_name: - fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) init.kaiming_uniform_(module.weight, a=math.sqrt(5)) elif "bias" == param_name: - bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + bound = 1 / math.sqrt(fan_in) if (fan_in > 0) else 0 init.uniform_(module.bias, -bound, bound) else: raise ValueError(f"Who the fuck is {param_name}?") From 91458707fb2c96b93074a05a80c2319a6040f097 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 27 Mar 2024 18:15:47 +0000 Subject: [PATCH 28/39] fix store increment --- examples/mamba/mamba.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 181874ce..149834b7 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -226,9 +226,10 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): if store["seqlen_offset"] > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) + store["seqlen_offset"] += 1 return out - - store["seqlen_offset"] += 1 + else: + store["seqlen_offset"] += 1 # We do matmul and transpose BLH -> HBL at the same time xz = self.in_proj(hidden_states).transpose(1, 2) From f6f65922d969c306b27c13b5b6207b94d82d8dd9 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 4 Apr 2024 13:25:27 +0000 Subject: [PATCH 29/39] clean code --- examples/mamba/config.py | 4 +- examples/mamba/convert_hf_to_nanotron.py | 129 +++++++++++++---------- examples/mamba/convert_nanotron_to_hf.py | 23 ++-- examples/mamba/mamba.py | 2 - examples/mamba/run_generate.py | 4 +- src/nanotron/generation/decode.py | 4 +- 6 files changed, 95 insertions(+), 71 deletions(-) diff --git a/examples/mamba/config.py b/examples/mamba/config.py index 6634971d..d2b01666 100644 --- a/examples/mamba/config.py +++ b/examples/mamba/config.py @@ -2,12 +2,12 @@ from typing import Optional, Union import torch - from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs from nanotron.config.utils_config import cast_str_to_torch_dtype + + @dataclass class MambaInit: - # mamba_ssm.models.mixer_seq_simple._init_weights initializer_range: float = 0.02 rescale_prenorm_residual: bool = True n_residuals_per_layer: int = 1 # Change to 2 if we have MLP diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index f3bb75ce..bd0fa9e6 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -3,56 +3,58 @@ Converts a HF model from (https://huggingface.co/state-spaces/) to a Brrr model Command: - torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --model 130M --save_path nanotron-weights + torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --model state-spaces/mamba-130m --tokenizer state-spaces/mamba-130m-hf --save_path nanotron_weights """ import argparse -import torch -import yaml import json +from dataclasses import asdict from pathlib import Path -from tqdm import tqdm from typing import Dict -from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel -from dataclasses import asdict +import torch +import yaml +from config import MambaConfig, MambaInit, MambaModelConfig +from mamba import MambaForTraining +from mamba_ssm.models.config_mamba import MambaConfig as HFMambaConfig +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel +from nanotron import logging from nanotron.config import ( AllForwardAllBackwardPipelineEngine, + GeneralArgs, + LoggingArgs, + ModelArgs, ParallelismArgs, TensorParallelLinearMode, + TokenizerArgs, ) -from config import MambaModelConfig, MambaConfig, MambaInit -from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs - from nanotron.distributed import dist -from nanotron.helpers import _vocab_size_with_padding +from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.models import build_model -from mamba import MambaForTraining from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter, sanity_check from nanotron.serialize import save_meta, save_weights from nanotron.trainer import mark_tied_parameters -from nanotron import logging -from nanotron.logging import log_rank, set_ranks_logging_level -from nanotron.config import LoggingArgs - -from transformers.utils import WEIGHTS_NAME, CONFIG_NAME +from tqdm import tqdm +from transformers.utils import CONFIG_NAME from transformers.utils.hub import cached_file -from mamba_ssm.models.config_mamba import MambaConfig as HFMambaConfig + logger = logging.get_logger(__name__) + def load_config_hf(model_name): resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) return json.load(open(resolved_archive_file)) + def get_weight_from_hf( name: str, ref_module_state_dict: Dict[str, torch.Tensor], ref_module: MambaLMHeadModel, nanotron_to_hf: Dict[str, str], - get_grad: bool = False + get_grad: bool = False, ) -> torch.Tensor: - """From our brrr implementation, we get the equivalent tensor in transformers implementation""" - + """From our brrr implementation, we get the equivalent tensor in transformers implementation""" + def _interleave_pattern(N): """ interleave_pattern(4) -> [0, 2, 1, 3] @@ -64,39 +66,46 @@ def _interleave_pattern(N): pattern.append(i) pattern.append(i + N // 2) return pattern - + hf_name = nanotron_to_hf[name] - + if get_grad is False: + def _get_tensor(path: str): return ref_module_state_dict[path] + else: + def _get_tensor(path: str): param = ref_module.get_parameter(path) return param.grad - + param = _get_tensor(hf_name) - + if "in_proj" in hf_name: # In Nanotron, we do tensor parallel column so weight need to be split in the column dimension (i.e: xz.view(...)) # However, the HF weights was trained such that it expected xz.chunk(...) to split the tensor in the row dimension # Thus, we need to interleaved the HF weights to make it compatible with Nanotron - log_rank(f"Interleaving {hf_name} to make it compatible with Nanotron", logger=logger, level=logging.INFO, rank=0) + log_rank( + f"Interleaving {hf_name} to make it compatible with Nanotron", logger=logger, level=logging.INFO, rank=0 + ) param = param[_interleave_pattern(param.shape[0]), :] - + return param + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert HF weights from states-space repo to brrr weights") parser.add_argument("--save_path", type=str, default="mamba-nanotron") parser.add_argument("--model", type=str, default="state-spaces/mamba-130m") + parser.add_argument("--tokenizer", type=str, default="state-spaces/mamba-130m-hf") parser.add_argument("--dp", type=int, default=1) parser.add_argument("--pp", type=int, default=1) parser.add_argument("--tp", type=int, default=1) args = parser.parse_args() assert "state-spaces" in args.model, "Only models from state-spaces repo are supported" - + save_path = Path(args.save_path) parallel_config = ParallelismArgs( @@ -107,7 +116,10 @@ def _get_tensor(path: str): tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, ) - assert parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE and parallel_config.tp_linear_async_communication is False + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) parallel_context = ParallelContext( data_parallel_size=parallel_config.dp, @@ -128,7 +140,7 @@ def _get_tensor(path: str): hf_config = HFMambaConfig(**hf_config_data) dtype_str = "float32" - + yaml_content = f""" is_mamba_config: true d_model: {hf_config.d_model} @@ -163,7 +175,7 @@ def _get_tensor(path: str): ), parallel_context=parallel_context, dtype=dtype, - device=device + device=device, ) device_map = {} @@ -186,28 +198,33 @@ def _get_tensor(path: str): nanotron_to_hf = {} for i in range(model_config.num_hidden_layers): - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.A_log'] = f'backbone.layers.{i}.mixer.A_log' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.D'] = f'backbone.layers.{i}.mixer.D' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.in_proj.weight'] = f'backbone.layers.{i}.mixer.in_proj.weight' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.conv1d.weight'] = f'backbone.layers.{i}.mixer.conv1d.weight' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.conv1d.bias'] = f'backbone.layers.{i}.mixer.conv1d.bias' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.x_proj.weight'] = f'backbone.layers.{i}.mixer.x_proj.weight' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.x_proj.bias'] = f'backbone.layers.{i}.mixer.x_proj.bias' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.dt_proj.weight'] = f'backbone.layers.{i}.mixer.dt_proj.weight' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.dt_proj.bias'] = f'backbone.layers.{i}.mixer.dt_proj.bias' - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.out_proj.weight'] = f'backbone.layers.{i}.mixer.out_proj.weight' - #TODO: Maybe check if bias exists? - nanotron_to_hf[f'decoder.{i}.pp_block.mixer.out_proj.bias'] = f'backbone.layers.{i}.mixer.out_proj.bias' - nanotron_to_hf[f'decoder.{i}.pp_block.norm.weight'] = f'backbone.layers.{i}.norm.weight' - - nanotron_to_hf['token_position_embeddings.pp_block.token_embedding.weight'] = 'backbone.embedding.weight' - nanotron_to_hf['final_layer_norm.pp_block.weight'] = 'backbone.norm_f.weight' - nanotron_to_hf['lm_head.pp_block.weight'] = 'lm_head.weight' + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.A_log"] = f"backbone.layers.{i}.mixer.A_log" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.D"] = f"backbone.layers.{i}.mixer.D" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.in_proj.weight"] = f"backbone.layers.{i}.mixer.in_proj.weight" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.conv1d.weight"] = f"backbone.layers.{i}.mixer.conv1d.weight" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.conv1d.bias"] = f"backbone.layers.{i}.mixer.conv1d.bias" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.x_proj.weight"] = f"backbone.layers.{i}.mixer.x_proj.weight" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.x_proj.bias"] = f"backbone.layers.{i}.mixer.x_proj.bias" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.dt_proj.weight"] = f"backbone.layers.{i}.mixer.dt_proj.weight" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.dt_proj.bias"] = f"backbone.layers.{i}.mixer.dt_proj.bias" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.out_proj.weight"] = f"backbone.layers.{i}.mixer.out_proj.weight" + nanotron_to_hf[f"decoder.{i}.pp_block.mixer.out_proj.bias"] = f"backbone.layers.{i}.mixer.out_proj.bias" + nanotron_to_hf[f"decoder.{i}.pp_block.norm.weight"] = f"backbone.layers.{i}.norm.weight" + + nanotron_to_hf["token_position_embeddings.pp_block.token_embedding.weight"] = "backbone.embedding.weight" + nanotron_to_hf["final_layer_norm.pp_block.weight"] = "backbone.norm_f.weight" + nanotron_to_hf["lm_head.pp_block.weight"] = "lm_head.weight" # Sync weights ref_state_dict = model_ref.state_dict() - for name, param in tqdm(nanotron_model.model.named_parameters(), total=len(list(nanotron_model.model.named_parameters())), desc="Converting"): - ref_param = get_weight_from_hf(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref, nanotron_to_hf=nanotron_to_hf) + for name, param in tqdm( + nanotron_model.model.named_parameters(), + total=len(list(nanotron_model.model.named_parameters())), + desc="Converting", + ): + ref_param = get_weight_from_hf( + name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref, nanotron_to_hf=nanotron_to_hf + ) param_is_tp_sharded = ( isinstance(param, NanotronParameter) @@ -237,7 +254,7 @@ def _get_tensor(path: str): mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) sanity_check(root_module=nanotron_model) - + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) checkpoint_metadata = { "last_train_step": 0, @@ -251,14 +268,14 @@ def _get_tensor(path: str): general=GeneralArgs(project="test", run="mamba"), parallelism=parallel_config, model=ModelArgs( - init_method=MambaInit(), - model_config=model_config, - ), - tokenizer=TokenizerArgs(args.model + "-hf"), + init_method=MambaInit(), + model_config=model_config, + ), + tokenizer=TokenizerArgs(args.tokenizer), ) log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) yaml.dump(config.as_dict(), f) - + with open(save_path / "model_config.json", "w") as f: log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) - json.dump(asdict(model_config), f) \ No newline at end of file + json.dump(asdict(model_config), f) diff --git a/examples/mamba/convert_nanotron_to_hf.py b/examples/mamba/convert_nanotron_to_hf.py index e6c45a25..235d4644 100644 --- a/examples/mamba/convert_nanotron_to_hf.py +++ b/examples/mamba/convert_nanotron_to_hf.py @@ -2,13 +2,14 @@ """ Converts a nanotron model to HF format Command: - torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=weights-tp1 --save_path=HF_130M + torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=nanotron_weights --save_path=HF_weights """ import argparse import json from pathlib import Path import torch +import yaml from config import MambaModelConfig from mamba import MambaForTraining from nanotron import logging @@ -25,13 +26,14 @@ logger = logging.get_logger(__name__) -TOKENIZER_NAME = "state-spaces/mamba-130m-hf" -HARCODED_PROMPT = "Hello" - def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path): device = torch.device("cuda") + with open(checkpoint_path / "config.yaml", "r") as f: + attrs = yaml.safe_load(f) + tokenizer_name = attrs["tokenizer"]["tokenizer_name_or_path"] + with open(checkpoint_path / "model_config.json", "r") as f: attrs = json.load(f) model_config = MambaModelConfig(**attrs) @@ -173,9 +175,16 @@ def __interleave_pattern(N): model_hf.save_pretrained(save_path) print(f"Model saved to {save_path}") - -def check_converted_model_generation(save_path: Path, tokenizer_name: str): + # Save the tokenizer tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(save_path) + print(f"Tokenizer saved to {save_path}") + + +def check_converted_model_generation(save_path: Path): + HARCODED_PROMPT = "What is your " + + tokenizer = AutoTokenizer.from_pretrained(save_path) input_ids = tokenizer(HARCODED_PROMPT, return_tensors="pt")["input_ids"] print("Inputs:", tokenizer.batch_decode(input_ids)) @@ -197,4 +206,4 @@ def check_converted_model_generation(save_path: Path, tokenizer_name: str): convert_checkpoint_and_save(checkpoint_path=checkpoint_path, save_path=save_path) # check if the conversion was successful by generating some text - check_converted_model_generation(save_path=save_path, tokenizer_name=TOKENIZER_NAME) + check_converted_model_generation(save_path=save_path) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 149834b7..5ff60dae 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -163,8 +163,6 @@ def __init__( inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt) - # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - self.dt_proj.bias._no_reinit = True self.dt_proj.weight = create_sharded_parameter_from_config( parameter=self.dt_proj.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0) diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index cb12e511..b03984d0 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -166,7 +166,7 @@ def main(): dummy_inputs = [ # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", # "This film was probably inspired by Godzilla", - "Hello" + "What is your " ] log_rank("Setup Inference mode for mamba model", logger=logger, level=logging.INFO, rank=0) @@ -183,7 +183,7 @@ def main(): generation_config=GenerationArgs(sampler="greedy", use_cache=True), tokenizer_config=TokenizerConfig(max_input_length=None), is_bench=os.environ.get("USE_BENCH", "0") == "1", - is_logits_transpose=False, + logits_are_batch_first=False, ) for output in outputs: input_ids = output.input_ids diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index f570dec9..f6c1f1a8 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -166,7 +166,7 @@ def decode_text( max_micro_batch_size: int, max_new_tokens: int, is_bench: bool = False, - is_logits_transpose: bool = True, + logits_are_batch_first: bool = True, ) -> Generator[GenerationOutput, None, None]: """We assume the following: - Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input. @@ -262,7 +262,7 @@ def decode_text( input_mask=batch_generated_mask, ) - if isinstance(sharded_logits, torch.Tensor) and is_logits_transpose: + if isinstance(sharded_logits, torch.Tensor) and logits_are_batch_first: sharded_logits = sharded_logits.transpose(0, 1) # Communicate # TODO @thomasw21: Make a diagram to show how this works From af75b417e64e470a0e4c80b5b93d6f2c77e37d97 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 4 Apr 2024 15:11:10 +0000 Subject: [PATCH 30/39] use HF instead of states-spaces version for hf_to_nanotron --- examples/mamba/convert_hf_to_nanotron.py | 56 ++++++++++++------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index bd0fa9e6..58a2cda5 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -1,9 +1,9 @@ # ruff: noqa: E402 """ -Converts a HF model from (https://huggingface.co/state-spaces/) to a Brrr model +Converts a HF model to a Nanotron model Command: - torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --model state-spaces/mamba-130m --tokenizer state-spaces/mamba-130m-hf --save_path nanotron_weights + torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --inp_path state-spaces/mamba-130m-hf --out_path nanotron_weights """ import argparse import json @@ -15,8 +15,6 @@ import yaml from config import MambaConfig, MambaInit, MambaModelConfig from mamba import MambaForTraining -from mamba_ssm.models.config_mamba import MambaConfig as HFMambaConfig -from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel from nanotron import logging from nanotron.config import ( AllForwardAllBackwardPipelineEngine, @@ -35,6 +33,8 @@ from nanotron.serialize import save_meta, save_weights from nanotron.trainer import mark_tied_parameters from tqdm import tqdm +from transformers import MambaConfig as HFMambaConfig +from transformers import MambaForCausalLM from transformers.utils import CONFIG_NAME from transformers.utils.hub import cached_file @@ -49,7 +49,7 @@ def load_config_hf(model_name): def get_weight_from_hf( name: str, ref_module_state_dict: Dict[str, torch.Tensor], - ref_module: MambaLMHeadModel, + ref_module: MambaForCausalLM, nanotron_to_hf: Dict[str, str], get_grad: bool = False, ) -> torch.Tensor: @@ -96,17 +96,14 @@ def _get_tensor(path: str): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert HF weights from states-space repo to brrr weights") - parser.add_argument("--save_path", type=str, default="mamba-nanotron") - parser.add_argument("--model", type=str, default="state-spaces/mamba-130m") - parser.add_argument("--tokenizer", type=str, default="state-spaces/mamba-130m-hf") + parser.add_argument("--inp_path", type=str, default="state-spaces/mamba-130m-hf") + parser.add_argument("--out_path", type=str, default="nanotron_weight") parser.add_argument("--dp", type=int, default=1) parser.add_argument("--pp", type=int, default=1) parser.add_argument("--tp", type=int, default=1) args = parser.parse_args() - assert "state-spaces" in args.model, "Only models from state-spaces repo are supported" - - save_path = Path(args.save_path) + out_path = Path(args.out_path) parallel_config = ParallelismArgs( dp=args.dp, @@ -136,25 +133,25 @@ def _get_tensor(path: str): # Set log levels set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config) - hf_config_data = load_config_hf(args.model) - hf_config = HFMambaConfig(**hf_config_data) + hf_config = HFMambaConfig.from_pretrained(args.inp_path) dtype_str = "float32" + # TODO(fmom): Add support for ssm_cfg yaml_content = f""" is_mamba_config: true - d_model: {hf_config.d_model} + d_model: {hf_config.hidden_size} dtype: {dtype_str} fused_add_norm: true is_mamba_config: true - num_hidden_layers: {hf_config.n_layer} + num_hidden_layers: {hf_config.num_hidden_layers} pad_token_id: null pad_vocab_size_multiple: 8 residual_in_fp32: true rms_norm: true rms_norm_eps: 1.0e-05 ssm_cfg: null - vocab_size: 50280 + vocab_size: {hf_config.vocab_size} """ dtype = getattr(torch, dtype_str) @@ -162,9 +159,10 @@ def _get_tensor(path: str): attrs = yaml.safe_load(yaml_content) model_config = MambaModelConfig(**attrs) - assert model_config.vocab_size == 50280 - model_ref = MambaLMHeadModel.from_pretrained(args.model, device=device, dtype=dtype) + model_ref = MambaForCausalLM.from_pretrained(args.inp_path) + model_ref.to(device, dtype=dtype) + model_ref.eval() nanotron_model = build_model( model_builder=lambda: MambaForTraining( @@ -194,9 +192,15 @@ def _get_tensor(path: str): device_map["lm_head"] = nanotron_model.model.lm_head.rank if current_pp_rank in tied_embs_ranks else "meta" - # Create a mapping from nanotron to hf + # Get mapping of Nanotron layer to HF layer nanotron_to_hf = {} + # Static mappings + nanotron_to_hf["token_position_embeddings.pp_block.token_embedding.weight"] = "backbone.embeddings.weight" + nanotron_to_hf["final_layer_norm.pp_block.weight"] = "backbone.norm_f.weight" + nanotron_to_hf["lm_head.pp_block.weight"] = "lm_head.weight" + + # Dynamic mappings within a loop for i in range(model_config.num_hidden_layers): nanotron_to_hf[f"decoder.{i}.pp_block.mixer.A_log"] = f"backbone.layers.{i}.mixer.A_log" nanotron_to_hf[f"decoder.{i}.pp_block.mixer.D"] = f"backbone.layers.{i}.mixer.D" @@ -211,10 +215,6 @@ def _get_tensor(path: str): nanotron_to_hf[f"decoder.{i}.pp_block.mixer.out_proj.bias"] = f"backbone.layers.{i}.mixer.out_proj.bias" nanotron_to_hf[f"decoder.{i}.pp_block.norm.weight"] = f"backbone.layers.{i}.norm.weight" - nanotron_to_hf["token_position_embeddings.pp_block.token_embedding.weight"] = "backbone.embedding.weight" - nanotron_to_hf["final_layer_norm.pp_block.weight"] = "backbone.norm_f.weight" - nanotron_to_hf["lm_head.pp_block.weight"] = "lm_head.weight" - # Sync weights ref_state_dict = model_ref.state_dict() for name, param in tqdm( @@ -255,15 +255,15 @@ def _get_tensor(path: str): sanity_check(root_module=nanotron_model) - save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=save_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=out_path) checkpoint_metadata = { "last_train_step": 0, "consumed_train_samples": 0, } - save_meta(root_folder=save_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata) + save_meta(root_folder=out_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata) if dist.get_rank() == 0: - with open(save_path / "config.yaml", "w") as f: + with open(out_path / "config.yaml", "w") as f: config = MambaConfig( general=GeneralArgs(project="test", run="mamba"), parallelism=parallel_config, @@ -271,11 +271,11 @@ def _get_tensor(path: str): init_method=MambaInit(), model_config=model_config, ), - tokenizer=TokenizerArgs(args.tokenizer), + tokenizer=TokenizerArgs(args.inp_path), ) log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) yaml.dump(config.as_dict(), f) - with open(save_path / "model_config.json", "w") as f: + with open(out_path / "model_config.json", "w") as f: log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) json.dump(asdict(model_config), f) From d407d44a85453efabe1d178e637bca865b75a6c9 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 4 Apr 2024 15:13:42 +0000 Subject: [PATCH 31/39] add contiguous_chunk + fix split bug when TP=1 --- examples/mamba/mamba.py | 18 +++++++++++++----- examples/mamba/selective_scan_interface.py | 8 +++++--- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 5ff60dae..664e8183 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -116,7 +116,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=False, - contiguous_chunks=None, + contiguous_chunks=(self.d_inner, self.d_inner), ) assert self.d_inner % self.tp_pg.size() == 0 @@ -251,10 +251,14 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): delta_softplus=True, ) else: - assert self.d_inner % self.tp_pg.size() == 0 - x, z = xz.view(batch_size, self.d_inner // self.tp_pg.size(), 2, seqlen).chunk(2, dim=2) + if self.tp_pg.size() > 1: + x, z = xz.view(batch_size, self.d_inner // 2, 2, seqlen).chunk(2, dim=2) + else: + x, z = xz.view(batch_size, self.d_inner, 2, seqlen).chunk(2, dim=2) + x = x.squeeze(2) z = z.squeeze(2) + # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv @@ -311,10 +315,14 @@ def step( dtype = hidden_states.dtype assert seqlen == 1, "Only support decoding with 1 token at a time for now" xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - x, z = xz.view(batch_size, self.d_inner // self.tp_pg.size(), 2).chunk(2, dim=2) + + if self.tp_pg.size() > 1: + x, z = xz.view(batch_size, self.d_inner // 2, 2).chunk(2, dim=2) + else: + x, z = xz.view(batch_size, self.d_inner, 2).chunk(2, dim=2) + x = x.squeeze(2) # (B D) z = z.squeeze(2) # (B D) - # Conv step if causal_conv1d_update is None: conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) diff --git a/examples/mamba/selective_scan_interface.py b/examples/mamba/selective_scan_interface.py index 123641c8..45d2aae1 100644 --- a/examples/mamba/selective_scan_interface.py +++ b/examples/mamba/selective_scan_interface.py @@ -235,9 +235,11 @@ def forward( xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") - # x, z = xz.chunk(2, dim=1) - assert d_inner % tp_pg.size() == 0 - x, z = xz.view(batch, d_inner // tp_pg.size(), 2, L).chunk(2, dim=2) + if tp_pg.size() > 1: + x, z = xz.view(batch, d_inner // 2, 2, L).chunk(2, dim=2) + else: + x, z = xz.view(batch, d_inner, 2, L).chunk(2, dim=2) + x = x.squeeze(2) z = z.squeeze(2) From 10e1a75d7cbc27739d9d6cf0ffac08518babc2e4 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 4 Apr 2024 15:17:38 +0000 Subject: [PATCH 32/39] better tiny mamba config --- examples/mamba/create_config_mamba.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/mamba/create_config_mamba.py b/examples/mamba/create_config_mamba.py index 40c211f9..9915ec67 100644 --- a/examples/mamba/create_config_mamba.py +++ b/examples/mamba/create_config_mamba.py @@ -3,7 +3,6 @@ import os from config import MambaConfig, MambaInit, MambaModelConfig - from nanotron.config import ( CheckpointsArgs, DataArgs, @@ -36,9 +35,9 @@ } # https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json model_config = MambaModelConfig( - d_model=1536, + d_model=1024, num_hidden_layers=48, - vocab_size=50277, + vocab_size=50278, ssm_cfg=ssm_cfg, rms_norm=True, fused_add_norm=True, @@ -98,7 +97,11 @@ adam_beta2=0.95, torch_adam_is_fused=True, learning_rate_scheduler=LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=10, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + learning_rate=0.0015, + lr_warmup_steps=30, + lr_warmup_style="linear", + lr_decay_style="cosine", + min_decay_lr=0.00015, ), ) @@ -111,7 +114,7 @@ tp_linear_async_communication=False, ) -tokens = TokensArgs(sequence_length=2048, train_steps=100, micro_batch_size=2, batch_accumulation_per_replica=1) +tokens = TokensArgs(sequence_length=2048, train_steps=300, micro_batch_size=8, batch_accumulation_per_replica=1) dataset = PretrainDatasetsArgs( hf_dataset_or_datasets={"roneneldan/TinyStories": 1.0}, @@ -127,7 +130,7 @@ config = MambaConfig( general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=100), parallelism=parallelism, model=ModelArgs( init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), From 5fa8401e45dc956bdc6947b51ab629be81071d6b Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 4 Apr 2024 15:18:00 +0000 Subject: [PATCH 33/39] remove assert in run generate --- examples/mamba/run_generate.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index b03984d0..f7194668 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -76,9 +76,6 @@ def main(): model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path - assert "state-spaces/mamba-130m-hf" == tokenizer_path - f"Should be 'state-spaces/mamba-130m-hf' tokenizer and not '{tokenizer_path}'" - parallel_config = ParallelismArgs( dp=args.dp or config.parallelism.dp, pp=args.pp or config.parallelism.pp, From c50a9fd56aab6bdc4b1dae03192a4caa305b6d2c Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 10 Apr 2024 11:42:47 +0000 Subject: [PATCH 34/39] breaking: apply interleaved only for TP=1 --- examples/mamba/convert_hf_to_nanotron.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 58a2cda5..059ff8af 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -52,6 +52,7 @@ def get_weight_from_hf( ref_module: MambaForCausalLM, nanotron_to_hf: Dict[str, str], get_grad: bool = False, + param_is_tp_sharded: bool = False, ) -> torch.Tensor: """From our brrr implementation, we get the equivalent tensor in transformers implementation""" @@ -82,14 +83,15 @@ def _get_tensor(path: str): param = _get_tensor(hf_name) - if "in_proj" in hf_name: + # only do this when the weight is not sharded in tensor parallel + if "in_proj" in hf_name and not param_is_tp_sharded: # In Nanotron, we do tensor parallel column so weight need to be split in the column dimension (i.e: xz.view(...)) # However, the HF weights was trained such that it expected xz.chunk(...) to split the tensor in the row dimension # Thus, we need to interleaved the HF weights to make it compatible with Nanotron log_rank( f"Interleaving {hf_name} to make it compatible with Nanotron", logger=logger, level=logging.INFO, rank=0 ) - param = param[_interleave_pattern(param.shape[0]), :] + return param[_interleave_pattern(param.shape[0]), :] return param @@ -222,16 +224,20 @@ def _get_tensor(path: str): total=len(list(nanotron_model.model.named_parameters())), desc="Converting", ): - ref_param = get_weight_from_hf( - name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref, nanotron_to_hf=nanotron_to_hf - ) - param_is_tp_sharded = ( isinstance(param, NanotronParameter) and param.is_sharded and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg ) + ref_param = get_weight_from_hf( + name=name, + ref_module_state_dict=ref_state_dict, + ref_module=model_ref, + nanotron_to_hf=nanotron_to_hf, + param_is_tp_sharded=param_is_tp_sharded, + ) + if param_is_tp_sharded: sharded_info = param.get_sharded_info() # copy param data (not just the reference) @@ -249,7 +255,6 @@ def _get_tensor(path: str): param.copy_(ref_param) ref_param = None torch.cuda.empty_cache() - # Marks parameters as NanotronParameters mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) From 8320e6254624878018b7d4d1ef809514236f7d80 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Apr 2024 09:48:38 +0000 Subject: [PATCH 35/39] backup --- examples/mamba/convert_hf_to_nanotron.py | 4 +- examples/mamba/create_config_mamba.py | 106 +++++++++++++++++++++-- 2 files changed, 103 insertions(+), 7 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 059ff8af..413eb66a 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -84,7 +84,8 @@ def _get_tensor(path: str): param = _get_tensor(hf_name) # only do this when the weight is not sharded in tensor parallel - if "in_proj" in hf_name and not param_is_tp_sharded: + # if "in_proj" in hf_name and not param_is_tp_sharded: + if "in_proj" in hf_name: # In Nanotron, we do tensor parallel column so weight need to be split in the column dimension (i.e: xz.view(...)) # However, the HF weights was trained such that it expected xz.chunk(...) to split the tensor in the row dimension # Thus, we need to interleaved the HF weights to make it compatible with Nanotron @@ -255,6 +256,7 @@ def _get_tensor(path: str): param.copy_(ref_param) ref_param = None torch.cuda.empty_cache() + # Marks parameters as NanotronParameters mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) diff --git a/examples/mamba/create_config_mamba.py b/examples/mamba/create_config_mamba.py index 47f214ad..60389423 100644 --- a/examples/mamba/create_config_mamba.py +++ b/examples/mamba/create_config_mamba.py @@ -1,6 +1,8 @@ """ Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.""" import math import os +import pprint +import uuid from config import MambaConfig, MambaInit, MambaModelConfig from nanotron.config import ( @@ -19,6 +21,11 @@ ) from nanotron.logging import human_format +import wandb + +new_job_id = uuid.uuid4() +job_id = str(new_job_id)[:8] + ssm_cfg_dtype = "bfloat16" ssm_cfg = { "d_state": 16, @@ -107,8 +114,8 @@ ) parallelism = ParallelismArgs( - dp=2, - pp=2, + dp=1, + pp=1, tp=2, pp_engine="1f1b", tp_mode="ALL_REDUCE", @@ -148,7 +155,94 @@ ) if __name__ == "__main__": - dir = os.path.dirname(__file__) - - # Save config as YAML file - config.save_as_yaml(f"{dir}/config_mamba.yaml") + import argparse + from dataclasses import fields, is_dataclass + + from nanotron.config import get_config_from_file + + def print_differences(target, updates): + if not is_dataclass(target) or not is_dataclass(updates): + raise ValueError("Both target and updates should be dataclass instances") + + for field in fields(target): + update_value = getattr(updates, field.name) + + if update_value is not None: + if is_dataclass(update_value): + print_differences(getattr(target, field.name), update_value) + else: + target_value = getattr(target, field.name) + if update_value != target_value: + if update_value.__class__.__module__ != "builtins": + continue + print(f"{field.name}: {target_value} -> {update_value}") + + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", required=True, help="Output directory for yaml", type=str) + parser.add_argument("--wandb-username", required=True, help="Specific wandb username", type=str) + parser.add_argument("--wandb-project", required=True, help="Specific wandb project name", type=str) + parser.add_argument("--wandb-run", required=True, help="Specific name for this run", type=str) + + args = parser.parse_args() + + config.general.project = args.wandb_project + config.general.run = f"{args.wandb_run}_{job_id}" + + api = wandb.Api() + projects = api.projects(entity=args.wandb_username) + project_exists = any(project.name == args.wandb_project for project in projects) + + if not project_exists: + raise ValueError( + f"Project '{args.wandb_project}' does not exist. You should create the project first at entity {config.experiment_logger.wandb_logger.wandb_entity}" + ) + + directories = [] + + experiment_path = f"{args.out_dir}/{config.general.project}/{config.general.run}" + directories.append(experiment_path) + + config.checkpoints.checkpoints_path = f"{experiment_path}/checkpoints" + config.checkpoints.resume_checkpoint_path = f"{experiment_path}/checkpoints" + directories.append(config.checkpoints.checkpoints_path) + directories.append(config.checkpoints.resume_checkpoint_path) + + # if config.lighteval is not None: + # config.lighteval.slurm_script_dir = f"{experiment_path}/lighteval/slurm_scripts" + # config.lighteval.slurm_template = f"{experiment_path}/run_eval.slurm.jinja" + # config.lighteval.logging.local_output_path = f"{experiment_path}/logs" + + # directories.append(config.lighteval.slurm_script_dir) + # directories.append(config.lighteval.logging.local_output_path) + + # if config.s3_upload is not None: + # config.s3_upload.upload_s3_path = f"s3://huggingface-brrr-us-east-1/fmom/checkpoints/{args.wandb_run}_{job_id}" + # directories.append(config.s3_upload.upload_s3_path) + + # if config.profiler is not None: + # config.profiler.profiler_export_path = f"{experiment_path}/logs" + + directories.append(f"{experiment_path}/logs") + + for dir_path in directories: + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + pprint.pprint(f"Dataset name: {config.data_stages}") + print("Parallelism") + print("\tdp", config.parallelism.dp) + print("\tpp", config.parallelism.pp) + print("\ttp", config.parallelism.tp) + if config.lighteval is not None: + print("Parallelism LightEval") + print("\tdp", config.lighteval.parallelism.dp) + print("\tpp", config.lighteval.parallelism.pp) + print("\ttp", config.lighteval.parallelism.tp) + + yaml_path = f"{experiment_path}/{config.general.run}.yaml" + # Sanity check that we can load, save to YAML and reload the config + config.save_as_yaml(yaml_path) + config2 = get_config_from_file(yaml_path, config_class=MambaConfig) + print_differences(config, config2) + + print("Save at", yaml_path) From d92be5e0a8dd81339025081d357988d217f6d1d7 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Apr 2024 20:11:42 +0000 Subject: [PATCH 36/39] fix all bug --- examples/mamba/mamba.py | 5 ++--- src/nanotron/serialize/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 17c981d9..20f86804 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -116,7 +116,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=False, - contiguous_chunks=(self.d_inner, self.d_inner), + contiguous_chunks=None, ) assert self.d_inner % self.tp_pg.size() == 0 @@ -228,7 +228,6 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]): return out else: store["seqlen_offset"] += 1 - # We do matmul and transpose BLH -> HBL at the same time xz = self.in_proj(hidden_states).transpose(1, 2) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) @@ -664,7 +663,7 @@ def get_block_compute_costs(self): def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """ - Get flops per second for a Mamba model. + Get flops per second for a Mamba model. Terms such as nonlinearities, biases, and layer normalization are omitted (https://arxiv.org/pdf/2001.08361.pdf) """ # world_size = self.parallel_context.world_pg.size() diff --git a/src/nanotron/serialize/utils.py b/src/nanotron/serialize/utils.py index 661dcab6..f46c6028 100644 --- a/src/nanotron/serialize/utils.py +++ b/src/nanotron/serialize/utils.py @@ -37,8 +37,6 @@ def get_path( suffix = tensor_name.split(".") suffix_path, suffix_name = suffix[:-1], suffix[-1] - suffix_name = f"{type.value}_{suffix_name}.safetensors" - if exp_tp_pp_rank_and_size: # We always show pp_rank and tp_rank if `exp_tp_pp_rank_and_size` is provided # We only show exp_rank if tensor is exp_sharded and exp_size > 1 @@ -49,6 +47,8 @@ def get_path( ) else: suffix_name = f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}_exp-rank-{exp_rank}-of-{exp_size}.safetensors" + else: + suffix_name = f"{type.value}_{suffix_name}.safetensors" suffix_path.append(suffix_name) if prefix is None: From e245eb8f3ab5e4f11453516bac6a152c8240598f Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Apr 2024 20:19:50 +0000 Subject: [PATCH 37/39] cleaning convert hf to nanotron --- examples/mamba/convert_hf_to_nanotron.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/mamba/convert_hf_to_nanotron.py b/examples/mamba/convert_hf_to_nanotron.py index 413eb66a..5109e970 100644 --- a/examples/mamba/convert_hf_to_nanotron.py +++ b/examples/mamba/convert_hf_to_nanotron.py @@ -83,8 +83,6 @@ def _get_tensor(path: str): param = _get_tensor(hf_name) - # only do this when the weight is not sharded in tensor parallel - # if "in_proj" in hf_name and not param_is_tp_sharded: if "in_proj" in hf_name: # In Nanotron, we do tensor parallel column so weight need to be split in the column dimension (i.e: xz.view(...)) # However, the HF weights was trained such that it expected xz.chunk(...) to split the tensor in the row dimension From f6befa64a86d36d5dd054ac7b20ae14ec0114b86 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Apr 2024 20:41:21 +0000 Subject: [PATCH 38/39] revert create config mamba --- examples/mamba/create_config_mamba.py | 128 +++++--------------------- 1 file changed, 23 insertions(+), 105 deletions(-) diff --git a/examples/mamba/create_config_mamba.py b/examples/mamba/create_config_mamba.py index 60389423..eee8d161 100644 --- a/examples/mamba/create_config_mamba.py +++ b/examples/mamba/create_config_mamba.py @@ -1,11 +1,11 @@ """ Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.""" import math import os -import pprint import uuid from config import MambaConfig, MambaInit, MambaModelConfig from nanotron.config import ( + AdamWOptimizerArgs, CheckpointsArgs, DataArgs, DatasetStageArgs, @@ -21,10 +21,9 @@ ) from nanotron.logging import human_format -import wandb - new_job_id = uuid.uuid4() job_id = str(new_job_id)[:8] +seed = 42 ssm_cfg_dtype = "bfloat16" ssm_cfg = { @@ -44,7 +43,7 @@ # https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json model_config = MambaModelConfig( d_model=1024, - num_hidden_layers=48, + num_hidden_layers=2, vocab_size=50278, ssm_cfg=ssm_cfg, rms_norm=True, @@ -95,15 +94,12 @@ seed = 42 + optimizer = OptimizerArgs( zero_stage=0, weight_decay=0.01, clip_grad=1.0, accumulate_grad_in_fp32=True, # NOTE(fmom): because we are using PP=TP=DP=1 - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, learning_rate_scheduler=LRSchedulerArgs( learning_rate=0.0015, lr_warmup_steps=30, @@ -111,11 +107,18 @@ lr_decay_style="cosine", min_decay_lr=0.00015, ), + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), ) + parallelism = ParallelismArgs( - dp=1, - pp=1, + dp=2, + pp=2, tp=2, pp_engine="1f1b", tp_mode="ALL_REDUCE", @@ -135,6 +138,11 @@ ) ] +model = ModelArgs( + init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), + model_config=model_config, +) + checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" os.makedirs(checkpoints_path, exist_ok=True) @@ -142,10 +150,7 @@ general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True), checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=100), parallelism=parallelism, - model=ModelArgs( - init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), - model_config=model_config, - ), + model=model, tokenizer=TokenizerArgs("gpt2"), optimizer=optimizer, logging=LoggingArgs(), @@ -155,94 +160,7 @@ ) if __name__ == "__main__": - import argparse - from dataclasses import fields, is_dataclass - - from nanotron.config import get_config_from_file - - def print_differences(target, updates): - if not is_dataclass(target) or not is_dataclass(updates): - raise ValueError("Both target and updates should be dataclass instances") - - for field in fields(target): - update_value = getattr(updates, field.name) - - if update_value is not None: - if is_dataclass(update_value): - print_differences(getattr(target, field.name), update_value) - else: - target_value = getattr(target, field.name) - if update_value != target_value: - if update_value.__class__.__module__ != "builtins": - continue - print(f"{field.name}: {target_value} -> {update_value}") - - parser = argparse.ArgumentParser() - parser.add_argument("--out-dir", required=True, help="Output directory for yaml", type=str) - parser.add_argument("--wandb-username", required=True, help="Specific wandb username", type=str) - parser.add_argument("--wandb-project", required=True, help="Specific wandb project name", type=str) - parser.add_argument("--wandb-run", required=True, help="Specific name for this run", type=str) - - args = parser.parse_args() - - config.general.project = args.wandb_project - config.general.run = f"{args.wandb_run}_{job_id}" - - api = wandb.Api() - projects = api.projects(entity=args.wandb_username) - project_exists = any(project.name == args.wandb_project for project in projects) - - if not project_exists: - raise ValueError( - f"Project '{args.wandb_project}' does not exist. You should create the project first at entity {config.experiment_logger.wandb_logger.wandb_entity}" - ) - - directories = [] - - experiment_path = f"{args.out_dir}/{config.general.project}/{config.general.run}" - directories.append(experiment_path) - - config.checkpoints.checkpoints_path = f"{experiment_path}/checkpoints" - config.checkpoints.resume_checkpoint_path = f"{experiment_path}/checkpoints" - directories.append(config.checkpoints.checkpoints_path) - directories.append(config.checkpoints.resume_checkpoint_path) - - # if config.lighteval is not None: - # config.lighteval.slurm_script_dir = f"{experiment_path}/lighteval/slurm_scripts" - # config.lighteval.slurm_template = f"{experiment_path}/run_eval.slurm.jinja" - # config.lighteval.logging.local_output_path = f"{experiment_path}/logs" - - # directories.append(config.lighteval.slurm_script_dir) - # directories.append(config.lighteval.logging.local_output_path) - - # if config.s3_upload is not None: - # config.s3_upload.upload_s3_path = f"s3://huggingface-brrr-us-east-1/fmom/checkpoints/{args.wandb_run}_{job_id}" - # directories.append(config.s3_upload.upload_s3_path) - - # if config.profiler is not None: - # config.profiler.profiler_export_path = f"{experiment_path}/logs" - - directories.append(f"{experiment_path}/logs") - - for dir_path in directories: - if not os.path.exists(dir_path): - os.makedirs(dir_path) - - pprint.pprint(f"Dataset name: {config.data_stages}") - print("Parallelism") - print("\tdp", config.parallelism.dp) - print("\tpp", config.parallelism.pp) - print("\ttp", config.parallelism.tp) - if config.lighteval is not None: - print("Parallelism LightEval") - print("\tdp", config.lighteval.parallelism.dp) - print("\tpp", config.lighteval.parallelism.pp) - print("\ttp", config.lighteval.parallelism.tp) - - yaml_path = f"{experiment_path}/{config.general.run}.yaml" - # Sanity check that we can load, save to YAML and reload the config - config.save_as_yaml(yaml_path) - config2 = get_config_from_file(yaml_path, config_class=MambaConfig) - print_differences(config, config2) - - print("Save at", yaml_path) + dir = os.path.dirname(__file__) + + # Save config as YAML file + config.save_as_yaml(f"{dir}/config_mamba.yaml") From f93373640ccfc53f355a60e71f2cfffb77c6613c Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Apr 2024 20:44:24 +0000 Subject: [PATCH 39/39] fix run_generate --- examples/mamba/run_generate.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/mamba/run_generate.py b/examples/mamba/run_generate.py index f7194668..75271fa9 100644 --- a/examples/mamba/run_generate.py +++ b/examples/mamba/run_generate.py @@ -58,9 +58,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") - parser.add_argument("--dp", type=int, default=0) - parser.add_argument("--pp", type=int, default=0) - parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--dp", type=int, default=1) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--tp", type=int, default=1) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") return parser.parse_args() @@ -77,9 +77,9 @@ def main(): tokenizer_path = config.tokenizer.tokenizer_name_or_path parallel_config = ParallelismArgs( - dp=args.dp or config.parallelism.dp, - pp=args.pp or config.parallelism.pp, - tp=args.tp or config.parallelism.tp, + dp=args.dp, + pp=args.pp, + tp=args.tp, pp_engine=OneForwardOneBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False,