merge PV-Tuning into AQLM main #110
@@ -0,0 +1,214 @@
"""
This abomination converts one of several quantized model formats into the same format as returned by main.py.
This code exists because we failed to produce a single data format for quantized models.
We should eventually switch to saving all models in the same data format. Once we do, this file should be deleted.
"""
import argparse
import os
import warnings
from copy import deepcopy

import torch
import transformers.models
from torch import nn

from src.aq import QuantizedLinear, QuantizedWeight
from src.modelutils import get_model, save_quantized_model
from src.utils import is_signed

def load_quantized_model_with_old_pickle(base_model_name: str, quantized_model_name: str, **kwargs):
    """Hacky way to allow compatibility between old *pickled* layers and new transformers"""
    # because patching it for the fourth time is better than writing a proper saver once >.<
    import transformers.activations

    if not hasattr(transformers.activations, "SiLUActivation"):
        transformers.activations.SiLUActivation = deepcopy(torch.nn.SiLU)
        transformers.activations.SiLUActivation.inplace = False
    # https://github.com/huggingface/transformers/issues/28496
    if not hasattr(transformers.models.llama.modeling_llama.LlamaAttention, "attention_dropout"):
        transformers.models.llama.modeling_llama.LlamaAttention.attention_dropout = 0
    quantized_model = get_model(base_model_name, None, **kwargs)
    quantized_model_src = get_model(base_model_name, quantized_model_name, **kwargs)
    for module in quantized_model_src.modules():
        if isinstance(module, QuantizedWeight) and not hasattr(module, "codes_storage"):
            module.codes_storage = None  # backwards compatibility with older pickled snapshots

    lut = {}
    for name, module in quantized_model_src.named_modules():
        for child_name, child_module in module.named_children():
            if isinstance(child_module, QuantizedWeight):
                lut[name + "." + child_name] = child_module
    print(f"found {len(lut)} quantized weight matrices")
    for name, module in quantized_model.named_modules():
        for child_name, child_module in module.named_children():
            if name + "." + child_name + ".quantized_weight" in lut:
                quantized_weight = lut.pop(name + "." + child_name + ".quantized_weight")
                assert isinstance(child_module, nn.Linear)
                setattr(module, child_name, QuantizedLinear(quantized_weight, bias=child_module.bias))
    assert not lut, list(lut.keys())
    quantized_model.load_state_dict(quantized_model_src.state_dict())
    warnings.warn("You should be ashamed of yourself.")
    return quantized_model


import functools

def rsetattr(obj, attr, val):
    pre, _, post = attr.rpartition(".")
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split("."))

def load_quantized_model_from_fdsp_checkpoint(base_model_name: str, fsdp_checkpoint_path: str, **kwargs):
    original_model = get_model(base_model_name, None, **kwargs)

    state_filenames = os.listdir(fsdp_checkpoint_path)

    # load everything except the quantized linear layers
    non_quant_fname = "non_quantized_state_dict.pth"
    non_quant_path = os.path.join(fsdp_checkpoint_path, non_quant_fname)
    non_quant_states = torch.load(non_quant_path)

    incomp_keys = original_model.load_state_dict(non_quant_states, strict=False)
    assert not incomp_keys.unexpected_keys

    missing_keys = list()
    for module_name, module in original_model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        assert not module.bias
        state_fname = f"{module_name}.weight.pth"

        if state_fname not in state_filenames:
            missing_keys.append(module_name)
            continue

        # each quantized linear layer is stored as a separate <module_name>.weight.pth file
        state_path = os.path.join(fsdp_checkpoint_path, state_fname)
        quantized_weight = torch.load(state_path, map_location="cpu")
        quantized_linear = QuantizedLinear(quantized_weight, bias=None)
        rsetattr(original_model, module_name, quantized_linear)

    return original_model

def main():
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--base_model",
        type=str,
        required=True,
        help="path or name of the teacher model",
    )
    parser.add_argument(
        "--quantized_model",
        type=str,
        required=True,
        help="path to the quantized model",
    )
    parser.add_argument(
        "--load_dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "float32", "bfloat16"],
        help="dtype to load the model in",
    )
    parser.add_argument(
        "--code_dtype",
        type=str,
        default=None,
        help="if specified, cast quantized layers' codes to this dtype; default = keep loaded dtype",
    )
    parser.add_argument(
        "--p_finetuned_state_dict",
        type=str,
        default=None,
        help="path to a quantized model state dict saved by the old FSDP finetuning code",
    )
    parser.add_argument(
        "--pv_fsdp_dir",
        type=str,
        default=None,
        help="path to a directory with the quantized model state saved by the PV-tuning FSDP code",
    )
    parser.add_argument(
        "--monkeypatch_old_pickle",
        action="store_true",
        help="If set, load quantized_model in a hacky way that allows pickled models with older transformers/torch.",
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default=None,
help="Attention implementation for both teacher and student models: eager, sdpa, or flash_attention_2", | ||
) | ||
parser.add_argument( | ||
"--trust_remote_code", | ||
action="store_true", | ||
help="Whether to trust remote code when loading base model.", | ||
) | ||
parser.add_argument("--save", type=str, required=True, help="Save the converted quantized model here") | ||
|
||
args = parser.parse_args() | ||
assert args.p_finetuned_state_dict or args.pv_fsdp_dir, "either one of those must be specified" | ||
print(f"{args.p_finetuned_state_dict=}, {args.pv_fsdp_dir=}") | ||
assert (args.p_finetuned_state_dict is not None) != (args.pv_fsdp_dir is not None) | ||
|
||
args.load_dtype = getattr(torch, args.load_dtype) if args.load_dtype != "auto" else "auto" | ||
args.code_dtype = getattr(torch, args.code_dtype) if args.code_dtype is not None else None | ||
|
||
if not args.monkeypatch_old_pickle: | ||
[Review thread on this line]
Review comment: The fact that there are three ways to load the model, each depending on the user choosing the correct argument, is, to put it mildly, not great.
Reply: True. As a small consolation, using the primary way with …
Reply: We ultimately decided that this is a grave, but currently necessary, evil.
        quantized_model = get_model(
            args.base_model,
            args.quantized_model,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
            attn_implementation=args.attn_implementation,
        )
    elif args.p_finetuned_state_dict:
        quantized_model = load_quantized_model_with_old_pickle(
            args.base_model,
            args.quantized_model,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
            attn_implementation=args.attn_implementation,
        )
    elif args.pv_fsdp_dir:
        quantized_model = load_quantized_model_from_fdsp_checkpoint(
            args.base_model,
            args.pv_fsdp_dir,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
        )

    for module in quantized_model.modules():
[Review thread on this line]
Review comment: I suppose this happens precisely because there are three ways to save the model.
Reply: It is a bit unclear. Please elaborate.
Reply: Resolution: this is a temporary backwards-compatibility patch for users that train with the previous main branch and finetune afterwards. It should be deleted in a subsequent PR, after 2-4 weeks.
        if isinstance(module, QuantizedWeight):
            # temporary backwards-compatibility patch for checkpoints trained with the previous main branch
            if not hasattr(module, "codes_storage"):
                module.codes_storage = None
            if module.codes is None:
                module.unwrap_codes_()
            assert module.codes is not None
            if args.code_dtype is not None:
                # make sure the codes fit into the requested integer dtype before casting
                assert module.nbits_per_codebook <= torch.iinfo(args.code_dtype).bits - is_signed(args.code_dtype)
                module.codes = nn.Parameter(module.codes.to(args.code_dtype), requires_grad=module.codes.requires_grad)

    if args.p_finetuned_state_dict is not None:
        state_dict = torch.load(args.p_finetuned_state_dict, map_location="cpu")
        state_dict = {k: v for k, v in state_dict.items() if not k.endswith(".codes_storage.data")}
        status = quantized_model.load_state_dict(state_dict, strict=False)
        assert all(key.endswith("codes") for key in status.missing_keys)
        assert not status.unexpected_keys
        del state_dict, status  # note: in this case, it is okay not to load codes since P step does not change them

    save_quantized_model(quantized_model, args.save)


if __name__ == "__main__":
    main()
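For orientation, here is a minimal sketch of how this converter might be invoked from Python. The script filename, model name, and paths below are placeholders rather than values taken from this PR; per the asserts in main(), exactly one of --p_finetuned_state_dict or --pv_fsdp_dir must be passed.

# A minimal usage sketch, assuming the file above is saved as convert_legacy_model_format.py
# (the actual filename is not shown in this diff); all paths and model names are placeholders.
import subprocess

subprocess.run(
    [
        "python", "convert_legacy_model_format.py",
        "--base_model", "meta-llama/Llama-2-7b-hf",  # hypothetical teacher model
        "--quantized_model", "/path/to/old_quantized_model",  # hypothetical old-format checkpoint
        "--p_finetuned_state_dict", "/path/to/p_finetuned_state_dict.pt",  # exactly one of this or --pv_fsdp_dir
        "--load_dtype", "auto",
        "--save", "/path/to/converted_model",
    ],
    check=True,
)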
[Review thread on this file]
Review comment: Do we need this file now? Why can't we save models in the same data format?
Reply: This file allows backward compatibility with models quantized with an older version.
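The review threads above note that the load path in main() is chosen implicitly by flag combinations. Purely as an illustration of the consolidation discussed there (not code from this PR), the three branches could be keyed on an explicit mode argument; load_for_conversion and mode are hypothetical names, while get_model, load_quantized_model_with_old_pickle, and load_quantized_model_from_fdsp_checkpoint are the functions defined in the diff above.

# Illustrative sketch only: dispatch on an explicit mode instead of inferring the load path from flags.
def load_for_conversion(args, mode: str):
    common = dict(
        dtype=args.load_dtype,
        trust_remote_code=args.trust_remote_code,
        attn_implementation=args.attn_implementation,
    )
    if mode == "default":
        return get_model(args.base_model, args.quantized_model, **common)
    if mode == "old_pickle":
        return load_quantized_model_with_old_pickle(args.base_model, args.quantized_model, **common)
    if mode == "pv_fsdp":
        common.pop("attn_implementation")  # the FSDP loader in the diff is not passed this argument
        return load_quantized_model_from_fdsp_checkpoint(args.base_model, args.pv_fsdp_dir, **common)
    raise ValueError(f"unknown load mode: {mode}")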