separated saving
Andrei Panferov committed Jan 18, 2024
1 parent c04ca2f commit 8613f57
Showing 2 changed files with 71 additions and 59 deletions.
61 changes: 2 additions & 59 deletions main.py
@@ -24,6 +24,7 @@
     get_model_head,
     get_sequential_groups,
 )
+from src.saving import save_fresh_model
 from src.utils import using_tf32
 from transformers import PreTrainedModel
 
@@ -151,59 +152,6 @@ def forward(self, inp, **kwargs):
     return inps, forward_args
 
 
-def update_config(model: PreTrainedModel, args):
-    old_config = model.config
-    old_config_type = type(old_config)
-    old_model_type = old_config.model_type
-    new_model_type = f"{old_model_type}_aqlm"
-
-    class AqlmConfig(old_config_type):
-        model_type = new_model_type
-
-        def __init__(
-            self,
-            aqlm: dict[str, int] = {
-                "nbits_per_codebook": 16,
-                "num_codebooks": 1,
-                "out_group_size": 8,
-                "in_group_size": 1,
-            },
-            **kwargs,
-        ):
-            super().__init__(**kwargs)
-            self.aqlm = aqlm
-
-    config_dict = old_config.to_dict()
-    config_dict["auto_map"] = {
-        f"AutoConfig": f"configuration_{new_model_type}.{old_config.__class__.__name__}",
-        "AutoModelForCausalLM": f"modeling_{new_model_type}.{model.__class__.__name__}",
-    }
-    del config_dict["_name_or_path"]
-
-    new_config = AqlmConfig(
-        {
-            "nbits_per_codebook": args.nbits_per_codebook,
-            "num_codebooks": args.num_codebooks,
-            "out_group_size": args.out_group_size,
-            "in_group_size": args.in_group_size,
-        }
-    )
-    new_config.update(config_dict)
-
-    model.config = new_config
-    model.__class__.__name__ = model.__class__.__name__ + "_AQLM"
-
-
-def add_inference_code(model: PreTrainedModel, save_path: os.PathLike):
-    model_type = model.config.model_type
-
-    shutil.copytree(f"./transformers/common", save_path, dirs_exist_ok=True)
-    if os.path.isdir(f"./transformers/{model_type}"):
-        shutil.copytree(f"./transformers/{model_type}", save_path, dirs_exist_ok=True)
-    else:
-        print(f"No predefined PreTrainedModel exists for {model_type}. You'll have to copy-paste some code yourself.")
-
-
 @torch.no_grad()
 def quantize_aq(model: PreTrainedModel, dataloader: Iterable, args: Namespace):
     assert not torch.backends.cuda.matmul.allow_tf32
@@ -305,12 +253,7 @@ def quantize_aq(model: PreTrainedModel, dataloader: Iterable, args: Namespace):
 
     print("=====================\nFinal stats:")
     if args.save:
-        for (submodule, child_name, quantized_linear) in replaced_linears:
-            setattr(submodule, child_name, quantized_linear.finalize())
-
-        update_config(model, args)
-        model.save_pretrained(args.save)
-        add_inference_code(model, args.save)
+        save_fresh_model(model, replaced_linears, args)
 
     if args.wandb:
         wandb.log({"max_cuda_mem_quantize": round(torch.cuda.max_memory_allocated() / 1e9, 2)})
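
For context, replaced_linears accumulates (parent module, attribute name, quantized replacement) triples during quantization, matching the annotation on save_fresh_model below. A minimal sketch of how such a list might be built; the QuantizedLinear wrapper and its constructor are hypothetical, only the triple layout and the .finalize() call come from this diff:

    import torch.nn as nn

    replaced_linears = []
    for submodule in model.modules():
        for child_name, child in list(submodule.named_children()):
            if isinstance(child, nn.Linear):
                # Hypothetical wrapper; .finalize() is assumed to return the inference-ready module.
                quantized = QuantizedLinear(child)
                setattr(submodule, child_name, quantized)
                replaced_linears.append((submodule, child_name, quantized))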
69 changes: 69 additions & 0 deletions src/saving.py
@@ -0,0 +1,69 @@
+import os
+import shutil
+
+import torch
+from torch import nn
+
+from transformers import PreTrainedModel
+
+
+def update_config(model: PreTrainedModel, args):
+    old_config = model.config
+    old_config_type = type(old_config)
+    old_model_type = old_config.model_type
+    new_model_type = f"{old_model_type}_aqlm"
+
+    class AqlmConfig(old_config_type):
+        model_type = new_model_type
+
+        def __init__(
+            self,
+            aqlm: dict[str, int] = {
+                "nbits_per_codebook": 16,
+                "num_codebooks": 1,
+                "out_group_size": 8,
+                "in_group_size": 1,
+            },
+            **kwargs,
+        ):
+            super().__init__(**kwargs)
+            self.aqlm = aqlm
+
+    config_dict = old_config.to_dict()
+    config_dict["auto_map"] = {
+        f"AutoConfig": f"configuration_{new_model_type}.{old_config.__class__.__name__}",
+        "AutoModelForCausalLM": f"modeling_{new_model_type}.{model.__class__.__name__}",
+    }
+    del config_dict["_name_or_path"]
+
+    new_config = AqlmConfig(
+        {
+            "nbits_per_codebook": args.nbits_per_codebook,
+            "num_codebooks": args.num_codebooks,
+            "out_group_size": args.out_group_size,
+            "in_group_size": args.in_group_size,
+        }
+    )
+    new_config.update(config_dict)
+
+    model.config = new_config
+    model.__class__.__name__ = model.__class__.__name__ + "_AQLM"
+
+
+def add_inference_code(model: PreTrainedModel, save_path: os.PathLike):
+    model_type = model.config.model_type
+
+    shutil.copytree(f"./transformers/common", save_path, dirs_exist_ok=True)
+    if os.path.isdir(f"./transformers/{model_type}"):
+        shutil.copytree(f"./transformers/{model_type}", save_path, dirs_exist_ok=True)
+    else:
+        print(f"No predefined PreTrainedModel exists for {model_type}. You'll have to copy-paste some code yourself.")
+
+
+def save_fresh_model(model: PreTrainedModel, replaced_linears: list[tuple[nn.Module, str, nn.Module]], args):
+    for (submodule, child_name, quantized_linear) in replaced_linears:
+        setattr(submodule, child_name, quantized_linear.finalize())
+
+    update_config(model, args)
+    model.save_pretrained(args.save)
+    add_inference_code(model, args.save)
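
Because update_config writes auto_map entries pointing at the configuration_*/modeling_* files that add_inference_code copies next to the weights, the saved directory can be loaded back through the transformers auto classes. A minimal usage sketch, assuming the quantized model was saved to "path/to/saved/model" (a placeholder for args.save) and that the copied inference code is complete for the architecture:

    from transformers import AutoModelForCausalLM

    # trust_remote_code=True lets transformers resolve the auto_map entries
    # written by update_config to the bundled configuration/modeling files.
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/saved/model",  # placeholder for args.save
        trust_remote_code=True,
    )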
