
Enable quant model support #1074

Merged Feb 19, 2025 (65 commits)

Commits (all changes)
3888824  enable IPEXModelForSeq2SeqLM (jiqing-feng, Dec 9, 2024)
f9fa807  set static cache (jiqing-feng, Dec 9, 2024)
202df43  add tests for IPEXModelForSeq2SeqLM (jiqing-feng, Dec 9, 2024)
4488073  add docs (jiqing-feng, Dec 9, 2024)
16fecf8  fix readme (jiqing-feng, Dec 9, 2024)
de501f4  Merge branch 'main' into text2text (jiqing-feng, Dec 10, 2024)
4225bf0  refactor compile (jiqing-feng, Dec 11, 2024)
2ac7ecf  fix check (jiqing-feng, Dec 11, 2024)
24b988c  fix ruff check (jiqing-feng, Dec 11, 2024)
5c4f9a1  Merge branch 'huggingface:main' into text2text (jiqing-feng, Dec 16, 2024)
46b93a4  enable quantized model (jiqing-feng, Dec 16, 2024)
82d39ce  add bnb test (jiqing-feng, Dec 16, 2024)
7dc08da  add bnb tests in yaml (jiqing-feng, Dec 16, 2024)
30027ff  fix tests (jiqing-feng, Dec 16, 2024)
314db04  disable bnb tests (jiqing-feng, Dec 16, 2024)
87656ca  fix gpt2 (jiqing-feng, Dec 16, 2024)
9a7e931  Merge branch 'main' into quant (jiqing-feng, Dec 18, 2024)
b0cec9c  set actual device (jiqing-feng, Dec 18, 2024)
94cf35d  assign device when convert class (jiqing-feng, Dec 18, 2024)
9af46d1  fix class init (jiqing-feng, Dec 18, 2024)
18b2a6a  fix ipex attn init (jiqing-feng, Dec 18, 2024)
9f6db33  rm set device on config (jiqing-feng, Dec 18, 2024)
6d8a969  fix format (jiqing-feng, Dec 18, 2024)
dd811f9  fix mlp class init (jiqing-feng, Dec 18, 2024)
d91eefb  Merge branch 'huggingface:main' into quant (jiqing-feng, Jan 14, 2025)
f094cad  Merge branch 'main' into quant (jiqing-feng, Jan 21, 2025)
dab4a78  add use_cache param when init generation config (jiqing-feng, Jan 21, 2025)
6bf3b8b  fix gpt2 quant model (jiqing-feng, Jan 21, 2025)
356d51d  fix falcon linear fusion (jiqing-feng, Jan 22, 2025)
d1eee87  fix falcon (jiqing-feng, Jan 22, 2025)
3aece6a  Merge branch 'huggingface:main' into quant (jiqing-feng, Feb 7, 2025)
57e3c27  enable awq model test (jiqing-feng, Feb 7, 2025)
8f6ba5c  fix install (jiqing-feng, Feb 7, 2025)
8870714  fix install (jiqing-feng, Feb 7, 2025)
5828fc0  fix install (jiqing-feng, Feb 7, 2025)
c616d57  fix install (jiqing-feng, Feb 7, 2025)
e1715b8  fix install (jiqing-feng, Feb 7, 2025)
e88faf2  fix install (jiqing-feng, Feb 7, 2025)
882f2b2  fix install (jiqing-feng, Feb 7, 2025)
d8208c7  fix install (jiqing-feng, Feb 7, 2025)
80b9ccb  fix install (jiqing-feng, Feb 7, 2025)
f05fb2f  fix install (jiqing-feng, Feb 7, 2025)
bd8e870  fix install (jiqing-feng, Feb 7, 2025)
e471be3  fix install (jiqing-feng, Feb 7, 2025)
96f4622  fix install (jiqing-feng, Feb 7, 2025)
32bf0a1  fix install (jiqing-feng, Feb 7, 2025)
ad3467b  fix install (jiqing-feng, Feb 7, 2025)
4a21d26  fix install (jiqing-feng, Feb 7, 2025)
fb8002c  fix install (jiqing-feng, Feb 7, 2025)
2ad2371  fix install (jiqing-feng, Feb 7, 2025)
3c2ddef  enable bnb test (jiqing-feng, Feb 7, 2025)
757ea8c  remove useless device (jiqing-feng, Feb 7, 2025)
0a6ab0f  update python to 3.10 on test_ipex (jiqing-feng, Feb 11, 2025)
8c4884b  Apply suggestions from code review (IlyasMoutawwakil, Feb 11, 2025)
7fa23a5  install autoawq (jiqing-feng, Feb 11, 2025)
5386bbe  install wheel (jiqing-feng, Feb 11, 2025)
c5f5d16  fix install autoawq (jiqing-feng, Feb 11, 2025)
41513f0  rm autoawq (jiqing-feng, Feb 11, 2025)
8e1caa2  rebase (jiqing-feng, Feb 12, 2025)
f73c08d  fix concat qkv (jiqing-feng, Feb 12, 2025)
f51777b  fix format (jiqing-feng, Feb 12, 2025)
f64b251  fix qwen patch (jiqing-feng, Feb 12, 2025)
778bf15  fix bias (jiqing-feng, Feb 12, 2025)
6ba1895  rm autoawq test (jiqing-feng, Feb 12, 2025)
88dba29  fix style (jiqing-feng, Feb 12, 2025)
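
For context, the end result of this PR is that already-quantized checkpoints (for example bitsandbytes or AWQ models) can be loaded through the IPEX model classes. A minimal usage sketch follows; the checkpoint name and quantization kwargs are illustrative assumptions, not taken from this PR or its tests.

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
from optimum.intel import IPEXModelForCausalLM

# Illustrative checkpoint and settings (assumptions, not from this PR).
model_id = "meta-llama/Llama-2-7b-hf"
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("Quantized inference with IPEX:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))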
8 changes: 7 additions & 1 deletion .github/workflows/test_ipex.yml
@@ -30,14 +30,20 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
- python-version: 3.9
+ python-version: "3.10"

- name: Install dependencies
run: |
pip install --upgrade pip
pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }}

+ - name: Install bitsandbytes
+ run: |
+ git clone --branch multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
+ cd bitsandbytes
+ pip install .
Comment on lines +41 to +45

Member: no luck with autoawq?

Collaborator (author): Yes, the autoawq installation has some issues; I will figure out why this environment cannot install autoawq. The tests pass in my local environment, which has autoawq installed.
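
One common way to reconcile the two environments is to gate the AWQ tests on the package being importable; a hypothetical sketch (not part of this PR) follows, assuming AutoAWQ is importable as "awq".

import importlib.util

import pytest

# Hypothetical helper (not from this PR): skip AWQ tests when autoawq is
# missing, so CI passes even where the package cannot be installed.
require_autoawq = pytest.mark.skipif(
    importlib.util.find_spec("awq") is None,
    reason="autoawq is not installed",
)

@require_autoawq
def test_awq_model_loading():
    # Placeholder body; a real test would load an AWQ-quantized checkpoint.
    assert importlib.util.find_spec("awq") is not None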


- name: Assert versions
run: |
python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))"
27 changes: 12 additions & 15 deletions optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
- from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
@@ -32,13 +32,11 @@

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_IPEXGPT2MLP,
_falcon_model_forward,
_gpt2_block_forward,
_gpt2_model_forward,
_ipex_rms_layer_norm_forward,
_IPEXFalconDecoderLayer,
_IPEXGPT2Attention,
_IPEXGPT2Block,
_IPEXIntermediate,
_IPEXLlamaDecoderLayer,
_IPEXQwen2DecoderLayer,
@@ -66,12 +64,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
convert_functions(sub_m, target_m, new_function_name, new_function)


- def convert_class(m, target_m, new_class, config=None):
+ def convert_class(m, target_m, new_class, device, config):
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
- new_m = new_class(sub_m, config)
+ new_m = new_class(sub_m, device, config)
setattr(m, name, new_m)
- convert_class(sub_m, target_m, new_class, config)
+ convert_class(sub_m, target_m, new_class, device, config)
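
The notable change here is that convert_class now threads the model's device (and always the config) into each replacement class. A minimal sketch of a replacement class compatible with this signature, assuming the usual pattern of wrapping the original module, is:

import torch

class _ExampleIPEXLayer(torch.nn.Module):
    # Sketch only: a replacement class compatible with the new
    # convert_class(m, target_m, new_class, device, config) signature.
    def __init__(self, module, device, config):
        super().__init__()
        self.module = module.to(device)  # place weights on the actual device
        self.config = config

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

A call like convert_class(model, SomeLayer, _ExampleIPEXLayer, model.device, model.config) would then swap every SomeLayer instance in place.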


def patch_op(m, target_m, new_op_name, new_op):
@@ -89,7 +87,7 @@ def _patch_llama_model(model):
"""
convert_functions(model, LlamaModel, "forward", _llama_model_forward)
convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
- convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+ convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
return model


@@ -105,21 +103,20 @@ def _patch_falcon_model(model):
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, FalconModel, "forward", _falcon_model_forward)
replace_customized_linear_with_linear(model)
- convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+ convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
return model
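
replace_customized_linear_with_linear is called here so that Falcon's customized linear subclasses become plain torch.nn.Linear before the class conversion, which keeps the downstream fusion logic to a single linear type. A rough sketch of that kind of swap (an assumption for illustration, not the actual helper) is:

import torch

def replace_custom_linear(module, custom_linear_cls):
    # Sketch only: swap any custom Linear subclass for an equivalent plain
    # torch.nn.Linear sharing the same parameters, so later passes only have
    # to recognize one linear type.
    for name, child in module.named_children():
        if isinstance(child, custom_linear_cls):
            new_linear = torch.nn.Linear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            new_linear.weight = child.weight
            if child.bias is not None:
                new_linear.bias = child.bias
            setattr(module, name, new_linear)
        else:
            replace_custom_linear(child, custom_linear_cls)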


def _patch_gpt2_model(model):
"""
Patch gpt2 model:
1. Use IPEX paged attention
2. Linear fusion with (Linear + Add)
"""
num_key_value_heads = model.config.num_attention_heads
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
- convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
- convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+ convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
Comment on lines -120 to +119

Member: Why are the MLP and attention no longer patched here?

Collaborator (author): Because they are handled inside _IPEXGPT2Block.
return model
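
To make the exchange above concrete: the block-level replacement owns its attention and MLP conversions, so they no longer need separate convert_class calls. A rough sketch of that structure (an assumption about the shape of _IPEXGPT2Block, not its actual implementation) is:

import torch

class _SketchAttention(torch.nn.Module):
    # Stand-in for the IPEX attention replacement.
    def __init__(self, attn, device, config):
        super().__init__()
        self.attn = attn.to(device)

    def forward(self, *args, **kwargs):
        return self.attn(*args, **kwargs)


class _SketchMLP(torch.nn.Module):
    # Stand-in for the IPEX MLP replacement (e.g. fused Linear + Add).
    def __init__(self, mlp, device, config):
        super().__init__()
        self.mlp = mlp.to(device)

    def forward(self, *args, **kwargs):
        return self.mlp(*args, **kwargs)


class _SketchGPT2Block(torch.nn.Module):
    # The block wrapper converts its own attention and MLP submodules, which
    # is why _patch_gpt2_model no longer calls convert_class on them.
    def __init__(self, block, device, config):
        super().__init__()
        self.block = block
        self.block.attn = _SketchAttention(block.attn, device, config)
        self.block.mlp = _SketchMLP(block.mlp, device, config)

    def forward(self, *args, **kwargs):
        return self.block(*args, **kwargs)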


@@ -131,7 +128,7 @@ def _patch_qwen2_model(model):
"""
convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
- convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config)
+ convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
return model


Expand All @@ -140,7 +137,7 @@ def _patch_bert_model(model):
Patch bert model:
1. Linear fusion with Linear + Gelu
"""
- convert_class(model, BertIntermediate, _IPEXIntermediate)
+ convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
return model


Expand All @@ -149,7 +146,7 @@ def _patch_vit_model(model):
Patch vit model:
1. Linear fusion with Linear + Gelu
"""
- convert_class(model, ViTIntermediate, _IPEXIntermediate)
+ convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
return model

