merge PV-Tuning into AQLM main #110

Merged
542 commits merged on Aug 21, 2024

Commits (542)
aa8898a
fix
dvmazur May 8, 2024
fb498f4
typo
justheuristic May 8, 2024
15c91ff
fix
dvmazur May 8, 2024
e90cbce
yet another fix
dvmazur May 8, 2024
798ff52
m
dvmazur May 8, 2024
2330ced
basic pre-shard version
justheuristic May 8, 2024
bbe38f7
typo
justheuristic May 8, 2024
e86ca81
\n
justheuristic May 8, 2024
d1b3637
\n
justheuristic May 8, 2024
cd05b9a
relative error
justheuristic May 8, 2024
2649aa4
TODOs
justheuristic May 8, 2024
60b4775
one less todo
justheuristic May 8, 2024
def623b
one less todo
justheuristic May 8, 2024
64c218b
typo
justheuristic May 8, 2024
1c4c3cb
typo
justheuristic May 8, 2024
811ba05
eps
justheuristic May 8, 2024
82b4e99
order names
justheuristic May 8, 2024
3a0fe20
add prefetch
dvmazur May 8, 2024
f116fb2
pv experimental
justheuristic May 9, 2024
ac91cee
pv experimental
justheuristic May 9, 2024
4547546
typo
justheuristic May 9, 2024
f491c54
load correctly
justheuristic May 9, 2024
9417a5a
load correctly
justheuristic May 9, 2024
a996b21
load correctly
justheuristic May 9, 2024
34084b7
debugprint
justheuristic May 9, 2024
ab1ee81
args.dtype
justheuristic May 9, 2024
cbc5170
debugprint
justheuristic May 9, 2024
a5971a0
compare exact checksums
justheuristic May 9, 2024
74ec292
debugprint
justheuristic May 9, 2024
68583cf
debugprint
justheuristic May 9, 2024
c480907
debugprint
justheuristic May 9, 2024
bd18254
debugprint
justheuristic May 9, 2024
5c78297
debugprint
justheuristic May 9, 2024
8dbcc41
debugprint
justheuristic May 9, 2024
11cf170
debugprint
justheuristic May 9, 2024
a41a36a
debugprint
justheuristic May 9, 2024
469fffa
debugprint
justheuristic May 9, 2024
d4c3916
debugprint
justheuristic May 9, 2024
e9bf34b
debugprint
justheuristic May 9, 2024
ffbbac9
debugprint
justheuristic May 9, 2024
56e64f8
debugprint
justheuristic May 9, 2024
99f2589
debugprint
justheuristic May 9, 2024
aeaf762
debugprint
justheuristic May 9, 2024
f1a4a53
debugprint
justheuristic May 9, 2024
4d192e9
debugprint
justheuristic May 9, 2024
dbf965e
intentionally break size
justheuristic May 9, 2024
bf3015c
intentionally break size
justheuristic May 9, 2024
190e918
intentionally break size
justheuristic May 9, 2024
0669da8
intentionally break size
justheuristic May 9, 2024
ffd1b5e
intentionally break size
justheuristic May 9, 2024
a8f1eb3
intentionally break size
justheuristic May 9, 2024
43e21d4
intentionally break size
justheuristic May 9, 2024
0ee923e
intentionally break size
justheuristic May 9, 2024
a7e1781
intentionally break size
justheuristic May 9, 2024
4ffcaac
intentionally break size
justheuristic May 9, 2024
000876a
intentionally break size
justheuristic May 9, 2024
2291a45
intentionally break size
justheuristic May 9, 2024
2ed53db
intentionally break size
justheuristic May 9, 2024
fb0c575
intentionally break size
justheuristic May 9, 2024
36ac8c3
intentionally break size
justheuristic May 9, 2024
e99fae3
better debugprint
justheuristic May 9, 2024
e49e4aa
add \n
justheuristic May 9, 2024
3b73ff2
verbose optimizer
justheuristic May 9, 2024
4a2b6bc
remove debugprints
justheuristic May 9, 2024
9134599
remove debugprints
justheuristic May 9, 2024
9240cff
remove debugprints
justheuristic May 9, 2024
68cf048
print timings
justheuristic May 9, 2024
6d50d2d
store shard sizes directly
justheuristic May 9, 2024
3b2a79d
remove debug tool
justheuristic May 9, 2024
d8fdda0
init more than one rank at a time; add docstr
justheuristic May 9, 2024
54ad0fd
rm all_reduce_coalesced
justheuristic May 9, 2024
d80937e
option to store code statistics in 16bit precision
justheuristic May 9, 2024
6766dba
option to store code statistics in 16bit precision
justheuristic May 9, 2024
22d1a9a
report timings
justheuristic May 9, 2024
ced4d75
report timings
justheuristic May 9, 2024
8cd1c9b
report timings
justheuristic May 9, 2024
562dac6
report timings
justheuristic May 9, 2024
dede8ee
remove unused line
justheuristic May 9, 2024
84829b8
set dtypes explicitly
justheuristic May 9, 2024
b0a5e54
set dtypes explicitly
justheuristic May 9, 2024
c20e006
bash typing into submission
justheuristic May 9, 2024
bfe3e1b
bash typing into submission
justheuristic May 9, 2024
9178240
state dict saving
justheuristic May 9, 2024
a0c7781
handle saving/loading
justheuristic May 9, 2024
186d31c
handle saving/loading
justheuristic May 9, 2024
0c31066
handle saving/loading
justheuristic May 9, 2024
de9a781
comment
justheuristic May 9, 2024
baeab44
beam size
justheuristic May 10, 2024
8fdb672
fix saving
justheuristic May 10, 2024
d23fe57
fix saving
justheuristic May 10, 2024
85e7b50
fix saving
justheuristic May 10, 2024
70fef4e
check for num_codebooks>1 instead of >0
justheuristic May 10, 2024
08b1cb5
remove extra \n from logs
justheuristic May 10, 2024
6501760
fix zero cost states
dvmazur May 10, 2024
9b7439c
deal with None grad
dvmazur May 10, 2024
bbcbfc3
style
dvmazur May 10, 2024
6307a74
Update src/configurable_adam.py
dvmazur May 10, 2024
b7113b1
Update src/configurable_adam.py
dvmazur May 10, 2024
9487c07
Update src/configurable_adam.py
dvmazur May 10, 2024
a0950be
Update src/configurable_adam.py
dvmazur May 10, 2024
7d160ef
style & optimizations
dvmazur May 10, 2024
2a525c6
Merge branch 'pv-offload-opt-states' of github.com:Vahe1994/GPTAQ int…
dvmazur May 10, 2024
26deb68
Update src/configurable_adam.py
justheuristic May 10, 2024
571ec70
use lerp where possible
justheuristic May 10, 2024
2eb0887
fix zero const amsgrad=False
dvmazur May 10, 2024
2cb76a7
Merge branch 'pv-offload-opt-states' of github.com:Vahe1994/GPTAQ int…
dvmazur May 10, 2024
e3c84be
Merge pull request #30 from Vahe1994/pv-offload-opt-states
justheuristic May 10, 2024
da885f6
check for num_codebooks>1 instead of >0
justheuristic May 10, 2024
73b559e
check for num_codebooks>1 instead of >0
justheuristic May 10, 2024
d9aa905
add p step finetuning
galqiwi May 10, 2024
fbf8ac4
same init method
justheuristic May 10, 2024
e171e95
undo
justheuristic May 10, 2024
11d8caf
remove load_quantized_state_dict
justheuristic May 10, 2024
628a901
move new modelutils to a separate file; create an utility script that…
justheuristic May 10, 2024
10d0054
move new modelutils to a separate file; create an utility script that…
justheuristic May 10, 2024
c8cda39
move new modelutils to a separate file; create an utility script that…
justheuristic May 10, 2024
6454e6d
move new modelutils to a separate file; create an utility script that…
justheuristic May 10, 2024
d5c2bd4
move new modelutils to a separate file; create an utility script that…
justheuristic May 10, 2024
ea38bc1
move new modelutils to a separate file; create an utility script that…
justheuristic May 10, 2024
2059e65
move new modelutils to a separate file; create an utility script that…
justheuristic May 10, 2024
b03eaf4
add params to set st buffer dtype, debias, lamb, amsgrad; print args …
justheuristic May 10, 2024
40dfdd5
fix debias default behavior
justheuristic May 10, 2024
33e2756
fix debiasing
justheuristic May 10, 2024
a635338
dataset config name
justheuristic May 11, 2024
ad07d46
dataset config name
justheuristic May 11, 2024
db878d2
download_num_workers
justheuristic May 11, 2024
4eda2a7
typo
justheuristic May 11, 2024
25b0ce0
preprocessing_keep_in_memory
justheuristic May 12, 2024
c6202fe
fix bug that prevents codes from learning when code_dtype.bits < 32
justheuristic May 12, 2024
10de380
fix bug that prevents codes from learning when code_dtype.bits < 32
justheuristic May 12, 2024
df785f1
fix bug that prevents codes from learning when code_dtype.bits < 32
justheuristic May 14, 2024
2a0de19
move profiler to utils
justheuristic May 14, 2024
97a1e80
support both hard limit and trust ratio on the number of code updates
justheuristic May 14, 2024
96ae954
remove _select_updates_with_highest_priority
justheuristic May 14, 2024
0d88e10
add docstrings for non-straight-through options
justheuristic May 14, 2024
ec901be
fix beam search update
justheuristic May 14, 2024
c493f23
fix beam search update
justheuristic May 14, 2024
cb00245
converted
dvmazur May 14, 2024
be22b19
fix ceil
justheuristic May 14, 2024
9148f55
fix indexing
justheuristic May 14, 2024
03094c0
test
dvmazur May 14, 2024
25e3f87
fix
dvmazur May 14, 2024
3982c51
better print
justheuristic May 14, 2024
43a458c
8sd
justheuristic May 14, 2024
7968d4e
beta2 095
justheuristic May 14, 2024
fe22221
typo
justheuristic May 15, 2024
de6d79f
typo
justheuristic May 15, 2024
4fc4b09
better printing
justheuristic May 15, 2024
3c8e911
set empty grads to zero
justheuristic May 15, 2024
19e453d
aggregate master params
justheuristic May 15, 2024
1d6aeb4
options to update (or not) the quantized components
justheuristic May 15, 2024
592795b
typo
justheuristic May 15, 2024
08f43f5
typo
justheuristic May 15, 2024
4a55546
view
justheuristic May 15, 2024
8668c50
check shape
justheuristic May 15, 2024
a9d00cf
select chunk
justheuristic May 15, 2024
662b097
print
justheuristic May 15, 2024
a7d7cc6
fix delta update
justheuristic May 15, 2024
1c7134c
Update convert_legacy_model_format.py
justheuristic May 15, 2024
b7a9709
Update convert_legacy_model_format.py
justheuristic May 15, 2024
b175829
Merge pull request #31 from Vahe1994/checkpoint-loading
justheuristic May 15, 2024
1d46873
fix delta update
justheuristic May 16, 2024
576e1f5
fix delta update
justheuristic May 16, 2024
f2ad37b
fix delta update
justheuristic May 16, 2024
83d0259
fix direction
justheuristic May 16, 2024
5caa0be
formatting
justheuristic May 16, 2024
66f8d50
comment
justheuristic May 16, 2024
295ec08
formatting
justheuristic May 16, 2024
aa1e4db
directional vectors
justheuristic May 16, 2024
62234b2
greater equal
justheuristic May 16, 2024
7b63472
strictly greater
justheuristic May 16, 2024
03e3d0c
strictly greater
justheuristic May 16, 2024
794176d
force code update
justheuristic May 16, 2024
b2c909c
force code update
justheuristic May 16, 2024
8ce8002
Merge branch 'pv' into pv2
justheuristic May 17, 2024
d32364a
Update convert_legacy_model_format.py
dvmazur May 18, 2024
c7155db
debug
justheuristic May 19, 2024
f1811e3
debug
justheuristic May 19, 2024
cb119e2
extend parameter description / help
justheuristic May 19, 2024
e8fde28
remove debugprints
justheuristic May 19, 2024
f5f9da3
Merge remote-tracking branch 'origin/pv' into pv2
justheuristic May 19, 2024
b671d7f
default
justheuristic May 19, 2024
52f013b
reconfigure MixedPrecision
justheuristic May 19, 2024
cf114d8
readme
justheuristic May 23, 2024
a49b5bc
announce models
justheuristic May 27, 2024
159e937
Update to actual requirements.txt for PV-T
justheuristic May 30, 2024
f933d95
Merge branch 'main' into pv-updated
justheuristic Jun 11, 2024
9b50b8f
bake in defaults
justheuristic Jun 11, 2024
a9f674e
fix typing
justheuristic Jun 11, 2024
316a300
fix typo
justheuristic Jun 11, 2024
c50c76d
split finetune_fsdp into three files
justheuristic Jun 11, 2024
30d1a1e
Merge branch 'pv-tuning' into pv-updated
justheuristic Jun 11, 2024
9b39ad3
undo extract args
justheuristic Jun 11, 2024
dcdd9c9
undo extract saveload
justheuristic Jun 11, 2024
80299f7
undo extract saveload
justheuristic Jun 11, 2024
c6fd897
undo
justheuristic Jun 11, 2024
e0dfb63
undo
justheuristic Jun 11, 2024
2ecbed5
auto-wrap IntCodes
justheuristic Jun 11, 2024
24e43a8
move param split to load_student_model
justheuristic Jun 11, 2024
a5db0a7
move param split to load_student_model
justheuristic Jun 11, 2024
a5bd5c4
move param split to load_student_model
justheuristic Jun 11, 2024
bd0b0b8
preserve parameter order
justheuristic Jun 12, 2024
49f67dc
option to offload teacher to CPU
justheuristic Jun 13, 2024
d4c8356
For review: memory efficient KL loss (#105)
justheuristic Jun 19, 2024
611c60f
pep
justheuristic Jun 19, 2024
8576b58
Merge remote-tracking branch 'origin/pv-updated' into pv-updated
justheuristic Jun 19, 2024
69cf0b0
better naming
justheuristic Jun 19, 2024
3d48441
move stuff around
justheuristic Jun 19, 2024
08170dd
fix
justheuristic Jun 19, 2024
364f486
refactor
justheuristic Jun 19, 2024
99a85e5
refactor
justheuristic Jun 19, 2024
136906f
refactor
justheuristic Jun 19, 2024
dce3641
check params
justheuristic Jun 19, 2024
5dc7ac6
check params
justheuristic Jun 19, 2024
46ed052
simpler wrapper
justheuristic Jun 19, 2024
f659b83
typo
justheuristic Jun 19, 2024
1e77444
typing
justheuristic Jun 20, 2024
31c98e2
deduplicate world size
justheuristic Jun 20, 2024
b1b3030
better autowrap
justheuristic Jun 24, 2024
8a33bcb
undo change
justheuristic Jun 24, 2024
e537735
inline wrapper
justheuristic Jun 24, 2024
ec546b5
offload optimizer
justheuristic Jun 24, 2024
bebf959
wrap properly
justheuristic Jun 24, 2024
1d84233
try fix eval
justheuristic Jun 25, 2024
fd43b8c
fix dtypes
justheuristic Jun 25, 2024
ff26df3
wrap explicitly
justheuristic Jun 25, 2024
f114362
explicit forward_prefetch
justheuristic Jun 25, 2024
bbacb28
explain the tricky bit
justheuristic Jun 25, 2024
c4120de
non-lazy init
justheuristic Jun 25, 2024
d16b8dd
return_tensors="pt"
justheuristic Jun 25, 2024
c79e5c8
fix
justheuristic Jun 25, 2024
3c840dc
trigger lazy init
justheuristic Jun 25, 2024
71ac701
rename
justheuristic Jun 25, 2024
fc1f2a3
rename
justheuristic Jun 25, 2024
d72f0bb
no_grad instead of inference_mode
justheuristic Jun 25, 2024
8d14a5c
offload_student_params
justheuristic Jun 28, 2024
20e5870
manually cast student to bf16
justheuristic Jun 29, 2024
1f887c4
typo
justheuristic Jun 29, 2024
d40720a
support --embed_dtype
justheuristic Jul 2, 2024
f0934ef
fix error found by @yaldashbz
justheuristic Jul 2, 2024
be828a0
review
galqiwi Jul 2, 2024
18a51da
rollback
justheuristic Jul 2, 2024
2278bb2
minimize diff
justheuristic Jul 2, 2024
247a54b
reduce diff
justheuristic Jul 2, 2024
534fa4b
isort
justheuristic Jul 2, 2024
da9e6c8
Merge remote-tracking branch 'origin/main' into pv-updated
justheuristic Jul 2, 2024
3954b4f
fix init for PV when amp_dtype is used
justheuristic Jul 10, 2024
06db959
typo
justheuristic Jul 10, 2024
a29e728
uncomment
justheuristic Jul 11, 2024
d3b2cfd
blacked
justheuristic Aug 15, 2024
12 changes: 7 additions & 5 deletions aq_engine.py
@@ -47,7 +47,6 @@ def quantize(self, *, args: Namespace, verbose: bool = True) -> QuantizedWeight:
         assert isinstance(args.devices, (list, tuple)) and len(args.devices) >= 1, f"Found devices = {args.devices}"
         assert args.devices[0] == self.device, (args.devices[0], self.XTX.device)
         self.quantized_weight = QuantizedWeight(
-            XTX=self.XTX.to(device=self.device, dtype=torch.float32),
             reference_weight=self.layer.weight.detach().to(device=self.device, dtype=torch.float32),
             out_group_size=args.out_group_size,
             in_group_size=args.in_group_size,
@@ -165,7 +164,7 @@ def _replace_and_beam_search(self, params_to_replace: nn.ParameterDict, selection
         )
         reference_weight = self.layer.weight.detach()[out_channel_selection].to(dtype)
         return self.quantized_weight.beam_search_update_codes_(
-            self.XTX.to(dtype), reference_weight, selection=selection, **kwargs
+            XTX=self.XTX.to(dtype), reference_weight=reference_weight, selection=selection, **kwargs
         ).clone()

     @torch.no_grad()
@@ -177,12 +176,15 @@ def beam_search_update_codes_(
         seed: Optional[int] = None,
         **kwargs,
     ):
-        """Update self.quantized_weight.codes in-place via beam search"""
+        """Update quantized_weight codes in-place via beam search"""
         if len(devices) == 1:  # single device
             assert replicas is None
             dtype = self.quantized_weight.codebooks.dtype
             self.quantized_weight.beam_search_update_codes_(
-                self.XTX.to(dtype), self.layer.weight.detach().to(dtype), dim_rng=random.Random(seed), **kwargs
+                XTX=self.XTX.to(dtype),
+                reference_weight=self.layer.weight.detach().to(dtype),
+                dim_rng=random.Random(seed),
+                **kwargs,
             )
         else:
             assert replicas[0] is self
@@ -203,7 +205,7 @@ def beam_search_update_codes_(
             )
             # gather all code parts and assign them to each replica
             for device, replica in zip(devices, replicas):
-                replica.quantized_weight.codes[...] = Gather.apply(device, 0, *new_code_parts_by_replica)
+                replica.quantized_weight.set_codes(Gather.apply(device, 0, *new_code_parts_by_replica))
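The last hunk replaces a direct write into quantized_weight.codes[...] with quantized_weight.set_codes(...). Below is a minimal sketch of why a setter helps, assuming (as other parts of this PR suggest) that codes may be offloaded into a separate codes_storage buffer. This is an illustration only; the real QuantizedWeight.set_codes lives in src/aq.py and may differ.

# Hedged sketch, not the actual implementation from src/aq.py.
import torch
from torch import nn

class QuantizedWeightSketch(nn.Module):
    def __init__(self, codes: torch.Tensor):
        super().__init__()
        self.codes = nn.Parameter(codes, requires_grad=False)
        self.codes_storage = None  # set when codes are offloaded (e.g. wrapped IntCodes on CPU)

    def set_codes(self, new_codes: torch.Tensor):
        # write through whichever buffer currently holds the codes
        target = self.codes.data if self.codes is not None else self.codes_storage.data
        target[...] = new_codes.to(device=target.device, dtype=target.dtype)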
214 changes: 214 additions & 0 deletions convert_legacy_model_format.py
@@ -0,0 +1,214 @@
"""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this file now? Why can't we save models in the same data format?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file allows backward compatibility with models quantized with older version

This abomination converts between one of several quantized model formats to the same format as returned by main.py .
This code exists because we failed to produce a single data format for quantized model.
We should eventually switch to saving all models in the same data format. Once we do, this file should be deleted.
"""
import argparse
import os
import warnings
from copy import deepcopy

import torch
import transformers.models
from torch import nn

from src.aq import QuantizedLinear, QuantizedWeight
from src.modelutils import get_model, save_quantized_model
from src.utils import is_signed


def load_quantized_model_with_old_pickle(base_model_name: str, quantized_model_name: str, **kwargs):
    """Hacky way to allow compatibility between old *pickled* layers and new transformers"""
    # because patching it for the fourth time is better than writing a proper saver once >.<
    import transformers.activations

    if not hasattr(transformers.activations, "SiLUActivation"):
        transformers.activations.SiLUActivation = deepcopy(torch.nn.SiLU)
        transformers.activations.SiLUActivation.inplace = False
        # https://github.com/huggingface/transformers/issues/28496
    if not hasattr(transformers.models.llama.modeling_llama.LlamaAttention, "attention_dropout"):
        transformers.models.llama.modeling_llama.LlamaAttention.attention_dropout = 0
    quantized_model = get_model(base_model_name, None, **kwargs)
    quantized_model_src = get_model(base_model_name, quantized_model_name, **kwargs)
    for module in quantized_model_src.modules():
        if isinstance(module, QuantizedWeight) and not hasattr(module, "codes_storage"):
            module.codes_storage = None  # backwards compatibility with older pickled snapshots

    lut = {}
    for name, module in quantized_model_src.named_modules():
        for child_name, child_module in module.named_children():
            if isinstance(child_module, QuantizedWeight):
                lut[name + "." + child_name] = child_module
    print(f"found {len(lut)} quantized weight matrices")
    for name, module in quantized_model.named_modules():
        for child_name, child_module in module.named_children():
            if name + "." + child_name + ".quantized_weight" in lut:
                quantized_weight = lut.pop(name + "." + child_name + ".quantized_weight")
                assert isinstance(child_module, nn.Linear)
                setattr(module, child_name, QuantizedLinear(quantized_weight, bias=child_module.bias))
    assert not lut, list(lut.keys())
    quantized_model.load_state_dict(quantized_model_src.state_dict())
    warnings.warn("You should be ashamed of yourself.")
    return quantized_model


import functools


def rsetattr(obj, attr, val):
    pre, _, post = attr.rpartition(".")
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split("."))
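Not part of the diff: a minimal usage sketch of rsetattr / rgetattr, using a tiny hypothetical nested module rather than anything from this PR. It shows how a dotted attribute name is resolved and replaced, which is what the FSDP loader below relies on when swapping an nn.Linear for a QuantizedLinear.

# Illustration only: a toy nested module, not from this PR.
toy = nn.Module()
toy.block = nn.Module()
toy.block.proj = nn.Linear(4, 4, bias=False)

print(rgetattr(toy, "block.proj"))                        # resolves the dotted path
rsetattr(toy, "block.proj", nn.Linear(4, 8, bias=False))  # replaces the nested child in-place
print(rgetattr(toy, "block.proj"))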


def load_quantized_model_from_fdsp_checkpoint(base_model_name: str, fsdp_checkpoint_path: str, **kwargs):
    original_model = get_model(base_model_name, None, **kwargs)

    state_filenames = os.listdir(fsdp_checkpoint_path)

    non_quant_fname = "non_quantized_state_dict.pth"
    non_quant_path = os.path.join(fsdp_checkpoint_path, non_quant_fname)
    non_quant_states = torch.load(non_quant_path)

    incomp_keys = original_model.load_state_dict(non_quant_states, strict=False)
    assert not incomp_keys.unexpected_keys

    missing_keys = list()
    for module_name, module in original_model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        assert not module.bias
        state_fname = f"{module_name}.weight.pth"

        if state_fname not in state_filenames:
            missing_keys.append(module_name)
            continue

        state_path = os.path.join(fsdp_checkpoint_path, state_fname)
        quantized_weight = torch.load(state_path, map_location="cpu")
        quantized_linear = QuantizedLinear(quantized_weight, bias=None)
        rsetattr(original_model, module_name, quantized_linear)

    return original_model
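For context, a hedged sketch of the on-disk layout this loader expects, inferred purely from the reads above: one non_quantized_state_dict.pth plus one pickled QuantizedWeight per linear layer. The actual saving code lives elsewhere in this PR and may differ; the function name below is hypothetical.

# Hypothetical writer mirroring load_quantized_model_from_fdsp_checkpoint's reads.
def save_fsdp_style_checkpoint(quantized_model, non_quant_states: dict, path: str):
    os.makedirs(path, exist_ok=True)
    # everything that is not a quantized linear layer goes into one state dict
    torch.save(non_quant_states, os.path.join(path, "non_quantized_state_dict.pth"))
    # each quantized linear layer is pickled whole, keyed by its module name
    for module_name, module in quantized_model.named_modules():
        if isinstance(module, QuantizedLinear):
            torch.save(module.quantized_weight, os.path.join(path, f"{module_name}.weight.pth"))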


def main():
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--base_model",
        type=str,
        required=True,
        help="path or name of the teacher model",
    )
    parser.add_argument(
        "--quantized_model",
        type=str,
        required=True,
        help="path to quantized model",
    )
    parser.add_argument(
        "--load_dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "float32", "bfloat16"],
        help="dtype to load the model in",
    )
    parser.add_argument(
        "--code_dtype",
        type=str,
        default=None,
        help="if specified, cast quantized layers' codes to this dtype; default = keep loaded dtype",
    )
    parser.add_argument(
        "--p_finetuned_state_dict",
        type=str,
        default=None,
        help="path to quantized model state dict saved by the old FSDP finetuning code",
    )
    parser.add_argument(
        "--pv_fsdp_dir",
        type=str,
        default=None,
        help="path to quantized model state dict saved by the old FSDP finetuning code",
    )
    parser.add_argument(
        "--monkeypatch_old_pickle",
        action="store_true",
        help="If set, load quantized_model in a hacky way that allows pickled models with older transformers/torch.",
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default=None,
        help="Attention implementation for both teacher and student models: eager, sdpa, or flash_attention_2",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Whether to trust remote code when loading base model.",
    )
    parser.add_argument("--save", type=str, required=True, help="Save the converted quantized model here")

    args = parser.parse_args()
    assert args.p_finetuned_state_dict or args.pv_fsdp_dir, "either one of those must be specified"
    print(f"{args.p_finetuned_state_dict=}, {args.pv_fsdp_dir=}")
    assert (args.p_finetuned_state_dict is not None) != (args.pv_fsdp_dir is not None)

    args.load_dtype = getattr(torch, args.load_dtype) if args.load_dtype != "auto" else "auto"
    args.code_dtype = getattr(torch, args.code_dtype) if args.code_dtype is not None else None

    if not args.monkeypatch_old_pickle:
Owner:
The fact that there are 3 ways to load the model, depending on the user choosing the correct argument, is, to put it mildly, not great.

Collaborator Author:
True. As a small consolation, the primary way with --monkeypatch_old_pickle will always work; it's just not the most efficient way.

Collaborator Author:
We ultimately decided that this is a grave, but currently necessary, evil.

        quantized_model = get_model(
            args.base_model,
            args.quantized_model,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
            attn_implementation=args.attn_implementation,
        )
    elif args.p_finetuned_state_dict:
        quantized_model = load_quantized_model_with_old_pickle(
            args.base_model,
            args.quantized_model,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
            attn_implementation=args.attn_implementation,
        )
    elif args.pv_fsdp_dir:
        quantized_model = load_quantized_model_from_fdsp_checkpoint(
            args.base_model,
            args.pv_fsdp_dir,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
        )

    for module in quantized_model.modules():
Owner:
I suppose this happens precisely because there are 3 ways to save the model.

Collaborator Author:
It is a bit unclear. Please elaborate.

Collaborator Author:
Resolution: this is a temporary backwards-compatibility patch for users that train with the previous main branch and finetune afterwards. It should be deleted in an upcoming PR after 2-4 weeks.

        if isinstance(module, QuantizedWeight):
            if not hasattr(module, "codes_storage"):
                module.codes_storage = None
            if module.codes is None:
                module.unwrap_codes_()
            assert module.codes is not None
            if args.code_dtype is not None:
                assert module.nbits_per_codebook <= torch.iinfo(args.code_dtype).bits - is_signed(args.code_dtype)
                module.codes = nn.Parameter(module.codes.to(args.code_dtype), requires_grad=module.codes.requires_grad)

    if args.p_finetuned_state_dict is not None:
        state_dict = torch.load(args.p_finetuned_state_dict, map_location="cpu")
        state_dict = {k: v for k, v in state_dict.items() if not k.endswith(".codes_storage.data")}
        status = quantized_model.load_state_dict(state_dict, strict=False)
        assert all(key.endswith("codes") for key in status.missing_keys)
        assert not status.unexpected_keys
        del state_dict, status  # note: in this case, it is okay not to load codes since P step does not change them

    save_quantized_model(quantized_model, args.save)


if __name__ == "__main__":
    main()
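A short worked example of the --code_dtype guard in main() above (numbers are illustrative, not from this PR): with nbits_per_codebook = 16, torch.int16 offers only 16 - 1 = 15 usable bits and is rejected, while torch.int32 offers 32 - 1 = 31 and passes.

# Illustration only; assumes is_signed() returns True/False, as the guard above implies.
import torch
from src.utils import is_signed

nbits_per_codebook = 16  # e.g. codebooks with 2**16 entries
for dtype in (torch.int16, torch.int32):
    usable_bits = torch.iinfo(dtype).bits - is_signed(dtype)
    print(dtype, "ok" if nbits_per_codebook <= usable_bits else "too narrow for these codes")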