Commit
Showing 13 changed files with 962 additions and 0 deletions.
@@ -0,0 +1,187 @@
#Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################

import torch
import gc, os
from tqdm import tqdm
from abc import abstractmethod

from huggingface_hub import snapshot_download
from ..quantize.core import HQQLinear

def cleanup():
    torch.cuda.empty_cache()
    gc.collect()

def fix_path(path):
    if(len(path)==0): return path
    return path + '/' if (path[-1]!='/') else path

#Base patching class. Patching defines how nn.Linear and other layers are replaced via a patching function.
class BasePatch():
    #Override these OR override the main patch_model() function
    ############################################
    #This method iterates through the layers of the model that are NOT nn.Linear and processes them via new_module = patch_fct(module, params)
    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        pass

    #This method iterates through the layers of the model that are nn.Linear and processes them via new_module = patch_fct(module, params)
    @classmethod
    def patch_linearlayers(cls, base_model, patch_fct, patch_params, verbose=True):
        pass
    ############################################
    #These tags are used to specify the parameters of the patching in patch_linearlayers()
    @classmethod
    def get_linear_tags(cls):
        return []

    #Automatically name modules. This is very important to save/load the weights
    @classmethod
    def autoname_modules(cls, model):
        for name, module in model.named_modules():
            module.name = name

    #Freeze all layers
    @classmethod
    def freeze_model(cls, model):
        for param in model.parameters():
            param.requires_grad = False
        try:
            for param in model.model.parameters():
                param.requires_grad = False
        except:
            pass

    #Main patching function
    @classmethod
    def patch_model(cls, model, patch_nonlinear_fct, patch_linear_fct, patch_params, verbose=True):
        model.eval()
        cls.freeze_model(model)
        cls.patch_nonlinearlayers(model, patch_nonlinear_fct, verbose=verbose)
        cls.patch_linearlayers(model, patch_linear_fct, patch_params, verbose=verbose)
        cls.autoname_modules(model)
        cleanup()


class BaseHQQModel:
    #Override these
    ############################################
    #This method creates an empty model based on the specified architecture
    @abstractmethod
    def create_model(self):
        pass

    #This method saves the model architecture only, without including the weights (for example, to a config.json)
    @abstractmethod
    def cache_model(cls, model, save_dir):
        pass
    ############################################

    @classmethod
    def get_config_file(cls, save_dir):
        return fix_path(save_dir) + 'config.json'

    @classmethod
    def get_weight_file(cls, save_dir):
        return fix_path(save_dir) + 'qmodel.pt'

    @classmethod
    def get_ignore_layers(cls, model):
        return []

    @classmethod
    def save_weights(cls, weights, save_dir):
        torch.save(weights, cls.get_weight_file(save_dir))

    @classmethod
    def load_weights(cls, save_dir):
        return torch.load(cls.get_weight_file(save_dir))

    @classmethod
    def quantize_model(cls, model, quant_config):
        #Use the same quantization config for all linear layers. Use None to skip quantizing a specific layer.
        patch_params = dict([(k, quant_config) for k in cls.get_linear_tags()])

        #We replace the nn.Linear layers with HQQLinear
        def _patch_linear(linear_layer, quant_config):
            return HQQLinear(linear_layer, quant_config) if (quant_config is not None) else linear_layer

        cls.patch_model(model, lambda l: l.half().cuda(), _patch_linear, patch_params)

    @classmethod
    def save_quantized(cls, model, save_dir, verbose=False):
        #Save config
        cls.cache_model(model, save_dir)

        #Save weights
        weights = {}
        ignore_keys = cls.get_ignore_layers(model)
        for name, module in model.named_modules():
            if(name in ignore_keys): continue
            try:
                state_dict = module.state_dict()
                if(len(state_dict)>0):
                    weights[name] = dict(state_dict)
            except Exception as error:
                if(verbose):
                    print('Skipping', name)

        cls.save_weights(weights, save_dir)

    @classmethod
    def try_snapshot_download(cls, save_dir_or_hub, cache_dir=''):
        save_dir = fix_path(cache_dir) + save_dir_or_hub

        if(os.path.exists(save_dir)==False):
            save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir)
            save_dir = fix_path(save_dir)

        #Check
        if(os.path.exists(cls.get_weight_file(save_dir))==False):
            raise Exception('Weight file missing. Check your cache directory.')
        if(os.path.exists(cls.get_config_file(save_dir))==False):
            raise Exception('Config file missing. Check your cache directory.')

        return save_dir

    @classmethod
    def from_quantized(cls, save_dir_or_hub, cache_dir=''):
        #Get directory path
        save_dir = cls.try_snapshot_download(save_dir_or_hub, cache_dir)

        #Load model from config
        model = cls.create_model(save_dir)

        #Name the layers
        cls.autoname_modules(model)

        #Load weights
        try:
            weights = cls.load_weights(save_dir)
        except Exception as error:
            print("Failed to load the weights", error)
            return

        #load_state_dict() doesn't work with modules initialized with init_empty_weights(), so we need to do this manually
        @torch.no_grad()
        def _load_module(module, params=None):
            if(module.name not in weights):
                return module.half().cuda()

            state_dict = weights[module.name]
            if(('W_q' in state_dict) and ('meta' in state_dict)):
                module = HQQLinear(linear_layer=None, quant_config=None)
                module.load_state_dict(state_dict)
            else:
                for key in state_dict:
                    setattr(module, key, torch.nn.Parameter(state_dict[key]))

            return module

        cls.patch_model(model, _load_module, _load_module, dict([(k, None) for k in cls.get_linear_tags()]))

        return model

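To make the division of labor between BasePatch and BaseHQQModel concrete, here is a minimal sketch of a custom backend. It is illustrative only: the ToyPatch/ToyHQQ names, the single 'fc' tag, and the toy layer layout are assumptions, not part of this commit.

#Illustrative sketch (assumed names): a backend for a toy model with one nn.Linear ('fc') and one LayerNorm ('norm')
class ToyPatch(BasePatch):
    @classmethod
    def get_linear_tags(cls):
        #One tag per linear layer whose quantization we want to control via patch_params
        return ['fc']

    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        #Non-linear layers are simply passed through patch_fct (e.g. cast to half and moved to the GPU)
        model.norm = patch_fct(model.norm)

    @classmethod
    def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True):
        #nn.Linear layers become HQQLinear when their tag carries a quantization config
        model.fc = patch_fct(model.fc, patch_params['fc'])

class ToyHQQ(ToyPatch, BaseHQQModel):
    @classmethod
    def cache_model(cls, model, save_dir):
        pass #would write the architecture description (e.g. a config.json) next to qmodel.pt

    @classmethod
    def create_model(cls, save_dir):
        pass #would rebuild an empty model of the same architecture from that description

With such a subclass, ToyHQQ.quantize_model(model, quant_config) replaces every tagged nn.Linear with an HQQLinear, and save_quantized()/from_quantized() round-trip the quantized weights through qmodel.pt.
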
@@ -0,0 +1,65 @@
from .base import *

from tqdm import tqdm
from accelerate import init_empty_weights
import transformers

#Patch Llama functions
class LLamaPatch(BasePatch):
    #These tags are used to specify the parameters of each layer type (for example, to give different quantization parameters to different layers)
    @classmethod
    def get_linear_tags(cls):
        return ['self_attn.q_proj',
                'self_attn.k_proj',
                'self_attn.v_proj',
                'self_attn.o_proj',
                'mlp.gate_proj',
                'mlp.up_proj',
                'mlp.down_proj']

    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        base_model = model.model
        model.lm_head = patch_fct(model.lm_head)
        base_model.embed_tokens = patch_fct(base_model.embed_tokens)
        base_model.norm = patch_fct(base_model.norm)

        layers = base_model.layers
        for i in tqdm(range(len(base_model.layers)), disable=not verbose):
            layers[i].self_attn.rotary_emb = patch_fct(layers[i].self_attn.rotary_emb)
            layers[i].mlp.act_fn = patch_fct(layers[i].mlp.act_fn)
            layers[i].input_layernorm = patch_fct(layers[i].input_layernorm)
            layers[i].post_attention_layernorm = patch_fct(layers[i].post_attention_layernorm)

    @classmethod
    def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True):
        base_model = model.model
        layers = base_model.layers
        for i in tqdm(range(len(layers)), disable=not verbose):
            layers[i].self_attn.q_proj = patch_fct(layers[i].self_attn.q_proj, patch_params['self_attn.q_proj'])
            layers[i].self_attn.k_proj = patch_fct(layers[i].self_attn.k_proj, patch_params['self_attn.k_proj'])
            layers[i].self_attn.v_proj = patch_fct(layers[i].self_attn.v_proj, patch_params['self_attn.v_proj'])
            layers[i].self_attn.o_proj = patch_fct(layers[i].self_attn.o_proj, patch_params['self_attn.o_proj'])
            layers[i].mlp.gate_proj = patch_fct(layers[i].mlp.gate_proj, patch_params['mlp.gate_proj'])
            layers[i].mlp.up_proj = patch_fct(layers[i].mlp.up_proj, patch_params['mlp.up_proj'])
            layers[i].mlp.down_proj = patch_fct(layers[i].mlp.down_proj, patch_params['mlp.down_proj'])


class LlamaHQQ(LLamaPatch, BaseHQQModel):
    #Layers to ignore when saving the weights
    @classmethod
    def get_ignore_layers(cls, model):
        return ['', 'model', 'model.layers'] + ['model.layers.' + str(i) for i in range(len(model.model.layers))]

    #Save model architecture
    @classmethod
    def cache_model(cls, model, save_dir):
        model.config.save_pretrained(save_dir)

    #Create empty model
    @classmethod
    def create_model(cls, save_dir):
        config = transformers.AutoConfig.from_pretrained(cls.get_config_file(save_dir))
        with init_empty_weights():
            model = transformers.LlamaForCausalLM(config)
        return model
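
End-to-end usage would look roughly like the sketch below. The import path, the checkpoint name, and the shape of quant_config are assumptions for illustration; quant_config is whatever dict the HQQLinear constructor in ..quantize.core expects.

import transformers
from hqq.models.llama import LlamaHQQ #import path assumed from this commit's package layout

model_id = 'meta-llama/Llama-2-7b-hf' #example checkpoint; any Llama-architecture model should work
quant_config = None #placeholder: pass a real HQQ quantization config dict; None leaves every linear layer unquantized

model = transformers.LlamaForCausalLM.from_pretrained(model_id)
LlamaHQQ.quantize_model(model, quant_config) #swaps the tagged nn.Linear layers for HQQLinear and moves the rest to fp16/CUDA
LlamaHQQ.save_quantized(model, save_dir='llama2-7b-hqq') #writes config.json + qmodel.pt
model = LlamaHQQ.from_quantized('llama2-7b-hqq') #rebuilds an empty model and loads the quantized weights
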
@@ -0,0 +1,65 @@
from .base import *

from tqdm import tqdm
import timm, json, os

#Patch ViT functions
class VitPatch(BasePatch):
    #These tags are used to specify the parameters of each layer type (for example, to give different quantization parameters to different layers)
    @classmethod
    def get_linear_tags(cls):
        return ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']

    @classmethod
    def freeze_model(cls, model):
        for param in model.parameters():
            param.requires_grad = False

    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        model.patch_embed.proj = patch_fct(model.patch_embed.proj)
        model.patch_embed.norm = patch_fct(model.patch_embed.norm)
        model.norm_pre = patch_fct(model.norm_pre)
        model.norm = patch_fct(model.norm)
        model.head = patch_fct(model.head)
        model.cls_token.data = patch_fct(model.cls_token.data)
        model.pos_embed.data = patch_fct(model.pos_embed.data)

        for i in tqdm(range(len(model.blocks)), disable=not verbose):
            model.blocks[i].norm1 = patch_fct(model.blocks[i].norm1)
            model.blocks[i].norm2 = patch_fct(model.blocks[i].norm2)

    @classmethod
    def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True):
        for i in tqdm(range(len(model.blocks)), disable=not verbose):
            model.blocks[i].attn.qkv = patch_fct(model.blocks[i].attn.qkv, patch_params['attn.qkv'])
            model.blocks[i].attn.proj = patch_fct(model.blocks[i].attn.proj, patch_params['attn.proj'])
            model.blocks[i].mlp.fc1 = patch_fct(model.blocks[i].mlp.fc1, patch_params['mlp.fc1'])
            model.blocks[i].mlp.fc2 = patch_fct(model.blocks[i].mlp.fc2, patch_params['mlp.fc2'])


class ViTHQQ(VitPatch, BaseHQQModel):
    #Layers to ignore when saving the weights
    @classmethod
    def get_ignore_layers(cls, model):
        return ['', 'model', 'model.blocks'] + ['model.blocks.' + str(i) for i in range(len(model.blocks))]

    #Save model architecture
    @classmethod
    def cache_model(cls, model, save_dir):
        try:
            os.makedirs(save_dir, exist_ok=True)
        except Exception as error:
            print(error)

        with open(cls.get_config_file(save_dir), "w") as file:
            json.dump(model.default_cfg, file)

    #Create empty model
    @classmethod
    def create_model(cls, save_dir):
        with open(cls.get_config_file(save_dir), "r") as file:
            config = json.load(file)

        model = timm.create_model(config['architecture'] + '.' + config['tag'], pretrained=True)
        return model
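
Usage mirrors the Llama backend; the timm model name below and the quant_config placeholder are again only illustrative.

import timm
from hqq.models.vit import ViTHQQ #import path assumed from this commit's package layout

model = timm.create_model('vit_base_patch16_224', pretrained=True) #example timm ViT
quant_config = None #placeholder: pass a real HQQ quantization config dict understood by HQQLinear

ViTHQQ.quantize_model(model, quant_config) #patches attn.qkv / attn.proj / mlp.fc1 / mlp.fc2 in every block
ViTHQQ.save_quantized(model, save_dir='vit-hqq') #stores model.default_cfg as config.json plus qmodel.pt
model = ViTHQQ.from_quantized('vit-hqq')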