undoing src and main
Andrei Panferov committed Jan 28, 2024
1 parent 48d6ddf commit 51d664b
Showing 3 changed files with 46 additions and 87 deletions.
20 changes: 15 additions & 5 deletions main.py
@@ -8,6 +8,7 @@
import torch.nn as nn
from tqdm import trange
from tqdm.auto import trange
from transformers import PreTrainedModel

from aq_engine import AQEngine
from src.aq import QuantizedLinear
@@ -22,9 +23,7 @@
get_model_head,
get_sequential_groups,
)
from src.saving import save_fresh_model
from src.utils import using_tf32
from transformers import PreTrainedModel

try:
import wandb
@@ -161,7 +160,6 @@ def quantize_aq(model: PreTrainedModel, dataloader: Iterable, args: Namespace):
model.config.use_cache = False

quantizers = {}
replaced_linears = [] # [(submodule, child_name, new_linear), ...]
overall_bits = 0
model_number_of_params = 0
layers = get_layers(model)
@@ -206,7 +204,6 @@ def quantize_aq(model: PreTrainedModel, dataloader: Iterable, args: Namespace):
for child_name, child_module in submodule.named_children():
if child_module is aq_handlers[sublayer_name].layer:
setattr(submodule, child_name, new_linear)
replaced_linears.append((submodule, child_name, new_linear))
found_original = True # note: do not break to handle tied layers

assert found_original, f"could not find {sublayer_name}"
@@ -224,6 +221,11 @@ def quantize_aq(model: PreTrainedModel, dataloader: Iterable, args: Namespace):
layer = finetune_groupwise(layer=layer, inps=inps, outs=outs, args=args, **forward_args)
layer = layer.to(dtype=layer_dtype_original)
print("FINISHED FINETUNING")
if args.save:
os.makedirs(args.save, exist_ok=True)
layer_save_path = os.path.join(args.save, f"{layer_index}.pth")
print(f"Saving layer {layer_index}... to {layer_save_path}")
torch.save(layer, layer_save_path)

if len(args.devices) == 1:
assert len(inps) == len(outs) == 1
Expand Down Expand Up @@ -251,7 +253,15 @@ def quantize_aq(model: PreTrainedModel, dataloader: Iterable, args: Namespace):

print("=====================\nFinal stats:")
if args.save:
save_fresh_model(model, replaced_linears, args)
torch.save(vars(args), args.save + "/args.pt")
already_saved_weights = set()
for layer in get_layers(model):
for param in layer.parameters():
already_saved_weights.add(param)
not_quantized_weights = {
name: param for name, param in model.named_parameters() if param not in already_saved_weights
}
torch.save(not_quantized_weights, args.save + "/not_quantized_weights.pt")

if args.wandb:
wandb.log({"max_cuda_mem_quantize": round(torch.cuda.max_memory_allocated() / 1e9, 2)})
59 changes: 31 additions & 28 deletions src/aq.py
@@ -7,9 +7,8 @@
import torch.nn.functional as F
from tqdm.auto import trange

from src.inference import FinalizedQuantizedLinear
from src.kmeans import find_nearest_cluster, fit_faiss_kmeans, fit_kmeans, fit_kmeans_1d
from src.utils import _dequantize_weight, ellipsis, maybe_script, pack_int_data
from src.utils import ellipsis, maybe_script


class QuantizedLinear(nn.Module):
@@ -23,32 +22,6 @@ def forward(self, input: torch.Tensor):
# TODO[aqlm] this can be optimized! maybe integrate with QuantizedLinear?
return F.linear(input, self.quantized_weight(), self.bias)

@torch.no_grad()
def finalize(self) -> FinalizedQuantizedLinear:
assert self.quantized_weight is not None, "Can't finalize an unprocessed layer"
finalized_quantized_linear = FinalizedQuantizedLinear(
self.in_features,
self.out_features,
self.quantized_weight.in_group_size,
self.quantized_weight.out_group_size,
self.quantized_weight.num_codebooks,
self.quantized_weight.nbits_per_codebook,
self.bias is not None,
device=self.quantized_weight.codes.device,
)

state_dict = {
"codes": pack_int_data(self.quantized_weight.codes.clone(), self.quantized_weight.nbits_per_codebook),
"codebooks": self.quantized_weight.get_codebooks().clone(),
"scales": self.quantized_weight.get_scales().clone(),
}
if self.bias is not None:
state_dict["bias"] = self.bias.clone()

finalized_quantized_linear.load_state_dict(state_dict)

return finalized_quantized_linear


class QuantizedWeight(nn.Module):
EPS = 1e-9
@@ -389,6 +362,36 @@ def _make_range(n: int) -> list:
return beam_codes[0]


@maybe_script
def _dequantize_weight(
codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
:param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
:return: reconstructed weight tensor of shape [*dims, num_out_groups*out_group_size, num_in_groups*in_group_size]
"""
num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
out_features = num_out_groups * out_group_size
in_features = num_in_groups * in_group_size
codebook_offsets = torch.arange(
0, num_codebooks * codebook_size, codebook_size, device=codes.device
) # shape: [num_codebooks]
reconstructed_weight_flat = F.embedding_bag(
codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum"
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]

reconstructed_weight_groupwise = reconstructed_weight_flat.view(
list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
)
if scales is not None:
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])


@maybe_script
def _beam_search_squared_errors(
XTX: torch.Tensor,
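The _dequantize_weight helper that moves into src/aq.py here is purely shape-driven: each code indexes one codebook entry, the selected vectors are summed across codebooks via F.embedding_bag, and the result is optionally scaled and reshaped into a dense weight. A small shape check, assuming the function is imported from src.aq as added above; the concrete sizes are illustrative only.

```python
import torch

from src.aq import _dequantize_weight

# Illustrative sizes: 2 codebooks of 256 entries, 1x8 groups tiling a 16x64 weight.
num_codebooks, codebook_size = 2, 256
out_group_size, in_group_size = 1, 8
num_out_groups, num_in_groups = 16, 8

codes = torch.randint(codebook_size, (num_out_groups, num_in_groups, num_codebooks))
codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
scales = torch.ones(num_out_groups, 1, 1, 1)  # broadcasts over in-groups and group dims

weight = _dequantize_weight(codes, codebooks, scales)
assert weight.shape == (num_out_groups * out_group_size, num_in_groups * in_group_size)  # (16, 64)
```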
54 changes: 0 additions & 54 deletions src/utils.py
@@ -6,7 +6,6 @@
from typing import Callable, Iterator, Optional, Sequence

import torch
import torch.nn.functional as F

ellipsis = type(...)

@@ -50,29 +49,6 @@ def get_mean_nbits_by_codebook(codes: torch.IntTensor, huffman_group_size: int =
return mean_code_lengths


def get_int_dtype(nbits: int) -> torch.dtype:
if nbits <= 8:
return torch.int8
if nbits <= 16:
return torch.int16
if nbits <= 32:
return torch.int32
if nbits <= 64:
return torch.int64
raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def pack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
data[data >= 2 ** (nbits - 1)] -= 2**nbits
return data.to(get_int_dtype(nbits))


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
return data.to(torch.int64) % (2**nbits)
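The removed pack_int_data / unpack_int_data pair relied on two's-complement reinterpretation: values at or above 2 ** (nbits - 1) are shifted down by 2 ** nbits so they fit the narrowest signed dtype, and the modulo in unpack_int_data undoes the shift. A worked round trip with illustrative 8-bit values (a sketch, not part of the repository):

```python
import torch

# Two's-complement packing as in the removed pack_int_data / unpack_int_data:
# values >= 2 ** (nbits - 1) are shifted down by 2 ** nbits so they fit a signed dtype.
nbits = 8
codes = torch.tensor([0, 5, 127, 128, 200, 255])  # unsigned 8-bit code values

packed = codes.clone()
packed[packed >= 2 ** (nbits - 1)] -= 2**nbits  # e.g. 200 -> -56
packed = packed.to(torch.int8)                  # get_int_dtype(8) == torch.int8

unpacked = packed.to(torch.int64) % (2**nbits)  # -56 % 256 == 200
assert torch.equal(unpacked, codes)
```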


@functools.lru_cache()
def maybe_script(fn: callable) -> callable:
"""Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
@@ -127,33 +103,3 @@ def iterate_minibatches(
prev_batch = batch if isinstance(batch, (list, tuple)) and len(tensors) > 1 else batch[0]
del batch
yield prev_batch


@maybe_script
def _dequantize_weight(
codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
:param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
:return: reconstructed weight tensor of shape [*dims, num_out_groups*out_group_size, num_in_groups*in_group_size]
"""
num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
out_features = num_out_groups * out_group_size
in_features = num_in_groups * in_group_size
codebook_offsets = torch.arange(
0, num_codebooks * codebook_size, codebook_size, device=codes.device
) # shape: [num_codebooks]
reconstructed_weight_flat = F.embedding_bag(
codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum"
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]

reconstructed_weight_groupwise = reconstructed_weight_flat.view(
list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
)
if scales is not None:
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
