diff --git a/code/hqq/__init__.py b/code/hqq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/code/hqq/models/__init__.py b/code/hqq/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/code/hqq/models/base.py b/code/hqq/models/base.py new file mode 100644 index 0000000..9313426 --- /dev/null +++ b/code/hqq/models/base.py @@ -0,0 +1,187 @@ +#Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023 +##################################################### + +import torch +import gc, os +from tqdm import tqdm +from abc import abstractmethod + +from huggingface_hub import snapshot_download +from ..quantize.core import HQQLinear + +def cleanup(): + torch.cuda.empty_cache() + gc.collect() + +def fix_path(path): + if(len(path)==0): return path + return path + '/' if (path[-1]!='/') else path + +#Base patching class. Patching defines how nn.Linear and other layers are replaced via a patching function. +class BasePatch(): + #Override these OR override the main patch_model() function + ############################################ + #This method iterates through layers of the model that are NOT nn.Linear and processes them via new_nodule = patch_fct(module, params) + @classmethod + def patch_nonlinearlayers(cls, model, patch_fct, verbose=True): + pass + + #This method iterates through layers of the model that are nn.Linear and processes them via new_nodule = patch_fct(module, params) + @classmethod + def patch_linearlayers(cls, base_model, patch_fct, patch_params, verbose=True): + pass + ############################################ + #These tags are used to specfiy parameters of the patching in patch_linearlayers() + @classmethod + def get_linear_tags(cls): + return [] + + #Autmatically name modules. This is very important to save/load the weights + @classmethod + def autoname_modules(cls, model): + for name, module in model.named_modules(): + module.name = name + + #Freeze all layers + @classmethod + def freeze_model(cls, model): + for param in model.parameters(): + param.requires_grad = False + try: + for param in model.model.parameters(): + param.requires_grad = False + except: + pass + + #Main patching function + @classmethod + def patch_model(cls, model, patch_nonlinear_fct, patch_linear_fct, patch_params, verbose=True): + model.eval() + cls.freeze_model(model) + cls.patch_nonlinearlayers(model, patch_nonlinear_fct, verbose=verbose) + cls.patch_linearlayers(model, patch_linear_fct, patch_params, verbose=verbose) + cls.autoname_modules(model) + cleanup() + + +class BaseHQQModel: + #Override these + ############################################ + #This method creates and empty model based on the specfied architecture + @abstractmethod + def create_model(self): + pass + + #This method saves the model architecture only without inculding the weights (for example to a config.json) + @abstractmethod + def cache_model(cls, model, save_dir): + pass + ############################################ + + @classmethod + def get_config_file(cls, save_dir): + return fix_path(save_dir) + 'config.json' + + @classmethod + def get_weight_file(cls, save_dir): + return fix_path(save_dir) + 'qmodel.pt' + + @classmethod + def get_ignore_layers(cls, model): + return [] + + @classmethod + def save_weights(cls, weights, save_dir): + torch.save(weights, cls.get_weight_file(save_dir)) + + @classmethod + def load_weights(cls, save_dir): + return torch.load(cls.get_weight_file(save_dir)) + + @classmethod + def quantize_model(cls, model, quant_config): + #Use the same quantization config for 
all linear layers. Use None to skip quantizing a specfic layer. + patch_params = dict([(k, quant_config) for k in cls.get_linear_tags()]) + + #We replace the nn.Linear layers with HQQLinear + def _patch_linear(linear_layer, quant_config): + return HQQLinear(linear_layer, quant_config) if (quant_config is not None) else linear_layer + + cls.patch_model(model, lambda l: l.half().cuda(), _patch_linear, patch_params) + + @classmethod + def save_quantized(cls, model, save_dir, verbose=False): + #Save config + cls.cache_model(model, save_dir) + + #Save weights + weights = {} + ignore_keys = cls.get_ignore_layers(model) + for name, module in model.named_modules(): + if(name in ignore_keys): continue + try: + state_dict = module.state_dict() + if(len(state_dict)>0): + weights[name] = dict(state_dict) + except Exception as error: + if(verbose): + print('Skipping', name) + + cls.save_weights(weights, save_dir) + + @classmethod + def try_snapshot_download(cls, save_dir_or_hub, cache_dir=''): + save_dir = fix_path(cache_dir) + save_dir_or_hub + + if(os.path.exists(save_dir)==False): + save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir) + save_dir = fix_path(save_dir) + + #Check + if(os.path.exists(cls.get_weight_file(save_dir))==False): + raise Exception('Weight file missing. Check your cache directory.') + if(os.path.exists(cls.get_config_file(save_dir))==False): + raise Exception('Config file missing. Check your cache directory.') + + return save_dir + + @classmethod + def from_quantized(cls, save_dir_or_hub, cache_dir=''): + #Get directory path + save_dir = cls.try_snapshot_download(save_dir_or_hub, cache_dir) + + #Load model from config + model = cls.create_model(save_dir) + + #Name the layers + cls.autoname_modules(model) + + #Load weights + try: + weights = cls.load_weights(save_dir) + except Exception as error: + print("Failed to load the weights", error) + return + + #load_state_dict() doesn't work with modules initialized with init_empty_weights(), so we need to do this manually + @torch.no_grad() + def _load_module(module, params=None): + if(module.name not in weights): + return module.half().cuda() + + state_dict = weights[module.name] + if(('W_q' in state_dict) and ('meta' in state_dict)): + module = HQQLinear(linear_layer=None, quant_config=None) + module.load_state_dict(state_dict) + else: + for key in state_dict: + setattr(module, key, torch.nn.Parameter(state_dict[key])) + + return module + + cls.patch_model(model, _load_module, _load_module, dict([(k, None) for k in cls.get_linear_tags()])) + + return model + + + diff --git a/code/hqq/models/llama.py b/code/hqq/models/llama.py new file mode 100644 index 0000000..5bd3844 --- /dev/null +++ b/code/hqq/models/llama.py @@ -0,0 +1,65 @@ +from .base import * + +from tqdm import tqdm +from accelerate import init_empty_weights +import transformers + +#Patch LLama functions +class LLamaPatch(BasePatch): + #These tags are used to specify the parameters of each layer type. 
For example, if you want to give different quantization parameters to different layers + @classmethod + def get_linear_tags(cls): + return ['self_attn.q_proj', + 'self_attn.k_proj', + 'self_attn.v_proj', + 'self_attn.o_proj', + 'mlp.gate_proj' , + 'mlp.up_proj' , + 'mlp.down_proj' ] + + @classmethod + def patch_nonlinearlayers(cls, model, patch_fct, verbose=True): + base_model = model.model + model.lm_head = patch_fct(model.lm_head) + base_model.embed_tokens = patch_fct(base_model.embed_tokens) + base_model.norm = patch_fct(base_model.norm) + + layers = base_model.layers + for i in tqdm(range(len(base_model.layers)), disable=not verbose): + layers[i].self_attn.rotary_emb = patch_fct(layers[i].self_attn.rotary_emb) + layers[i].mlp.act_fn = patch_fct(layers[i].mlp.act_fn) + layers[i].input_layernorm = patch_fct(layers[i].input_layernorm) + layers[i].post_attention_layernorm = patch_fct(layers[i].post_attention_layernorm) + + @classmethod + def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True): + base_model = model.model + layers = base_model.layers + for i in tqdm(range(len(layers)), disable=not verbose): + layers[i].self_attn.q_proj = patch_fct(layers[i].self_attn.q_proj, patch_params['self_attn.q_proj']) + layers[i].self_attn.k_proj = patch_fct(layers[i].self_attn.k_proj, patch_params['self_attn.k_proj']) + layers[i].self_attn.v_proj = patch_fct(layers[i].self_attn.v_proj, patch_params['self_attn.v_proj']) + layers[i].self_attn.o_proj = patch_fct(layers[i].self_attn.o_proj, patch_params['self_attn.o_proj']) + layers[i].mlp.gate_proj = patch_fct(layers[i].mlp.gate_proj, patch_params['mlp.gate_proj']) + layers[i].mlp.up_proj = patch_fct(layers[i].mlp.up_proj, patch_params['mlp.up_proj']) + layers[i].mlp.down_proj = patch_fct(layers[i].mlp.down_proj, patch_params['mlp.down_proj']) + + +class LlamaHQQ(LLamaPatch, BaseHQQModel): + #layers to ignore when saving the weights + @classmethod + def get_ignore_layers(cls, model): + return ['', 'model', 'model.layers'] + ['model.layers.' + str(i) for i in range(len(model.model.layers))] + + #Save model architecture + @classmethod + def cache_model(cls, model, save_dir): + model.config.save_pretrained(save_dir) + + #Create empty model + @classmethod + def create_model(cls, save_dir): + config = transformers.AutoConfig.from_pretrained(cls.get_config_file(save_dir)) + with init_empty_weights(): + model = transformers.LlamaForCausalLM(config) + return model diff --git a/code/hqq/models/vit.py b/code/hqq/models/vit.py new file mode 100644 index 0000000..c179883 --- /dev/null +++ b/code/hqq/models/vit.py @@ -0,0 +1,65 @@ +from .base import * + +from tqdm import tqdm +import timm, json, os + +#Patch ViT functions +class VitPatch(BasePatch): + #These tags are used to specify the parameters of each layer type. 
For example, if you want to give different quantization parameters to different layers + @classmethod + def get_linear_tags(cls): + return ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'] + + @classmethod + def freeze_model(cls, model): + for param in model.parameters(): + param.requires_grad = False + + @classmethod + def patch_nonlinearlayers(cls, model, patch_fct, verbose=True): + model.patch_embed.proj = patch_fct(model.patch_embed.proj) + model.patch_embed.norm = patch_fct(model.patch_embed.norm) + model.norm_pre = patch_fct(model.norm_pre) + model.norm = patch_fct(model.norm) + model.head = patch_fct(model.head) + model.cls_token.data = patch_fct(model.cls_token.data) + model.pos_embed.data = patch_fct(model.pos_embed.data) + + for i in tqdm(range(len(model.blocks)), disable=not verbose): + model.blocks[i].norm1 = patch_fct(model.blocks[i].norm1) + model.blocks[i].norm2 = patch_fct(model.blocks[i].norm2) + + @classmethod + def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True): + for i in tqdm(range(len(model.blocks))): + model.blocks[i].attn.qkv = patch_fct(model.blocks[i].attn.qkv, patch_params['attn.qkv']) + model.blocks[i].attn.proj = patch_fct(model.blocks[i].attn.proj, patch_params['attn.proj']) + model.blocks[i].mlp.fc1 = patch_fct(model.blocks[i].mlp.fc1, patch_params['mlp.fc1']) + model.blocks[i].mlp.fc2 = patch_fct(model.blocks[i].mlp.fc2, patch_params['mlp.fc2']) + + +class ViTHQQ(VitPatch, BaseHQQModel): + #layers to ignore when saving the weights + @classmethod + def get_ignore_layers(cls, model): + return ['', 'model', 'model.blocks'] + ['model.blocks.' + str(i) for i in range(len(model.blocks))] + + #Save model architecture + @classmethod + def cache_model(cls, model, save_dir): + try: + os.makedirs(save_dir, exist_ok=True) + except Exception as error: + print(error) + + with open(cls.get_config_file(save_dir), "w") as file: + json.dump(model.default_cfg, file) + + #Create empty model + @classmethod + def create_model(cls, save_dir): + with open(cls.get_config_file(save_dir), "r") as file: + config = json.load(file) + + model = timm.create_model(config['architecture'] + '.' + config['tag'], pretrained=True) + return model diff --git a/code/hqq/quantize/__init__.py b/code/hqq/quantize/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/code/hqq/quantize/core.py b/code/hqq/quantize/core.py new file mode 100644 index 0000000..426ea87 --- /dev/null +++ b/code/hqq/quantize/core.py @@ -0,0 +1,375 @@ +#Written by Dr. 
Hicham Badri @Mobius Labs GmbH - 2023 +##################################################### + +import torch +import numpy as np +from tqdm import tqdm + +import gc +def cleanup(): + torch.cuda.empty_cache() + gc.collect() + +#Proximal solver || W - dequantize(quantize(W))||_p^p +@torch.inference_mode() +def optimize_weights_proximal(tensor, scale, zero, min_max, axis=0, device='cuda', opt_params={'lp_norm':0.7, 'beta':1e1, 'kappa':1.01, 'iters':20}, verbose=False): + lp_norm, beta, kappa, iters = opt_params['lp_norm'], opt_params['beta'], opt_params['kappa'], opt_params['iters'] + + dtype = torch.float16 if (device=='cuda') else torch.float32 + W_f = tensor.to(dtype).to(device) + scale = scale.to(dtype).to(device) + zero = zero.to(dtype).to(device) + + if(lp_norm==1): + shrink_op = lambda x, beta: torch.sign(x)*torch.nn.functional.relu(torch.abs(x) - 1./beta) + else: + shrink_op = lambda x, beta,p=lp_norm: torch.sign(x)*torch.nn.functional.relu(torch.abs(x) - (1./beta)*torch.pow(torch.abs(x), p-1)) + + best_error = 1e4 + for i in range(iters): + W_q = torch.round(W_f*scale + zero).clamp(min_max[0], min_max[1]) + W_r = (W_q - zero)/scale + W_e = shrink_op(W_f - W_r, beta) + zero = torch.mean(W_q - (W_f - W_e)*scale, axis=axis, keepdim=True) + beta *= kappa + + current_error = float(torch.abs(W_f - W_r).mean()) + if(verbose): + print(i, np.round(current_error, 6)) + if(current_error < best_error): + best_error = current_error + else: + break + + scale = scale.to(tensor.device) + zero = zero.to(tensor.device) + del W_f, W_q, W_r, W_e + torch.cuda.empty_cache() + + return scale, zero + +#SGD solver || W - dequantize(quantize(W))||_1 (p=1 only) +def optimize_weights_autograd(tensor, scale, zero, min_max, axis=0, device='cuda', opt_params={'lr':2e-3, 'iters':2500}, verbose=False): + W_f = tensor.to(device) + params = {} + params['scale'] = torch.nn.Parameter(scale.float().to(device), requires_grad=True) + params['zero'] = torch.nn.Parameter(zero.float().to(device), requires_grad=True) + optimizer = torch.optim.AdamW([params[k] for k in params], lr=opt_params['lr'], betas=(0.9, 0.99), eps=1e-06, weight_decay=0.) 
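+    #The objective here is ||W - dequantize(quantize(W))||_1, minimized over (scale, zero) with AdamW;
+    #if the final loss does not improve on the initial loss, the input (scale, zero) are kept
+    #(see the fallback at the end of this function).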
+
+    def _loss_fct(output, target):
+        return torch.mean(torch.abs(target - output)) #L1
+
+    def _fake_quant():
+        #Quantize
+        W_q = torch.round(W_f*params['scale'] + params['zero']).clamp(min_max[0], min_max[1])
+        #Dequantize
+        W_r = (W_q - params['zero'])/params['scale']
+        return W_r
+
+    with torch.no_grad():
+        _init_loss = _loss_fct(_fake_quant(), W_f).item()
+
+    def _step():
+        optimizer.zero_grad()
+        loss = _loss_fct(_fake_quant(), W_f)
+        loss.backward()
+        optimizer.step()
+        return np.round(loss.item(), 10)
+
+    for i in range(opt_params['iters']):
+        l = _step()
+        if(verbose and (i%100)==0): print(i, l)
+
+    with torch.no_grad():
+        _final_loss = _loss_fct(_fake_quant(), W_f).item()
+
+    if(_final_loss<_init_loss):
+        for k in params: params[k] = params[k].data.detach().to(tensor.device)
+    else:
+        if(verbose): print('optimization failed...')
+        params = {'scale':scale, 'zero':zero}
+
+    del W_f
+    torch.cuda.empty_cache()
+    return params['scale'], params['zero']
+
+def is_divisible(val1, val2):
+    return int(val2*np.ceil(val1/val2))==val1
+
+def make_multiple(val, multiple):
+    return int(multiple*np.ceil(val/float(multiple)))
+
+def zero_pad_row(tensor, num_rows, dtype=None):
+    out = torch.zeros([num_rows, tensor.shape[1]], device=tensor.device, dtype=tensor.dtype if (dtype is None) else dtype)
+    out[:len(tensor)] = tensor
+    return out
+
+#Bit-packing helpers: groups of rows along axis 0 are packed into the bits of a single uint8/int32 row.
+class BitPack:
+    @staticmethod
+    def pack_8bit_u8(W_q):
+        return W_q.to(torch.uint8)
+
+    @staticmethod
+    def unpack_8bit_u8(W_q):
+        return W_q
+
+    @staticmethod
+    def pack_4bit_u8(W_q): #uint8 > uint8/2
+        W_q = W_q.to(torch.uint8)
+        _step = int(len(W_q)/2)
+        return (W_q[:_step] << 4) | W_q[_step:]
+
+    @staticmethod
+    def unpack_4bit_u8(W_q): #uint8/2 > uint8
+        return torch.cat([(W_q & 0b11110000) >> 4, W_q & 0b00001111], axis=0)
+
+    @staticmethod
+    def pack_2bit_u8(W_q): #uint8 > uint8/4
+        W_q = W_q.to(torch.uint8)
+        _step = int(len(W_q)/4)
+        return (W_q[:_step] << 6 | W_q[_step:2*_step] << 4 | W_q[2*_step:3*_step] << 2 | W_q[3*_step:] )
+
+    @staticmethod
+    def unpack_2bit_u8(W_q):
+        return torch.cat([(W_q & 0b11000000) >> 6, (W_q & 0b00110000) >> 4, (W_q & 0b00001100) >> 2, W_q & 0b00000011], axis=0)
+
+    #int32 bit packing
+    ###################
+    @staticmethod
+    def pack_3bit_32(W_q_in):
+        W_q = torch.zeros([int(10*np.ceil(W_q_in.shape[0]/10.)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32)
+        W_q[:len(W_q_in)] = W_q_in
+        _step = int(len(W_q)/10)
+        W_q = (W_q[:_step] << 27) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 21) | (W_q[_step*3:_step*4] << 18) | (W_q[_step*4:_step*5] << 15) | (W_q[_step*5:_step*6] << 12) | (W_q[_step*6:_step*7] << 9) | (W_q[7*_step:_step*8] << 6) | (W_q[_step*8:_step*9] << 3) | (W_q[_step*9:])
+        return W_q
+
+    @staticmethod
+    def unpack_3bit_32(W_q):
+        return torch.cat([((W_q & 0b00111000000000000000000000000000) >> 27),
+                          ((W_q & 0b00000111000000000000000000000000) >> 24),
+                          ((W_q & 0b00000000111000000000000000000000) >> 21),
+                          ((W_q & 0b00000000000111000000000000000000) >> 18),
+                          ((W_q & 0b00000000000000111000000000000000) >> 15),
+                          ((W_q & 0b00000000000000000111000000000000) >> 12),
+                          ((W_q & 0b00000000000000000000111000000000) >> 9),
+                          ((W_q & 0b00000000000000000000000111000000) >> 6),
+                          ((W_q & 0b00000000000000000000000000111000) >> 3),
+                          ((W_q & 0b00000000000000000000000000000111))], axis=0)
+
+    @staticmethod
+    def pack_3bit2bit_u8(W_q):
+        assert is_divisible(len(W_q),3), "Input should have shape[0] divisible by 3 to use mixed 3-bit/2-bit packing"
+        _step = int(len(W_q)/3)
+        return (W_q[:_step] << 6 | 
W_q[1*_step:2*_step] << 3 | W_q[2*_step:] ) + + @staticmethod + def unpack_3bit2bit_u8(W_q): + return torch.cat([(W_q & 0b11100000) >> 6, (W_q & 0b00011100) >> 3, W_q & 0b00000011], axis=0) + + @staticmethod + def pack_4bit_32(W_q): + W_q = W_q.to(torch.int32) + _step = int(len(W_q)/8) + W_q = (W_q[:_step] << 28) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 20) | (W_q[_step*3:_step*4] << 16) | (W_q[_step*4:_step*5] << 12) | (W_q[_step*5:_step*6] << 8) | (W_q[_step*6:_step*7] << 4) | (W_q[_step*7:]) + return W_q + + @staticmethod + def unpack_4bit_32(W_q): + return torch.cat([((W_q & 0b11110000000000000000000000000000) >> 28), + ((W_q & 0b00001111000000000000000000000000) >> 24), + ((W_q & 0b00000000111100000000000000000000) >> 20), + ((W_q & 0b00000000000011110000000000000000) >> 16), + ((W_q & 0b00000000000000001111000000000000) >> 12), + ((W_q & 0b00000000000000000000111100000000) >> 8), + ((W_q & 0b00000000000000000000000011110000) >> 4), + ((W_q & 0b00000000000000000000000000001111))], axis=0) + + +class Quantizer: + SUPPORTED_BITS = [8, 4, 3, 2] + optimize_weights = optimize_weights_proximal + + bit_to_packing = {8:'8bit_u8', 4:'4bit_u8', 3:'3bit_32', 2:'2bit_u8'} + + pack = {'8bit_u8':BitPack.pack_8bit_u8, + '4bit_u8':BitPack.pack_4bit_u8, + '3bit_32':BitPack.pack_3bit_32, + '2bit_u8':BitPack.pack_2bit_u8} + + unpack = {'8bit_u8':BitPack.unpack_8bit_u8, + '4bit_u8':BitPack.unpack_4bit_u8, + '3bit_32':BitPack.unpack_3bit_32, + '2bit_u8':BitPack.unpack_2bit_u8} + + @classmethod + def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0): + assert nbits in Quantizer.SUPPORTED_BITS, "nbits=" + str(nbits) + " not supported." + assert axis in [0, 1], "axis should be either 0 or 1" + if(group_size is not None): + assert is_divisible(tensor.shape[axis], group_size), "group_size should be divisble by the tensor dimension" + + W = tensor.float() + shape = W.shape + + #Reshape for grouping + if((group_size is not None) and channel_wise): + W = W.reshape([-1, group_size]) if (axis==1) else W.reshape([group_size, -1]) + + #Get min/max values + if(channel_wise==False): + _min, _max = W.min(), W.max() + optimize = False + else: + _min = W.min(axis=axis, keepdim=True)[0] + _max = W.max(axis=axis, keepdim=True)[0] + + max_v = 2**nbits - 1 + min_v = 0 + min_max = [min_v, max_v] + + #Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on. 
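+        #Illustrative example (not used by the code): with nbits=4 (max_v=15), a group with
+        #_min=-0.5 and _max=1.0 gives scale = 15/1.5 = 10 and zero = -(-0.5)*10 = 5, so
+        #W_q = round(W*10 + 5).clamp(0, 15); dequantization later computes W_r = (W_q - 5)*(1/10)
+        #using the inverted scale stored in the meta-data.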
+ scale = (max_v/(_max - _min)).clamp(max=2e4) #clamp to avoid half-precision problems + zero = -_min*scale + + #Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14 + if(round_zero): zero = torch.round(zero) + + #Fine-tune weights + if(optimize): scale, zero = Quantizer.optimize_weights(tensor=W, scale=scale, zero=zero, min_max=min_max, axis=axis) + + #Quantize + W_q = torch.round(W*scale + zero).clamp(min_max[0], min_max[1]) + + #Store meta-data (we invert the scale for dequantization) + meta = {'nbits':nbits, 'group_size':group_size, 'shape':shape, 'scale':1./scale, 'zero':zero, 'axis':axis, 'packing':Quantizer.bit_to_packing[nbits]} + + #Pack bits + W_q = Quantizer.pack[meta['packing']](W_q) + + #cleanup + del W, _min, _max + torch.cuda.empty_cache() + + return W_q, meta + + @classmethod + def dequantize(cls, W_q, meta): + W_q_p = Quantizer.unpack[meta['packing']](W_q).half() + if((meta['group_size'] is not None) and (meta['nbits']==3)): + W_q_p = W_q_p[:meta['group_size']] if (meta['axis']==0) else W_q_p[:,:meta['group_size']] + W_r = ((W_q_p - meta['zero'])*meta['scale']).reshape(meta['shape']) + del W_q_p + return W_r + + @classmethod + def to_inplace(cls, W_q, meta, device): + W_q = W_q.to(device).contiguous() + for key in meta: + if(type(meta[key])==torch.Tensor): + meta[key] = (meta[key].half() if meta[key].dtype==torch.float32 else meta[key]).to(device).contiguous() + return W_q, meta + + @classmethod + def to_ooplace(cls, W_q, meta, device): + W_q_c = W_q.to(device).contiguous() + meta_c = {} + for key in meta: + if(type(meta[key])==torch.Tensor): + meta_c[key] = (meta[key].half() if meta[key].dtype==torch.float32 else meta[key]).to(device).contiguous() + else: + meta_c[key] = meta[key] + return W_q_c, meta_c + + @classmethod + def cuda(cls, W_q, meta): + return Quantizer.to_inplace(W_q, meta, device='cuda') + + @classmethod + def cpu(cls, W_q, meta): + return Quantizer.to_ooplace(W_q, meta, device='cpu') + +#Main linear layer +class HQQLinear(torch.nn.Module): + def __init__(self, linear_layer, quant_config, del_orig=True): + super().__init__() + self.ready = False + self.in_gpu = False + self.quant_config = quant_config + if(linear_layer is not None): + self.quantize(linear_layer.weight.data, **quant_config) + self.bias = None if (linear_layer.bias==None) else linear_layer.bias.half().cuda() + if(del_orig): del linear_layer + torch.cuda.empty_cache() + + def cuda(self): + if(self.in_gpu): return + self.W_q, self.meta = Quantizer.cuda(self.W_q, self.meta) + if(self.meta['quant_scale']): + self.meta['scale_q'] , self.meta['meta_scale'] = Quantizer.cuda(self.meta['scale_q'], self.meta['meta_scale']) + if(self.meta['quant_zero']): + self.meta['zero_q'] , self.meta['meta_zero'] = Quantizer.cuda(self.meta['zero_q'], self.meta['meta_zero']) + self.in_gpu = True + + def to(self, device): + pass + + def half(self): + return self + + def state_dict(self): + return {'W_q':self.W_q, 'meta':self.meta, 'bias':self.bias} + + def load_state_dict(self, state_dict): + self.W_q = state_dict['W_q'] + self.meta = state_dict['meta'] + self.bias = state_dict['bias'] if ('bias' in state_dict) else None + self.in_gpu = self.W_q.device.type == 'cuda' + if(self.in_gpu==False): self.cuda() + self.ready = True + + def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params): + quant_scale = scale_quant_params is not None + quant_zero = zero_quant_params is not None + + #Quantize + W_q , meta = Quantizer.quantize(W, 
**weight_quant_params) + meta.update({'quant_scale':quant_scale, 'quant_zero':quant_zero}) + if(meta['quant_scale']): + meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], **scale_quant_params); del meta['scale'] + if(meta['quant_zero']): + meta['zero_q'], meta['meta_zero'] = Quantizer.quantize(meta['zero'], **zero_quant_params); del meta['zero'] + + self.W_q = W_q + self.meta = meta + self.cuda() + self.ready = True + + @torch.inference_mode() + def dequantize(self): + assert self.ready, "model was not quantized" + W_q, meta = self.W_q, self.meta + del_keys = [] + if(meta['quant_scale']): + meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale') + if(meta['quant_zero']): + meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero') + W_est = Quantizer.dequantize(W_q, meta) + #Cleanup + for key in del_keys: del meta[key] + return W_est + + @torch.no_grad() + def forward(self, x): + W_est = self.dequantize() + out = torch.matmul(x, W_est.t()) + if(self.bias!=None): out += self.bias + del W_est + return out + +def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False): + assert nbits in Quantizer.SUPPORTED_BITS, "nbits value not supported. Check Quantizer.SUPPORTED_BITS." + assert is_divisible(group_size, 8), "Invalid group_size param: the value should be a multiple of 8." + weight_quant_params = {'nbits':nbits,'channel_wise':True, 'group_size':group_size, 'optimize':True, 'round_zero':True if nbits==4 else False} + scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None + zero_quant_params = {'nbits':8, 'channel_wise':False, 'group_size':None, 'optimize':False} if (quant_zero) else None + return {'weight_quant_params':weight_quant_params, 'scale_quant_params':scale_quant_params, 'zero_quant_params':zero_quant_params} diff --git a/code/llama2_benchmark/eval_model.py b/code/llama2_benchmark/eval_model.py new file mode 100644 index 0000000..2f7829a --- /dev/null +++ b/code/llama2_benchmark/eval_model.py @@ -0,0 +1,52 @@ +from datasets import load_dataset +import torch, time +import numpy as np +from tqdm import tqdm + +import gc +def cleanup(): + torch.cuda.empty_cache() + gc.collect() + +#Adapted from https://huggingface.co/transformers/v4.2.2/perplexity.html +def eval_wikitext2(model, tokenizer, max_length=1024, stride=512, verbose=True): + model.eval() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + tokenizer.add_eos_token = False + + dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') + + encodings['input_ids'] = encodings['input_ids'].to('cuda') + + lls, t = [], [] + for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): + begin_loc = max(i + stride - max_length, 0) + end_loc = min(i + stride, encodings['input_ids'].size(1)) + trg_len = end_loc - i + input_ids = encodings['input_ids'][:,begin_loc:end_loc] + target_ids = input_ids.clone() + target_ids[:,:-trg_len] = -100 #ignore context + + t1 = time.time() + with torch.no_grad(): + log_likelihood = model(input_ids, labels=target_ids).loss * trg_len + torch.cuda.synchronize() + t2 = time.time() + t.append((t2-t1)) + lls.append(log_likelihood) + + del input_ids, target_ids + + ppl = np.round(float(torch.exp(torch.stack(lls).sum() / end_loc)), 4) + pred_time = np.round(np.mean(t), 3) + if(verbose): + 
print('perplexity', ppl) + print('time', str(pred_time) + ' sec') + + del encodings + cleanup() + + return {'perplexity':ppl, 'prediction_time':pred_time} + diff --git a/code/llama2_benchmark/quant_llama2_awq_demo.py b/code/llama2_benchmark/quant_llama2_awq_demo.py new file mode 100644 index 0000000..e8370c4 --- /dev/null +++ b/code/llama2_benchmark/quant_llama2_awq_demo.py @@ -0,0 +1,39 @@ +import torch, transformers + +#Settings +###################################################################################### +hf_auth = None #HuggingFace token +cache_path = '' #cache directory to store data + +#Chose a model +model_id = "meta-llama/Llama-2-7b-hf" +#model_id = "meta-llama/Llama-2-13b-hf" +#model_id = "meta-llama/Llama-2-70b-hf" + +#AWQ settings +###################################################################################### +from awq import AutoAWQForCausalLM +import gc, time + +# Load model +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth) +model = AutoAWQForCausalLM.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path, resume_download=True) + +#quant_config = {"w_bit": 4, "q_group_size": 128, "zero_point": True, 'version':'GEMM'} +quant_config = {"w_bit": 4, "q_group_size": 64, "zero_point": True, 'version':'GEMM'} + +t1 = time.time() +model.quantize(tokenizer, quant_config=quant_config) +t2 = time.time() +print('Took ' + str(t2-t1) + ' seconds to quantize the model with AWQ') + +model = model.cuda() +torch.cuda.empty_cache() +gc.collect() + +#Evaluate the quantized model +###################################################################################### +from eval_model import eval_wikitext2 + +eval_wikitext2(model, tokenizer, verbose=True) + diff --git a/code/llama2_benchmark/quant_llama2_gptq_demo.py b/code/llama2_benchmark/quant_llama2_gptq_demo.py new file mode 100644 index 0000000..f4fdd48 --- /dev/null +++ b/code/llama2_benchmark/quant_llama2_gptq_demo.py @@ -0,0 +1,71 @@ +import torch, transformers + +#Important: limit the number of threads otherwise the process will hang for a long time +#num_threads=32; +#OMP_NUM_THREADS=$num_threads OPENBLAS_NUM_THREADS=$num_threads MKL_NUM_THREADS=$num_threads VECLIB_MAXIMUM_THREADS=$num_threads NUMEXPR_NUM_THREADS=$num_threads CUDA_VISIBLE_DEVICES=0 ipython3 + +#Settings +###################################################################################### +hf_auth = None #HuggingFace token +cache_path = '' #cache directory to store data + +#Chose a model +model_id = "meta-llama/Llama-2-7b-hf" +#model_id = "meta-llama/Llama-2-13b-hf" +#model_id = "meta-llama/Llama-2-70b-hf" + +#GPTQ settings +###################################################################################### +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +import logging, gc, time +from tqdm import tqdm + +logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S") + +#Adapted from: https://towardsdatascience.com/4-bit-quantization-with-gptq-36b0f4f02c34 +def prepare_model(model, tokenizer, n_samples=1024, max_tokens=512, use_triton=True): + # Load data and tokenize examples + from datasets import load_dataset + import random + data = load_dataset("allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", split=f"train[:{n_samples}]", cache_dir=cache_path) + tokenized_data = torch.cat([tokenizer(data[i]['text'], return_tensors='pt').input_ids for i in tqdm(range(len(data)))], axis=-1) #~536K tokens + + # Format 
tokenized examples + random.seed(1) + examples_ids = [] + for _ in range(n_samples): + i = random.randint(0, tokenized_data.shape[1] - max_tokens - 1) + j = i + max_tokens + input_ids = tokenized_data[:, i:j] + attention_mask = torch.ones_like(input_ids) + examples_ids.append({'input_ids': input_ids, 'attention_mask': attention_mask}) + + print('Using ' + str(len(examples_ids)) + ' samples for calibration.') + model.quantize(examples_ids, batch_size=1, use_triton=use_triton) + model = model.cuda(); + with torch.no_grad(): x = model(input_ids.to('cuda')); + del examples_ids, x + torch.cuda.empty_cache() + gc.collect() + return model + +#quantize_config = BaseQuantizeConfig(bits=8, group_size=128, damp_percent=0.01, desc_act=False); use_triton=True; +#quantize_config = BaseQuantizeConfig(bits=4, group_size=128, damp_percent=0.01, desc_act=False); use_triton=True; +quantize_config = BaseQuantizeConfig(bits=4, group_size=64, damp_percent=0.01, desc_act=False); use_triton=True; +#quantize_config = BaseQuantizeConfig(bits=3, group_size=128, damp_percent=0.01, desc_act=False); use_triton=False; +#quantize_config = BaseQuantizeConfig(bits=3, group_size=64, damp_percent=0.01, desc_act=False); use_triton=False; +#quantize_config = BaseQuantizeConfig(bits=2, group_size=64, damp_percent=0.01, desc_act=False); use_triton=True; + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth) +model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config, use_auth_token=hf_auth, cache_dir=cache_path) +t1 = time.time() +model = prepare_model(model, tokenizer, use_triton=use_triton) +t2 = time.time() +print('Took ' + str(t2-t1) + ' seconds to quantize the model with GPTQ') + +#Evaluate the quantized model +###################################################################################### +from eval_model import eval_wikitext2 + +eval_wikitext2(model, tokenizer, verbose=True) + diff --git a/code/llama2_benchmark/quant_llama2_hqq_demo.py b/code/llama2_benchmark/quant_llama2_hqq_demo.py new file mode 100644 index 0000000..a917c82 --- /dev/null +++ b/code/llama2_benchmark/quant_llama2_hqq_demo.py @@ -0,0 +1,36 @@ +import torch, transformers + +#Settings +###################################################################################### +hf_auth = None #HuggingFace token +cache_path = '' #cache directory to store data + +#Chose a model +model_id = "meta-llama/Llama-2-7b-hf" +#model_id = "meta-llama/Llama-2-13b-hf" +#model_id = "meta-llama/Llama-2-70b-hf" + +#Load model on the CPU +###################################################################################### +model = transformers.AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=hf_auth, cache_dir=cache_path) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth) + +#Quantize the model +###################################################################################### +from hqq.quantize.core import hqq_base_quant_config +from hqq.models.llama import LlamaHQQ + +#quant_config = hqq_base_quant_config(nbits=8, group_size=128) +quant_config = hqq_base_quant_config(nbits=4, group_size=64) +#quant_config = hqq_base_quant_config(nbits=3, group_size=64) +#quant_config = hqq_base_quant_config(nbits=2, group_size=16) +#quant_config = hqq_base_quant_config(nbits=2, group_size=16, quant_scale=True) #scale is quantized to 8-bit/g=128 + +#quantize_model(model, quant_config=quant_config) +LlamaHQQ.quantize_model(model, quant_config=quant_config) + +# #Evaluate the quantized 
model +###################################################################################### +from eval_model import eval_wikitext2 +eval_wikitext2(model, tokenizer, verbose=True) + diff --git a/code/setup.py b/code/setup.py new file mode 100644 index 0000000..8cc67e0 --- /dev/null +++ b/code/setup.py @@ -0,0 +1,13 @@ +from setuptools import setup, find_packages + +setup( + name='hqq', + version='1.0.0', + description='Half-Quadratic Quantization (HQQ)', + url='https://github.com/mobiusml/hqq/tree/main/code', + author='Dr. Hicham Badri', + author_email='hicham@mobiuslabs.com', + license='Apache 2', + packages=['hqq', 'hqq/models', 'hqq/quantize'], + install_requires=['numpy>=1.24.4','tqdm>=4.64.1', 'torch>=2.0.1'], +) diff --git a/code/vit_example/vit_example.py b/code/vit_example/vit_example.py new file mode 100644 index 0000000..15f4c5b --- /dev/null +++ b/code/vit_example/vit_example.py @@ -0,0 +1,59 @@ +import numpy as np +import timm, torch + +from hqq.quantize.core import hqq_base_quant_config +from hqq.models.vit import ViTHQQ + +#Model ID +model_id = 'vit_large_patch14_clip_224.laion2b' + +#Load model (on CPU) +model = timm.create_model(model_id, pretrained=True) + +#Quantize +quant_config = hqq_base_quant_config(nbits=4, group_size=64) +ViTHQQ.quantize_model(model, quant_config=quant_config) + +############################################################### +# #Save model +# save_dir = "repo/" + model_id +# ViTHQQ.save_quantized(model, save_dir=save_dir) + +# #Load model +# model = ViTHQQ.from_quantized(save_dir) +############################################################### + +#Load reference model to compare with +model_ref = timm.create_model(model_id, pretrained=True) +model_ref = model_ref.half().cuda() +model_ref.eval(); + +#Pre-processing +mean_clip = np.array([0.4815, 0.4578, 0.4082], 'float32') +std_clip = np.array([0.2686, 0.2613, 0.2758], 'float32') +def normalize_images_clip(data_np_in, BCHW=True): + data_np = torch.from_numpy(data_np_in).float() if(type(data_np_in)==np.ndarray) else data_np_in.float() + + data_np /= 255. + for i in range(3): + data_np[...,i] -= mean_clip[i] + data_np[...,i] /= std_clip[i] + if(BCHW): + data_np = data_np.swapaxes(2, 3).swapaxes(1, 2) + return data_np + +############################################################### +#Compare the compressed model with the original +x = np.random.rand(16, 224, 224, 3) +x = normalize_images_clip(x).half().cuda() + +with torch.no_grad(): + y1 = model(x) + y1 /= torch.norm(y1, p=2, dim=-1, keepdim=True) + +with torch.no_grad(): + y2 = model_ref(x) + y2 /= torch.norm(y2, p=2, dim=-1, keepdim=True) + +#We want the dot product to be as close as possible to 1 +print('Average dot-product score', float(torch.diag(torch.matmul(y1, y2.t())).mean())) #~0.9736328125
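A minimal usage sketch of how the pieces added above fit together, mirroring quant_llama2_hqq_demo.py and the save/load helpers in hqq/models/base.py; the model id, save directory, and quantization settings are illustrative, and a CUDA device is required since HQQLinear keeps the quantized weights on the GPU:

import transformers
from hqq.quantize.core import hqq_base_quant_config
from hqq.models.llama import LlamaHQQ

#Load the original model on the CPU
model_id  = "meta-llama/Llama-2-7b-hf"
model     = transformers.AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

#Quantize in place: nn.Linear layers become HQQLinear, the remaining layers are moved to fp16 on the GPU
quant_config = hqq_base_quant_config(nbits=4, group_size=64)
LlamaHQQ.quantize_model(model, quant_config=quant_config)

#Save the quantized weights/config and reload them later without the original checkpoint
save_dir = "llama2-7b-hqq-4bit" #illustrative path
LlamaHQQ.save_quantized(model, save_dir=save_dir)
model = LlamaHQQ.from_quantized(save_dir)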