Enable quant model support #1074
Merged

Changes from all commits (65 commits)
- 3888824 enable IPEXModelForSeq2SeqLM (jiqing-feng)
- f9fa807 set static cache (jiqing-feng)
- 202df43 add tests for IPEXModelForSeq2SeqLM (jiqing-feng)
- 4488073 add docs (jiqing-feng)
- 16fecf8 fix readme (jiqing-feng)
- de501f4 Merge branch 'main' into text2text (jiqing-feng)
- 4225bf0 refactor compile (jiqing-feng)
- 2ac7ecf fix check (jiqing-feng)
- 24b988c fix ruff check (jiqing-feng)
- 5c4f9a1 Merge branch 'huggingface:main' into text2text (jiqing-feng)
- 46b93a4 enable quantized model (jiqing-feng)
- 82d39ce add bnb test (jiqing-feng)
- 7dc08da add bnb tests in yaml (jiqing-feng)
- 30027ff fix tests (jiqing-feng)
- 314db04 disable bnb tests (jiqing-feng)
- 87656ca fix gpt2 (jiqing-feng)
- 9a7e931 Merge branch 'main' into quant (jiqing-feng)
- b0cec9c set actual device (jiqing-feng)
- 94cf35d assign device when convert class (jiqing-feng)
- 9af46d1 fix class init (jiqing-feng)
- 18b2a6a fix ipex attn init (jiqing-feng)
- 9f6db33 rm set device on config (jiqing-feng)
- 6d8a969 fix format (jiqing-feng)
- dd811f9 fix mlp class init (jiqing-feng)
- d91eefb Merge branch 'huggingface:main' into quant (jiqing-feng)
- f094cad Merge branch 'main' into quant (jiqing-feng)
- dab4a78 add use_cache param when init generation config (jiqing-feng)
- 6bf3b8b fix gpt2 quant model (jiqing-feng)
- 356d51d fix falcon linear fusion (jiqing-feng)
- d1eee87 fix falcon (jiqing-feng)
- 3aece6a Merge branch 'huggingface:main' into quant (jiqing-feng)
- 57e3c27 enable awq model test (jiqing-feng)
- 8f6ba5c fix install (jiqing-feng)
- 8870714 fix install (jiqing-feng)
- 5828fc0 fix install (jiqing-feng)
- c616d57 fix install (jiqing-feng)
- e1715b8 fix install (jiqing-feng)
- e88faf2 fix install (jiqing-feng)
- 882f2b2 fix install (jiqing-feng)
- d8208c7 fix install (jiqing-feng)
- 80b9ccb fix install (jiqing-feng)
- f05fb2f fix install (jiqing-feng)
- bd8e870 fix install (jiqing-feng)
- e471be3 fix install (jiqing-feng)
- 96f4622 fix install (jiqing-feng)
- 32bf0a1 fix install (jiqing-feng)
- ad3467b fix install (jiqing-feng)
- 4a21d26 fix install (jiqing-feng)
- fb8002c fix install (jiqing-feng)
- 2ad2371 fix install (jiqing-feng)
- 3c2ddef enable bnb test (jiqing-feng)
- 757ea8c remove useless device (jiqing-feng)
- 0a6ab0f update python to 3.10 on test_ipex (jiqing-feng)
- 8c4884b Apply suggestions from code review (IlyasMoutawwakil)
- 7fa23a5 install autoawq (jiqing-feng)
- 5386bbe install wheel (jiqing-feng)
- c5f5d16 fix install autoawq (jiqing-feng)
- 41513f0 rm autoawq (jiqing-feng)
- 8e1caa2 rebase (jiqing-feng)
- f73c08d fix concat qkv (jiqing-feng)
- f51777b fix format (jiqing-feng)
- f64b251 fix qwen patch (jiqing-feng)
- 778bf15 fix bias (jiqing-feng)
- 6ba1895 rm autoawq test (jiqing-feng)
- 88dba29 fix style (jiqing-feng)
@@ -14,7 +14,7 @@
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,

@@ -32,13 +32,11 @@
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _IPEXQwen2DecoderLayer,

@@ -66,12 +64,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)


-def convert_class(m, target_m, new_class, config=None):
+def convert_class(m, target_m, new_class, device, config):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config)
+            new_m = new_class(sub_m, device, config)
             setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config)
+        convert_class(sub_m, target_m, new_class, device, config)


 def patch_op(m, target_m, new_op_name, new_op):
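For readers skimming the diff, here is a minimal, self-contained sketch of how the new calling convention is intended to be used. Only convert_class mirrors the diff above; the _DummyWrapper class and the toy module tree are hypothetical stand-ins for the IPEX wrapper classes in modeling_utils.

```python
import torch
from torch import nn


# Hypothetical wrapper following the new (module, device, config) constructor
# contract that the wrapped classes are now expected to implement.
class _DummyWrapper(nn.Module):
    def __init__(self, module, device, config):
        super().__init__()
        self.module = module
        self.device = device
        self.config = config

    def forward(self, x):
        return self.module(x)


# Same recursion as in the diff: every child matching target_m is swapped for
# new_class, and device/config are threaded down the whole module tree.
def convert_class(m, target_m, new_class, device, config):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, device, config)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, device, config)


toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
convert_class(toy, nn.Linear, _DummyWrapper, torch.device("cpu"), config=None)
print(toy)  # every nn.Linear is now wrapped in _DummyWrapper
```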
@@ -89,7 +87,7 @@ def _patch_llama_model(model):
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
     return model

@@ -105,21 +103,20 @@ def _patch_falcon_model(model):
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
-    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
     return model


 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
         1. Use IPEX paged attention
         2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
Comment on lines -120 to +119:

Why are the MLP and attention no longer patched here?

Reply: Because they are now handled inside _IPEXGPT2Block.
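To make the reply concrete, here is a rough, hypothetical sketch of the pattern being described: the block-level wrapper re-wraps its own attention and MLP children in its constructor, so a single convert_class call on GPT2Block is enough. The _Example* classes below are illustrative stand-ins, not the actual _IPEXGPT2Block/_IPEXGPT2Attention/_IPEXGPT2MLP implementations from modeling_utils.

```python
from torch import nn


class _ExampleAttention(nn.Module):
    # Stand-in for _IPEXGPT2Attention: the real class would swap in IPEX
    # paged attention; here we only delegate to the original module.
    def __init__(self, module, device, config):
        super().__init__()
        self.module, self.device, self.config = module, device, config

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class _ExampleMLP(nn.Module):
    # Stand-in for _IPEXGPT2MLP: the real class would apply linear fusion.
    def __init__(self, module, device, config):
        super().__init__()
        self.module, self.device, self.config = module, device, config

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class _ExampleGPT2Block(nn.Module):
    # Stand-in for _IPEXGPT2Block: patching the block is sufficient because
    # the block wraps its own attn/mlp children, which is why the separate
    # GPT2Attention/GPT2MLP conversions could be dropped from the patcher.
    def __init__(self, module, device, config):
        super().__init__()
        module.attn = _ExampleAttention(module.attn, device, config)
        module.mlp = _ExampleMLP(module.mlp, device, config)
        self.block = module

    def forward(self, *args, **kwargs):
        return self.block(*args, **kwargs)
```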
     return model

@@ -131,7 +128,7 @@ def _patch_qwen2_model(model):
     """
     convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
     convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
     return model

@@ -140,7 +137,7 @@ def _patch_bert_model(model):
     Patch bert model:
         1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
     return model

@@ -149,7 +146,7 @@ def _patch_vit_model(model):
     Patch vit model:
         1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
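For the user-facing side that the PR title refers to, the intended flow is presumably along these lines. This is a hedged sketch based on the bnb test commits above, not code from the PR: the checkpoint id is illustrative, and it assumes from_pretrained forwards quantization_config to transformers the way the tests do.

```python
from transformers import AutoTokenizer, BitsAndBytesConfig
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"  # illustrative checkpoint, not taken from the PR
# bitsandbytes quantization config, as exercised by the new bnb tests
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Assumption: with this PR, quantized checkpoints load through the IPEX model
# classes the same way they do through plain transformers.
model = IPEXModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Quantized generation with IPEX:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```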
no luck with autoawq?
Yes, the autoawq installation has some issues; I will figure out why this environment cannot install autoawq. The tests passed in my local environment, which has autoawq installed.