diff --git a/aq_engine.py b/aq_engine.py
index f051a43..ea14eec 100644
--- a/aq_engine.py
+++ b/aq_engine.py
@@ -47,7 +47,6 @@ def quantize(self, *, args: Namespace, verbose: bool = True) -> QuantizedWeight:
         assert isinstance(args.devices, (list, tuple)) and len(args.devices) >= 1, f"Found devices = {args.devices}"
         assert args.devices[0] == self.device, (args.devices[0], self.XTX.device)
         self.quantized_weight = QuantizedWeight(
-            XTX=self.XTX.to(device=self.device, dtype=torch.float32),
             reference_weight=self.layer.weight.detach().to(device=self.device, dtype=torch.float32),
             out_group_size=args.out_group_size,
             in_group_size=args.in_group_size,
@@ -165,7 +164,7 @@ def _replace_and_beam_search(self, params_to_replace: nn.ParameterDict, selectio
         )
         reference_weight = self.layer.weight.detach()[out_channel_selection].to(dtype)
         return self.quantized_weight.beam_search_update_codes_(
-            self.XTX.to(dtype), reference_weight, selection=selection, **kwargs
+            XTX=self.XTX.to(dtype), reference_weight=reference_weight, selection=selection, **kwargs
         ).clone()

     @torch.no_grad()
@@ -177,12 +176,15 @@ def beam_search_update_codes_(
         seed: Optional[int] = None,
         **kwargs,
     ):
-        """Update self.quantized_weight.codes in-place via beam search"""
+        """Update quantized_weight codes in-place via beam search"""
         if len(devices) == 1:  # single device
             assert replicas is None
             dtype = self.quantized_weight.codebooks.dtype
             self.quantized_weight.beam_search_update_codes_(
-                self.XTX.to(dtype), self.layer.weight.detach().to(dtype), dim_rng=random.Random(seed), **kwargs
+                XTX=self.XTX.to(dtype),
+                reference_weight=self.layer.weight.detach().to(dtype),
+                dim_rng=random.Random(seed),
+                **kwargs,
             )
         else:
             assert replicas[0] is self
@@ -203,7 +205,7 @@ def beam_search_update_codes_(
             )
             # gather all code parts and assign them to each replica
             for device, replica in zip(devices, replicas):
-                replica.quantized_weight.codes[...] = Gather.apply(device, 0, *new_code_parts_by_replica)
+                replica.quantized_weight.set_codes(Gather.apply(device, 0, *new_code_parts_by_replica))


 def replace_parameter_(module: nn.Module, name: str, new_value: torch.Tensor):
diff --git a/convert_legacy_model_format.py b/convert_legacy_model_format.py
new file mode 100644
index 0000000..e7e7a16
--- /dev/null
+++ b/convert_legacy_model_format.py
@@ -0,0 +1,214 @@
+"""
+This abomination converts one of several quantized model formats into the format returned by main.py.
+This code exists because we failed to produce a single data format for quantized models.
+We should eventually switch to saving all models in the same data format. Once we do, this file should be deleted.
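+
+Example invocation (illustrative only: the model name and paths below are placeholders; pass
+--p_finetuned_state_dict for a state dict saved by P-only tuning, or --pv_fsdp_dir for a sharded
+PV/FSDP checkpoint directory):
+
+    python convert_legacy_model_format.py \
+        --base_model meta-llama/Llama-2-7b-hf \
+        --quantized_model ./path/to/quantized_model \
+        --p_finetuned_state_dict ./path/to/best_model_state_dict.pt \
+        --save ./converted_model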
+""" +import argparse +import os +import warnings +from copy import deepcopy + +import torch +import transformers.models +from torch import nn + +from src.aq import QuantizedLinear, QuantizedWeight +from src.modelutils import get_model, save_quantized_model +from src.utils import is_signed + + +def load_quantized_model_with_old_pickle(base_model_name: str, quantized_model_name: str, **kwargs): + """Hacky way to allow compatibility between old *pickled* layers and new transformers""" + # because patching it for the fourth time is better than writing a proper saver once >.< + import transformers.activations + + if not hasattr(transformers.activations, "SiLUActivation"): + transformers.activations.SiLUActivation = deepcopy(torch.nn.SiLU) + transformers.activations.SiLUActivation.inplace = False + # https://github.com/huggingface/transformers/issues/28496 + if not hasattr(transformers.models.llama.modeling_llama.LlamaAttention, "attention_dropout"): + transformers.models.llama.modeling_llama.LlamaAttention.attention_dropout = 0 + quantized_model = get_model(base_model_name, None, **kwargs) + quantized_model_src = get_model(base_model_name, quantized_model_name, **kwargs) + for module in quantized_model_src.modules(): + if isinstance(module, QuantizedWeight) and not hasattr(module, "codes_storage"): + module.codes_storage = None # backwards compatibility with older pickled snapshots + + lut = {} + for name, module in quantized_model_src.named_modules(): + for child_name, child_module in module.named_children(): + if isinstance(child_module, QuantizedWeight): + lut[name + "." + child_name] = child_module + print(f"found {len(lut)} quantized weight matrices") + for name, module in quantized_model.named_modules(): + for child_name, child_module in module.named_children(): + if name + "." + child_name + ".quantized_weight" in lut: + quantized_weight = lut.pop(name + "." 
+ child_name + ".quantized_weight") + assert isinstance(child_module, nn.Linear) + setattr(module, child_name, QuantizedLinear(quantized_weight, bias=child_module.bias)) + assert not lut, list(lut.keys()) + quantized_model.load_state_dict(quantized_model_src.state_dict()) + warnings.warn("You should be ashamed of yourself.") + return quantized_model + + +import functools + + +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition(".") + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def load_quantized_model_from_fdsp_checkpoint(base_model_name: str, fsdp_checkpoint_path: str, **kwargs): + original_model = get_model(base_model_name, None, **kwargs) + + state_filenames = os.listdir(fsdp_checkpoint_path) + + non_quant_fname = "non_quantized_state_dict.pth" + non_quant_path = os.path.join(fsdp_checkpoint_path, non_quant_fname) + non_quant_states = torch.load(non_quant_path) + + incomp_keys = original_model.load_state_dict(non_quant_states, strict=False) + assert not incomp_keys.unexpected_keys + + missing_keys = list() + for module_name, module in original_model.named_modules(): + if not isinstance(module, nn.Linear): + continue + + assert not module.bias + state_fname = f"{module_name}.weight.pth" + + if state_fname not in state_filenames: + missing_keys.append(module_name) + continue + + state_path = os.path.join(fsdp_checkpoint_path, state_fname) + quantized_weight = torch.load(state_path, map_location="cpu") + quantized_linear = QuantizedLinear(quantized_weight, bias=None) + rsetattr(original_model, module_name, quantized_linear) + + return original_model + + +def main(): + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--base_model", + type=str, + required=True, + help="path or name of the teacher model", + ) + parser.add_argument( + "--quantized_model", + type=str, + required=True, + help="path to quantized model", + ) + parser.add_argument( + "--load_dtype", + type=str, + default="auto", + choices=["auto", "float16", "float32", "bfloat16"], + help="dtype to load the model in", + ) + parser.add_argument( + "--code_dtype", + type=str, + default=None, + help="if specified, cast quantized layers' codes to this dtype; default = keep loaded dtype", + ) + parser.add_argument( + "--p_finetuned_state_dict", + type=str, + default=None, + help="path to quantized model state dict saved by the old FSDP finetuning code", + ) + parser.add_argument( + "--pv_fsdp_dir", + type=str, + default=None, + help="path to quantized model state dict saved by the old FSDP finetuning code", + ) + parser.add_argument( + "--monkeypatch_old_pickle", + action="store_true", + help="If set, load quantized_model in a hacky way that allows pickled models with older transformers/torch.", + ) + parser.add_argument( + "--attn_implementation", + type=str, + default=None, + help="Attention implementation for both teacher and student models: eager, sdpa, or flash_attention_2", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether to trust remote code when loading base model.", + ) + parser.add_argument("--save", type=str, required=True, help="Save the converted quantized model here") + + args = parser.parse_args() + assert args.p_finetuned_state_dict or args.pv_fsdp_dir, "either one of those must be specified" + print(f"{args.p_finetuned_state_dict=}, {args.pv_fsdp_dir=}") + 
assert (args.p_finetuned_state_dict is not None) != (args.pv_fsdp_dir is not None) + + args.load_dtype = getattr(torch, args.load_dtype) if args.load_dtype != "auto" else "auto" + args.code_dtype = getattr(torch, args.code_dtype) if args.code_dtype is not None else None + + if not args.monkeypatch_old_pickle: + quantized_model = get_model( + args.base_model, + args.quantized_model, + dtype=args.load_dtype, + trust_remote_code=args.trust_remote_code, + attn_implementation=args.attn_implementation, + ) + elif args.p_finetuned_state_dict: + quantized_model = load_quantized_model_with_old_pickle( + args.base_model, + args.quantized_model, + dtype=args.load_dtype, + trust_remote_code=args.trust_remote_code, + attn_implementation=args.attn_implementation, + ) + elif args.pv_fsdp_dir: + quantized_model = load_quantized_model_from_fdsp_checkpoint( + args.base_model, + args.pv_fsdp_dir, + dtype=args.load_dtype, + trust_remote_code=args.trust_remote_code, + ) + + for module in quantized_model.modules(): + if isinstance(module, QuantizedWeight): + if not hasattr(module, "codes_storage"): + module.codes_storage = None + if module.codes is None: + module.unwrap_codes_() + assert module.codes is not None + if args.code_dtype is not None: + assert module.nbits_per_codebook <= torch.iinfo(args.code_dtype).bits - is_signed(args.code_dtype) + module.codes = nn.Parameter(module.codes.to(args.code_dtype), requires_grad=module.codes.requires_grad) + + if args.p_finetuned_state_dict is not None: + state_dict = torch.load(args.p_finetuned_state_dict, map_location="cpu") + state_dict = {k: v for k, v in state_dict.items() if not k.endswith(".codes_storage.data")} + status = quantized_model.load_state_dict(state_dict, strict=False) + assert all(key.endswith("codes") for key in status.missing_keys) + assert not status.unexpected_keys + del state_dict, status # note: in this case, it is okay not to load codes since P step does not change them + + save_quantized_model(quantized_model, args.save) + + +if __name__ == "__main__": + main() diff --git a/finetune.py b/finetune.py index 7c68488..bfb405f 100644 --- a/finetune.py +++ b/finetune.py @@ -1,12 +1,45 @@ +""" +Fine-tune an LLM that was previously quantized with AQLM; +based on https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py +""" import argparse import os -import shutil -from copy import deepcopy +from contextlib import nullcontext +from functools import partial +from typing import Dict, Optional, Tuple +import datasets import torch +import torch.distributed import torch.nn.functional as F -from accelerate.hooks import remove_hook_from_submodules -from tqdm import tqdm, trange +import torch.optim +import torch.utils.data +import transformers +from torch import nn as nn +from torch.distributed.fsdp import ( + CPUOffload, + FullStateDictConfig, + FullyShardedDataParallel, + MixedPrecision, + StateDictType, +) +from tqdm.auto import tqdm + +from convert_legacy_model_format import load_quantized_model_with_old_pickle +from src.aq import QuantizedWeight +from src.configurable_adam import ConfigurableAdamW +from src.datautils import evaluate_perplexity, get_loaders, group_texts, split_long_texts +from src.memory_efficient_loss import compute_kl_divergence_loss_values +from src.modelutils import get_model, is_model_for_causal_lm +from src.pv_optimizer import StraightThroughAdamW +from src.pv_utils import ( + YourQuantizedWeightIsInAnotherRank, + create_dequantized_model, + get_original_named_parameters_from_fsdp_module, + 
infer_module_classes, + split_quantized_weights_between_ranks, +) +from src.utils import IntCodes, is_signed, master_rank_first, one_rank_at_a_time try: import wandb @@ -15,159 +48,8 @@ except ModuleNotFoundError: has_wandb = False -from main import perplexity_eval -from src.datautils import get_loaders -from src.modelutils import get_layers, get_model, save_not_quantized_weights -from src.utils import _extract_into_tensor, maybe_get_0th_element - - -@torch.inference_mode() -def cache_hiddens(model, dataloader, args): - device = next(model.parameters()).device - cached_hiddens = [] - for i in trange(len(dataloader), total=len(dataloader), desc="Caching hiddens", leave=False): - with torch.autocast(device_type="cuda", enabled=args.amp): - batch = maybe_get_0th_element(dataloader[i]).to(device) - cached_hiddens.append(model.model(batch).last_hidden_state.cpu()) - return cached_hiddens - - -def kl_div(student_hiddens, teacher_hiddens): - C = student_hiddens.shape[-1] # num classes - return F.kl_div( - input=F.log_softmax(student_hiddens.view(-1, C), dim=-1), - target=F.log_softmax(teacher_hiddens.view(-1, C), dim=-1), - log_target=True, - reduction="batchmean", - ) - - -@torch.no_grad() -def evaluate(model, lm_head, loader, hiddens, batch_size, dtype): - model.eval() - loss_numerator, loss_denominator = 0, 0 - device = next(model.parameters()).device - # convert tensor to list - for i in range(0, len(loader), batch_size): - batch_ids = range(i, i + batch_size) - inputs = _extract_into_tensor(loader, batch_ids, device=device) - targets = lm_head(_extract_into_tensor(hiddens, batch_ids, device=device, dtype=dtype)) - outputs = model(inputs).logits - loss = kl_div(outputs, targets.to(outputs.device)) - loss_numerator += loss.item() - loss_denominator += 1 - return loss_numerator / loss_denominator - - -def finetune(model, train_loader, train_hiddens, args, device, val_loader=None, val_hiddens=None): - # cast model to finetune dtype - model.to(args.finetune_dtype) - lm_head = deepcopy(model.lm_head) - for param in lm_head.parameters(): - param.requires_grad = False - - diff_params = {name: param for name, param in model.named_parameters() if param.requires_grad} - print(f"Fine-tuning {sum(param.numel() for _, param in diff_params.items())} parameters") - opt = torch.optim.Adam(diff_params.values(), lr=args.lr, betas=(args.adam_beta1, args.adam_beta2)) - scaler = torch.cuda.amp.GradScaler(enabled=args.amp) - - num_accumulation_steps = args.batch_size // args.microbatch_size - num_samples = len(train_loader) - epoch_samples = num_samples - num_samples % args.microbatch_size - microbatches_per_epoch = epoch_samples // args.microbatch_size - - if args.gradient_checkpointing: - model.gradient_checkpointing_enable() - - run_validation = val_loader is not None and val_hiddens is not None - # validate before training - if run_validation: - valid_loss_epoch = evaluate(model, lm_head, val_loader, val_hiddens, args.microbatch_size, args.finetune_dtype) - print(f"Evaluation before training.") - print(f"valid loss={valid_loss_epoch:.3e}\t") - best_loss = valid_loss_epoch - best_params = deepcopy(diff_params) - worse_count = 0 - - for epoch in range(args.epochs): - # train loop - model.train() - loss_numerator, loss_denominator = 0, 0 - steps_accumulated = 0 - # prepare batch indices - batch_indices_epoch = torch.randperm(num_samples)[:epoch_samples].chunk(microbatches_per_epoch) - - for batch_indices in tqdm(batch_indices_epoch, desc=f"Train epoch {epoch}", leave=False): - # convert tensor to list - 
batch_indices = batch_indices.tolist() - inputs = _extract_into_tensor(train_loader, batch_indices, device=device) - with torch.no_grad(): - targets = lm_head( - _extract_into_tensor(train_hiddens, batch_indices, device=device, dtype=args.finetune_dtype) - ) - - with torch.autocast(device_type="cuda", enabled=args.amp): - outputs = model(inputs).logits - loss = kl_div(outputs, targets.to(device=outputs.device, dtype=args.finetune_dtype)) - - if not torch.isfinite(loss).item(): - raise ValueError(f"Fine-tuning loss is {loss}") - - scaler.scale(loss / num_accumulation_steps).backward() - - steps_accumulated += 1 - if steps_accumulated == num_accumulation_steps: - scaler.step(opt) - scaler.update() - opt.zero_grad() - # reset accumulated step and loss - steps_accumulated = 0 - - loss_numerator += loss.item() - loss_denominator += 1 - train_loss_epoch = loss_numerator / loss_denominator - - if run_validation: - valid_loss_epoch = evaluate( - model, lm_head, val_loader, val_hiddens, args.microbatch_size, args.finetune_dtype - ) - - # log losses in the end of the epoch - print("-" * 10) - print(f"epoch={epoch}") - print(f"train loss={train_loss_epoch:.3e}\t") - if run_validation: - print(f"valid loss={valid_loss_epoch:.3e}\t") - - if args.wandb: - wandb.log({"train_loss": train_loss_epoch}, step=epoch) - if run_validation: - wandb.log({"valid_loss": valid_loss_epoch}, step=epoch) - - if run_validation: - if valid_loss_epoch < best_loss: - print(f"new best loss {valid_loss_epoch:.3e} on epoch {epoch}") - best_loss = valid_loss_epoch - best_params = deepcopy(diff_params) - worse_count = 0 - else: - worse_count += 1 - if worse_count >= args.early_stop: - break - - if run_validation: - model.load_state_dict(best_params, strict=False) - - -def print_memory_stats(): - print(f"GPU max memory allocated: {torch.cuda.max_memory_allocated() / 2 ** 30:.2f} GB.") - print(f"GPU max memory reserved: {torch.cuda.max_memory_reserved() / 2 ** 30:.2f} GB.") - - -def main(): - parser = argparse.ArgumentParser(add_help=True) - # Model params +def add_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--base_model", type=str, @@ -175,22 +57,15 @@ def main(): help="path or name of the teacher model", ) parser.add_argument( - "--quant_model", + "--quantized_model", type=str, required=True, help="path to quantized model", ) - # Data params - parser.add_argument( - "--dataset", - type=str, - help="Dataset name [c4, pajama] or path to data where to extract calibration data from.", - ) parser.add_argument( - "--nsamples", - type=int, - default=1024, - help="number of samples", + "--monkeypatch_old_pickle", + action="store_true", + help="If set, load quantized_model in a hacky way that allows pickled models with older transformers/torch.", ) parser.add_argument( "--model_seqlen", @@ -199,114 +74,379 @@ def main(): help="Model seqlen and calibration data context length.", ) parser.add_argument( - "--eval_model_seqlen", - type=int, + "--master_dtype", + type=str, + default="float32", + help="data type for storing master parameters and computing optimizer updates", + ) + parser.add_argument( + "--embed_dtype", + type=str, default=None, - help="Model seqlen on validation. 
By default is equal to model_seqlen.",
+        help="data type for storing master input and output embeddings; defaults to master_dtype",
     )
     parser.add_argument(
-        "--val_size",
-        type=int,
-        default=0,
-        help="size of validation split",
+        "--load_dtype",
+        type=str,
+        default="auto",
+        choices=["auto", "float16", "float32", "bfloat16"],
+        help="dtype to load the model in",
     )
     parser.add_argument(
-        "--eval_datasets",
-        nargs="+",
+        "--amp_dtype",
         type=str,
-        default=["wikitext2", "c4"],
-        help="Datasets to run evaluation on",
+        default=None,
+        help="if specified, runs automated mixed precision with this dtype",
+    )
+    parser.add_argument(
+        "--straight_through_buffer_dtype",
+        type=str,
+        default=None,
+        help="data type for storing optimized straight through buffers, defaults to master_dtype",
+    )
+    parser.add_argument(
+        "--code_dtype",
+        type=str,
+        default=None,
+        help="if specified, cast quantized layers' codes to this dtype; default = keep loaded dtype",
+    )
+    parser.add_argument(
+        "--block_type",
+        type=str,
+        required=True,
+        help="string name of a transformer layer to wrap, e.g. LlamaDecoderLayer",
+    )
+    parser.add_argument(
+        "--wrap_separately",
+        type=str,
+        nargs="*",
+        default=[],
+        help="module classes (by name, similar to block_type) that will be wrapped in a separate fsdp instance and do "
+        "not participate in FSDP AMP (if used). Applies to the student (de)quantized model, not the teacher model.",
+    )
+    parser.add_argument(
+        "--attn_implementation",
+        type=str,
+        default=None,
+        help="Attention implementation for both teacher and student models: eager, sdpa, or flash_attention_2",
+    )
+    parser.add_argument(
+        "--limit_parallel_inits",
+        type=int,
+        default=1,
+        help="this many ranks (per host) initialize their model in parallel. This parameter is meant to save host RAM.",
+    )
+
+
+def add_finetuning_args(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--update_codes",
+        action="store_true",
+        help="If set, train discrete codes; if not, freeze them",
+    )
+    parser.add_argument(
+        "--update_codebooks_and_scales",
+        action="store_true",
+        help="If set, train continuous parameters of quantized representations; if not, freeze them",
+    )
+    parser.add_argument(
+        "--update_non_quantized_parameters",
+        action="store_true",
+        help="If set, train the non-quantized model parameters (layernorm scales, biases, logits); if not, freeze them",
+    )
+    parser.add_argument(
+        "--force_dequantize",
+        action="store_true",
+        help="If set, the algorithm will create a de-quantized model instead of dequantizing weights just in time even "
+        "when doing p-only tuning. This option only has an effect if --update_codes is not set. Setting this will"
+        " make the training run faster, but it will also use substantially more memory.",
     )
-    # Training params
     parser.add_argument(
         "--lr",
         type=float,
         default=1e-5,
-        help="finetuning learning rate",
+        help="finetuning learning rate for continuous params",
     )
     parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.90,
-        help="Adam beta1",
+        help="Adam beta1 for continuous params",
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
         default=0.95,
-        help="Adam beta2",
+        help="Adam beta2 for continuous params",
+    )
+
+    parser.add_argument(
+        "--code_lr",
+        type=float,
+        default=1e-2,
+        help="finetuning learning rate for discrete codes",
     )
     parser.add_argument(
-        "--epochs",
+        "--code_beta1",
+        type=float,
+        default=0.0,
+        help="Adam beta1 for discrete params",
+    )
+    parser.add_argument(
+        "--code_beta2",
+        type=float,
+        default=0.95,
+        help="Adam beta2 for discrete params",
+    )
+    parser.add_argument(
+        "--delta_decay",
+        type=float,
+        default=0.0,
+        help="Determines whether to use direct training, straight-through estimation or a mixture thereof. "
+        "If delta_decay is 0, use straight-through estimation. If delta_decay is 1, do not use it at all. "
+        "If between 0 and 1, every straight-through buffer will decay to the quantized weight with moving average."
+        " straight_through_buffer = (1 - delta_decay) * straight_through_buffer + delta_decay * quantized_weight."
+        " Please refer to the docstring of StraightThroughAdam for details.",
+    )
+    parser.add_argument(
+        "--max_code_change_per_step",
+        type=float,
+        default=1e-2,
+        help="Maximum number of code groups that can be changed during one update to codes. "
+        "This constraint is enforced on a per-tensor level. If the weight is represented with multiple codes, "
+        "changing any of the codes will count towards the limit. If more than this many code groups have changed, "
+        "the algorithm will rollback the changes with least update norm until the constraint is satisfied.",
+    )
+    parser.add_argument(
+        "--code_trust_ratio",
+        type=float,
+        default=None,
+        help="By default, the optimizer can make arbitrary changes to quantized weights. If this parameter is set, "
+        "the optimizer ensures that the change to quantized weights is not too large by undoing some of the change "
+        "until ||new_quantized_weights - prev_quantized_weights|| / ||prev_quantized_weight|| <= code_trust_ratio."
+        " See StraightThroughAdam docstring for details.",
+    )
+    parser.add_argument(
+        "--force_code_update",
+        action="store_true",
+        help="If set, force discrete codes to change in the direction of optimizer update, even if previous codes "
+        "were optimal in terms of MSE. See StraightThroughAdam docstring for details. Use when delta_decay==1.",
+    )
+    parser.add_argument(
+        "--code_selection_temperature",
+        type=float,
+        default=0,
+        help="If max_code_change_per_step or code_trust_ratio is set and code_selection_temperature=0, beam search will"
+        " prioritize updating codes that have the largest continuous update norm. If code_selection_temperature is"
+        " not 0, sample a subset of codes for update stochastically. See StraightThroughAdam for details.",
+    )
+    parser.add_argument(
+        "--beam_size",
         type=int,
-        default=10,
-        help="Maximum number of epochs",
+        default=1,
+        help="Beam size when updating codes; higher is slower but more accurate. For single codebook, use beam_size=1",
+    )
+    parser.add_argument(
+        "--code_adam_16bit",
+        action="store_true",
+        help="If set, adam statistics for codes will be stored as float16 (exp_avg and v_hat) or bfloat16 (exp_avg_sq)",
+    )
+    parser.add_argument(
+        "--offload_optimizer",
+        action="store_true",
+        help="If set, adam statistics will be offloaded to RAM",
+    )
+    parser.add_argument(
+        "--offload_teacher_params",
+        action="store_true",
+        help="If set, the teacher model will be offloaded to RAM and paged using FSDP's CPUOffload",
+    )
+    parser.add_argument(
+        "--offload_student_params",
+        action="store_true",
+        help="If set, the student model will be offloaded to RAM and paged using FSDP's CPUOffload",
+    )
+    parser.add_argument(
+        "--limit_all_gathers",
+        action="store_true",
+        help="sets limit_all_gathers in both FSDP instances",
+    )
+    parser.add_argument(
+        "--forward_prefetch",
+        action="store_true",
+        help="sets forward_prefetch in both FSDP instances",
+    )
+    parser.add_argument(
+        "--lamb",
+        action="store_true",
+        help="If set, use Lamb (aka Adam with trust ratio)",
+    )
+    parser.add_argument(
+        "--amsgrad",
+        action="store_true",
+        help="if True, use the AMSGrad variant of adam/lamb",
+    )
+    parser.add_argument(
+        "--debias",
+        action="store_true",
+        default=None,
+        help="Whether or not to debias optimizer statistics; defaults to True for adam and False for Lamb",
+    )
+    parser.add_argument(
+        "--no_debias",
+        action="store_false",
+        dest="debias",
+        help="Disable optimizer debiasing (see above)",
+    )
+    parser.add_argument(
+        "--verbose_optimizer",
+        action="store_true",
+        help="If set, the optimizer will print beam search results, tensor norms, etc",
     )
     parser.add_argument(
         "--batch_size",
         type=int,
         default=1,
-        help="training batch size",
+        help="training batch size - how many samples are processed per optimizer step, across all GPUs in total",
     )
     parser.add_argument(
         "--microbatch_size",
         type=int,
         default=None,
-        help="training microbatch size",
+        help="training microbatch size - how many samples are processed per GPU per forward pass",
     )
     parser.add_argument(
         "--gradient_checkpointing",
         action="store_true",
-        help="Whether to apply gradient checkpointing",
+        help="Whether to apply gradient checkpointing for transformer blocks",
     )
     parser.add_argument(
-        "--amp",
-        action="store_true",
-        help="Whether to use amp",
+        "--loss_tokens_per_chunk",
+        type=int,
+        default=None,
+        help="If specified, compute LM logits and loss using gradient checkpointing in chunks of this size. "
+        "This option slows down loss computation, but reduces memory usage. Recommended for large vocabularies",
     )
     parser.add_argument(
-        "--early_stop",
-        type=int,
-        default=3,
-        help="Terminate finetuning if loss doesn't improve after this number of epochs.",
+        "--use_fsdp_amp",
+        action="store_true",
+        help="Whether to use FSDP native mixed precision (excluding registered layernorms and --wrap_separately).",
     )
     parser.add_argument(
-        "--finetune_dtype",
-        type=str,
-        default="float32",
-        choices=["float16", "float32", "bfloat16"],
-        help="dtype to finetune the model",
+        "--minimize_sync",
+        action="store_true",
+        help="if True, accumulate microbatch gradients locally and synchronize once per optimizer step. If False, "
+        "synchronize after every step. This reduces communication overhead but increases memory usage. 
See " + "https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync", ) - # Logging params - parser.add_argument("--wandb", action="store_true", help="Whether to use wandb or store locally.") - # Save params - parser.add_argument("--save", type=str, default=None, help="Path to save quantized statistics.") - # Misc params parser.add_argument( "--seed", type=int, - default=0, + default=42, help="Seed for calibration data and initialization. " "Note that the main training is not strictly deterministic.", ) + parser.add_argument("--wandb", action="store_true", help="Whether to use wandb or store locally.") + parser.add_argument("--save", type=str, default=None, help="Path to save training snapshot.") parser.add_argument( - "--offload_activations", - action="store_true", - help="Offload activations to RAM to save GPU memory.", + "--max_epochs", + type=int, + default=1000, + help="Total number of training epochs (passes over calibration data) after which the training will conclude", + ) + parser.add_argument( + "--print_every_steps", + type=int, + default=None, + help="print training metrics once in this many optimizer steps (this many updates to model parameters)", ) parser.add_argument( - "--dtype", + "--eval_every_steps", + type=int, + default=None, + help="evaluate once in this many optimizer steps (this many updates to model parameters)", + ) + parser.add_argument( + "--save_every_steps", + type=int, + default=None, + help="save state once in this many optimizer steps (this many updates to model parameters)", + ) + parser.add_argument("--keep_best_model", action="store_true", help="Save best model state separately") + parser.add_argument( + "--on_save", type=str, - default="auto", - choices=["auto", "float16", "float32", "bfloat16"], - help="dtype to load the model in", + default=None, + help="Optional callback (python code string) to call after each saved layer. Example: when " + "training on preemptible compute, upload partially finetuned model and resume later", + ) + + +def add_data_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--dataset_name", + type=str, + required=True, + help="Training dataset name (from HF datasets) or path to data where to extract calibration data from", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--split", + type=str, + default="train", + help="Training dataset split name, e.g. 
'train'", ) parser.add_argument( - "--device_map", + "--cache_dir", type=str, default=None, - choices=[None, "auto"], - help="accelerate device map", + help="Cache dir for huggingface datasets", + ) + parser.add_argument( + "--overwrite_cache", + action="store_true", + help="If set, re-run data preprocessing even if it is cached", + ) + parser.add_argument( + "--num_workers", + type=int, + default=8, + help="Number of CPU workers for preprocessing and data loading", + ) + parser.add_argument( + "--download_num_workers", + type=int, + default=None, + help="Number of CPU workers for downloading the training dataset; overrides num_workers", + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="Number of CPU workers for preprocessing; overrides num_workers", + ) + parser.add_argument( + "--preprocessing_chunk_length", + type=int, + default=100_000, + help="Texts exceeding this length will be split approximately in the middle", + ) + parser.add_argument( + "--preprocessing_keep_in_memory", + action="store_true", + help="If set, do not save intermediate preprocessing steps in memory", + ) + parser.add_argument( + "--eval_datasets", + nargs="+", + type=str, + default=["wikitext2", "c4"], + help="Datasets to run evaluation on", ) parser.add_argument( "--use_fast_tokenizer", @@ -316,105 +456,729 @@ def main(): parser.add_argument( "--trust_remote_code", action="store_true", - help="Whether to trust remote code.", + help="Whether to trust remote code when loading base model.", ) - args = parser.parse_args() - args.microbatch_size = args.microbatch_size or args.batch_size - args.finetune_dtype = getattr(torch, args.finetune_dtype) - if args.amp: - assert args.finetune_dtype == torch.float32, "AMP works only with original model in fp32." 
- # get device - assert torch.cuda.is_available() - device = "cuda" - args.devices = [device] # needed for perplexity eval - if args.wandb: - assert has_wandb, "`wandb` not installed, try pip install `wandb`" - wandb.init(config=args) - # get data - dataloader = get_loaders( - args.dataset, - nsamples=args.nsamples, - seed=args.seed, - model_path=args.base_model, - seqlen=args.model_seqlen, - use_fast_tokenizer=args.use_fast_tokenizer, + parser.add_argument( + "--save_dataset_and_exit", + type=str, + default=None, + help="If not None, save tokenized dataset to this path and exit training immediately", + ) + + +def prepare_training_dataset(args: argparse.Namespace, tokenizer: transformers.PreTrainedTokenizer) -> datasets.Dataset: + if os.path.exists(args.dataset_name): + dataset = datasets.load_from_disk(args.dataset_name) + else: + dataset = datasets.load_dataset( + args.dataset_name, + args.dataset_config_name, + split=args.split, + cache_dir=args.cache_dir, + trust_remote_code=args.trust_remote_code, + num_proc=args.download_num_workers if args.download_num_workers is not None else args.num_workers, + streaming=False, + ) + + def is_tokenized(dataset): + return "input_ids" in dataset.column_names + + if is_tokenized(dataset): + if torch.distributed.get_rank() == 0: + print("Dataset already tokenized") + assert len(dataset[0]["input_ids"]) == args.model_seqlen + return dataset + + text_column_name = "text" if "text" in dataset.column_names else next(iter(dataset.column_names)) + + if args.preprocessing_chunk_length is not None: + dataset = dataset.map( + lambda examples: { + text_column_name: split_long_texts(examples[text_column_name], args.preprocessing_chunk_length) + }, + batched=True, + num_proc=args.preprocessing_num_workers if args.preprocessing_num_workers is not None else args.num_workers, + remove_columns=list(dataset.column_names), + keep_in_memory=args.preprocessing_keep_in_memory, + load_from_cache_file=not args.overwrite_cache, + desc=f"Splitting dataset over newline into chunks of ~{args.preprocessing_chunk_length} characters", + ) + + tokenized_dataset = dataset.map( + lambda example: tokenizer(example[text_column_name]), + num_proc=args.preprocessing_num_workers if args.preprocessing_num_workers is not None else args.num_workers, + remove_columns=list(dataset.column_names), + keep_in_memory=args.preprocessing_keep_in_memory, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + lm_dataset = tokenized_dataset.map( + partial(group_texts, block_size=args.model_seqlen, add_labels=False), + batched=True, + num_proc=args.preprocessing_num_workers if args.preprocessing_num_workers is not None else args.num_workers, + keep_in_memory=args.preprocessing_keep_in_memory, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {args.model_seqlen}", + ) + assert is_tokenized(lm_dataset) + return lm_dataset + + +def load_teacher_model(args: argparse.Namespace, device: torch.device) -> FullyShardedDataParallel: + """Load unquantized model with frozen parameters""" + model = get_model( + args.base_model, + load_quantized=None, + dtype=args.load_dtype, trust_remote_code=args.trust_remote_code, + attn_implementation=args.attn_implementation, + ).to(dtype=args.load_dtype if args.load_dtype != "auto" else None) + model.train(False) + for param in model.parameters(): + param.requires_grad = False + + model.config.use_cache = False + transformer_block_types = infer_module_classes(model, args.block_type) + + return 
wrap_model_with_fsdp_( + model, + auto_wrap_policy=lambda module, recurse, **_etc: recurse or isinstance(module, transformer_block_types), + cpu_offload=CPUOffload(offload_params=args.offload_teacher_params) if args.offload_teacher_params else None, + limit_all_gathers=args.limit_all_gathers, + forward_prefetch=args.forward_prefetch, + device_id=device, + ) + + +def load_student_model( + args: argparse.Namespace, device: torch.device, dequantize: bool +) -> Tuple[FullyShardedDataParallel, Optional[Dict[str, QuantizedWeight]]]: + """ + load student model for fine-tuning. If dequantize is set, dequantize all quantized weights to accumulate full grads + """ + if not args.monkeypatch_old_pickle: + student_model = get_model( + args.base_model, + args.quantized_model, + dtype=args.load_dtype, + trust_remote_code=args.trust_remote_code, + attn_implementation=args.attn_implementation, + ).to( + args.master_dtype + ) # master parameters + else: + student_model = load_quantized_model_with_old_pickle( + args.base_model, + args.quantized_model, + dtype=args.load_dtype, + trust_remote_code=args.trust_remote_code, + attn_implementation=args.attn_implementation, + ).to(args.master_dtype) + + if args.embed_dtype != args.master_dtype: + student_model.set_output_embeddings(student_model.get_output_embeddings().to(args.embed_dtype)) + student_model.set_input_embeddings(student_model.get_input_embeddings().to(args.embed_dtype)) + + student_model.config.use_cache = False + student_model.train(True) # note: HF gradient checkpoints do not work for some models without train(True); see + # https://github.com/huggingface/transformers/blob/2d92db8/src/transformers/models/llama/modeling_llama.py#L1006 + if args.gradient_checkpointing: + student_model.gradient_checkpointing_enable() + student_model.enable_input_require_grads() + + # convert QuantizedModel state dict to make it compatible with FSDP + for name, module in student_model.named_modules(): + if isinstance(module, QuantizedWeight): + assert module.codes is not None + if args.code_dtype is not None: + assert module.nbits_per_codebook <= torch.iinfo(args.code_dtype).bits - is_signed(args.code_dtype) + module.codes = nn.Parameter(module.codes.to(args.code_dtype), requires_grad=module.codes.requires_grad) + module.wrap_codes_for_fsdp_() + assert module.codes is None and isinstance(module.codes_storage, IntCodes) + assert any(isinstance(module, IntCodes) for module in student_model.modules()) + + if dequantize: + student_model, named_quantized_params = create_dequantized_model( + student_model, dequantized_dtype=args.amp_dtype, reuse_non_quantized=True + ) + else: + named_quantized_params = None + + transformer_block_types = list(infer_module_classes(student_model, args.block_type)) + layernorm_types = list(transformers.pytorch_utils.ALL_LAYERNORM_LAYERS) + extra_block_types = list() + for extra_module_name in args.wrap_separately: + extra_block_types.extend(infer_module_classes(student_model, extra_module_name)) + block_types_to_wrap = tuple( + set( + transformer_block_types + + layernorm_types + + extra_block_types + + [ + IntCodes, + ] + ) + ) + if torch.distributed.get_rank() == 0: + print(f"Blocks to be wrapped separately: {block_types_to_wrap}\n") + + mixed_precision = None + if args.use_fsdp_amp: + assert args.amp_dtype is not None, "requested to use_fsdp_amp, but amp_dtype is not None" + block_types_for_amp_to_ignore = tuple(set(layernorm_types + extra_block_types)) + if torch.distributed.get_rank() == 0: + print(f"Blocks excluded from AMP: 
{block_types_for_amp_to_ignore}\n") + mixed_precision = MixedPrecision( + param_dtype=args.amp_dtype, + reduce_dtype=args.amp_dtype, + _module_classes_to_ignore=block_types_for_amp_to_ignore, + ) + else: + if torch.distributed.get_rank() == 0: + print(f"Not using FSDP native MixedPrecision; Local amp_dtype={args.amp_dtype}.") + + student_model = wrap_model_with_fsdp_( + student_model, + use_orig_params=True, + auto_wrap_policy=lambda module, recurse, **_etc: recurse or isinstance(module, block_types_to_wrap), + cpu_offload=CPUOffload(offload_params=args.offload_student_params) if args.offload_student_params else None, + limit_all_gathers=args.limit_all_gathers, + forward_prefetch=args.forward_prefetch, + mixed_precision=mixed_precision, + device_id=device, + ) + + if named_quantized_params is not None: + if torch.distributed.get_world_size() > 1: + # distributed pv: each rank holds a subset of all quantized weights; the rest are replaced with pointers + named_quantized_params = split_quantized_weights_between_ranks( + named_quantized_params, verify_checksums=False + ) + for quantized_weight in named_quantized_params.values(): + if isinstance(quantized_weight, QuantizedWeight): + quantized_weight.to(device) + else: + assert isinstance(quantized_weight, YourQuantizedWeightIsInAnotherRank) + + return student_model, named_quantized_params + + +def wrap_model_with_fsdp_( + model: transformers.PreTrainedModel, auto_wrap_policy: callable, **kwargs +) -> FullyShardedDataParallel: + """Wrap a model *ForCausalLM components: transformer and lm_head are wrapped as FSDP instances""" + assert isinstance(model, transformers.PreTrainedModel) and is_model_for_causal_lm(model) + base_model, lm_head = model.base_model, model.get_output_embeddings() + + def _modified_auto_wrap_policy(module, recurse, **kwargs): + return auto_wrap_policy(module, recurse, **kwargs) or (module in (base_model, lm_head)) + + model = FullyShardedDataParallel(model, auto_wrap_policy=_modified_auto_wrap_policy, **kwargs) + + assert isinstance(model.module, transformers.PreTrainedModel) + assert isinstance(model.base_model, FullyShardedDataParallel) + assert isinstance(model.get_output_embeddings(), FullyShardedDataParallel) + return model + + +def trigger_fsdp_lazy_init_( + tokenizer: transformers.PreTrainedTokenizer, + teacher_model: FullyShardedDataParallel, + student_model: FullyShardedDataParallel, + device: torch.device, + amp_dtype: Optional[torch.dtype], +): + """Trigger FullyShardedDataParallel lazy init in the correct order to allow both training and eval""" + print("Initializing FSDP root") + dummy_batch = tokenizer("I am the monument to all your sins", return_tensors="pt") + dummy_batch = {k: v.to(device) for k, v in dummy_batch.items()} + with torch.cuda.amp.autocast(enabled=amp_dtype is not None, dtype=amp_dtype): + with torch.no_grad(): + teacher_model(**dummy_batch) + (student_model(**dummy_batch).logits * 0).sum().backward() + + +def create_pv_optimizer( + args: argparse.Namespace, + student_model: FullyShardedDataParallel, + named_quantized_params: Dict[str, QuantizedWeight], +) -> torch.optim.Optimizer: + """Create optimizer for PV-Tuning using a de-quantized student model and a dictionary of quantized weights""" + named_dequantized_params = get_original_named_parameters_from_fsdp_module(student_model) + opt_device = torch.device("cpu") if args.offload_optimizer else next(student_model.parameters()).device + assert all(name in named_dequantized_params for name in named_quantized_params) + return 
StraightThroughAdamW( + named_dequantized_params=named_dequantized_params, + named_quantized_params=named_quantized_params, + update_codes=dict( + lr=args.code_lr, + betas=(args.code_beta1, args.code_beta2), + lamb=args.lamb, + debias=args.debias, + amsgrad=args.amsgrad, + compute_dtype=args.master_dtype, + exp_avg_dtype=torch.float16 if args.code_adam_16bit else args.master_dtype, + exp_avg_sq_dtype=torch.bfloat16 if args.code_adam_16bit else args.master_dtype, + v_hat_max_dtype=torch.float16 if args.code_adam_16bit else args.master_dtype, + exp_avg_device=opt_device, + exp_avg_sq_device=opt_device, + v_hat_max_device=opt_device, + ) + if args.update_codes + else None, + update_codebooks_and_scales=dict( + lr=args.lr, + betas=(args.adam_beta1, args.adam_beta2), + lamb=args.lamb, + debias=args.debias, + amsgrad=args.amsgrad, + compute_dtype=args.master_dtype, + exp_avg_dtype=args.master_dtype, + exp_avg_sq_dtype=args.master_dtype, + v_hat_max_dtype=args.master_dtype, + exp_avg_device=opt_device, + exp_avg_sq_device=opt_device, + v_hat_max_device=opt_device, + ) + if args.update_codebooks_and_scales + else None, + update_non_quantized_parameters=dict( + lr=args.lr, + betas=(args.adam_beta1, args.adam_beta2), + lamb=args.lamb, + debias=args.debias, + amsgrad=args.amsgrad, + compute_dtype=args.master_dtype, + exp_avg_dtype=args.master_dtype, + exp_avg_sq_dtype=args.master_dtype, + v_hat_max_dtype=args.master_dtype, + exp_avg_device=opt_device, + exp_avg_sq_device=opt_device, + v_hat_max_device=opt_device, + ) + if args.update_non_quantized_parameters + else None, + delta_decay=args.delta_decay, + max_code_change_per_step=args.max_code_change_per_step, + force_code_update=args.force_code_update, + code_trust_ratio=args.code_trust_ratio, + beam_size=args.beam_size, + straight_through_buffer_dtype=args.straight_through_buffer_dtype, + verbose=args.verbose_optimizer, + ) + + +def create_p_optimizer(args: argparse.Namespace, student_model: FullyShardedDataParallel) -> torch.optim.Optimizer: + """Create optimizer for training only continuous parameters of a quantized model""" + quantized_weight_continuous_parameters = set() + for module in student_model.modules(): + if isinstance(module, QuantizedWeight): + for param in module.parameters(): + if torch.is_floating_point(param) and param.requires_grad: + quantized_weight_continuous_parameters.add(param) + all_trainable_params = [] + if args.update_codebooks_and_scales: + all_trainable_params.extend( + param for param in student_model.parameters() if param in quantized_weight_continuous_parameters + ) # use iteration instead of simply adding list(set) to ensure deterministic order of parameters + if args.update_non_quantized_parameters: + all_trainable_params.extend( + param + for param in student_model.parameters() + if torch.is_floating_point(param) + and param.requires_grad + and param not in quantized_weight_continuous_parameters + ) + if args.update_codes: + raise RuntimeError("When asked to update_codes, one should create_pv_optimizer, but this is create_p_optimizer") + assert len(all_trainable_params) > 0, ( + "found no trainable parameters. Did you specify update_codes, " + "update_codebooks_and_scales or update_non_quantized_parameters?" 
) - if args.val_size > 0: - all_ids = torch.randperm(len(dataloader)) - train_ids, val_ids = all_ids[args.val_size :], all_ids[: args.val_size] - train_dataloader = [dataloader[i] for i in train_ids] - val_dataloader = [dataloader[i] for i in val_ids] + opt_device = torch.device("cpu") if args.offload_optimizer else next(student_model.parameters()).device + return ConfigurableAdamW( + params=list(all_trainable_params), + lr=args.lr, + betas=(args.adam_beta1, args.adam_beta2), + lamb=args.lamb, + debias=args.debias, + amsgrad=args.amsgrad, + compute_dtype=args.master_dtype, + exp_avg_dtype=args.master_dtype, + exp_avg_sq_dtype=args.master_dtype, + v_hat_max_dtype=args.master_dtype, + exp_avg_device=opt_device, + exp_avg_sq_device=opt_device, + v_hat_max_device=opt_device, + ) + + +def save_training_state( + args: argparse.Namespace, metadata: dict, quantized_model: nn.Module, optimizer: torch.optim.Optimizer +): + """Save model, optimizer state dict and training metadata to be loaded via load_training_state""" + if args.save is None: + return + rank = torch.distributed.get_rank() + os.makedirs(args.save, exist_ok=True) + if rank == 0: + print(f"Saving snapshot to {args.save}") + torch.save(metadata, os.path.join(args.save, "metadata.pt")) + with FullyShardedDataParallel.state_dict_type(quantized_model, StateDictType.LOCAL_STATE_DICT): + torch.save(quantized_model.state_dict(), os.path.join(args.save, f"quantized_model_state_dict_rank{rank}.pt")) + # model saves non-quantized weights and dequantized versions of QuantizedWeight; the latter is not necessary + torch.save(optimizer.state_dict(), os.path.join(args.save, f"optimizer_state_dict_rank{rank}.pt")) + # optimizer state dict saves statistics QuantizedWeight instances and straight-through buffers + if args.on_save: + exec(args.on_save) + + +def load_training_state( + args: argparse.Namespace, metadata: dict, quantized_model: nn.Module, optimizer: torch.optim.Optimizer +): + """Load model, optimizer state dict and metadata saved via save_training_state; update parameters in-place""" + rank = torch.distributed.get_rank() + if args.save is None or not os.path.exists(args.save): + if args.save is not None and rank == 0: + print(f"No checkpoint found at {args.save}") else: - train_dataloader = dataloader - val_dataloader = None - # create original model - orig_model = get_model(args.base_model, None, args.dtype, args.device_map, trust_remote_code=args.trust_remote_code) - if not args.device_map: - orig_model = orig_model.to(device) - # cache logits - orig_train_hiddens = cache_hiddens(orig_model, train_dataloader, args) - if val_dataloader: - orig_val_hiddens = cache_hiddens(orig_model, val_dataloader, args) + with FullyShardedDataParallel.state_dict_type(quantized_model, StateDictType.LOCAL_STATE_DICT): + # this loads non-quantized weights and de-quantized versions of QuantizedWeight instances + state_dict_ptr = quantized_model.state_dict() + loaded_state_dict = torch.load(os.path.join(args.save, f"quantized_model_state_dict_rank{rank}.pt")) + with torch.no_grad(): + for key in state_dict_ptr: + state_dict_ptr[key].copy_(loaded_state_dict.pop(key)) + assert len(loaded_state_dict) == 0, f"Unused keys:, {tuple(loaded_state_dict.keys())}" + del state_dict_ptr, loaded_state_dict + + # v-- loading optimizer state dict also loads all QuantizedWeights and straight-through buffers + optimizer.load_state_dict( + torch.load(os.path.join(args.save, f"optimizer_state_dict_rank{rank}.pt"), map_location="cpu") + ) + 
metadata.update(torch.load(os.path.join(args.save, "metadata.pt"))) + if args.eval_datasets is not None and metadata["early_stop_on"] not in args.eval_datasets: + if rank == 0: + print(f"Stopping criterion {metadata['early_stop_on']} is not in eval_datasets; resetting best loss.") + metadata["early_stop_on"] = next(iter(args.eval_datasets)) + metadata["best_eval_perplexity"] = float("inf") + metadata["best_step"] = 0 + if rank == 0: + print(f"Loaded training state from {args.save}: {metadata}") + + +def save_model(args: argparse.Namespace, student_model: FullyShardedDataParallel, optimizer: torch.optim.Optimizer): + """Save model for either P- or PV-Tuning using the appropriate saver""" + if isinstance(optimizer, StraightThroughAdamW): + save_pv_model(args, student_model, optimizer) else: - orig_val_hiddens = None - del orig_model - torch.cuda.empty_cache() - quant_model = get_model( - args.base_model, args.quant_model, args.dtype, args.device_map, trust_remote_code=args.trust_remote_code - ) - if not args.device_map: - quant_model = quant_model.to(device) - - # finetune - finetune( - quant_model, - train_loader=train_dataloader, - train_hiddens=orig_train_hiddens, - args=args, - device=device, - val_loader=val_dataloader, - val_hiddens=orig_val_hiddens, - ) - - print_memory_stats() - - # offload model to cpu - quant_model = quant_model.cpu() - if args.device_map: - remove_hook_from_submodules(quant_model) - torch.cuda.empty_cache() - - # save model - if args.save: - os.makedirs(args.save, exist_ok=True) - for layer_index, layer in enumerate(get_layers(quant_model)): - layer_save_path = os.path.join(args.save, f"{layer_index}.pth") - torch.save(layer, layer_save_path) - save_not_quantized_weights(quant_model, args.save) - # copy args - shutil.copy(os.path.join(args.quant_model, "args.pt"), os.path.join(args.save, "args.pt")) - - print("\n============ Evaluating perplexity... 
============") - torch.cuda.reset_peak_memory_stats() - for dataset in args.eval_datasets: - testloader = get_loaders( - dataset, + assert any(isinstance(module, QuantizedWeight) for module in student_model.modules()) + save_p_model(args, student_model) + + +def save_pv_model( + args: argparse.Namespace, dequantized_model: FullyShardedDataParallel, optimizer: StraightThroughAdamW +): + """Save consolidated model from PV tuning, can be exported later via convert_legacy_model_format.py""" + output_path = os.path.join(args.save, "best_model") + os.makedirs(output_path, exist_ok=True) + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + local_quantized_weight_names = set() + for name, quantized_weight in optimizer.iterate_local_quantized_weights(): + torch.save(quantized_weight, os.path.join(output_path, f"{name}.pth")) + local_quantized_weight_names.add(name) + + quantized_weight_names_by_rank = [None for _ in range(world_size)] if rank == 0 else None + torch.distributed.gather_object(local_quantized_weight_names, quantized_weight_names_by_rank, dst=0) + + with FullyShardedDataParallel.state_dict_type( + dequantized_model, + StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + model_state_dict = dequantized_model.state_dict() + if rank == 0: + all_quantized_weight_names = set() + for local_quantized_weight_names in quantized_weight_names_by_rank: + all_quantized_weight_names |= set(local_quantized_weight_names) + + non_quantized_state_dict = dict() + for name, tensor in model_state_dict.items(): + if name in all_quantized_weight_names: + all_quantized_weight_names.remove(name) # do not save de-quantized versions of quantized weights + else: + non_quantized_state_dict[name] = tensor + assert len(all_quantized_weight_names) == 0, f"mismatched names: {all_quantized_weight_names}" + torch.save(non_quantized_state_dict, os.path.join(output_path, "non_quantized_state_dict.pth")) + torch.distributed.barrier() + if rank == 0: + print(f"Saved best model shards to {output_path}") + + +def save_p_model(args: argparse.Namespace, quantized_model: FullyShardedDataParallel): + """Save consolidated model state dict from P-only tuning, can be exported via convert_legacy_model_format.py""" + os.makedirs(args.save, exist_ok=True) + rank = torch.distributed.get_rank() + with FullyShardedDataParallel.state_dict_type( + quantized_model, + StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + model_state_dict = quantized_model.state_dict() + if rank == 0: + torch.save(model_state_dict, os.path.join(args.save, f"best_model_state_dict.pt")) + torch.distributed.barrier() + if rank == 0: + print(f"Saved best model state dict to {os.path.join(args.save, f'best_model_state_dict.pt')}") + + +def compute_loss_on_batch( + batch: dict, + teacher_model: FullyShardedDataParallel, + student_model: FullyShardedDataParallel, + *, + amp_dtype: Optional[torch.dtype], + max_tokens_per_chunk: Optional[int], +) -> torch.Tensor: + if max_tokens_per_chunk is not None: # chunked inference, transformer and lm head must be separate FSDP instances + with torch.no_grad(): + teacher_hidden_states = teacher_model.base_model(**batch).last_hidden_state + with torch.cuda.amp.autocast(enabled=amp_dtype is not None, dtype=amp_dtype): + student_hidden_states = student_model.base_model(**batch).last_hidden_state + return compute_kl_divergence_loss_values( + 
student_hidden_states=student_hidden_states, + student_lm_head=student_model.get_output_embeddings(), + teacher_hidden_states=teacher_hidden_states, + teacher_lm_head=teacher_model.get_output_embeddings(), + max_tokens_per_chunk=max_tokens_per_chunk, + checkpoint_last_chunk=False, + use_reentrant=False, + determinism_check="none", + ).mean() + + else: # combined inference without gradient checkpointing + with torch.no_grad(): + teacher_logprobs = F.log_softmax(teacher_model(**batch).logits, dim=-1) + with torch.cuda.amp.autocast(enabled=amp_dtype is not None, dtype=amp_dtype): + student_logprobs = F.log_softmax(student_model(**batch).logits, dim=-1) + loss = F.kl_div( + input=student_logprobs.flatten(0, -2), + target=teacher_logprobs.flatten(0, -2), + log_target=True, + reduction="batchmean", + ).mean() + return loss + + +def compute_validation_perplexities(args: argparse.Namespace, model: nn.Module, eval_datasets: dict): + rank = torch.distributed.get_rank() + perplexities = {} + for dataset_name, eval_dataset in eval_datasets.items(): + if rank == 0: + print(f"Evaluating perplexity on {dataset_name} ...") + device = next(model.parameters()).device + original_dtype = args.load_dtype if args.load_dtype != "auto" else None + amp_dtype = args.amp_dtype if args.amp_dtype is not None else original_dtype + ppl = evaluate_perplexity(model, eval_dataset, args.model_seqlen, device=device, amp_dtype=amp_dtype) + if rank == 0: + print(f"{dataset_name} perplexity: {ppl:.9f}") + perplexities[dataset_name] = ppl + return perplexities + + +def main(): + assert torch.cuda.is_available() and torch.distributed.is_available() + torch.distributed.init_process_group() + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + parser = argparse.ArgumentParser(add_help=True) + add_model_args(parser) + add_data_args(parser) + add_finetuning_args(parser) + args = parser.parse_args() + + assert torch.distributed.is_initialized() + assert args.batch_size is not None, "please specify batch size" + assert args.batch_size % world_size == 0 + if args.microbatch_size is None: + args.microbatch_size = args.batch_size // world_size + assert args.batch_size % (world_size * args.microbatch_size) == 0 + grad_accumulation_steps = args.batch_size // (world_size * args.microbatch_size) + + args.master_dtype = getattr(torch, args.master_dtype) + args.embed_dtype = getattr(torch, args.embed_dtype) if args.embed_dtype is not None else args.master_dtype + args.load_dtype = getattr(torch, args.load_dtype) if args.load_dtype != "auto" else "auto" + args.code_dtype = getattr(torch, args.code_dtype) if args.code_dtype is not None else None + args.amp_dtype = getattr(torch, args.amp_dtype) if args.amp_dtype is not None else None + + if args.straight_through_buffer_dtype is not None: + args.straight_through_buffer_dtype = getattr(torch, args.straight_through_buffer_dtype) + else: + args.straight_through_buffer_dtype = args.master_dtype + + if args.save_every_steps is not None: + assert args.save is not None, f"save_every_steps={args.save_every_steps}, but --save path not specified" + if args.keep_best_model: + assert args.save is not None, f"--keep_best_model requires --save path" + assert args.eval_every_steps is not None, f"--keep_best_model requires --eval_every_steps" + assert args.eval_datasets is not None, f"--keep_best_model requires --eval_datasets" + + if args.wandb and rank == 0: + assert has_wandb, "`wandb` not installed, 
try pip install `wandb`" + wandb.init(config={a: getattr(args, a) for a in dir(args) if not a.startswith("_")}) + + if rank == 0: + print(args) + + tokenizer = transformers.AutoTokenizer.from_pretrained(args.base_model) + assert tokenizer.eos_token_id is not None + tokenizer.pad_token = tokenizer.eos_token + + with master_rank_first(local=True): + dataset = prepare_training_dataset(args, tokenizer) + if args.save_dataset_and_exit is not None: + if rank == 0: + dataset.save_to_disk(args.save_dataset_and_exit) + + if args.save_dataset_and_exit is not None: + torch.distributed.barrier() + return + + sampler = torch.utils.data.DistributedSampler( + dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=args.seed + ) + + train_dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.microbatch_size, + num_workers=args.num_workers, + sampler=sampler, + collate_fn=transformers.default_data_collator, + ) + eval_datasets = { + dataset_name: get_loaders( + dataset_name, seed=args.seed, model_path=args.base_model, - seqlen=args.eval_model_seqlen or args.model_seqlen, + seqlen=args.model_seqlen, eval_mode=True, - use_fast_tokenizer=args.use_fast_tokenizer, - trust_remote_code=args.trust_remote_code, ) - args.dataset_name = dataset - perplexity_eval(quant_model, testloader, args) - # make sure that the cache is released - torch.cuda.empty_cache() - - print(f"eval: {torch.cuda.max_memory_allocated()=:,}") - if args.wandb: - wandb.log({"max_cuda_mem_eval": round(torch.cuda.max_memory_allocated() / 1e9, 2)}) + for dataset_name in args.eval_datasets + } + + use_pv_tuning = args.update_codes and not args.force_dequantize + if rank == 0: + print(f"Training {['without', 'with'][use_pv_tuning]} PV-Tuning") + + with one_rank_at_a_time(local=True, group_size=args.limit_parallel_inits): + teacher_model = load_teacher_model(args, device) + student_model, named_quantized_params = load_student_model(args, device, dequantize=use_pv_tuning) + if rank == 0: + print("Wrapped model:") + print(student_model) + for name, param in student_model.named_parameters(): + print(name, param.shape, param.dtype) + + if use_pv_tuning: + optimizer = create_pv_optimizer(args, student_model, named_quantized_params) + else: + optimizer = create_p_optimizer(args, student_model) + del named_quantized_params + + metadata = dict( + current_epoch=0, + microbatches_since_epoch_start=0, + total_microbatches=0, + total_optimizer_steps=0, + loss_numerator=0, + loss_denominator=0, + aggregated_loss=float("nan"), + grad_steps_accumulated=0, + early_stop_on=next(iter(args.eval_datasets)) if args.eval_datasets else None, + best_eval_perplexity=float("inf"), + best_step=0, + ) + + load_training_state(args, metadata, student_model, optimizer) + torch.distributed.barrier() + trigger_fsdp_lazy_init_(tokenizer, teacher_model, student_model, device, amp_dtype=args.amp_dtype) + + for current_epoch in range(args.max_epochs): + if current_epoch < metadata["current_epoch"]: + continue # skip finished epochs + sampler.set_epoch(current_epoch) + + batch_iter = tqdm(train_dataloader, desc=f"Training epoch #{current_epoch}") if rank == 0 else train_dataloader + for batch_index, batch in enumerate(batch_iter): + if batch_index <= metadata["microbatches_since_epoch_start"]: + continue # skip batches processed before checkpoint + metadata["microbatches_since_epoch_start"] += 1 + metadata["total_microbatches"] += 1 + + batch = {k: v.to(device) for k, v in batch.items()} + loss = compute_loss_on_batch( + batch, + teacher_model, + student_model, + 
amp_dtype=args.amp_dtype, + max_tokens_per_chunk=args.loss_tokens_per_chunk, + ) + + metadata["loss_numerator"] += loss.item() + metadata["loss_denominator"] += 1 + metadata["grad_steps_accumulated"] += 1 + if metadata["grad_steps_accumulated"] < grad_accumulation_steps: + with student_model.no_sync() if args.minimize_sync else nullcontext(): + (loss / grad_accumulation_steps).backward() + else: + (loss / grad_accumulation_steps).backward() + optimizer.step() + optimizer.zero_grad() + metadata["grad_steps_accumulated"] = 0 + metadata["total_optimizer_steps"] += 1 + + if args.print_every_steps and metadata["total_optimizer_steps"] % args.print_every_steps == 0: + loss_numerator_and_denominator = torch.tensor( + [metadata["loss_numerator"], metadata["loss_denominator"]], dtype=torch.float64, device=device + ) + + torch.distributed.all_reduce(loss_numerator_and_denominator, op=torch.distributed.ReduceOp.SUM) + loss_numerator, loss_denominator = loss_numerator_and_denominator.tolist() + metadata["aggregated_loss"] = loss_numerator / loss_denominator + metadata["loss_numerator"] = metadata["loss_denominator"] = 0 + if rank == 0: + print( + f"epoch {metadata['current_epoch']}\tbatch {batch_index}", + f"\t| total updates = {metadata['total_optimizer_steps']}", + f"\tloss = {metadata['aggregated_loss']:.9f}", + ) + + if args.eval_every_steps and metadata["total_optimizer_steps"] % args.eval_every_steps == 0: + perplexity_scores = compute_validation_perplexities(args, student_model, eval_datasets) + for dataset_name, perplexity in perplexity_scores.items(): + metadata[f"perplexity_{dataset_name}"] = perplexity + metric_name = metadata["early_stop_on"] + if perplexity_scores[metric_name] < metadata["best_eval_perplexity"]: + if rank == 0: + print(f"New best perplexity ({metric_name}) = {perplexity_scores[metric_name]:.9f}") + metadata["best_eval_perplexity"] = perplexity_scores[args.eval_datasets[0]] + metadata["best_step"] = metadata["total_optimizer_steps"] + if args.keep_best_model: + save_model(args, student_model, optimizer) + if args.wandb and rank == 0: + wandb.log(metadata, step=metadata["total_microbatches"]) + if args.save_every_steps and metadata["total_optimizer_steps"] % args.save_every_steps == 0: + save_training_state(args, metadata, student_model, optimizer) + + metadata["microbatches_since_epoch_start"] = 0 + metadata["current_epoch"] += 1 + + save_training_state(args, metadata, student_model, optimizer) if __name__ == "__main__": diff --git a/main.py b/main.py index c570c26..9020978 100644 --- a/main.py +++ b/main.py @@ -20,7 +20,7 @@ get_layers, get_lm_logits, get_model, - get_model_head, + get_model_head_with_norm, get_sequential_groups, save_not_quantized_weights, ) @@ -397,7 +397,7 @@ def perplexity_eval(model: PreTrainedModel, testenc: torch.LongTensor, args: Nam torch.cuda.empty_cache() inps, outs = outs, inps - get_model_head(model).to(device) + get_model_head_with_norm(model).to(device) testenc = testenc.to(device) nsamples_per_device = len(inps[0]) assert len(set(map(len, inps[:-1]))) <= 1 and len(inps[-1]) <= len(inps[0]) @@ -415,7 +415,7 @@ def perplexity_eval(model: PreTrainedModel, testenc: torch.LongTensor, args: Nam ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * args.model_seqlen)).item() print(f"\n{args.dataset_name} perplexity = {ppl:.4f}\n") - get_model_head(model).to(torch.device("cpu")) + get_model_head_with_norm(model).to(torch.device("cpu")) if args.wandb: wandb.log({args.dataset_name: ppl}) diff --git a/requirements.txt b/requirements.txt index 
f25869e..c7da84a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -safetensors==0.3.2 -datasets==2.15.0 -sentencepiece==0.1.99 -torch>=2.1.1 -numpy>=1.21.5 -transformers==4.37.0 -accelerate==0.24.1 +safetensors==0.4.3 +datasets==2.19.0 +sentencepiece==0.2.0 +torch>=2.3.0 +numpy>=1.26.4 +transformers==4.40.1 +accelerate==0.29.3 diff --git a/src/aq.py b/src/aq.py index c7100f9..df2f681 100644 --- a/src/aq.py +++ b/src/aq.py @@ -1,5 +1,6 @@ """ Core mathematics for Additive Quantization (AQ): initialization, reconstruction and beam search""" -import random +from __future__ import annotations + from typing import List, Optional, Tuple, Union import torch @@ -8,12 +9,14 @@ from torch.utils.checkpoint import checkpoint from tqdm.auto import trange +from src.beam_search_l2 import beam_search_optimal_codes as beam_search_minimize_weight_mse +from src.beam_search_xtx import beam_search_optimal_codes as beam_search_minimize_activation_mse from src.kmeans import find_nearest_cluster, fit_faiss_kmeans, fit_kmeans, fit_kmeans_1d -from src.utils import ellipsis, maybe_script +from src.utils import IntCodes, _dequantize_weight, ellipsis, is_signed class QuantizedLinear(nn.Module): - def __init__(self, quantized_weight, bias: Optional[nn.Parameter]): + def __init__(self, quantized_weight: QuantizedWeight, bias: Optional[nn.Parameter]): super().__init__() self.out_features, self.in_features = quantized_weight.out_features, quantized_weight.in_features self.quantized_weight = quantized_weight @@ -37,7 +40,6 @@ class QuantizedWeight(nn.Module): def __init__( self, *, - XTX: torch.Tensor, reference_weight: torch.Tensor, in_group_size: int, out_group_size: int, @@ -47,12 +49,15 @@ def __init__( codebook_value_num_groups: int = 1, scale_nbits: int = 0, straight_through_gradient: Optional[bool] = None, + code_dtype: torch.dtype = torch.int32, **init_kwargs, ): super().__init__() self.out_features, self.in_features = reference_weight.shape assert self.in_features % in_group_size == 0 assert self.out_features % out_group_size == 0 + if nbits_per_codebook > torch.iinfo(code_dtype).bits - is_signed(code_dtype): + raise ValueError(f"Code dtype cannot store {nbits_per_codebook} bits; please specify code_dtype manually") self.out_group_size, self.in_group_size = out_group_size, in_group_size self.num_codebooks = num_codebooks @@ -105,10 +110,34 @@ def __init__( self.codebooks = nn.Parameter( codebooks, requires_grad=True ) # [num_codebooks, codebook_size, out_group_size, in_group_size] - self.codes = nn.Parameter( - codes.to(torch.int32), - requires_grad=False, - ) # [num_out_groups, num_in_groups, num_codebooks] + self.codes: Optional[nn.Parameter] = nn.Parameter( + codes.to(code_dtype), requires_grad=False + ) # [num_out_groups, num_in_groups, num_codebooks] + self.codes_storage: Optional[IntCodes] = None # storage for FSDP compatibility + + def get_codes(self) -> torch.IntTensor: + """Get a non view to codes, regardless of how codes are stored""" + assert (self.codes is None) != (self.codes_storage is None), "must have either .codes or storage, but not both" + codes = self.codes if self.codes is not None else self.codes_storage() + if torch.iinfo(codes.dtype).bits < 32: + codes = codes.to(torch.int32) # cast to int32 to allow indexing if codes are int16 or uint8 + return codes + + def set_codes(self, new_codes: torch.Tensor, selection: Union[slice, ellipsis, torch.Tensor] = ..., **kwargs): + """Update codes[selection] to new_codes, regardless of their dtype and whether they are wrapped as 
storage""" + assert (self.codes is None) != (self.codes_storage is None), "must have either .codes or storage, but not both" + codes_ptr = self.codes if self.codes is not None else self.codes_storage() + codes_ptr[selection].copy_(new_codes, **kwargs) + + def wrap_codes_for_fsdp_(self, **kwargs): + """Make this module compatible with FullyShardedDataParallel; modifies state dict in-place""" + assert self.codes is not None and self.codes_storage is None + self.codes_storage, self.codes = IntCodes(self.codes, **kwargs), None + + def unwrap_codes_(self): + """Undo the effect of wrap_codes_for_fsdp_; modifies state dict in-place""" + assert self.codes is None and self.codes_storage is not None + self.codes, self.codes_storage = nn.Parameter(self.codes_storage(), requires_grad=False), None def get_codebooks(self) -> torch.Tensor: """Get quantization codebooks or reconstruct them from second level quantization (see codebook_values_nbits)""" @@ -164,6 +193,10 @@ def get_scales(self) -> torch.Tensor: else: # train scale codebook only return self.scales_clusters.gather(1, self.scales_indices)[:, :, None, None] + @property + def shape(self) -> Tuple[int, int]: + return self.out_features, self.in_features + def forward(self, selection: Union[slice, ellipsis, torch.Tensor] = ...): """ Differentably reconstruct the weight (or parts thereof) from compressed components @@ -173,23 +206,25 @@ def forward(self, selection: Union[slice, ellipsis, torch.Tensor] = ...): Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size ) """ - weight = _dequantize_weight(self.codes[selection], self.get_codebooks(), self.get_scales()[selection]) + weight = _dequantize_weight(self.get_codes()[selection], self.get_codebooks(), self.get_scales()[selection]) return weight @torch.no_grad() def beam_search_update_codes_( self, - XTX: torch.Tensor, - reference_weight: torch.Tensor, *, + XTX: Optional[torch.Tensor] = None, + reference_weight: torch.Tensor, selection: Union[slice, ellipsis, torch.LongTensor] = ..., **kwargs, ) -> torch: """ - Update self.codes in-place via beam search so as to minimize squared errors. Return the updated codes. - :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] - :note: if XTX is divided by dataset size, this function will return *mean* squared error + Update own codes in-place via beam search so as to minimize squared errors. Return the updated codes. :param reference_weight: original weight matrix that is being quantized, shape: [out_features, in_features] + :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] + - if XTX is divided by dataset size, this function will return *mean* squared error + - if XTX is not specified, this function minimizes squared error between weights, as if XTX was identity + :note: if selection is specified, reference_weight must instead be [num_selected_out_features, in_features] :param selection: By default, this function updates all codes, If selection specified, it will instead update only the codes for a portion of output dimensions (used for parallelism). 
@@ -197,17 +232,26 @@ def beam_search_update_codes_( Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size ) :param beam_size: consider up to this many best encoding combinations (this param is passed through via kwargs) :param kwargs: any additional keyword arguments are forwarded to beam_search_optimal_codes function - :returns: the updated codes + :returns: the updated codes, in the same shape as self.get_codes()[selection] """ - self.codes[selection] = beam_search_optimal_codes( - XTX=XTX, - reference_weight=reference_weight, - codebooks=self.get_codebooks(), - prev_codes=self.codes[selection], - scales=self.get_scales()[selection], - **kwargs, - ) - return self.codes[selection] + codebooks = self.get_codebooks() + prev_codes = self.get_codes()[selection] + scales = self.get_scales()[selection] + if XTX is not None: + new_codes = beam_search_minimize_activation_mse( + XTX=XTX, + reference_weight=reference_weight, + codebooks=codebooks, + prev_codes=prev_codes, + scales=scales, + **kwargs, + ) + else: + new_codes = beam_search_minimize_weight_mse( + reference_weight=reference_weight, codebooks=codebooks, prev_codes=prev_codes, scales=scales, **kwargs + ) + self.set_codes(new_codes, selection) + return new_codes def estimate_nbits_per_parameter(self) -> float: """Calculate the effective number of bits per original matrix parameters""" @@ -240,388 +284,6 @@ def extra_repr(self) -> str: return f"{self.out_features=}, {self.in_features=}, bits_per_parameter={self.estimate_nbits_per_parameter()}" -@torch.inference_mode() -def beam_search_optimal_codes( - *, - XTX: torch.Tensor, - reference_weight: torch.Tensor, - codebooks: torch.Tensor, - prev_codes: torch.IntTensor, - scales: Optional[torch.Tensor], - beam_size: int, - dim_rng: Optional[random.Random] = None, - sparsity_regularizer: float = 0, - verbose: bool, -): - """ - :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] - :note: if XTX is divided by dataset size, this function will return *mean* squared error - :param reference_weight: original weight matrix that is being quantized, shape: [out_features, in_features] - :param codebooks: look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_siz, in_group_size] - :param prev_codes: previous-best integer weight codes, shape: [num_out_groups, num_in_groups, num_codebooks] - :param scales: weight will be multiplied by this factor, shape = [num_out_groups, num_in_groups or 1, 1, 1] - :param dim_rng: a source of randomness to (optionally) shuffle the order in which the beam search runs - None = update dimensions and codebooks in their natural order (0, 1, ..., n) - random.Random(optional_seed) = shuffle dimensions at random, optionally using the specified seed - - :param beam_size: consider up to this many best encoding combinations - :param sparsity_regularizer: subtract this value from beam search objective each time you have a zero code somewhere - :param verbose: if True, draw a progressbar and periodically print best loss - :return: best quantization codes found, same shape as prev_codes - - :intuition: the beam search needs to produce weight codes that minimize MSE error - - the codes are of shape [out_features / out_group_size, in_features / in_group_size, num_codebooks] - - Out of those three dimensions, out_features is "independent", i.e. changing code in - one output feature does not increase the MSE error for another feature. 
Therefore, - beam search for different output features can run in independently in parallel. - - Neither (in_features / in_group_size) nor (num_codebooks) dimension are independent: - - changing the encoding for one feature can compensate the error from encoding another, OBC-style - - for a single weight group, changing code in one codebook can affect the optimal choice in another codebook - Therefore, beam search must go in a double loop over (in_features/in_group_size) and (num_codebooks) dimensions - - This leaves one choice: which dimension used for outer loop, and which one goes is in the inner loop? - Due to the nature of beam search, interactions between dimensions of inner loop will be explored better. - We chose to use (in_features/in_group_size) in the outer loop and (num_codebooks) for the inner loop. - This is based on an intuition from GPTQ: you can get decent performance by quantizing each input unit ... - ... greedily --- GPTQ does not change quantizations for previously quantized features and works fine. - Therefore, we believe that we can also use a greedy approach to compensate error between input features. - In turn, we believe that the codes used to encode the same weights (additively) are more inter-dependent. - This should be treated as an educated guess with no proof and no ablation (as of the time of writing). - - """ - num_out_groups, num_in_groups, num_codebooks = prev_codes.shape - num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape - in_features = num_in_groups * in_group_size - out_features = num_out_groups * out_group_size - assert reference_weight.shape == (out_features, in_features) - prev_weight = _dequantize_weight(prev_codes, codebooks, scales) - - # initialize all beam codes as previous codes - so they can be updated during beam search - beam_codes = prev_codes.unsqueeze(0) - # beam_codes shape: [current beam_size, num_out_groups, num_in_groups, num_codebooks], initial beam_size = 1 - beam_weights = prev_weight.unsqueeze(0) - # beam_weights shape: [current beam_size, out_features, in_features], initial beam size = 1 - - beam_losses = ( - _channelwise_squared_error(XTX, prev_weight, reference_weight) - .reshape(1, num_out_groups, out_group_size) - .sum(-1) - ) - # beam_losses shape: [current beam_size, num_out_groups], initial beam_size = 1 - if sparsity_regularizer != 0: - beam_losses = beam_losses - sparsity_regularizer * (prev_codes == 0).sum(dim=(-1, -2))[None, :] - - if verbose: - progressbar = trange(num_in_groups * num_codebooks) - - def _make_range(n: int) -> list: - seq = list(range(n)) - if dim_rng is not None: - dim_rng.shuffle(seq) - return seq - - for input_group_index in _make_range(num_in_groups): - for codebook_index in _make_range(num_codebooks): - ### part 1: compute losses for every possible candidate for one given codebook and input group. - # Currently, we compute errors for all output features in parallel in a vectorized fashion. - best_losses, best_indices = _beam_search_squared_errors( - XTX=XTX, - reference_weight=reference_weight, - codebooks=codebooks, - scales=scales, - beam_losses=beam_losses, - beam_codes=beam_codes, - beam_weights=beam_weights, - input_group_index=input_group_index, - codebook_index=codebook_index, - k_best=beam_size, - sparsity_regularizer=sparsity_regularizer, - ) # [current beam_size, codebook_size, num_out_groups] - - # part 2: select beam_size new best codes and re-arrange beam to account for the fact that ... - # ... 
sometimes two or more top candidates originate from the same source in previous beam - beam_codes, beam_weights, beam_losses = _beam_search_select_best( - beam_codes=beam_codes, - beam_weights=beam_weights, - codebooks=codebooks, - scales=scales, - input_group_index=input_group_index, - codebook_index=codebook_index, - best_losses=best_losses, - best_indices=best_indices, - beam_size=beam_size, - ) - - if verbose: - progressbar.update() - if (input_group_index * num_codebooks + codebook_index) % verbose != 0: - continue # if update is an integer, compute metrics every (this many) beam search steps - best_loss = beam_losses.min(0).values.sum().item() / out_features - info = f"in_group {input_group_index} / {num_in_groups} " - info += f"| codebook {codebook_index} / {num_codebooks} " - if sparsity_regularizer == 0: - info += f"| loss {best_loss:.10f}" - else: # un-regularize to restore MSE loss, report sparsity rate - num_zero_codes = (beam_codes[0] == 0).sum().item() - best_loss = best_loss + sparsity_regularizer / out_features * num_zero_codes - sparsity = num_zero_codes / prev_codes.numel() - info += f"| loss {best_loss:.5f} | sparse {sparsity * 100:.1f}% |" - - progressbar.desc = info - return beam_codes[0] - - -@maybe_script -def _dequantize_weight( - codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None -) -> torch.Tensor: - """ - Decode float weights from quantization codes. Differentiable. - :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks] - :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size] - :param scales: weight will be multiplied by this factor, must be broadcastble with [*dims, out_groups, num_in_groups, out_group_size, in_group_size] - :return: reconstructed weight tensor of shape [*dims, num_in_groups*group_size] - """ - num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] - num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape - out_features = num_out_groups * out_group_size - in_features = num_in_groups * in_group_size - codebook_offsets = torch.arange( - 0, num_codebooks * codebook_size, codebook_size, device=codes.device - ) # shape: [num_codebooks] - reconstructed_weight_flat = F.embedding_bag( - codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum" - ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size] - - reconstructed_weight_groupwise = reconstructed_weight_flat.view( - list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size] - ) - if scales is not None: - reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales) - return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) - - -@maybe_script -def _beam_search_squared_errors( - XTX: torch.Tensor, - reference_weight: torch.Tensor, - codebooks: torch.Tensor, - scales: Optional[torch.Tensor], - beam_losses: torch.Tensor, - beam_codes: torch.Tensor, - beam_weights: torch.Tensor, - input_group_index: int, - codebook_index: int, - k_best: int, - sparsity_regularizer: float, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute MSE or sum-of-squared-error losses for all possible ways to replace quantization codes for one input group - and one codebook. Works in parallel for all output-dimension groups. 
- - :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] - :note: if both XTX *and* beam_loses are divided by dataset size, this function will return mean squared error - :param reference_weight: original weight matrix that is being quantized, shape: [out_features, in_features] - :param codebooks: look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_size, in_group_size] - :param scales: weight will be multiplied by this factor, [num_out_groups, num_in_groups, 1, 1] - - :param beam_losses: sum-of-squared-error for each hypothesis in beam and for each output channel; - shape: [beam_size, num_out_groups] - :param beam_codes: a tensor with best weight codes, shape: [beam_size, num_out_groups, num_in_groups, num_codebooks] - :param beam_weights: a tensor with de-quantized beam_codes, shape: [beam_size, out_features, in_features] - :param input_group_index: an index of one group of in_features that is being re-encoded - :param codebook_index: an index of one codebook for that group of features that is being re-encoded - :return: tuple(Tensor, Tensor) of 3d tensor of shape = [beam_size, k_best, num_out_groups]. - First one is float tensor of losses of k_best lowest square errors for each beam and out_group - Second one is int64 tensor of indices of k_best lowest square errors for each beam and out_group - - :note: The code computes MSE using the square-of-difference expansion - ||X@W.T - sum_i X@(Bi@Ci).T||^2 = ||X@W.T||^2 - 2 + ||sum_i X@Bi@Ci||^2 - where X[nsamples,in_features] is calibration data, W[out_features, in_features] is the reference weight, - C[num_codebooks, codebook_size, in_features] are learned codebooks (Ci has shape [codebook_size, out_features]) - B[num_codebooks, out_features, codebook_size] are one-hot encoded indices (quantization codes) - The formula above uses a single group per output "neuron" and a single group. - The algorithm below generalizes the formula for multiple groups and codebooks. - - Furthermore, the algorithm does not compute the entire formula. Instead, it begins from some baseline loss - and computes the change in loss from changing a single code to every possible altearnative code. - When computing the changed loss, the algorithm only computes the few affected parts of the loss formula above. 
- """ - num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape - beam_size, num_out_groups, num_in_groups, num_codebooks = beam_codes.shape - out_features = num_out_groups * out_group_size - - input_group_slice = slice(input_group_index * in_group_size, (input_group_index + 1) * in_group_size) - - prev_codes_part = beam_codes[:, :, input_group_index, codebook_index] # [beam_size, num_out_groups] - - if scales is not None: - scales_part = scales[:, input_group_index % scales.shape[1], :, :] # [num_out_groups, 1, 1] - else: - scales_part = torch.empty(0, device=XTX.device) - prev_part_dequantized = F.embedding(prev_codes_part, codebooks[codebook_index].flatten(-2, -1)).view( - beam_size, out_features, in_group_size - ) # previous codes de-quantized - - prev_weight_part = prev_part_dequantized - if scales is not None: - prev_weight_part = ( - prev_weight_part.view(beam_size, num_out_groups, out_group_size, in_group_size) - .mul(scales_part) - .view(beam_size, out_features, in_group_size) - ) - - cand_weights = codebooks[codebook_index] # [codebook_size, out_group_size, in_group_size], all replacement codes - - delta_weight_without_part = reference_weight - beam_weights - delta_weight_without_part[:, :, input_group_slice] += prev_weight_part - - # dWTXTX is equivalent to < X @ (W - \sum BiCi except current codebook), X @ SOMETHING > - dWTXTXg = delta_weight_without_part @ XTX[..., input_group_slice] # [beam_size, out_features, in_group_size] - # below: use torch.matmul to compute broadcasted batch matrix multiplication; see matmul docs - - XnewBkC_norms_sq = torch.bmm( - (cand_weights.flatten(0, 1) @ XTX[input_group_slice, input_group_slice]).view( - codebook_size, 1, out_group_size * in_group_size - ), - cand_weights.view(codebook_size, out_group_size * in_group_size, 1), - ).reshape( - codebook_size, 1 - ) # [codebook_size, num_out_groups] - if scales is not None: - XnewBkC_norms_sq = XnewBkC_norms_sq.mul(scales_part.square().reshape(1, num_out_groups)) - - best_losses = torch.empty( - (beam_size, k_best, num_out_groups), dtype=XTX.dtype, device=XTX.device - ) # shape: [beam_size, k_best, num_out_groups] - best_indices = torch.empty( - (beam_size, k_best, num_out_groups), - dtype=torch.int64, - device=XTX.device, - ) - for beam_id in range(beam_size): - dot_products = ( - torch.einsum( - "mg,og->mo", - cand_weights.reshape(codebook_size, out_group_size * in_group_size), - dWTXTXg[beam_id].view(num_out_groups, out_group_size * in_group_size), - ) - .sub_( - torch.einsum( - "og,og->o", - prev_part_dequantized[beam_id].reshape(num_out_groups, out_group_size * in_group_size), - dWTXTXg[beam_id].view(num_out_groups, out_group_size * in_group_size), - ).view(1, num_out_groups) - ) - .view(codebook_size, num_out_groups) - ) - if scales is not None: - dot_products = dot_products.mul_(scales_part.reshape(1, num_out_groups)) - - XoldBkC_norms_sq = torch.bmm( - (prev_weight_part[beam_id] @ XTX[input_group_slice, input_group_slice]).view( - num_out_groups, 1, out_group_size * in_group_size - ), - prev_weight_part[beam_id].view(num_out_groups, out_group_size * in_group_size, 1), - ).reshape(1, num_out_groups) - - # finally, combine them to get MSE - candidate_squared_errors = ( - beam_losses[beam_id, None, :] - 2 * dot_products + XnewBkC_norms_sq - XoldBkC_norms_sq - ) # shape: [codebook_size, num_out_groups] - - if sparsity_regularizer != 0: - candidate_squared_errors += sparsity_regularizer * (prev_codes_part[beam_id] == 0).to(XTX.dtype)[None, :] - candidate_squared_errors[0, :] 
-= sparsity_regularizer - - best_beam_squared_errors, best_beam_indices = torch.topk( - candidate_squared_errors, k_best, dim=0, largest=False, sorted=False - ) - best_losses[beam_id] = best_beam_squared_errors - best_indices[beam_id] = best_beam_indices - - return best_losses, best_indices - - -@maybe_script -def _beam_search_select_best( - beam_codes: torch.Tensor, - beam_weights: torch.Tensor, - codebooks: torch.Tensor, - scales: Optional[torch.Tensor], - input_group_index: int, - codebook_index: int, - best_losses: torch.Tensor, - best_indices: torch.Tensor, - beam_size: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Select top-:beam_size: and reorder beam accordingly, return new beam - :param beam_codes: a tensor with best weight codes, shape: [beam_size, num_out_groups, num_in_groups, num_codebooks] - :param beam_weights: a tensor with de-quantized beam_codes, shape: [beam_size, out_features, in_features] - :param codebooks: a tensor with look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_size, in_group_size] - :param scales: weight will be multiplied by this factor, [num_out_groups, num_in_groups, 1, 1] - - :param input_group_index: an index of one group of in_features that is being re-encoded - :param codebook_index: an index of one codebook for that group of features that is being re-encoded - :param best_losses: a 3d tensor of losses of k_best lowest square errors for each beam and out group, - shape = [beam_size, k_best, num_out_groups] - :param best_indices: a 3d tensor of indices of k_best lowest square errors for each beam and out group, - shape = [beam_size, k_best, num_out_groups] - :param beam_size: how many top hypotheses should be selected - - :returns: new (beam_codes, beam_weights, beam_losses) - """ - dtype = best_losses.dtype - device = best_losses.device - _prev_beam_size, k_best, num_out_groups = best_losses.shape - _prev_beam_size, out_features, in_features = beam_weights.shape - _prev_beam_size, num_out_groups, num_in_groups, num_codebooks = beam_codes.shape - flat_best = best_losses.flatten(0, 1).topk(dim=0, k=beam_size, largest=False) - best_hypo_source_ids = flat_best.indices // k_best - arange_out_groups = torch.arange(num_out_groups, device=device) - best_hypo_codes = best_indices.flatten(0, 1)[flat_best.indices, arange_out_groups].reshape( - beam_size, num_out_groups - ) - # ^-- shape: [beam_size, num_out_groups] - - # reorder beam codes and weights - new_beam_codes = torch.full( - size=(len(best_hypo_codes), num_out_groups, num_in_groups, num_codebooks), - fill_value=-1, - dtype=beam_codes.dtype, - device=device, - ) # [beam_size, num_out_groups, num_in_groups, num_codebooks] - new_beam_weights = torch.empty(len(best_hypo_codes), out_features, in_features, dtype=dtype, device=device) - - for beam_index in range(len(best_hypo_codes)): - new_beam_codes[beam_index, :, ...] = beam_codes[best_hypo_source_ids[beam_index, :], arange_out_groups, ...] - new_beam_codes[beam_index, :, input_group_index, codebook_index] = best_hypo_codes[beam_index, :] - new_beam_weights[beam_index, :, :] = _dequantize_weight(new_beam_codes[beam_index, ...], codebooks, scales) - - # Note: the code above can be further accelerated by 1) vectorzing loop and ... - # ... 
2) updating new_beam_weights only for the chosen input group - return new_beam_codes, new_beam_weights, flat_best.values - - -@maybe_script -def _channelwise_squared_error(XTX: torch.Tensor, weight: torch.Tensor, reference_weight: torch.Tensor): - """ - Compute per-channel squared error between X @ weight_or_weights and X @ reference_weight - :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] - :note: if XTX is divided by dataset size, this function will return *mean* squared error - :param weight: predicted/reconstructed weights of shape [*dims, out_features, in_features] - :param reference_weight: reference weight of shape [out_features, in_features] - :return: per-channel squared errors of shape [*dims, out_features] - """ - XW_norm_square = torch.matmul(weight[..., :, None, :], (weight @ XTX)[..., :, :, None]).flatten(-3) - XWreference_norm_square = torch.bmm(reference_weight[:, None, :], (reference_weight @ XTX)[:, :, None]).flatten(-3) - dot_product = torch.matmul((reference_weight @ XTX)[:, None, :], weight[..., :, :, None]).flatten(-3) - return XW_norm_square - 2 * dot_product + XWreference_norm_square - - @torch.no_grad() def init_aq_kmeans( reference_weight: torch.Tensor, diff --git a/src/beam_search_l2.py b/src/beam_search_l2.py new file mode 100644 index 0000000..9c49cb4 --- /dev/null +++ b/src/beam_search_l2.py @@ -0,0 +1,325 @@ +"""Beam search that minimizes ||Wref - Wq||^2 w.r.t. Wq""" +import math +import random +import time +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from src.utils import _dequantize_weight, maybe_script + + +@torch.inference_mode +def beam_search_optimal_codes( + reference_weight: torch.Tensor, + codebooks: torch.Tensor, + prev_codes: torch.Tensor, + scales: Optional[torch.Tensor], + beam_size: int, + stochastic_rounding_tau: float = 0.0, + chunk_size_bytes: int = 2**32, + dim_rng: Optional[random.Random] = None, + force_update: bool = False, + max_update_fraction: float = 1.0, + code_selection_temperature: float = 0, + trust_ratio: Optional[float] = None, +) -> torch.Tensor: + """ + Update codes using beam search to minimize L2 error in code values (regardless of activations) + :param reference_weight: a target for L2 error, [out_features, in_features] + :param codebooks: look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_size, in_group_size] + :param prev_codes: previous-best integer weight codes, shape: [num_output_groups, num_input_groups, num_codebooks] + :param scales: weight will be multiplied by this factor, shape = [num_output_groups, num_input_groups or 1, 1, 1] + :param dim_rng: a source of randomness to (optionally) shuffle the order in which the beam search runs + None = update dimensions and codebooks in their natural order (0, 1, ..., n) + random.Random(optional_seed) = shuffle dimensions at random, optionally using the specified seed + + :param beam_size: consider up to this many best encoding combinations + :param stochastic_rounding_tau: if positive, each time the algorithm chooses a code, it will have a probability + of replacing it with the second-best choice. If the two best codes increase the error by delta1 and delta2, + then the probability of choosing each code is P_i = delta_i ^ -1/tau / (sum_j_in_choices delta_j ^ -1/tau). 
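+        For instance, with tau=1 and error increases delta1=0.1 and delta2=0.3, the best code is kept with
+        probability (1/0.1) / (1/0.1 + 1/0.3) = 0.75 and replaced by the runner-up with probability 0.25.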
+        Note that if there is a code that has zero error, the algorithm will always choose such a code + :param chunk_size_bytes: process this many candidates at a time; reduce to save memory + :param force_update: if True, the algorithm will force codes to change even if the code is already optimal in terms + of mean squared error. By default, the algorithm forces *all* weights to update this way, which may change weights + too much. To limit the number of updated weights, set max_update_fraction and trust_ratio. + :param max_update_fraction: the maximum portion of discrete code groups that *can* be updated; + By default, all codes can be updated. If < 1, only this portion of all code groups is allowed to update. + The algorithm selects the codes for update based on the difference between de-quantized and reference_weight. + If there are multiple codebooks, changing any one code responsible for the group counts as code group changed. + Note that a small max_update_fraction also speeds up computation since not all codes need beam search. + If the number of code groups does not divide evenly, the algorithm will round the number of updates up. + :param code_selection_temperature: only used if max_update_fraction < 1 or trust_ratio is set; by default, prioritize updating the codes with + the largest delta = ||(reference_weight - quantized_weight) * mask_only_weights_that_depend_on_this_code||_2 . + If temperature > 0, the updated codes are instead *sampled* at random, proportionally to delta^(1/temperature) . + :param trust_ratio: if not None, the algorithm only admits code changes as long as they do not change too much. + Formally, ||new_quantized_weight - prev_quantized_weight|| / ||prev_quantized_weight|| <= trust_ratio + If this is not true, the algorithm will reset some of the new quantized weights to their old values until the + constraint becomes satisfied. The algorithm still prioritizes changes to weights with largest delta (see above). + If code_selection_temperature > 0, the algorithm instead samples which weights to change with the same probability. + The algorithm will always allow changing exactly *one* code in excess of trust ratio to ensure that at least + one weight is updated. If both this and max_update_fraction are set, both constraints are enforced.
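+    Example call with illustrative inputs and hyperparameters:
+        beam_search_optimal_codes(reference_weight=reference_weight, codebooks=codebooks, prev_codes=prev_codes,
+                                  scales=None, beam_size=8, max_update_fraction=0.1, trust_ratio=0.01)
+    updates at most 10% of the code groups and keeps the total weight change within roughly 1% of the norm of
+    the previously dequantized weight (plus at most one extra code change beyond that threshold).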
+ :return: the best quantization codes found within constraints, same shape as prev_codes + + """ + assert 0 < max_update_fraction <= 1 and (trust_ratio is None or trust_ratio > 0) + # reshape references, codes and codebooks so they are no longer group-wise + num_output_groups, num_input_groups, num_codebooks = prev_codes.shape + _num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape + + flat_unscaled_reference = reference_weight.reshape( + num_output_groups, out_group_size, num_input_groups, in_group_size + ).permute( + 0, 2, 1, 3 + ) # [num_output_groups, num_input_groups, out_group_size, in_group_size] + if scales is not None: + flat_unscaled_reference = flat_unscaled_reference / scales + # divide by scales; the resulting problem is equivalent to multiplying dequantized weight + flat_unscaled_reference = flat_unscaled_reference.flatten(2, 3).flatten(0, 1) + flat_prev_codes = prev_codes.flatten(0, -2) + flat_codebooks = codebooks.flatten(-2, -1).detach() + dim_order = list(range(num_codebooks)) + if dim_rng is not None: + dim_rng.shuffle(dim_order) + + def _update_flat_codes(_flat_reference, _flat_codes): + """update _flat_codes [num_groups, num_codebooks] to approximate _flat_reference [num_groups, group_size]""" + if num_codebooks == 1 and beam_size == 1 and stochastic_rounding_tau == 0 and not force_update: + # a faster algorithm for a special case of one codebook + return _greedy_find_best_codes( + reference=_flat_reference, + codebook=flat_codebooks[0], + chunk_size_values=chunk_size_bytes // _flat_reference[0, 0].nbytes, + code_dtype=prev_codes.dtype, + ) + else: + return _beam_search_update_codes_groupwise( + reference=_flat_reference, + codebooks=flat_codebooks, + codes=_flat_codes, + beam_size=beam_size, + stochastic_rounding_tau=stochastic_rounding_tau, + force_update=force_update, + chunk_size_values=chunk_size_bytes // _flat_reference[0, 0].nbytes, + dim_order=dim_order, + ) + + def _groupwise_squared_norms(delta: torch.Tensor): + """ + Given a matrix delta [out_features, in_features], compute a tensor [num_output_groups, num_input_groups] that + contains the squared sum of elements of delta from each tile of (out_group_size, in_group_size) values. 
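+        For example, with out_group_size=1 and in_group_size=8, a [4096, 4096] delta is reduced to a
+        [4096, 512] tensor of per-group squared norms.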
+ """ + return ( + delta.view(delta.shape[0] // out_group_size, out_group_size, delta.shape[1] // in_group_size, in_group_size) + .square() + .sum(dim=(1, 3)) + ) + + flat_indices_to_update = prev_dequantized_weight = None + if max_update_fraction < 1 or trust_ratio is not None: + # precompute ordered code indices to be used for constraints on the number of updates + prev_dequantized_weight = _dequantize_weight(prev_codes, codebooks, scales) + num_codes_to_update = int(math.ceil(max_update_fraction * num_output_groups * num_input_groups)) + difference_with_reference_squared_norms = _groupwise_squared_norms(reference_weight - prev_dequantized_weight) + # ^-- [num_output_groups, num_input_groups] + if code_selection_temperature > 0: + flat_indices_to_update = torch.pow( + difference_with_reference_squared_norms.flatten(), + 0.5 / code_selection_temperature, + # note: temperature is multuplied by 0.5 because sampling is proportional to norms without square + ).multinomial(num_samples=num_codes_to_update, replacement=False) + else: + flat_indices_to_update = torch.topk( + difference_with_reference_squared_norms.flatten(), k=num_codes_to_update, largest=True, sorted=True + ).indices + + if max_update_fraction == 1: + flat_new_codes = _update_flat_codes(flat_unscaled_reference, flat_prev_codes) + else: + flat_new_codes = flat_prev_codes.index_put( # note: this is an out-of-place op that does not modify prev codes + (flat_indices_to_update[:, None], torch.arange(num_codebooks, device=codebooks.device)[None, :]), + _update_flat_codes( + flat_unscaled_reference[flat_indices_to_update], flat_prev_codes[flat_indices_to_update] + ), + ) + + if trust_ratio is not None: + assert isinstance(flat_indices_to_update, torch.Tensor) and isinstance(prev_dequantized_weight, torch.Tensor) + new_dequantized_weight = _dequantize_weight(flat_new_codes.view_as(prev_codes), codebooks, scales) + weight_change_squared_norms = _groupwise_squared_norms(new_dequantized_weight - prev_dequantized_weight) + # ^-- shape: [num_output_groups, num_input_groups] + + flat_ordered_weight_change_squared_norms = weight_change_squared_norms.flatten()[flat_indices_to_update] + flat_ordered_cumulative_norms = flat_ordered_weight_change_squared_norms.cumsum(0).sqrt() + # [num_codes_to_update] + + num_codes_selected = 1 + torch.searchsorted( + flat_ordered_cumulative_norms, trust_ratio * prev_dequantized_weight.norm(), side="left" + ) + truncated_flat_indices_to_update = flat_indices_to_update[:num_codes_selected] # sorted most to least important + flat_new_codes = flat_prev_codes.index_put( # <-- note: this is an out-of-place operation + (truncated_flat_indices_to_update[:, None], torch.arange(num_codebooks, device=codebooks.device)[None, :]), + flat_new_codes[truncated_flat_indices_to_update], + ) + return flat_new_codes.view_as(prev_codes) + + +@maybe_script +def _beam_search_update_codes_groupwise( + reference: torch.Tensor, + codebooks: torch.Tensor, + codes: torch.Tensor, + *, + beam_size: int, + stochastic_rounding_tau: float, + chunk_size_values: int, + dim_order: Optional[List[int]], + force_update: bool, +) -> torch.Tensor: + """ + :param reference: [num_groups, group_size] + :param codes: [num_groups, num_codebooks] + :param codebooks: [num_codebooks, codebook_size, group_size] + :returns: [num_groups, num_codebooks] + """ + if stochastic_rounding_tau > 0: + assert beam_size >= 2, "with stochastic rounding, we need at least 2 hypotheses to choose from" + + prev_codes = codes + device = reference.device + num_groups, group_size 
= reference.shape + num_codebooks, codebook_size, group_size = codebooks.shape + codebook_offsets = torch.arange(0, num_codebooks * codebook_size, codebook_size, device=device) # [num_codebooks] + original_dequantized_vectors = F.embedding_bag( + codes + codebook_offsets, codebooks.flatten(0, 1), mode="sum" + ) # [num_groups, group_size] + if dim_order is None: + dim_order = list(range(num_codebooks)) + + code_norms_sq = codebooks.square().sum(-1) # [num_codebooks, codebook_size] + beam_codes = codes.clone().unsqueeze(1) # [num_groups, current_beam_size, num_codebooks] + residue = (reference - original_dequantized_vectors).view(num_groups, 1, group_size) + # shape: [num_groups, current_beam_size, group_size] + direction = residue.clone().view(num_groups, group_size) if force_update else torch.empty(0) + + for i, codebook_index in enumerate(dim_order): + current_beam_size = residue.shape[1] + is_last_step = i == len(dim_order) - 1 + # ^-- [num_groups, current_beam_size, group_size] + residue = residue + F.embedding(beam_codes[..., codebook_index], codebooks[codebook_index, ...]) + if beam_size > 1 or stochastic_rounding_tau > 0: + residue_norms_sq = residue.square().sum(-1).unsqueeze(-1) # [num_groups, current beam size, 1] + else: + residue_norms_sq = torch.empty(0, device=device) # when doing greedy search, these are const + + if not is_last_step: + target_num_candidates = beam_size + int(stochastic_rounding_tau > 0) + else: + target_num_candidates = 2 if stochastic_rounding_tau > 0 or force_update else 1 + + flat_best_indices = torch.empty(num_groups, target_num_candidates, device=device, dtype=codes.dtype) + chunk_size_rows = chunk_size_values // (codebook_size * current_beam_size) // 32 + for chunk_start in range(0, num_groups, chunk_size_rows): + chunk_end = min(chunk_start + chunk_size_rows, num_groups) + scores = torch.matmul(residue[chunk_start:chunk_end], codebooks[codebook_index].T) + if beam_size > 1 or stochastic_rounding_tau > 0: + scores = residue_norms_sq[chunk_start:chunk_end] - 2 * scores + code_norms_sq[codebook_index] + else: + scores = -2 * scores + code_norms_sq[codebook_index] # residue norms are const(j) + # ^-- [num_groups_chunk, beam_size, codebook_size] + + flat_best_losses_chunk, flat_best_indices_chunk = torch.topk( + scores.flatten(1, 2), + k=target_num_candidates, + largest=False, + sorted=is_last_step or beam_size > 1 or stochastic_rounding_tau > 0, + ) # [num_groups_chunk, target_num_candidates] + + if stochastic_rounding_tau > 0: + errors = flat_best_losses_chunk.relu().sqrt() # non-squared errors + scores = torch.pow(errors / errors.sum(-1, keepdim=True), -1 / stochastic_rounding_tau) + # ^-- [num_groups_chunk, beam_size + 1] + keep_prob = scores[:, :-1] / (scores[:, :-1] + scores[:, 1:]) # [num_groups, k_best] + keep_prob = torch.where(torch.isinf(scores[:, :-1]), 1.0, keep_prob) + keep = torch.less_equal(torch.rand_like(keep_prob), keep_prob) + flat_best_indices_chunk = torch.where( + keep, flat_best_indices_chunk[:, :-1], flat_best_indices_chunk[:, 1:] + ) + + flat_best_indices[chunk_start:chunk_end] = flat_best_indices_chunk + + arange_num_groups = torch.arange(num_groups, device=device) + best_hypo_source_ids = flat_best_indices // codebook_size + best_hypo_codes = flat_best_indices % codebook_size + beam_codes = beam_codes[arange_num_groups[:, None], best_hypo_source_ids, :] + beam_codes[:, :, codebook_index] = best_hypo_codes.to(beam_codes.dtype) + # ^-- [num_groups, beam_size, num_codebooks] + + if not is_last_step: + residue = residue - 
F.embedding(beam_codes[..., codebook_index], codebooks[codebook_index, ...]) + + if force_update: + assert beam_codes.shape[1] == 2 + best_codes = beam_codes[:, 0, :] + second_best_codes = beam_codes[:, 1, :] + best_code_changed = torch.ne(best_codes, prev_codes).any(dim=-1) + return torch.where(best_code_changed.unsqueeze(-1), best_codes, second_best_codes) + else: + return beam_codes[:, 0, :] + + +@maybe_script +def _greedy_find_best_codes( + reference: torch.Tensor, codebook: torch.Tensor, chunk_size_values: int, code_dtype: torch.dtype +) -> torch.Tensor: + """ + :param reference: [num_groups, group_size] + :param codebook: [codebook_size, group_size] + :param chunk_size_values: how many values can be materialized in memory simultaneously + :parma code_dtype the dtype of optimal codes returned by this function + :returns: codes [num_groups, 1] + """ + codebook_t = codebook.T.contiguous() + chunk_size = chunk_size_values // len(codebook) + codebook_norms_sq = codebook.square().sum(dim=-1) + new_codes = torch.empty((len(reference),), dtype=code_dtype, device=reference.device) + for chunk_start in range(0, len(reference), chunk_size): + new_codes[chunk_start : chunk_start + chunk_size] = torch.addmm( + codebook_norms_sq[None], reference[chunk_start : chunk_start + chunk_size], codebook_t, alpha=-2 + ).argmin(-1) + return new_codes.unsqueeze(-1) + + +def _find_optimal_codebooks( + reference: torch.Tensor, + codebooks: torch.Tensor, + codes: torch.Tensor, +) -> torch.Tensor: + num_samples = len(reference) + num_codebooks, codebook_size, group_size = codebooks.shape + + # compute optimal codebooks via linsolve + codebook_offsets = torch.arange(num_codebooks, device=codes.device) * codebook_size + code_indicators = torch.sparse_coo_tensor( + indices=torch.stack( + [ + torch.arange(num_samples * num_codebooks, device=codes.device) // num_codebooks, + (codes + codebook_offsets).flatten(), + ], + 0, + ), + values=torch.ones(num_samples * num_codebooks, device=codes.device), + size=(num_samples, num_codebooks * codebook_size), + ) + cooc = (code_indicators.T @ code_indicators).coalesce() + rhs = code_indicators.T @ reference + + try: + cooc = cooc.to_dense() + cooc[torch.arange(len(cooc)), torch.arange(len(cooc))].clamp_min_(1.0) + optimal_codebooks = (torch.linalg.lstsq(cooc, rhs)).solution.reshape(num_codebooks, codebook_size, group_size) + except Exception as e: + print(f"Linsolve failed with {e}") + optimal_codebooks = codebooks + return optimal_codebooks diff --git a/src/beam_search_xtx.py b/src/beam_search_xtx.py new file mode 100644 index 0000000..90c4801 --- /dev/null +++ b/src/beam_search_xtx.py @@ -0,0 +1,361 @@ +""" Beam search that minimizes ||XWref - XWq||^2 w.r.t. 
Wq codes """ +import random +from typing import Optional, Tuple + +import torch +from torch.nn import functional as F +from tqdm.asyncio import trange + +from src.utils import _dequantize_weight, maybe_script + + +@torch.inference_mode() +def beam_search_optimal_codes( + *, + XTX: torch.Tensor, + reference_weight: torch.Tensor, + codebooks: torch.Tensor, + prev_codes: torch.IntTensor, + scales: Optional[torch.Tensor], + beam_size: int, + dim_rng: Optional[random.Random] = None, + sparsity_regularizer: float = 0, + verbose: bool, +): + """ + :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] + :note: if XTX is divided by dataset size, this function will return *mean* squared error + :param reference_weight: original weight matrix that is being quantized, shape: [out_features, in_features] + :param codebooks: look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_siz, in_group_size] + :param prev_codes: previous-best integer weight codes, shape: [num_out_groups, num_in_groups, num_codebooks] + :param scales: weight will be multiplied by this factor, shape = [num_out_groups, num_in_groups or 1, 1, 1] + :param dim_rng: a source of randomness to (optionally) shuffle the order in which the beam search runs + None = update dimensions and codebooks in their natural order (0, 1, ..., n) + random.Random(optional_seed) = shuffle dimensions at random, optionally using the specified seed + + :param beam_size: consider up to this many best encoding combinations + :param sparsity_regularizer: subtract this value from beam search objective each time you have a zero code somewhere + :param verbose: if True, draw a progressbar and periodically print best loss + :return: best quantization codes found, same shape as prev_codes + + :intuition: the beam search needs to produce weight codes that minimize MSE error + - the codes are of shape [out_features / out_group_size, in_features / in_group_size, num_codebooks] + + Out of those three dimensions, out_features is "independent", i.e. changing code in + one output feature does not increase the MSE error for another feature. Therefore, + beam search for different output features can run in independently in parallel. + + Neither (in_features / in_group_size) nor (num_codebooks) dimension are independent: + - changing the encoding for one feature can compensate the error from encoding another, OBC-style + - for a single weight group, changing code in one codebook can affect the optimal choice in another codebook + Therefore, beam search must go in a double loop over (in_features/in_group_size) and (num_codebooks) dimensions + + This leaves one choice: which dimension used for outer loop, and which one goes is in the inner loop? + Due to the nature of beam search, interactions between dimensions of inner loop will be explored better. + We chose to use (in_features/in_group_size) in the outer loop and (num_codebooks) for the inner loop. + This is based on an intuition from GPTQ: you can get decent performance by quantizing each input unit ... + ... greedily --- GPTQ does not change quantizations for previously quantized features and works fine. + Therefore, we believe that we can also use a greedy approach to compensate error between input features. + In turn, we believe that the codes used to encode the same weights (additively) are more inter-dependent. + This should be treated as an educated guess with no proof and no ablation (as of the time of writing). 
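+    Concretely, the double loop below performs num_in_groups * num_codebooks beam search steps; each step scores
+    all codebook_size candidate codes for every beam hypothesis and every output group in parallel.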
+ + """ + num_out_groups, num_in_groups, num_codebooks = prev_codes.shape + num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape + in_features = num_in_groups * in_group_size + out_features = num_out_groups * out_group_size + assert reference_weight.shape == (out_features, in_features) + prev_weight = _dequantize_weight(prev_codes, codebooks, scales) + + # initialize all beam codes as previous codes - so they can be updated during beam search + beam_codes = prev_codes.unsqueeze(0) + # beam_codes shape: [current beam_size, num_out_groups, num_in_groups, num_codebooks], initial beam_size = 1 + beam_weights = prev_weight.unsqueeze(0) + # beam_weights shape: [current beam_size, out_features, in_features], initial beam size = 1 + + beam_losses = ( + _channelwise_squared_error(XTX, prev_weight, reference_weight) + .reshape(1, num_out_groups, out_group_size) + .sum(-1) + ) + # beam_losses shape: [current beam_size, num_out_groups], initial beam_size = 1 + if sparsity_regularizer != 0: + beam_losses = beam_losses - sparsity_regularizer * (prev_codes == 0).sum(dim=(-1, -2))[None, :] + + if verbose: + progressbar = trange(num_in_groups * num_codebooks) + + def _make_range(n: int) -> list: + seq = list(range(n)) + if dim_rng is not None: + dim_rng.shuffle(seq) + return seq + + for input_group_index in _make_range(num_in_groups): + for codebook_index in _make_range(num_codebooks): + ### part 1: compute losses for every possible candidate for one given codebook and input group. + # Currently, we compute errors for all output features in parallel in a vectorized fashion. + best_losses, best_indices = _beam_search_squared_errors( + XTX=XTX, + reference_weight=reference_weight, + codebooks=codebooks, + scales=scales, + beam_losses=beam_losses, + beam_codes=beam_codes, + beam_weights=beam_weights, + input_group_index=input_group_index, + codebook_index=codebook_index, + k_best=beam_size, + sparsity_regularizer=sparsity_regularizer, + ) # [current beam_size, codebook_size, num_out_groups] + + # part 2: select beam_size new best codes and re-arrange beam to account for the fact that ... + # ... 
sometimes two or more top candidates originate from the same source in previous beam + beam_codes, beam_weights, beam_losses = _beam_search_select_best( + beam_codes=beam_codes, + beam_weights=beam_weights, + codebooks=codebooks, + scales=scales, + input_group_index=input_group_index, + codebook_index=codebook_index, + best_losses=best_losses, + best_indices=best_indices, + beam_size=beam_size, + ) + + if verbose: + progressbar.update() + if (input_group_index * num_codebooks + codebook_index) % verbose != 0: + continue # if update is an integer, compute metrics every (this many) beam search steps + best_loss = beam_losses.min(0).values.sum().item() / out_features + info = f"in_group {input_group_index} / {num_in_groups} " + info += f"| codebook {codebook_index} / {num_codebooks} " + if sparsity_regularizer == 0: + info += f"| loss {best_loss:.10f}" + else: # un-regularize to restore MSE loss, report sparsity rate + num_zero_codes = (beam_codes[0] == 0).sum().item() + best_loss = best_loss + sparsity_regularizer / out_features * num_zero_codes + sparsity = num_zero_codes / prev_codes.numel() + info += f"| loss {best_loss:.5f} | sparse {sparsity * 100:.1f}% |" + + progressbar.desc = info + return beam_codes[0] + + +@maybe_script +def _beam_search_squared_errors( + XTX: torch.Tensor, + reference_weight: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor], + beam_losses: torch.Tensor, + beam_codes: torch.Tensor, + beam_weights: torch.Tensor, + input_group_index: int, + codebook_index: int, + k_best: int, + sparsity_regularizer: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute MSE or sum-of-squared-error losses for all possible ways to replace quantization codes for one input group + and one codebook. Works in parallel for all output-dimension groups. + + :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] + :note: if both XTX *and* beam_loses are divided by dataset size, this function will return mean squared error + :param reference_weight: original weight matrix that is being quantized, shape: [out_features, in_features] + :param codebooks: look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, [num_out_groups, num_in_groups, 1, 1] + + :param beam_losses: sum-of-squared-error for each hypothesis in beam and for each output channel; + shape: [beam_size, num_out_groups] + :param beam_codes: a tensor with best weight codes, shape: [beam_size, num_out_groups, num_in_groups, num_codebooks] + :param beam_weights: a tensor with de-quantized beam_codes, shape: [beam_size, out_features, in_features] + :param input_group_index: an index of one group of in_features that is being re-encoded + :param codebook_index: an index of one codebook for that group of features that is being re-encoded + :return: tuple(Tensor, Tensor) of 3d tensor of shape = [beam_size, k_best, num_out_groups]. 
+ First one is float tensor of losses of k_best lowest square errors for each beam and out_group + Second one is int64 tensor of indices of k_best lowest square errors for each beam and out_group + + :note: The code computes MSE using the square-of-difference expansion + ||X@W.T - sum_i X@(Bi@Ci).T||^2 = ||X@W.T||^2 - 2 <X@W.T, sum_i X@(Bi@Ci).T> + ||sum_i X@(Bi@Ci).T||^2 + where X[nsamples,in_features] is calibration data, W[out_features, in_features] is the reference weight, + C[num_codebooks, codebook_size, in_features] are learned codebooks (Ci has shape [codebook_size, in_features]) + B[num_codebooks, out_features, codebook_size] are one-hot encoded indices (quantization codes) + The formula above uses a single group per output "neuron" and a single input group. + The algorithm below generalizes the formula for multiple groups and codebooks. + + Furthermore, the algorithm does not compute the entire formula. Instead, it begins from some baseline loss + and computes the change in loss from changing a single code to every possible alternative code. + When computing the changed loss, the algorithm only computes the few affected parts of the loss formula above. + """ + num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape + beam_size, num_out_groups, num_in_groups, num_codebooks = beam_codes.shape + out_features = num_out_groups * out_group_size + + input_group_slice = slice(input_group_index * in_group_size, (input_group_index + 1) * in_group_size) + + prev_codes_part = beam_codes[:, :, input_group_index, codebook_index] # [beam_size, num_out_groups] + + if scales is not None: + scales_part = scales[:, input_group_index % scales.shape[1], :, :] # [num_out_groups, 1, 1] + else: + scales_part = torch.empty(0, device=XTX.device) + prev_part_dequantized = F.embedding(prev_codes_part, codebooks[codebook_index].flatten(-2, -1)).view( + beam_size, out_features, in_group_size + ) # previous codes de-quantized + + prev_weight_part = prev_part_dequantized + if scales is not None: + prev_weight_part = ( + prev_weight_part.view(beam_size, num_out_groups, out_group_size, in_group_size) + .mul(scales_part) + .view(beam_size, out_features, in_group_size) + ) + + cand_weights = codebooks[codebook_index] # [codebook_size, out_group_size, in_group_size], all replacement codes + + delta_weight_without_part = reference_weight - beam_weights + delta_weight_without_part[:, :, input_group_slice] += prev_weight_part + + # dWTXTX is equivalent to < X @ (W - \sum BiCi except current codebook), X @ SOMETHING > + dWTXTXg = delta_weight_without_part @ XTX[..., input_group_slice] # [beam_size, out_features, in_group_size] + # below: use torch.matmul to compute broadcasted batch matrix multiplication; see matmul docs + + XnewBkC_norms_sq = torch.bmm( + (cand_weights.flatten(0, 1) @ XTX[input_group_slice, input_group_slice]).view( + codebook_size, 1, out_group_size * in_group_size + ), + cand_weights.view(codebook_size, out_group_size * in_group_size, 1), + ).reshape( + codebook_size, 1 + ) # [codebook_size, 1], later broadcast against num_out_groups + if scales is not None: + XnewBkC_norms_sq = XnewBkC_norms_sq.mul(scales_part.square().reshape(1, num_out_groups)) + + best_losses = torch.empty( + (beam_size, k_best, num_out_groups), dtype=XTX.dtype, device=XTX.device + ) # shape: [beam_size, k_best, num_out_groups] + best_indices = torch.empty( + (beam_size, k_best, num_out_groups), + dtype=torch.int64, + device=XTX.device, + ) + for beam_id in range(beam_size): + dot_products = ( + torch.einsum( + "mg,og->mo", + cand_weights.reshape(codebook_size,
out_group_size * in_group_size), + dWTXTXg[beam_id].view(num_out_groups, out_group_size * in_group_size), + ) + .sub_( + torch.einsum( + "og,og->o", + prev_part_dequantized[beam_id].reshape(num_out_groups, out_group_size * in_group_size), + dWTXTXg[beam_id].view(num_out_groups, out_group_size * in_group_size), + ).view(1, num_out_groups) + ) + .view(codebook_size, num_out_groups) + ) + if scales is not None: + dot_products = dot_products.mul_(scales_part.reshape(1, num_out_groups)) + + XoldBkC_norms_sq = torch.bmm( + (prev_weight_part[beam_id] @ XTX[input_group_slice, input_group_slice]).view( + num_out_groups, 1, out_group_size * in_group_size + ), + prev_weight_part[beam_id].view(num_out_groups, out_group_size * in_group_size, 1), + ).reshape(1, num_out_groups) + + # finally, combine them to get MSE + candidate_squared_errors = ( + beam_losses[beam_id, None, :] - 2 * dot_products + XnewBkC_norms_sq - XoldBkC_norms_sq + ) # shape: [codebook_size, num_out_groups] + + if sparsity_regularizer != 0: + candidate_squared_errors += sparsity_regularizer * (prev_codes_part[beam_id] == 0).to(XTX.dtype)[None, :] + candidate_squared_errors[0, :] -= sparsity_regularizer + + best_beam_squared_errors, best_beam_indices = torch.topk( + candidate_squared_errors, k_best, dim=0, largest=False, sorted=False + ) + best_losses[beam_id] = best_beam_squared_errors + best_indices[beam_id] = best_beam_indices + + return best_losses, best_indices + + +@maybe_script +def _beam_search_select_best( + beam_codes: torch.Tensor, + beam_weights: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor], + input_group_index: int, + codebook_index: int, + best_losses: torch.Tensor, + best_indices: torch.Tensor, + beam_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Select top-:beam_size: and reorder beam accordingly, return new beam + :param beam_codes: a tensor with best weight codes, shape: [beam_size, num_out_groups, num_in_groups, num_codebooks] + :param beam_weights: a tensor with de-quantized beam_codes, shape: [beam_size, out_features, in_features] + :param codebooks: a tensor with look-up tables of codes, shape: [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, [num_out_groups, num_in_groups, 1, 1] + + :param input_group_index: an index of one group of in_features that is being re-encoded + :param codebook_index: an index of one codebook for that group of features that is being re-encoded + :param best_losses: a 3d tensor of losses of k_best lowest square errors for each beam and out group, + shape = [beam_size, k_best, num_out_groups] + :param best_indices: a 3d tensor of indices of k_best lowest square errors for each beam and out group, + shape = [beam_size, k_best, num_out_groups] + :param beam_size: how many top hypotheses should be selected + + :returns: new (beam_codes, beam_weights, beam_losses) + """ + dtype = best_losses.dtype + device = best_losses.device + _prev_beam_size, k_best, num_out_groups = best_losses.shape + _prev_beam_size, out_features, in_features = beam_weights.shape + _prev_beam_size, num_out_groups, num_in_groups, num_codebooks = beam_codes.shape + flat_best = best_losses.flatten(0, 1).topk(dim=0, k=beam_size, largest=False) + best_hypo_source_ids = flat_best.indices // k_best + arange_out_groups = torch.arange(num_out_groups, device=device) + best_hypo_codes = best_indices.flatten(0, 1)[flat_best.indices, arange_out_groups].reshape( + beam_size, num_out_groups + ) + # ^-- 
shape: [beam_size, num_out_groups] + + # reorder beam codes and weights + new_beam_codes = torch.full( + size=(len(best_hypo_codes), num_out_groups, num_in_groups, num_codebooks), + fill_value=-1, + dtype=beam_codes.dtype, + device=device, + ) # [beam_size, num_out_groups, num_in_groups, num_codebooks] + new_beam_weights = torch.empty(len(best_hypo_codes), out_features, in_features, dtype=dtype, device=device) + + for beam_index in range(len(best_hypo_codes)): + new_beam_codes[beam_index, :, ...] = beam_codes[best_hypo_source_ids[beam_index, :], arange_out_groups, ...] + new_beam_codes[beam_index, :, input_group_index, codebook_index] = best_hypo_codes[beam_index, :] + new_beam_weights[beam_index, :, :] = _dequantize_weight(new_beam_codes[beam_index, ...], codebooks, scales) + + # Note: the code above can be further accelerated by 1) vectorizing the loop and ... + # ... 2) updating new_beam_weights only for the chosen input group + return new_beam_codes, new_beam_weights, flat_best.values + + + @maybe_script + def _channelwise_squared_error(XTX: torch.Tensor, weight: torch.Tensor, reference_weight: torch.Tensor): + """ + Compute per-channel squared error between X @ weight_or_weights and X @ reference_weight + :param XTX: pairwise products of input features matmul(X.transpose(), X), shape: [in_features, in_features] + :note: if XTX is divided by dataset size, this function will return *mean* squared error + :param weight: predicted/reconstructed weights of shape [*dims, out_features, in_features] + :param reference_weight: reference weight of shape [out_features, in_features] + :return: per-channel squared errors of shape [*dims, out_features] + """ + XW_norm_square = torch.matmul(weight[..., :, None, :], (weight @ XTX)[..., :, :, None]).flatten(-3) + XWreference_norm_square = torch.bmm(reference_weight[:, None, :], (reference_weight @ XTX)[:, :, None]).flatten(-3) + dot_product = torch.matmul((reference_weight @ XTX)[:, None, :], weight[..., :, :, None]).flatten(-3) + return XW_norm_square - 2 * dot_product + XWreference_norm_square diff --git a/src/configurable_adam.py b/src/configurable_adam.py new file mode 100644 index 0000000..2140876 --- /dev/null +++ b/src/configurable_adam.py @@ -0,0 +1,245 @@ +import math +from contextlib import contextmanager +from typing import Iterable, Optional, Tuple, Union + +import torch + +from src.utils import maybe_script + +NO_DATA = torch.empty(0) + + +class ConfigurableAdamW(torch.optim.Optimizer): + r""" + A version of Adam optimizer that supports custom parameter dtypes, amsgrad, lamb or rmsprop on a per-group basis.
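+ For example (illustrative values only; head_params and other_params are placeholders for the caller's parameter lists), one parameter group can keep its Adam statistics in bfloat16 on CPU while another uses the defaults: + ConfigurableAdamW([dict(params=head_params, exp_avg_dtype=torch.bfloat16, exp_avg_device=torch.device("cpu")), dict(params=other_params)], lr=1e-4)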
+ Adam and Amsgrad based on https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py + Lamb flag based on https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py + This was tested to match Adam and Lamb exactly for torch 2.3.0 (when compute_dtypes are all None) + :param exp_avg_dtype: dtype for storing first moments; only created if betas[0] != 0; defaults to param dtype + :param exp_avg_sq_dtype: dtype for storing second moments; only created if betas[1] != 0; defaults to param dtype + :param v_hat_max_dtype: dtype for storing maximum v_hat; only created if amsgrad=True; defaults to param dtype + :param exp_avg_device: device for storing exp_avg buffers; only created if betas[0]!=0; defaults to param.device + :param exp_avg_sq_device: device for storing exp_avg_sq only created if betas[1]!=0; defaults to param.device + :param v_hat_max_device: device for storing v_hat buffers; only created if amsgrad=True; defaults to param.device + :note: if any of these devices are CPU, they will be prefetched for optimizer step using pinned memory + :param compute_dtype: dtype for optimizer step computation; defaults to param dtype + """ + + def __init__( + self, + params: Iterable[Union[torch.Tensor, dict]], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0, + debias: Optional[bool] = None, + amsgrad: bool = False, + lamb: bool = False, + clamp_value: Optional[float] = None, + compute_dtype: Optional[torch.dtype] = None, + exp_avg_dtype: Optional[torch.dtype] = None, + exp_avg_sq_dtype: Optional[torch.dtype] = None, + v_hat_max_dtype: Optional[torch.dtype] = None, + exp_avg_device: torch.device = None, + exp_avg_sq_device: torch.device = None, + v_hat_max_device: torch.device = None, + ) -> None: + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + debias=debias, + amsgrad=amsgrad, + lamb=lamb, + clamp_value=clamp_value, + compute_dtype=compute_dtype, + exp_avg_dtype=exp_avg_dtype, + exp_avg_sq_dtype=exp_avg_sq_dtype, + v_hat_max_dtype=v_hat_max_dtype, + exp_avg_device=exp_avg_device, + exp_avg_sq_device=exp_avg_sq_device, + v_hat_max_device=v_hat_max_device, + ) + super().__init__(params, defaults) + + def _maybe_init_state(self, param: torch.Tensor, group: dict) -> dict: + state = self.state[param] + if "step" not in state: + state["step"] = 0 + if group["betas"][0] != 0 and "exp_avg" not in state: + pin_memory = group["exp_avg_device"] == torch.device("cpu") + state["exp_avg"] = torch.zeros_like( + param, + dtype=group["exp_avg_dtype"], + memory_format=torch.preserve_format, + device=group["exp_avg_device"], + pin_memory=pin_memory, + ) + if group["betas"][1] not in (0, 1) and "exp_avg_sq" not in state: + pin_memory = group["exp_avg_sq_device"] == torch.device("cpu") + state["exp_avg_sq"] = torch.zeros_like( + param, + dtype=group["exp_avg_sq_dtype"], + memory_format=torch.preserve_format, + device=group["exp_avg_sq_device"], + pin_memory=pin_memory, + ) + if group["amsgrad"] and "v_hat_max" not in state: + pin_memory = group["v_hat_max_device"] == torch.device("cpu") + state["v_hat_max"] = torch.zeros_like( + param, + dtype=group["v_hat_max_dtype"], + memory_format=torch.preserve_format, + device=group["v_hat_max_device"], + pin_memory=pin_memory, + ) + return state + + @torch.no_grad() + def step(self, closure: Optional[callable] = None): + r"""Performs a single optimization step. + Arguments: + closure: A closure that reevaluates the model and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group, p, state in self.iterate_groups_with_prefetch(): + assert p.grad is not None + assert not p.grad.is_sparse, f"{self} does not support sparse gradients" + grad = p.grad.data + + state["step"] += 1 + beta1, beta2 = group["betas"] + compute_dtype = group.get("compute_dtype") or p.dtype + + if not group["lamb"] and group["weight_decay"] != 0: + p.data = p.data.mul_(1 - group["lr"] * group["weight_decay"]) + # adam weight decay is not scaled by bias correction + + # Decay the first and second moment running average coefficient + update = _inner_adam_step_and_update_statistics( + p, + grad, + state.get("exp_avg", p), + state.get("exp_avg_sq", p), + state.get("v_hat_max", p), + beta1, + beta2, + group["eps"], + group["amsgrad"], + compute_dtype, + ) + + if group["lamb"] and group["weight_decay"] != 0: + update = update.add(p, alpha=group["weight_decay"]) + # lamb weight decay is later multiplied by -lr * trust_ratio * bias_correction + + update_scale = -group["lr"] + # below: to save compute, we update scalar coefficient to account for debias/lamb/.. and multiply once + if group["debias"] if group["debias"] is not None else (not group["lamb"]): + # if not specified, default to True for Adam, False for Lamb + mt_debias = 1.0 / (1 - beta1 ** state["step"]) if beta1 != 0 else 1 + vt_debias = 1.0 / math.sqrt(1 - beta2 ** state["step"]) if beta2 != 0 else 1 + bias_correction = mt_debias / vt_debias + update_scale *= bias_correction + + if group["lamb"]: + weight_norm = torch.norm(p.data.to(compute_dtype)) + update_norm = torch.norm(update) + # note: lamb does not count debiasing when computing trust ratio + if group["clamp_value"] is not None: + weight_norm = torch.clamp_max_(weight_norm, group["clamp_value"]) + if weight_norm == 0 or update_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / update_norm + update_scale *= trust_ratio + + p.data.add_(update, alpha=update_scale) + return loss + + def iterate_groups_with_prefetch(self): + """Iterate parameters and optimizer states; skip parameters that do not require grad""" + flat_params = [ + (group, param) for group, param in _get_flat_param_groups(self.param_groups) if param.grad is not None + ] + + active_group, active_param = flat_params[0] + active_state = self._maybe_init_state(active_param, active_group) + active_state_fetched = _fetch_state_to_device(active_state, active_param.device) + + for next_group, next_param in flat_params[1:] + [(active_group, active_param)]: + next_state = self._maybe_init_state(next_param, next_group) + next_state_fetched = _fetch_state_to_device(next_state, next_param.device) + + yield active_group, active_param, active_state_fetched + + _commit_state_updates(active_state, active_state_fetched) + + active_group, active_param, active_state, active_state_fetched = ( + next_group, + next_param, + next_state, + next_state_fetched, + ) + + +@maybe_script +def _inner_adam_step_and_update_statistics( + p: torch.Tensor, + grad: torch.Tensor, + exp_avg: torch.Tensor, + exp_avg_sq: torch.Tensor, + v_hat_max: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + amsgrad: bool, + compute_dtype: torch.dtype, +): + grad = grad.to(compute_dtype, copy=True) + stored_exp_avg, stored_exp_avg_sq, stored_v_hat_max = exp_avg, exp_avg_sq, v_hat_max + if beta1 != 0: + exp_avg = exp_avg.to(compute_dtype).lerp(grad, 1 - beta1) + stored_exp_avg.copy_(exp_avg, non_blocking=True) + update = exp_avg + else: + update = grad.clone() + + if beta2 
== 1: + pass + else: + if beta2 == 0: + exp_avg_sq = grad.square() + else: + exp_avg_sq = exp_avg_sq.to(compute_dtype).lerp(grad.square(), (1 - beta2)) + stored_exp_avg_sq.copy_(exp_avg_sq, non_blocking=True) + if amsgrad: + exp_avg_sq = torch.maximum(exp_avg_sq, v_hat_max, out=exp_avg_sq) + stored_v_hat_max.copy_(exp_avg_sq, non_blocking=True) + + update /= exp_avg_sq.sqrt().add(eps) + + return update + + +def _get_flat_param_groups(param_groups): + return [(group, param) for group in param_groups for param in group["params"]] + + +def _fetch_state_to_device(state, device): + fetchable_state_keys = {"exp_avg", "exp_avg_sq", "v_hat_max"}.intersection(state.keys()) + fetched_states = {state_key: state[state_key].to(device, non_blocking=True) for state_key in fetchable_state_keys} + return state | fetched_states + + +def _commit_state_updates(offloaded_states, fetched_states): + fetched_keys = {"exp_avg", "exp_avg_sq", "v_hat_max"} + for state_key in offloaded_states: + if state_key not in fetched_keys: + offloaded_states[state_key] = fetched_states[state_key] + elif offloaded_states[state_key] is not fetched_states[state_key]: + offloaded_states[state_key].copy_(fetched_states[state_key], non_blocking=True) diff --git a/src/datautils.py b/src/datautils.py index c944e72..aa8c56f 100644 --- a/src/datautils.py +++ b/src/datautils.py @@ -1,14 +1,16 @@ import os import random -from typing import Optional +from itertools import chain +from typing import Optional, Sequence -import datasets import numpy as np import torch +import torch.distributed from datasets import load_dataset -from packaging import version +from torch import nn from tqdm import trange -from transformers import AutoTokenizer, LlamaTokenizer +from tqdm.auto import tqdm +from transformers import AutoTokenizer def set_seed(seed: Optional[int]): @@ -248,3 +250,76 @@ def get_loaders( print(f"Loaded data from {name}; {len(data)=} sequences") return data + + +def split_long_texts(inputs: Sequence[str], split_max_length: int): + """Split examples that exceed split_max_length into multiple sub-examples""" + outputs = [] + for index, input_str in enumerate(inputs): + while True: + truncation_index = input_str.find("\n", split_max_length) + if truncation_index == -1: + outputs.append(input_str) + break + outputs.append(input_str[:truncation_index]) + input_str = input_str[truncation_index + 1 :] # continue after \n + return outputs + + +def group_texts(examples: Sequence[Sequence[int]], block_size: int, add_labels: bool = True): + """Group tokenized examples together and split them into blocks of up to block_size tokens""" + # based on https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. 
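+ # Illustrative example (hypothetical values): with block_size=4 and examples == {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]}, + # the concatenated stream 1..8 is cut into blocks [[1, 2, 3, 4], [5, 6, 7, 8]]; + # with add_labels=True, "labels" is set to a copy of "input_ids".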
+ result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() + } + if add_labels: + result["labels"] = result["input_ids"].copy() + return result + + +@torch.no_grad() +def evaluate_perplexity( + model: nn.Module, data: torch.Tensor, seqlen: int, device: torch.device, amp_dtype: Optional[torch.dtype] = None +) -> float: + """Perplexity evaluation as per https://github.com/IST-DASLab/gptq (standard among quantization research)""" + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + + inps = [ + data[:, start : start + seqlen] for start in range(0, data.shape[1], seqlen) if start + seqlen < data.shape[1] + ] # ignore last incomplete sequence as in the GPTQ paper + num_sequences_without_padding = len(inps) + + # pad sequences to be divisible by world_size for DDP/FSDP compatibility + num_padding_sequences = -len(inps) % world_size + inps.extend([inps[-1]] * num_padding_sequences) + + total_nll_and_tokens = torch.tensor([0.0, 0.0], dtype=torch.float64, device=device) + total_nll, total_tokens = total_nll_and_tokens[0], total_nll_and_tokens[1] + + for sequence_index, input_ids in enumerate(tqdm(inps, desc="Evaluating perplexity") if rank == 0 else inps): + if sequence_index % world_size != rank: + continue + input_ids = input_ids.to(device) + with torch.cuda.amp.autocast(enabled=amp_dtype is not None, dtype=amp_dtype or torch.float32): + lm_logits = model(input_ids).logits + + if sequence_index < num_sequences_without_padding: + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = input_ids[:, 1:] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + total_nll += loss.float() * shift_labels.numel() + total_tokens += shift_labels.numel() + + if world_size > 1: + torch.distributed.all_reduce(total_nll_and_tokens, op=torch.distributed.ReduceOp.SUM) + ppl = torch.exp(total_nll / total_tokens) + return ppl.item() diff --git a/src/finetune.py b/src/finetune.py index c71ee38..729bfcf 100644 --- a/src/finetune.py +++ b/src/finetune.py @@ -1,3 +1,4 @@ +"""Utilities for internal **block-wise** finetuning used during initial AQLM calibration""" from __future__ import annotations import warnings diff --git a/src/memory_efficient_loss.py b/src/memory_efficient_loss.py new file mode 100644 index 0000000..9d7bdd0 --- /dev/null +++ b/src/memory_efficient_loss.py @@ -0,0 +1,114 @@ +""" +Utility functions for computing a KL divergence loss without materializing all logits / logprobs simultaneously +""" +import itertools +from typing import Callable, TypeVar + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +T = TypeVar("T") + + +def compute_kl_divergence_loss_values( + *, + student_hidden_states: torch.Tensor, + student_lm_head: nn.Module, + teacher_hidden_states: torch.Tensor, + teacher_lm_head: nn.Module, + max_tokens_per_chunk: int = 256, + checkpoint_last_chunk: bool = True, + **checkpoint_kwargs, +) -> torch.Tensor: + """ + Compute token-wise KL divergence loss without materializing all logits/logprobs simultaneously + :param student_hidden_states: input hidden states for student head, [batch_size, sequence_length, student_dim] + :param student_lm_head: a token-wise layer (e.g. 
nn.Linear) mapping from student_dim to logits [vocabulary_size] + :param teacher_hidden_states: input hidden states for teacher head, [batch_size, sequence_length, teacher_dim] + :param teacher_lm_head: a token-wise layer (e.g. nn.Linear) mapping from teacher_dim to logits [vocabulary_size] + :note: teacher is applied to hidden states without no_grad. If required, set requires_grad=False on teacher manually + :param max_tokens_per_chunk: materialize logits logprobs for at most this many tokens at a time + :param checkpoint_kwargs: additional arguments passed to checkpoint (e.g. use_reentrant or determinism_check) + :param checkpoint_last_chunk: if False, do not apply gradient checkpointing to the very last chunk of inputs + since they are the first ones to be re-materialized anyway. Useful if loss is backpropagated immediately. + :returns: token-wise KL loss values of shape [batch_size, sequence_length] + """ + assert student_hidden_states.requires_grad or teacher_hidden_states.requires_grad or not torch.is_grad_enabled() + assert teacher_hidden_states.shape[:-1] == student_hidden_states.shape[:-1] + flat_student_hidden_states = student_hidden_states.flatten(0, -2) + flat_teacher_hidden_states = teacher_hidden_states.flatten(0, -2) + total_tokens = flat_teacher_hidden_states.shape[0] + + loss_values_by_chunk = [] + for chunk_start in range(0, total_tokens, max_tokens_per_chunk): + is_last_chunk = chunk_start + max_tokens_per_chunk >= total_tokens + loss_values_by_chunk.append( + maybe_checkpoint( + _compute_kl_div_from_flat_hidden_states, + flat_student_hidden_states[chunk_start : chunk_start + max_tokens_per_chunk], + student_lm_head, + flat_teacher_hidden_states[chunk_start : chunk_start + max_tokens_per_chunk], + teacher_lm_head, + checkpoint_enabled=torch.is_grad_enabled() and (checkpoint_last_chunk or not is_last_chunk), + **checkpoint_kwargs, + ) + ) + return torch.cat(loss_values_by_chunk).reshape(*student_hidden_states.shape[:2]) + + +def _compute_kl_div_from_flat_hidden_states( + flat_student_hidden_states: torch.Tensor, + student_lm_head: nn.Module, + flat_teacher_hidden_states: torch.Tensor, + teacher_lm_head: nn.Module, +) -> torch.Tensor: + student_logprobs = F.log_softmax(student_lm_head(flat_student_hidden_states), dim=-1) + teacher_logprobs = F.log_softmax(teacher_lm_head(flat_teacher_hidden_states), dim=-1) + return F.kl_div(input=student_logprobs, target=teacher_logprobs, log_target=True, reduction="none").sum(-1) + + +def maybe_checkpoint(func: Callable[[...], T], *inputs, checkpoint_enabled: bool, **checkpoint_kwargs) -> T: + """Execute function normally or with checkpointing, depending on checkpoint_enabled. 
Forward **checkpoint_kwargs""" + return func(*inputs) if not checkpoint_enabled else checkpoint(func, *inputs, **checkpoint_kwargs) + + + def test_kl_divergence( + teacher_hidden_size=2048, + student_hidden_size=1024, + batch_size=2, + seq_length=450, + vocab_size=10_000, + max_tokens_per_chunk=128, + ): + """Verify correctness of compute_kl_divergence_loss_values""" + + teacher_lm_head = nn.Linear(teacher_hidden_size, vocab_size) + student_lm_head = nn.Linear(student_hidden_size, vocab_size) + + teacher_hidden_states = torch.randn(batch_size, seq_length, teacher_hidden_size) + student_hidden_states = torch.randn(batch_size, seq_length, student_hidden_size, requires_grad=True) + + ref_loss_values = F.kl_div( + input=F.log_softmax(student_lm_head(student_hidden_states), dim=-1), + target=F.log_softmax(teacher_lm_head(teacher_hidden_states), dim=-1), + log_target=True, + reduction="none", + ).sum(-1) + + for use_reentrant, checkpoint_last_chunk, determinism_check in itertools.product( + (True, False), (True, False), ("default", "none") + ): + loss_values = compute_kl_divergence_loss_values( + student_hidden_states=student_hidden_states, + student_lm_head=student_lm_head, + teacher_hidden_states=teacher_hidden_states, + teacher_lm_head=teacher_lm_head, + max_tokens_per_chunk=max_tokens_per_chunk, + checkpoint_last_chunk=checkpoint_last_chunk, + use_reentrant=use_reentrant, + determinism_check=determinism_check, + ) + assert loss_values.shape == (batch_size, seq_length) + assert torch.allclose(loss_values, ref_loss_values) diff --git a/src/modelutils.py b/src/modelutils.py index 939c29c..3325e31 100644 --- a/src/modelutils.py +++ b/src/modelutils.py @@ -1,13 +1,17 @@ import math import os from contextlib import contextmanager +from typing import Optional import torch import torch.nn as nn import transformers from accelerate import dispatch_model +from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision from transformers import AutoConfig, AutoModelForCausalLM +from src.aq import QuantizedWeight + MODEL_ERROR_MSG = "Unsupported model type {} - only 'llama', 'Yi', 'opt', 'falcon', 'phi3' are supported" FALCON_TYPES = ("falcon", "refinedweb", "refinedwebmodel") LLAMA_LIKE = ("llama", "Yi", "mistral", "mixtral", "gemma", "cohere", "qwen2") @@ -48,7 +52,7 @@ def get_model( dtype = ( AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code).torch_dtype or "auto" ) # force transformers 4.29.2 to follow the same rules as 4.30.x - else: + elif isinstance(dtype, str): dtype = getattr(torch, dtype) model_kwargs = {} @@ -82,7 +86,14 @@ def get_model( return model -def get_model_head(model): +def is_model_for_causal_lm(model: nn.Module): + assert isinstance(model, transformers.PreTrainedModel) + assert len(model.base_model_prefix) > 0 and hasattr(model, model.base_model_prefix) + assert model.get_output_embeddings() is not None + return True + + +def get_model_head_with_norm(model): head = torch.nn.ModuleList() if model.config.model_type in (*LLAMA_LIKE, "phi3"): if model.model.norm is not None: @@ -228,6 +239,10 @@ def load_dequantized_model(model, load_path): print("layer", layer_index) layer = layers[layer_index] quant_layer = torch.load(os.path.join(load_path, str(layer_index) + ".pth"), map_location="cpu") + for module in quant_layer.modules(): + if isinstance(module, QuantizedWeight): + if not hasattr(module, "codes_storage"): + module.codes_storage = None # backwards compatibility layers[layer_index] = load_linear_layers(layer, quant_layer, model)
model.load_state_dict(torch.load(os.path.join(load_path, "not_quantized_weights.pt")), strict=False) return model @@ -241,6 +256,11 @@ def load_quantized_model(model, load_path): os.path.join(load_path, str(layer_index) + ".pth"), map_location=model.model.layers[layer_index].input_layernorm.weight.device, ) + for module in model.model.layers[layer_index].modules(): + if isinstance(module, QuantizedWeight): + if not hasattr(module, "codes_storage"): + module.codes_storage = None # backwards compatibility + model.load_state_dict(torch.load(os.path.join(load_path, "not_quantized_weights.pt")), strict=False) return model @@ -254,3 +274,18 @@ def save_not_quantized_weights(model: nn.Module, save_dir: str): name: param for name, param in model.named_parameters() if param not in already_saved_weights } torch.save(not_quantized_weights, os.path.join(save_dir, "not_quantized_weights.pt")) + + +def save_quantized_model(model: transformers.PreTrainedModel, save_dir: str): + """Save dequantized model state in the same format as returned by AQLM calibration (main.py)""" + os.makedirs(save_dir, exist_ok=True) + for layer_index, layer in enumerate(get_layers(model)): + layer_save_path = os.path.join(save_dir, f"{layer_index}.pth") + torch.save(layer, layer_save_path) + save_not_quantized_weights(model, save_dir) + + +def get_layers_prefix(config: transformers.PretrainedConfig) -> str: + if config.model_type in ("llama", "mistral", "mixtral", "gemma"): + return "model.layers" + raise NotImplementedError(f"Can't get layers prefix for {config.model_type}") diff --git a/src/pv_optimizer.py b/src/pv_optimizer.py new file mode 100644 index 0000000..8e13d47 --- /dev/null +++ b/src/pv_optimizer.py @@ -0,0 +1,510 @@ +"""Module containing utilities for straight-through fine-tuning of language models""" +import random +from enum import Enum, auto +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union + +import torch +import torch.distributed +import torch.nn as nn +from torch.optim.optimizer import StateDict + +from src.aq import QuantizedWeight +from src.configurable_adam import ConfigurableAdamW +from src.pv_utils import YourQuantizedWeightIsInAnotherRank, print_runtime_stats + + +class ParameterRole(Enum): + QUANTIZED_PARAMETER = auto() # entire quantized weight, in a de-quantized form + QUANTIZED_REPRESENTATION_PARAMETER = auto() # part of quantized weight inner parameters, e.g. codebooks or scales + NON_QUANTIZED_PARAMETER = auto() + + +class StraightThroughAdamW(ConfigurableAdamW): + """ + A wrapper for a PyTorch optimizer that can perform updates on quantized and/or de-quantized parameters + :param update_non_quantized_params: how to update parameters that are not directly linked to a QuantizedWeight. + This may include biases, embeddings/heads, normalization layers or parts of the model that were not quantized. + This should be either None (do not update) or a dictionary of optimizer kwargs. In the latter case, these + keyword arguments will be used when configuring optimizer for this specific parameter group. + :param update_codebooks_and_scales: how to update continuous params of QuantizedWeight: codebooks and scales. + This should be either None (do not update) or a dictionary of optimizer kwargs. In the latter case, these + keyword arguments will be used when configuring optimizer for this specific parameter group. + :param update_codes: how to update codes in each QuantizedWeight with beam search and straight-through grad. 
+ This should be either None (do not update codes) or a dictionary of hyperparameters, similar to the above. + :param delta_decay: determines whether to use straight-through estimation, direct optimization or a mixture thereof + - if delta_decay == 1, do not use straight-through estimation. In this regime, the optimizer first updates + de-quantized weights as though they were continuous, then uses modified weights to update codes, codebooks and + scales; at the end of each step, the optimizer overwrites de-quantized weights to a de-quantization of the + possibly updated quantized representations (codes, codebooks, scales). + - if delta_decay == 0, use standard straight-through estimation. In this regime, the optimizer creates + an internal set of straight-through buffers in the shape of de-quantized weights. The optimizer trains these + buffers as though they were continuous; the quantized weights are then updated to minimize the L2 distance to + these straight-through buffers; finally, the optimizer updates de-quantized weights from the quantized versions. + - if delta_decay is between 0 and 1, use penalized straight-through estimation. The optimizer acts as though + using standard straight-through estimation (see delta_decay == 0), but after every step, the straight-through + buffers are set to (1 - delta_decay) * straight_through_buffer + delta_decay * quantized_weight. + + :param max_code_change_per_step: max portion of discrete code groups that can be updated; only affects codes + :param code_trust_ratio: the maximum relative change to quantized weights per step, as a fraction of weight norm; + see details in src/beam_search_l2.py, and in particular, beam_search_optimal_codes docstring. + :param code_selection_temperature: if max_code_change or code_trust_ratio is set, the optimizer will by default + prioritize updating codes with the largest delta = ||dequantized_weight_after_sgd_step - quantized_weight||_2. + If code_selection_temperature is above 0, it will instead sample codes randomly in proportion to the same + delta ^ (1 / temperature). If temperature is very high, the optimizer will choose codes uniformly at random. + :param force_code_update: if True, beam search will force codes to change even if the code is optimal in + terms of mean squared error. By default, the algorithm forces *all* weights to update this way, which may change + weights too much. To limit the number of updated weights, set max_code_change and trust_ratio. + :param stochastic_rounding_tau: if above 0, use stochastic rounding with this temperature. See aq.py + + :param beam_size: beam search width used only when updating codes. See beam_size in aq.py + + :param straight_through_buffer_dtype: use this dtype when accumulating updates to de-quantized weight matrices + Used only if delta_decay != 1.
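+ :note: a typical construction (illustrative hyperparameters; the two name-to-parameter dicts are prepared by the calling code) is + StraightThroughAdamW(named_dequantized_params, named_quantized_params, update_codes=dict(lr=1e-3), + update_codebooks_and_scales=dict(lr=1e-4), update_non_quantized_parameters=dict(lr=1e-5), + beam_size=8, max_code_change_per_step=1e-2, delta_decay=0)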
+ + """ + + def __init__( + self, + named_dequantized_params: Dict[str, nn.Parameter], + named_quantized_params: Dict[str, Union[QuantizedWeight, YourQuantizedWeightIsInAnotherRank]], + *, + update_non_quantized_parameters: Optional[dict] = None, + update_codebooks_and_scales: Optional[dict] = None, + update_codes: Optional[dict] = None, + beam_size: int, + delta_decay: float = 1, + max_code_change_per_step: float, + code_trust_ratio: Optional[float] = None, + code_selection_temperature: float = 0, + force_code_update: bool = False, + stochastic_rounding_tau: float = 0, + straight_through_buffer_dtype: Optional[torch.dtype] = None, + verbose: bool = False, + **kwargs, + ): + assert 0 <= delta_decay <= 1 + assert all( + isinstance(qw, (QuantizedWeight, YourQuantizedWeightIsInAnotherRank)) + for qw in named_quantized_params.values() + ) + assert all(name in named_dequantized_params for name in named_quantized_params), "param names mismatch" + + self.sharded = not all(isinstance(qw, QuantizedWeight) for qw in named_quantized_params.values()) + self.is_straight_through = delta_decay != 1 + if verbose and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0): + print(end=f"PV optimizer init:\n\tAre quantized weights sharded? : {self.sharded}.\n") + print(end=f"\tOptimizing {('without', 'with')[self.is_straight_through]} straight-through buffers\n") + param_groups, all_optimized_params = self._select_optimized_parameters( + named_dequantized_params=named_dequantized_params, + named_quantized_params=named_quantized_params, + update_non_quantized_parameters=update_non_quantized_parameters, + update_codebooks_and_scales=update_codebooks_and_scales, + update_codes=update_codes, + straight_through_buffer_dtype=straight_through_buffer_dtype, + ) + + super().__init__(param_groups, **kwargs) + self.ordered_quantized_weight_names = tuple(sorted(named_quantized_params.keys())) + self.optimized_param_to_name = {param: name for name, param in all_optimized_params.items()} + self.quantized_weights_by_name = { + name: qw + for name, qw in named_quantized_params.items() + if isinstance(qw, (QuantizedWeight, YourQuantizedWeightIsInAnotherRank)) + } + self.straight_through_buffer_by_name = ( + { + name: all_optimized_params[name] + for name in self.quantized_weights_by_name.keys() + if name in all_optimized_params + } + if self.is_straight_through + else {} + ) + self.dequantized_weights_by_name = { + name: param for name, param in named_dequantized_params.items() if name in named_quantized_params + } + if self.sharded: + self.sharded_param_sizes_by_rank = _get_sharded_param_sizes_by_rank(named_dequantized_params) + self.target_rank_by_name = { + name: qw.rank if isinstance(qw, YourQuantizedWeightIsInAnotherRank) else torch.distributed.get_rank() + for name, qw in self.quantized_weights_by_name.items() + } + + self.should_update_non_quantized_parameters = update_non_quantized_parameters is not None + self.should_update_codebooks_and_scales = update_codebooks_and_scales is not None + self.should_update_codes = update_codes is not None + + self.delta_decay = delta_decay + self.max_code_change_per_step = max_code_change_per_step + self.code_trust_ratio = code_trust_ratio + self.force_code_update = force_code_update + self.code_selection_temperature = code_selection_temperature + self.stochastic_rounding_tau = stochastic_rounding_tau + self.beam_size = beam_size + self.verbose = verbose + + def _select_optimized_parameters( + self, + named_dequantized_params, + named_quantized_params, + 
straight_through_buffer_dtype, + update_non_quantized_parameters: Optional[dict], + update_codebooks_and_scales: Optional[dict], + update_codes: Optional[dict], + ) -> Tuple[List[Dict[str, Any]], Dict[str, nn.Parameter]]: + """Choose which version of parameter to optimize: the parameter itself or a straight-through buffer""" + non_quantized_params, quantized_params, quantized_representation_params = dict(), dict(), dict() + for name, param in named_dequantized_params.items(): + if name not in named_quantized_params or isinstance(named_quantized_params[name], torch.Tensor): + non_quantized_params[name] = param + elif isinstance(named_quantized_params[name], QuantizedWeight): + quantized_weight = named_quantized_params[name] + if self.is_straight_through: # create an accumulator for optimizer updates; sharded alongside FSDP + with torch.no_grad(): + dequantized_weight = quantized_weight() + dequantized_weight = nn.Parameter( + dequantized_weight.to(dtype=straight_through_buffer_dtype), + requires_grad=dequantized_weight.requires_grad, + ) + else: + dequantized_weight = param + quantized_params[name] = dequantized_weight + for subparam_name, subparam in quantized_weight.named_parameters(): + full_name = f"{name}.{subparam_name}" + assert full_name not in quantized_representation_params, full_name + quantized_representation_params[full_name] = subparam + elif isinstance(named_quantized_params[name], YourQuantizedWeightIsInAnotherRank): + assert self.sharded # running sharded optimizer, this weight should be optimized by another rank + else: + raise RuntimeError(f"Unxpected quantized param type {type(named_quantized_params[name])}") + + total_params = len(set(non_quantized_params) | set(quantized_params) | set(quantized_representation_params)) + assert total_params == len(non_quantized_params) + len(quantized_params) + len(quantized_representation_params) + param_groups = [] + all_optimized_params = dict() + if update_non_quantized_parameters is not None: + all_optimized_params.update(non_quantized_params) + param_groups.append( + dict( + params=list(non_quantized_params.values()), + role=ParameterRole.NON_QUANTIZED_PARAMETER, + **update_non_quantized_parameters, + ) + ) + if update_codebooks_and_scales is not None: + all_optimized_params.update(quantized_representation_params) + param_groups.append( + dict( + params=list(quantized_representation_params.values()), + role=ParameterRole.QUANTIZED_REPRESENTATION_PARAMETER, + **update_codebooks_and_scales, + ) + ) + if update_codes is not None: + all_optimized_params.update(quantized_params) + param_groups.append( + dict(params=list(quantized_params.values()), role=ParameterRole.QUANTIZED_PARAMETER, **update_codes) + ) + assert len(param_groups) > 0, ( + "Please set at least one of update_codes, update_codebooks_and_scales " "or update_non_quantized_parameters" + ) + return param_groups, all_optimized_params + + def step(self, *args, **kwargs): + with print_runtime_stats("_propagate_grads_to_optimized_parameters", enabled=self.verbose): + self._propagate_grads_to_optimized_parameters() + with print_runtime_stats("super().step", enabled=self.verbose): + original_output = super().step(*args, **kwargs) + with print_runtime_stats("_optimize_quantized_weights", enabled=self.verbose): + self._optimize_quantized_weights() + with print_runtime_stats("_update_dequantized_weights", enabled=self.verbose): + self._update_dequantized_weights() + return original_output + + def _aggregate_gradients_for_dequantized_weights(self): + """collect full parameter 
gradients from fsdp-sharded parameters, return dict[name -> grad]""" + grad_shards_by_name = dict() + + for name in self.ordered_quantized_weight_names: + if self.dequantized_weights_by_name[name].grad is None: + assert self.dequantized_weights_by_name[name].numel() == 0 + self.dequantized_weights_by_name[name].grad = torch.zeros_like(self.dequantized_weights_by_name[name]) + grad = self.dequantized_weights_by_name[name].grad + assert grad is not None, name + grad_shards_by_name[name] = grad + + if self.sharded: + aggregated_grads_by_name = _aggregate_tensors_by_name( + grad_shards_by_name, + self.sharded_param_sizes_by_rank, + self.target_rank_by_name, + name_order=self.ordered_quantized_weight_names, + ) + else: + aggregated_grads_by_name = grad_shards_by_name + + aggregated_grads_by_name = { + name: grad.view(self.quantized_weights_by_name[name].shape) + for name, grad in aggregated_grads_by_name.items() + } + if self.verbose: + for name, grad in aggregated_grads_by_name.items(): + print(end=f"aggregated grad norm for {name}: {grad.norm().item()}\n") + return aggregated_grads_by_name + + def _aggregate_dequantized_weights(self): + """collect full (possibly optimizer-updated) dequantized weights""" + if not self.sharded: + return self.dequantized_weights_by_name + dequantized_flat_param_shards = { + name: param.data.flatten() for name, param in self.dequantized_weights_by_name.items() + } + flat_aggregated_params_by_name = _aggregate_tensors_by_name( + dequantized_flat_param_shards, + self.sharded_param_sizes_by_rank, + self.target_rank_by_name, + name_order=self.ordered_quantized_weight_names, + ) + aggregated_params_by_name = { + name: param.view(self.quantized_weights_by_name[name].shape) + for name, param in flat_aggregated_params_by_name.items() + } + return aggregated_params_by_name + + @torch.no_grad() + def _propagate_grads_to_optimized_parameters(self): + """Ensure that every optimized parameter receives gradient""" + aggregated_grads_by_name = self._aggregate_gradients_for_dequantized_weights() + for param_group in self.param_groups: + for param in param_group["params"]: + name = self.optimized_param_to_name[param] + if param_group["role"] == ParameterRole.QUANTIZED_PARAMETER: + if self.is_straight_through: + assert param is self.straight_through_buffer_by_name[name] + # pass gradients to straight-through update buffer or (possibly offloaded) quantized parameter + grad_wrt_dequantized_parameter = aggregated_grads_by_name[name] + assert grad_wrt_dequantized_parameter.shape == param.shape + param.grad = grad_wrt_dequantized_parameter.to(dtype=param.dtype, device=param.device) + else: + assert len(self.straight_through_buffer_by_name) == 0, self.straight_through_buffer_by_name + assert param.grad is not None + elif param_group["role"] == ParameterRole.NON_QUANTIZED_PARAMETER: + assert name not in self.dequantized_weights_by_name and name not in self.quantized_weights_by_name + elif param_group["role"] == ParameterRole.QUANTIZED_REPRESENTATION_PARAMETER: + assert name not in self.dequantized_weights_by_name + assert self.should_update_codebooks_and_scales + # gradients w.r.t quantized representation parameters are computed below via backprop + else: + raise RuntimeError(f"Unexpected param role: {param_group['role']}") + + if self.should_update_codebooks_and_scales: + # propagate gradients from dequantized weights to quantization parameters so they can be updated in step; + # if sharded, every rank propagates gradients only for the QuantizedWeight instances owned by this rank + with 
torch.enable_grad(): + for name, quantized_weight in self.quantized_weights_by_name.items(): + if isinstance(quantized_weight, QuantizedWeight): + quantized_weight.forward().backward(aggregated_grads_by_name[name]) + + @torch.no_grad() + def _optimize_quantized_weights(self): + """Update discrete state representations to approximate straight through buffers""" + # note: if sharded, this only updates the subset of quantized weights that are assigned to local rank + remaining_quantized_weights = { + name: qw for name, qw in self.quantized_weights_by_name.items() if isinstance(qw, QuantizedWeight) + } + if self.is_straight_through: + reference_weights_by_name = self.straight_through_buffer_by_name + else: + reference_weights_by_name = self._aggregate_dequantized_weights() + + for param_group in self.param_groups: + if param_group["role"] == ParameterRole.QUANTIZED_PARAMETER: + for param in param_group["params"]: + # param is either a dequantized weight or a special straight-through buffer (if is_straight_through) + name = self.optimized_param_to_name[param] + quantized_weight = remaining_quantized_weights.pop(name) + reference_weight = reference_weights_by_name[name] + assert reference_weight.shape == quantized_weight.shape, ( + reference_weight.shape, + quantized_weight.shape, + ) + assert isinstance(quantized_weight, QuantizedWeight) + + prev_codes = quantized_weight.get_codes().clone() # [num_output_groups, num_input_groups] + new_codes = quantized_weight.beam_search_update_codes_( + reference_weight=reference_weight, + beam_size=self.beam_size, + stochastic_rounding_tau=self.stochastic_rounding_tau, + max_update_fraction=self.max_code_change_per_step, + force_update=self.force_code_update, + code_selection_temperature=self.code_selection_temperature, + trust_ratio=self.code_trust_ratio, + dim_rng=random.Random(None), + ) # note: this updates quantized_weight codes in-place + if self.delta_decay != 0 and self.is_straight_through: + self.straight_through_buffer_by_name[name][...] 
= ( + self.delta_decay * quantized_weight() + (1 - self.delta_decay) * reference_weight + ) + # if not is_straight_throuh, param will be properly updated in _update_dequantized_weights + + if self.verbose: + code_change_rate = torch.not_equal(prev_codes, new_codes).any(-1).float().mean().item() + maybe_distributed_msg = "" + if torch.distributed.is_initialized(): + maybe_distributed_msg = f" (rank {torch.distributed.get_rank()})" + maybe_limit_msg = "" + if self.max_code_change_per_step is not None: + maybe_limit_msg = f"(limit {self.max_code_change_per_step})" + maybe_individual_msg = "" + if quantized_weight.num_codebooks > 1: + subcode_change = torch.not_equal(prev_codes, new_codes).float().mean().item() + maybe_individual_msg = f" | overall change {subcode_change:.8f}" + maybe_delta_msg = "" + if self.delta_decay != 1: + _dequantized_weight = quantized_weight() + delta_norm = (reference_weight - _dequantized_weight).norm().item() + relative_error = delta_norm / max(_dequantized_weight.norm().item(), 1e-9) + maybe_delta_msg = ( + f"\t||quantized_weight - optimized_weight|| / ||quantized_weight||" + f" = {relative_error}\n" + ) + print( + end=f"Updated codes for {name}{maybe_distributed_msg}:\n\tFraction of weights with at " + f"least one code change: {code_change_rate:.8f} " + f"{maybe_limit_msg}{maybe_individual_msg}\n{maybe_delta_msg}\n" + ) + assert len(remaining_quantized_weights) == 0 + + @torch.no_grad() + def _update_dequantized_weights(self): + """Assign dequantized weight buffers to latest quantized weights after codebook/scale/code updates""" + own_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + async_ops = list() + for name in self.ordered_quantized_weight_names: + quantized_weight = self.quantized_weights_by_name[name] + dequantized_weight_buffer = self.dequantized_weights_by_name[name] + dequantized_weight_buffer.fill_(float("nan")) # this is to ensure that the update reaches the buffer + + if not self.sharded: + dequantized_weight_buffer[...] 
= quantized_weight() + + else: + if isinstance(quantized_weight, QuantizedWeight): + new_dequantized_weight = quantized_weight().to(dequantized_weight_buffer.dtype) + shard_sizes: Sequence[int] = self.sharded_param_sizes_by_rank[name] + assert sum(shard_sizes) == new_dequantized_weight.numel() + new_dequantized_weight_parts = new_dequantized_weight.flatten().split_with_sizes(shard_sizes) + for i in range(world_size): + if i != own_rank: + async_ops.append(torch.distributed.isend(new_dequantized_weight_parts[i], dst=i)) + else: + dequantized_weight_buffer.copy_(new_dequantized_weight_parts[i]) + + else: + assert isinstance(quantized_weight, YourQuantizedWeightIsInAnotherRank) + source_rank = self.quantized_weights_by_name[name].rank + async_ops.append(torch.distributed.irecv(dequantized_weight_buffer, src=source_rank)) + for handle in async_ops: + handle.wait() + + def zero_grad(self, set_to_none: bool = True, *args, **kwargs) -> None: + super().zero_grad(set_to_none=set_to_none, *args, **kwargs) + for param in self.dequantized_weights_by_name.values(): + # dequantized weights are not in param_groups, but they still accumulate grads; reset them manually + if set_to_none: + param.grad = None + elif param.grad is not None: + param.grad.zero_() + + def iterate_local_quantized_weights(self) -> Iterator[Tuple[str, QuantizedWeight]]: + """Iterate over (name, QuantizedWeight) pairs for all quantized weights trained by this optimizer and rank""" + for name, quantized_weight in self.quantized_weights_by_name.items(): + if isinstance(quantized_weight, QuantizedWeight): # skip YourQuantizedWeightIsInAnotherRank if sharded + yield name, quantized_weight + + def state_dict(self) -> StateDict: + state_dict = super().state_dict() + assert "quantized_weight_state_dicts" not in state_dict + state_dict["quantized_weight_state_dicts"] = { + name: quantized_weight.state_dict() for name, quantized_weight in self.iterate_local_quantized_weights() + } + state_dict["straight_through_buffers"] = dict(self.straight_through_buffer_by_name) # may be empty + # note: the de-quantized params are not saved here; instead, they are saved with model.state_dict + return state_dict + + def load_state_dict(self, state_dict: StateDict) -> None: + quantized_weight_state_dicts: Dict[str, StateDict] = dict(state_dict.pop("quantized_weight_state_dicts")) + for name, quantized_weight in self.iterate_local_quantized_weights(): + quantized_weight.load_state_dict(quantized_weight_state_dicts.pop(name)) + assert len(quantized_weight_state_dicts) == 0, f"unused keys: {quantized_weight_state_dicts.keys()}" + + straight_through_buffers = state_dict.pop("straight_through_buffers") + assert all(name in straight_through_buffers for name in self.straight_through_buffer_by_name) + for name, loaded_values in straight_through_buffers.items(): + self.straight_through_buffer_by_name[name][...] 
= loaded_values + super().load_state_dict(state_dict) + + +def _get_sharded_param_sizes_by_rank(named_dequantized_params: Dict[str, torch.Tensor]) -> Dict[str, Sequence[int]]: + """For each parameter name, return a tuple of sizes (numbers of elements) this parameter across all FSDP ranks""" + assert torch.distributed.is_initialized() + own_dequantized_param_shard_size = {name: param.numel() for name, param in named_dequantized_params.items()} + world_size = torch.distributed.get_world_size() + gathered_list = [{} for _ in range(world_size)] + torch.distributed.all_gather_object(gathered_list, own_dequantized_param_shard_size) + assert all(name in sizes_dict for sizes_dict in gathered_list for name in own_dequantized_param_shard_size) + dequantized_param_sizes_by_rank = dict() + for name in named_dequantized_params.keys(): + dequantized_param_sizes_by_rank[name] = [gathered_list[rank][name] for rank in range(world_size)] + return dequantized_param_sizes_by_rank + + +def _aggregate_tensors_by_name( + sharded_tensors_by_name: Dict[str, torch.Tensor], + shard_sizes_by_name: Dict[str, Sequence[int]], + target_rank_by_name: Dict[str, int], + name_order: Optional[Sequence[str]] = None, +) -> Dict[str, torch.Tensor]: + """ + :param sharded_tensors_by_name: a dictionary from string to flat (1d) tensors available on the current shard + :note: the keys should be the same across ranks and go in the same order; if not, use ordered_names + :param shard_sizes_by_name: a dictionary from name to a list of sizes (numel) for this key across ranks + :param target_rank_by_name: a dictionary from name to a rank that this name should be aggregated to + :param name_order: if specified, this defines the order in which devices go over named shards + """ + assert torch.distributed.is_initialized() + own_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + aggregated_tensors_by_name = dict() + async_ops = list() + + for name in sorted(sharded_tensors_by_name.keys()) if name_order is None else name_order: + shard = sharded_tensors_by_name[name] + assert shard.ndim == 1 + destination_rank = target_rank_by_name[name] + shard_sizes: Sequence[int] = shard_sizes_by_name[name] + if destination_rank == own_rank: + total_numel = sum(shard_sizes) + combined_buffer = torch.full((total_numel,), fill_value=torch.nan, dtype=shard.dtype, device=shard.device) + gather_buffers = list(combined_buffer.split_with_sizes(shard_sizes)) + assert all( + part.untyped_storage().data_ptr() == combined_buffer.untyped_storage().data_ptr() + for part in gather_buffers + ) + for i in range(world_size): + if shard_sizes[i] == 0: + continue # optimization: this handles FSDP where some param/grad shards are empty + elif i != own_rank: + async_ops.append(torch.distributed.irecv(gather_buffers[i], src=i)) + else: + gather_buffers[i].copy_(shard) + aggregated_tensors_by_name[name] = combined_buffer + else: + if shard_sizes[own_rank] == 0: + continue + async_ops.append(torch.distributed.isend(shard, destination_rank)) + + for handle in async_ops: + handle.wait() + return aggregated_tensors_by_name diff --git a/src/pv_utils.py b/src/pv_utils.py new file mode 100644 index 0000000..e462fa7 --- /dev/null +++ b/src/pv_utils.py @@ -0,0 +1,203 @@ +import contextlib +import dataclasses +import hashlib +import json +import time +from collections import defaultdict +from copy import deepcopy +from itertools import chain +from typing import Dict, List, Optional, Tuple + +import torch +import transformers +from torch import nn as nn 
+
+from src.aq import QuantizedLinear, QuantizedWeight
+
+
+def infer_module_classes(model: nn.Module, class_name: str) -> Tuple[type[nn.Module], ...]:
+    """find transformer block classes that should be wrapped with inner FullyShardedDataParallel (auto_wrap_policy)"""
+    found_module_types = []
+    for module in model.modules():
+        if module.__class__.__name__ == class_name:
+            found_module_types.append(type(module))
+    if not found_module_types:
+        raise ValueError(f"Could not find {class_name} among submodules of {model}")
+    found_module_types = tuple(found_module_types)
+    assert any(isinstance(module, found_module_types) for module in model.modules())
+    return found_module_types
+
+
+def create_dequantized_model(
+    model: transformers.PreTrainedModel, *, reuse_non_quantized: bool, dequantized_dtype: Optional[torch.dtype] = None
+) -> transformers.PreTrainedModel:
+    """
+    Create a version of the model where all QuantizedWeight and derivative layers are de-quantized and cast to the given dtype.
+    :param model: model to be dequantized (out-of-place)
+    :param reuse_non_quantized: if True, any non-quantized parameters and buffers are reused for the de-quantized model;
+        otherwise they are copied and linked in the returned dictionary
+    :returns: a model (converted out-of-place) and a mapping (dict) from de-quantized parameter names to master parameters
+    """
+    memo = dict()  # for deepcopy with replacement
+    master_parameters = dict()
+    all_quantized_weight_parameters = set()
+
+    for name, module in model.named_modules():
+        if isinstance(module, QuantizedLinear):
+            assert module not in master_parameters and id(module) not in memo, f"{name} is converted more than once"
+            quantized_weight = module.quantized_weight
+
+            dequantized_module = nn.Linear(
+                module.in_features,
+                module.out_features,
+                bias=module.bias is not None,
+                dtype=dequantized_dtype if dequantized_dtype is not None else quantized_weight.get_codebooks().dtype,
+                device=next(quantized_weight.parameters()).device,
+            )
+            with torch.no_grad():
+                dequantized_module.weight[...] = quantized_weight()
+                dequantized_module.weight.requires_grad = any(p.requires_grad for p in quantized_weight.parameters())
+
+                if module.bias is not None and not reuse_non_quantized:
+                    dequantized_module.bias[...] = module.bias
+                    dequantized_module.bias.requires_grad = module.bias.requires_grad
+                elif module.bias is not None and reuse_non_quantized:
+                    dequantized_module.bias = module.bias
+
+            memo[id(module)] = dequantized_module
+            master_parameters[f"{name}.weight"] = quantized_weight
+            if dequantized_module.bias is not module.bias:
+                master_parameters[f"{name}.bias"] = module.bias
+            all_quantized_weight_parameters |= set(quantized_weight.parameters())
+            assert all(
+                param in {dequantized_module.weight, dequantized_module.bias}
+                for param in dequantized_module.parameters()
+            )
+
+    for name, param_or_buffer in chain(model.named_parameters(), model.named_buffers()):
+        if name in master_parameters or param_or_buffer in all_quantized_weight_parameters:
+            continue  # parameter already accounted for in the previous loop
+        assert name not in master_parameters, name
+        assert id(param_or_buffer) not in memo, name
+        if reuse_non_quantized:
+            new_param_or_buffer = param_or_buffer
+        elif isinstance(param_or_buffer, nn.Parameter):
+            new_param_or_buffer = nn.Parameter(param_or_buffer.data.clone(), param_or_buffer.requires_grad)
+        else:
+            new_param_or_buffer = param_or_buffer.detach().clone().requires_grad_(param_or_buffer.requires_grad)
+        if new_param_or_buffer is not param_or_buffer:
+            master_parameters[name] = new_param_or_buffer
+        memo[id(param_or_buffer)] = new_param_or_buffer
+
+    dequantized_model = deepcopy(model, memo=memo)
+
+    for name, module in dequantized_model.named_modules():
+        assert not isinstance(module, QuantizedWeight), (
+            f"Dequantized model should not have quantized weights, " f"but found {name} that is {module}"
+        )
+    if reuse_non_quantized:
+        assert all(isinstance(master, QuantizedWeight) for master in master_parameters.values())
+    verify_dequantized_model(dequantized_model, master_parameters)
+    return dequantized_model, master_parameters
+
+
+def verify_dequantized_model(dequantized_model: nn.Module, master_parameters: dict):
+    """Test that the dequantized model parameters still match the dequantized_to_master dictionary"""
+    unmatched_master_parameters = set(master_parameters.keys())
+    for name, param_or_buffer in chain(dequantized_model.named_parameters(), dequantized_model.named_buffers()):
+        if name not in master_parameters:
+            continue  # non-quantized weight
+        master_param_or_buffer = master_parameters[name]
+        assert param_or_buffer.shape == master_param_or_buffer.shape
+        unmatched_master_parameters.remove(name)
+    assert len(unmatched_master_parameters) == 0, f"Found unmatched tensors: {unmatched_master_parameters}"
+
+
+def get_original_named_parameters_from_fsdp_module(dequantized_model) -> Dict[str, nn.Parameter]:
+    return {name.replace("_fsdp_wrapped_module.", ""): param for name, param in dequantized_model.named_parameters()}
+
+
+@contextlib.contextmanager
+def print_runtime_stats(operation_name: str, enabled: bool = True, device: Optional[torch.device] = None):
+    if not enabled:
+        yield
+        return
+
+    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+    if device is None:
+        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    start_time = time.perf_counter()
+    yield
+    if device.type == "cuda":
+        torch.cuda.synchronize(device)
+    maybe_distributed_msg = f"rank {rank} " if torch.distributed.is_initialized() else ""
+    print(end=f"{maybe_distributed_msg}{operation_name} took {time.perf_counter() - start_time:.3f}s\n")
+
+
+def split_quantized_weights_between_ranks(quantized_weights: Dict[str, QuantizedWeight], verify_checksums: bool):
+    """
+    Split all quantized weights between ranks in a distributed setup; uses a greedy knapsack heuristic.
+    Note that unlike FSDP, this heuristic will always assign the entire quantized weight to one rank.
+
+    :param quantized_weights: a dictionary [parameter_name] -> QuantizedWeight
+    :param verify_checksums: if True, synchronize with other ranks and verify that parameters are split consistently.
+        If False, do not synchronize, but instead print a checksum hash for each rank, to be verified by the user.
+    :returns: a dictionary similar to quantized_weights, except that weights stored on other ranks are replaced by
+        placeholders. If the current rank stores the quantized weight for [name], then returned_dict[name] is
+        quantized_weights[name]; otherwise, returned_dict[name] = YourQuantizedWeightIsInAnotherRank(rank=where_it_is_stored)
+    """
+    assert torch.distributed.is_initialized()
+    own_rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    all_quantized_weights: Dict[QuantizedWeight, List[str]] = defaultdict(list)
+    for name, quantized_weight in quantized_weights.items():
+        all_quantized_weights[quantized_weight].append(name)
+
+    # order quantized weights in a rank-agnostic way: order by (param size desc, linked param name asc)
+    def _compute_size(qw: QuantizedWeight) -> float:
+        return qw.out_features * qw.in_features * qw.estimate_nbits_per_parameter()
+
+    ordered_quantized_weights = sorted(
+        all_quantized_weights, key=lambda qw: (-_compute_size(qw), min(all_quantized_weights[qw]))
+    )
+    assert len(ordered_quantized_weights) > 0, "internal error: could not find any linked QuantizedWeight in state"
+
+    # split between ranks
+    quantized_weight_to_rank = dict()
+    total_size_by_rank = [0 for _ in range(world_size)]
+    for quantized_weight in ordered_quantized_weights:
+        least_busy_rank = min(range(world_size), key=lambda rank: total_size_by_rank[rank])
+        total_size_by_rank[least_busy_rank] += _compute_size(quantized_weight)
+        quantized_weight_to_rank[quantized_weight] = least_busy_rank
+
+    checksum = tuple(
+        (min(all_quantized_weights[qw]), quantized_weight_to_rank[qw], _compute_size(qw))
+        for qw in ordered_quantized_weights
+    )
+    if verify_checksums:
+        checksums = [() for _ in range(world_size)]
+        torch.distributed.all_gather_object(checksums, checksum)
+        assert checksums[own_rank] == checksum, (checksums, own_rank, checksum)
+        assert all(other_checksum == checksum for other_checksum in checksums), checksums
+    else:
+        hashing = hashlib.sha256()
+        hashing.update(json.dumps(checksum).encode())
+        print(end=f"Splitting quantized weights, rank {own_rank} checksum hash: {hashing.hexdigest()}\n")
+
+    sharded_quantized_weights = dict()
+    for name, quantized_weight in list(quantized_weights.items()):
+        target_rank = quantized_weight_to_rank[quantized_weight]
+        if target_rank == own_rank:
+            sharded_quantized_weights[name] = quantized_weight
+        else:
+            sharded_quantized_weights[name] = YourQuantizedWeightIsInAnotherRank(target_rank)
+    return sharded_quantized_weights
+
+
+@dataclasses.dataclass(init=True, frozen=True)
+class YourQuantizedWeightIsInAnotherRank:
+    """This replaces quantized weights that are not held on this rank"""
+
+    rank: int
diff --git a/src/utils.py b/src/utils.py
index a709ab1..9dd77ed 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,3 +1,4 @@
+"""Common utility functions for additive quantization"""
 from __future__ import annotations
 
 import contextlib
@@ -6,12 +7,14 @@
 from typing import Any, Callable, Iterable, Iterator, List, Optional, Sequence, Union
 
 import torch
+import torch.distributed
+from torch import nn
+from torch.nn import functional as F
 
 ellipsis = type(...)
 
 
 def get_mean_nbits_by_codebook(codes: torch.IntTensor, huffman_group_size: int = 2):
-
     """
     Calculates average code length in codebooks.
     :param codes: codebook codes
@@ -58,6 +61,36 @@ def maybe_script(fn: callable) -> callable:
     return torch.jit.script(fn) if should_script else fn
 
 
+@maybe_script
+def _dequantize_weight(
+    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Decode float weights from quantization codes. Differentiable.
+    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
+    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
+    :param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
+    :return: reconstructed weight tensor of shape [*dims, out_features, in_features]
+    """
+    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
+    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
+    out_features = num_out_groups * out_group_size
+    in_features = num_in_groups * in_group_size
+    codebook_offsets = torch.arange(
+        0, num_codebooks * codebook_size, codebook_size, device=codes.device
+    )  # shape: [num_codebooks]
+    reconstructed_weight_flat = F.embedding_bag(
+        codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum"
+    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
+
+    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
+        list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
+    )
+    if scales is not None:
+        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
+    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
+
+
 @contextlib.contextmanager
 def using_tf32(enabled: bool):
     was_cudnn = torch.backends.cudnn.allow_tf32
@@ -117,3 +150,72 @@ def maybe_get_0th_element(x: Union[Any, Sequence[Any]]) -> Any:
 def _extract_into_tensor(tensor_list: List[torch.Tensor], indices: Iterable[int], device=None, dtype=None):
     extracted_items = [maybe_get_0th_element(tensor_list[i]) for i in indices]
     return torch.cat(extracted_items, dim=0).to(device=device, dtype=dtype)
+
+
+class IntCodes(nn.Module):
+    """
+    A storage for integer codes that makes them compatible with FullyShardedDataParallel,
+    see https://github.com/pytorch/pytorch/issues/123528 for details
+    """
+
+    def __init__(self, codes: torch.Tensor, storage_dtype: torch.dtype = torch.float64):
+        super().__init__()
+        assert torch.finfo(storage_dtype).bits % torch.iinfo(codes.dtype).bits == 0
+        self.dtype, self.shape, self.numel = codes.dtype, codes.shape, codes.numel()
+        size_ratio = torch.finfo(storage_dtype).bits // torch.iinfo(codes.dtype).bits
+        codes = F.pad(codes.flatten().clone(), pad=[0, -codes.numel() % size_ratio])
+        assert len(codes.untyped_storage()) == codes.nbytes  # no offset / stride / tail
+        self.storage_dtype = storage_dtype
+        self.data = nn.Parameter(
+            torch.as_tensor(codes.untyped_storage(), device=codes.device, dtype=storage_dtype), requires_grad=False
+        )
+
+    def forward(self):
+        assert self.data.is_contiguous() and self.data.dtype == self.storage_dtype
+        byte_offset = self.data.storage_offset() * self.data.nbytes // self.data.numel()
+        return torch.as_tensor(
+            self.data.untyped_storage()[byte_offset : byte_offset + self.data.nbytes],
+            device=self.data.device,
+            dtype=self.dtype,
+        )[: self.numel].view(*self.shape)
+
+
+@contextlib.contextmanager
+def one_rank_at_a_time(local: bool = False, group_size: int = 1):
+    """
+    In a distributed setting, let only group_size processes enter at a time
+    :param local: if True, the limit is enforced within each host, i.e. distributed hosts can act concurrently
+    :param group_size: if greater than one, this many processes will be allowed to enter at the same time
+    """
+    distributed = torch.distributed.is_initialized()
+    rank = int(os.environ.get("LOCAL_RANK" if local else "RANK", 0)) if distributed else 0
+    world_size = int(os.environ.get("LOCAL_WORLD_SIZE" if local else "WORLD_SIZE", 1)) if distributed else 1
+    if distributed:
+        torch.distributed.barrier()
+    for current_group_index in range(world_size // group_size):
+        if current_group_index == rank // group_size:
+            yield
+        if distributed:
+            torch.distributed.barrier()
+
+
+@contextlib.contextmanager
+def master_rank_first(local: bool, master_rank: int = 0):
+    distributed = torch.distributed.is_initialized()
+    rank = int(os.environ.get("LOCAL_RANK" if local else "RANK", 0)) if distributed else 0
+    if distributed and rank != master_rank:
+        torch.distributed.barrier()
+    yield
+    if distributed and rank == master_rank:
+        torch.distributed.barrier()
+
+
+def is_signed(dtype: torch.dtype) -> bool:
+    """Return True iff the given dtype can represent negative values"""
+    try:
+        return dtype.is_signed
+    except RuntimeError:  # see https://github.com/pytorch/pytorch/issues/125124
+        if dtype.is_floating_point:
+            return torch.finfo(dtype).min < 0
+        else:
+            return torch.iinfo(dtype).min < 0 and torch.iinfo(dtype).max > 0
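Usage sketches (illustrative notes, not part of the diff above).

The gather helpers _get_sharded_param_sizes_by_rank and _aggregate_tensors_by_name are easiest to read with a caller in mind. The sketch below is hypothetical: the wrapper name, the owner_rank_by_name mapping, and the commented-out import are assumptions, since the module that these helpers belong to is not shown in this excerpt.

from typing import Dict

import torch

# the two helpers are added earlier in this patch; their module path is not shown in this excerpt
# from <that module> import _aggregate_tensors_by_name, _get_sharded_param_sizes_by_rank


def gather_full_grads_on_owner_ranks(
    grad_shards_by_name: Dict[str, torch.Tensor],  # flat (1d) FSDP gradient shards, same keys on every rank
    owner_rank_by_name: Dict[str, int],  # rank that holds the master QuantizedWeight for each name
) -> Dict[str, torch.Tensor]:
    """Reassemble full gradients, but only on the rank that owns the corresponding master weight."""
    shard_sizes_by_name = _get_sharded_param_sizes_by_rank(grad_shards_by_name)
    return _aggregate_tensors_by_name(
        grad_shards_by_name,
        shard_sizes_by_name,
        target_rank_by_name=owner_rank_by_name,
        name_order=sorted(grad_shards_by_name),  # identical iteration order on every rank
    )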
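The intended composition of the new pv_utils pieces, as a hedged sketch: build a trainable de-quantized copy of a quantized model, then give each rank ownership of a subset of the master QuantizedWeights. The checkpoint arguments are placeholders and the actual training step is omitted.

import torch
import torch.distributed

from src.modelutils import get_model
from src.pv_utils import (
    YourQuantizedWeightIsInAnotherRank,
    create_dequantized_model,
    split_quantized_weights_between_ranks,
)

torch.distributed.init_process_group(backend="nccl")
quantized_model = get_model("<base model>", "<path to quantized checkpoint>")  # placeholders

# trainable floating-point copy for forward/backward; master_parameters maps weight names to QuantizedWeight masters
dequantized_model, master_parameters = create_dequantized_model(
    quantized_model, reuse_non_quantized=True, dequantized_dtype=torch.bfloat16
)

# each master weight is assigned to exactly one rank; other ranks see lightweight placeholders
sharded_masters = split_quantized_weights_between_ranks(master_parameters, verify_checksums=True)
for name, master in sharded_masters.items():
    if isinstance(master, YourQuantizedWeightIsInAnotherRank):
        continue  # another rank updates the codes/codebooks of this weight
    # ... this rank updates `master` from the gradients aggregated onto it ...

With reuse_non_quantized=True, master_parameters contains only QuantizedWeight entries (the assert inside create_dequantized_model checks exactly this), so it can be passed to split_quantized_weights_between_ranks directly.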
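A quick shape sanity-check for _dequantize_weight in src/utils.py; all sizes below are made up.

import torch

from src.utils import _dequantize_weight

num_codebooks, codebook_size = 2, 256
out_group_size, in_group_size = 1, 8
num_out_groups, num_in_groups = 4, 16  # out_features = 4, in_features = 128

codes = torch.randint(codebook_size, (num_out_groups, num_in_groups, num_codebooks))
codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
scales = torch.ones(num_out_groups, 1, 1, 1)  # broadcastable group-wise scales

weight = _dequantize_weight(codes, codebooks, scales)
assert weight.shape == (num_out_groups * out_group_size, num_in_groups * in_group_size)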
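IntCodes stores integer codes inside a float64 parameter so that FSDP can flat-shard them, and forward() reinterprets the raw bytes back into the original integer tensor. A minimal CPU round-trip check:

import torch

from src.utils import IntCodes

codes = torch.randint(0, 256, (4, 16, 2), dtype=torch.int16)
wrapped = IntCodes(codes, storage_dtype=torch.float64)

assert torch.equal(wrapped(), codes)  # the bytes survive the float64 detour unchanged
assert all(not p.requires_grad for p in wrapped.parameters())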
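Finally, the two rank-serialization context managers are meant for loading-time coordination under torchrun; the model name and snapshot path below are placeholders.

import torch
import torch.distributed

from src.modelutils import get_model
from src.utils import master_rank_first, one_rank_at_a_time

torch.distributed.init_process_group(backend="nccl")

# the local master rank populates the HuggingFace cache first, the other ranks then reuse it
with master_rank_first(local=True):
    model = get_model("meta-llama/Llama-2-7b-hf", None)  # model name is a placeholder

# heavy, memory-hungry loading runs on one local rank at a time
with one_rank_at_a_time(local=True, group_size=1):
    snapshot = torch.load("/path/to/snapshot.pt", map_location="cpu")  # path is a placeholder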