diff --git a/.github/workflows/script/install_binary.sh b/.github/workflows/script/install_binary.sh index 7bca0d4d2f3..a33a6607b1b 100644 --- a/.github/workflows/script/install_binary.sh +++ b/.github/workflows/script/install_binary.sh @@ -1,4 +1,5 @@ #!/bin/bash + source /intel-extension-for-transformers/.github/workflows/script/change_color.sh cd /intel-extension-for-transformers @@ -10,7 +11,7 @@ git config --global --add safe.directory "*" git submodule update --init --recursive -$BOLD_YELLOW && echo "---------------- run python setup.py sdist bdist_wheel -------------" && $RESET +$BOLD_YELLOW && echo "---------------- run python setup.py bdist_wheel -------------" && $RESET python setup.py bdist_wheel diff --git a/.github/workflows/unit-test-optimize.yml b/.github/workflows/unit-test-optimize.yml index 4d11947d92c..6399df03878 100644 --- a/.github/workflows/unit-test-optimize.yml +++ b/.github/workflows/unit-test-optimize.yml @@ -67,7 +67,7 @@ jobs: with: submodules: "recursive" ref: ${{ matrix.test_branch }} - fetch-tags: true + fetch-depth: 0 - name: Docker Build run: | diff --git a/docker/Dockerfile_chatbot b/docker/Dockerfile_chatbot index 3bdb8cf9e05..2908acf8b69 100644 --- a/docker/Dockerfile_chatbot +++ b/docker/Dockerfile_chatbot @@ -153,6 +153,7 @@ RUN apt update \ && apt install -y wget numactl git nvidia-cuda* \ && apt install -y openssh-server \ && apt install -y python${PYTHON_VERSION} python3-pip \ + && echo 'LoginGraceTime 0' >> /etc/ssh/sshd_config \# https://ubuntu.com/security/CVE-2024-6387 && apt clean \ && rm -rf /var/lib/apt/lists/* RUN ln -s /usr/bin/python3 /usr/bin/python diff --git a/examples/huggingface/pytorch/code-generation/quantization/Dockerfile-multiple b/examples/huggingface/pytorch/code-generation/quantization/Dockerfile-multiple index a903d161614..d58375c38b4 100644 --- a/examples/huggingface/pytorch/code-generation/quantization/Dockerfile-multiple +++ b/examples/huggingface/pytorch/code-generation/quantization/Dockerfile-multiple @@ -61,7 +61,7 @@ ENV COMPOSE_DOCKER_CLI_BUILD=0 # Install torch and intel-extension-for-pytorch 2.1 RUN python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu RUN python3 -m pip install intel-extension-for-pytorch intel-extension-for-transformers optimum -RUN python3 -m pip install git+https://github.com/huggingface/optimum-intel.git@f95dea1ae8966dee4d75d622e7b2468c514ba02d +RUN python3 -m pip install git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85 RUN python3 -m pip install git+https://github.com/bigcode-project/bigcode-evaluation-harness@0d84db85f9ff971fa23a187a3347b7f59af288dc # Standard requirements diff --git a/examples/huggingface/pytorch/code-generation/quantization/requirements.txt b/examples/huggingface/pytorch/code-generation/quantization/requirements.txt index 5f02605d7d0..455eccd2b26 100644 --- a/examples/huggingface/pytorch/code-generation/quantization/requirements.txt +++ b/examples/huggingface/pytorch/code-generation/quantization/requirements.txt @@ -10,6 +10,6 @@ transformers >= 4.35.0 tiktoken #code_gen neural-compressor intel_extension_for_pytorch==2.3.0 -optimum-intel +git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85 auto-round==0.2 git+https://github.com/bigcode-project/bigcode-evaluation-harness@094c7cc197d13a53c19303865e2056f1c7488ac1 diff --git a/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt b/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt index 65980710432..1b000e0c61b 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt +++ b/examples/huggingface/pytorch/text-generation/quantization/requirements_GPU.txt @@ -7,7 +7,7 @@ sentencepiece != 0.1.92 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ torch==2.1.0a0 transformers -optimum-intel +git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85 bitsandbytes #baichuan transformers_stream_generator tiktoken #qwen diff --git a/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt b/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt index 047f65d091a..02655339b5d 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt +++ b/examples/huggingface/pytorch/text-generation/quantization/requirements_sq.txt @@ -7,7 +7,7 @@ sentencepiece != 0.1.92 torch==2.3.0+cpu transformers==4.38.1 intel_extension_for_pytorch==2.3.0 -optimum-intel==1.16.1 +git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85 bitsandbytes #baichuan transformers_stream_generator tiktoken #qwen diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py index 9363b45cf5f..d54dd5f127f 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py @@ -142,12 +142,7 @@ user_model = None -# tokenizer -if config.model_type == "llama": - from transformers import LlamaTokenizer - tokenizer = LlamaTokenizer.from_pretrained(args.model) -else: - tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) +tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code) quantization_config = None if args.woq: diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp index 1f33cfc663b..784c512220f 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp +++ b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp @@ -14,7 +14,6 @@ #pragma once #include #include -#include "bestla/bestla_storage.h" #include "../include/dispatcher_utils.hpp" #include #include diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp index 8a0c99b3b3a..05a8c718b26 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp +++ b/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp @@ -16,6 +16,7 @@ #include #include #include "bestla/bestla_device.h" +#include "bestla/bestla_storage.h" #include "bestla/bestla_utils.h" #include "bestla/bestla_parallel.h" namespace dispatcher_utils { @@ -26,6 +27,12 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()-> inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); } inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); } +template +constexpr bool is_int8_cmpt_gemmcore() { + return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI || + GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v>; +} + class qbits_threading { public: static bestla::parallel::IThreading* get() { diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp index 399deaba7e0..cf6889a9f15 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp +++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp @@ -16,12 +16,19 @@ #include "../include/bestla_packq_impl.hpp" namespace woq { -template + +template void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) { - using proB = bestla::prologue_b::gemm::WeightKBlockNInteger; static proB ker; - auto qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type), - scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym); + using WType = typename proB::StorageWeight; + WType qpackw(0); + if constexpr (std::is_same_v) { + qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type), + scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym); + } else { + qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type), + scale2bestladt_map.at(p->scale_type)); + } if (p->enable_act_shuffle) ker.enableShuffle(&qpackw); ctx->packw_size = qpackw.mSize; if (task == WOQ_GET_PACKW_SIZE) return; @@ -33,6 +40,20 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx p->asym ? ctx->zp->data_ptr() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get()); } +template +void parse_prob(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) { + if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" || + p->weight_type == "int2_clip") { + return execute_qpack>(p, ctx, task); + } + if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1") { + TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization."); + return execute_qpack>(p, ctx, task); + } + TORCH_CHECK(false, "Qbits: unsupported bestla packq config, compute_type: " + p->compute_type + + " weight_type: " + p->weight_type); +} + std::string get_dtype_str(BTLA_DTYPE dtype) { switch (dtype) { case BTLA_DTYPE::F32: @@ -183,40 +204,38 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) { } void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) { - // TODO(zhe): elegant impl. - TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" || - p->weight_type == "int2_clip", - "Qbits: only support Integer WOQ in PACKQ"); - if (p->compute_type == "int8") { + TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" || + p->weight_type == "int2_clip", + "Qbits: only support Integer weight-type with int8 compute-type"); if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AMX_INT8>(p, ctx, task); + return parse_prob, BTLA_ISA::AMX_INT8>(p, ctx, task); } if (dispatcher_utils::check_avx512_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AVX512_VNNI>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX512_VNNI>(p, ctx, task); } if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AVX_VNNI>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX_VNNI>(p, ctx, task); } if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AVX2>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX2>(p, ctx, task); } TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize, ", ISA support avx2:", dispatcher_utils::check_avx2()); } if (p->compute_type == "fp32") { if (dispatcher_utils::check_avx512f()) { - return execute_qpack, BTLA_ISA::AVX512F>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX512F>(p, ctx, task); } if (dispatcher_utils::check_avx2()) { - return execute_qpack, BTLA_ISA::AVX2>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX2>(p, ctx, task); } TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32"); } if (p->compute_type == "bf16") { if (dispatcher_utils::check_amx()) { - return execute_qpack, BTLA_ISA::AMX_BF16>(p, ctx, task); + return parse_prob, BTLA_ISA::AMX_BF16>(p, ctx, task); } TORCH_CHECK(false, "Qbits: device ISA must support AMX-BF16 when compute_type==bf16"); } diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp index f9864ddece0..c04e652a4aa 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp +++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp @@ -43,12 +43,6 @@ concept quant_PrologueA = requires { requires !std::is_same_v; }; -template -constexpr bool is_int8_cmpt_gemmcore() { - return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI || - GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v>; -} - template void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) { if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start(); @@ -133,7 +127,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) { using StorageWeight = typename Launcher::PrologueB::StorageWeight; size_t asym_size = 0, shuf_size = 0; int8_t* tmpbuf = nullptr; - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { using Parallel = bestla::parallel::gemm::SchedulerKBlockS; bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize); StorageWeight* packedw = dynamic_cast(ctx->deseries_wei); @@ -236,7 +230,7 @@ void execute_task(woq_config_param* p, woq_runtime_ctx* ctx) { template class PrologueB, template class PrologueA, template class Epilogue> void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) { - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { using Launcher = bestla::wrapper::gemm::LauncherIntKBlock; return execute_task(p, ctx); } else { @@ -260,7 +254,7 @@ template class Pro void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) { using namespace bestla::prologue_a::gemm; if (p->src_dt == dispatcher_utils::QBITS_FP32) { - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { return parse_store( p, ctx); } else { @@ -269,7 +263,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) { } } if (p->src_dt == dispatcher_utils::QBITS_BF16) { - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { return parse_store( p, ctx); } else { @@ -289,7 +283,7 @@ void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) { if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" || p->weight_type == "fp8_e4m3" || p->weight_type == "fp8_e5m2") { TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization."); - if constexpr (!is_int8_cmpt_gemmcore()) + if constexpr (!dispatcher_utils::is_int8_cmpt_gemmcore()) return parse_activation(p, ctx); } TORCH_CHECK(false, diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py index eb18e88baab..8d43d29dde6 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py @@ -20,7 +20,7 @@ import logging import math import os - +from ....tools.utils import _ipex_version from accelerate import init_empty_weights from datasets import load_dataset from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear @@ -308,14 +308,12 @@ def _replace_linear( scale_dtype=quantization_config.scale_dtype, blocksize=quantization_config.group_size, scheme=quantization_config.scheme, - compression_dtype=getattr( - module, "compression_dtype", torch.int8 - ), - compression_dim=getattr(module, "compression_dim", 0), + compression_dtype=getattr(module, "compression_dtype", + torch.int8 if _ipex_version < "2.3.10" else torch.int32), + compression_dim=getattr(module, "compression_dim", 0 if _ipex_version < "2.3.10" else 1), device=device, - use_optimum_format=getattr( - module, "use_optimum_format", False - ), + use_optimum_format=getattr(module, "use_optimum_format", + False if _ipex_version < "2.3.10" else True), ) if quantization_config.quant_method.value == "gptq": g_idx = getattr( @@ -340,6 +338,17 @@ def _replace_linear( quantization_config.compute_dtype ), device=torch.device(device), + ) if _ipex_version < "2.3.10" else torch.ones( + ( + math.ceil( + in_features / quantization_config.group_size + ), + out_features, + ), + dtype=convert_dtype_str2torch( + quantization_config.compute_dtype + ), + device=torch.device(device), ) ), module.qzeros if hasattr(module, "qzeros") else None, @@ -392,11 +401,13 @@ def _replace_linear( else: if not hasattr(module, "qweight"): n_pack = ( - 8 // quantization_config.bits + (8 if _ipex_version < "2.3.10" else 32) + // DTYPE_BITS_MAPPING[quantization_config.weight_dtype] ) weight = torch.zeros( - (math.ceil(out_features / n_pack), in_features), - dtype=torch.int8, + (math.ceil(out_features / n_pack), in_features) if _ipex_version < "2.3.10" else + (math.ceil(in_features / n_pack), out_features), + dtype=torch.int8 if _ipex_version < "2.3.10" else torch.int32, device=torch.device(device), ) model._modules[name].set_weights_bias( @@ -677,7 +688,7 @@ def convert_to_quantized_model(model, config, device="cpu"): use_optimum_format=False, scale_dtype=convert_dtype_str2torch(config.scale_dtype), device="xpu", - ) + ) if _ipex_version < "2.3.10" else inc_model.export_compressed_model(use_optimum_format=True, device="xpu") q_model = replace_linear(model, None, None, config, device=device) else: diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index f691f951b15..708f34326c0 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -184,7 +184,7 @@ def convert_model_to_public(model): # reorder weight and scales if they have been transposed if model.device == "xpu" or (isinstance(model.device, torch.device) and model.device.type == "xpu"): for name, module in model.named_modules(): - if isinstance(module, WeightOnlyQuantizedLinear): + if isinstance(module, WeightOnlyQuantizedLinear) and not module.use_optimum_format: if module.weight_transposed: module.qweight.data = module.qweight.t_().contiguous() module.scales.data = module.scales.t_().contiguous() @@ -1768,7 +1768,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): # weight dtype is higher priority than bits in config.json when both existed. if quantization_config.weight_dtype is None: if quantization_config.bits == 4: - quantization_config.weight_dtype = "int4_clip" + if use_xpu: + quantization_config.weight_dtype = "int4_fullrange" + else: + quantization_config.weight_dtype = "int4_clip" logger.info( "{} quantization weight_dtype is used due to bits is 4 in config.json.".format( quantization_config.weight_dtype @@ -1825,7 +1828,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): "fp4_e2m1", "fp4_e2m1_bnb", "nf4", - "int4_fullrange", ]: model = build_woq_model(model, quantization_config) else: @@ -1938,18 +1940,14 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]): # Set model in evaluation mode to deactivate DropOut modules by default model.eval() - if ( - quantization_config.weight_dtype - not in [ - "fp8_e5m2", - "fp8_e4m3", - "nf4", - "fp4_e2m1", - "fp4_e2m1_bnb", - "int4_fullrange", - ] - and not quantization_config.use_ipex - ): + + if quantization_config.weight_dtype not in [ + "fp8_e5m2", + "fp8_e4m3", + "nf4", + "fp4_e2m1", + "fp4_e2m1_bnb", + ] and not quantization_config.use_ipex: model = replace_linear( model, quantization_config=quantization_config, diff --git a/tests/requirements.txt b/tests/requirements.txt index 4dbac7def89..d2c2dca3f74 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -7,6 +7,7 @@ datasets==2.16.1 einops evaluate gguf +git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85 git+https://github.com/intel/neural-compressor.git git+https://github.com/intel/neural-speed.git intel-extension-for-pytorch==2.3.0 @@ -16,7 +17,6 @@ mlflow nlpaug==1.1.9 onnx onnxruntime -optimum-intel==1.16.1 peft==0.6.2 py-cpuinfo sacremoses