
Merge branch 'main' into wangchang/inc3.x
changwangss committed Jul 10, 2024
2 parents 54c8157 + 3fd99c8 commit b5537bc
Showing 15 changed files with 93 additions and 68 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/script/install_binary.sh
@@ -1,4 +1,5 @@
#!/bin/bash

source /intel-extension-for-transformers/.github/workflows/script/change_color.sh

cd /intel-extension-for-transformers
@@ -10,7 +11,7 @@ git config --global --add safe.directory "*"
git submodule update --init --recursive


$BOLD_YELLOW && echo "---------------- run python setup.py sdist bdist_wheel -------------" && $RESET
$BOLD_YELLOW && echo "---------------- run python setup.py bdist_wheel -------------" && $RESET
python setup.py bdist_wheel


2 changes: 1 addition & 1 deletion .github/workflows/unit-test-optimize.yml
@@ -67,7 +67,7 @@ jobs:
with:
submodules: "recursive"
ref: ${{ matrix.test_branch }}
fetch-tags: true
fetch-depth: 0

- name: Docker Build
run: |
1 change: 1 addition & 0 deletions docker/Dockerfile_chatbot
@@ -153,6 +153,7 @@ RUN apt update \
&& apt install -y wget numactl git nvidia-cuda* \
&& apt install -y openssh-server \
&& apt install -y python${PYTHON_VERSION} python3-pip \
&& echo 'LoginGraceTime 0' >> /etc/ssh/sshd_config \# https://ubuntu.com/security/CVE-2024-6387
&& apt clean \
&& rm -rf /var/lib/apt/lists/*
RUN ln -s /usr/bin/python3 /usr/bin/python
Changed file (path not shown)
@@ -61,7 +61,7 @@ ENV COMPOSE_DOCKER_CLI_BUILD=0
# Install torch and intel-extension-for-pytorch 2.1
RUN python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
RUN python3 -m pip install intel-extension-for-pytorch intel-extension-for-transformers optimum
RUN python3 -m pip install git+https://github.com/huggingface/optimum-intel.git@f95dea1ae8966dee4d75d622e7b2468c514ba02d
RUN python3 -m pip install git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85
RUN python3 -m pip install git+https://github.com/bigcode-project/bigcode-evaluation-harness@0d84db85f9ff971fa23a187a3347b7f59af288dc

# Standard requirements
Changed file (path not shown)
@@ -10,6 +10,6 @@ transformers >= 4.35.0
tiktoken #code_gen
neural-compressor
intel_extension_for_pytorch==2.3.0
optimum-intel
git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85
auto-round==0.2
git+https://github.com/bigcode-project/bigcode-evaluation-harness@094c7cc197d13a53c19303865e2056f1c7488ac1
Changed file (path not shown)
@@ -7,7 +7,7 @@ sentencepiece != 0.1.92
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
torch==2.1.0a0
transformers
optimum-intel
git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85
bitsandbytes #baichuan
transformers_stream_generator
tiktoken #qwen
Changed file (path not shown)
@@ -7,7 +7,7 @@ sentencepiece != 0.1.92
torch==2.3.0+cpu
transformers==4.38.1
intel_extension_for_pytorch==2.3.0
optimum-intel==1.16.1
git+https://github.com/huggingface/optimum-intel.git@50d867c13b22c22eda451ddb67bddb8159670f85
bitsandbytes #baichuan
transformers_stream_generator
tiktoken #qwen
Changed file (path not shown)
@@ -142,12 +142,7 @@

user_model = None

# tokenizer
if config.model_type == "llama":
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(args.model)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

quantization_config = None
if args.woq:
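The hunk above drops the Llama-specific branch: AutoTokenizer reads the checkpoint's tokenizer configuration and resolves the concrete tokenizer class on its own, so a model_type check is unnecessary. A minimal sketch of the unified call, assuming a recent transformers release; the model id below is a placeholder, not one used by this script:

```python
# Minimal sketch: AutoTokenizer resolves the right tokenizer class for Llama
# checkpoints, which is why the explicit LlamaTokenizer branch could be removed.
from transformers import AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False)
print(type(tokenizer).__name__)  # a Llama tokenizer class, chosen automatically
```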
Changed file (path not shown)
@@ -14,7 +14,6 @@
#pragma once
#include <ATen/core/TensorBody.h>
#include <torch/torch.h>
#include "bestla/bestla_storage.h"
#include "../include/dispatcher_utils.hpp"
#include <string.h>
#include <assert.h>
Changed file (path not shown)
@@ -16,6 +16,7 @@
#include <chrono>
#include <string>
#include "bestla/bestla_device.h"
#include "bestla/bestla_storage.h"
#include "bestla/bestla_utils.h"
#include "bestla/bestla_parallel.h"
namespace dispatcher_utils {
@@ -26,6 +27,12 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()->
inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); }
inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); }

template <class GemmCore>
constexpr bool is_int8_cmpt_gemmcore() {
return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>;
}

class qbits_threading {
public:
static bestla::parallel::IThreading* get() {
Changed file (path not shown)
@@ -16,12 +16,19 @@
#include "../include/bestla_packq_impl.hpp"

namespace woq {
template <class GemmCore, BTLA_ISA ISA>

template <class proB>
void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
using proB = bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>;
static proB ker;
auto qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym);
using WType = typename proB::StorageWeight;
WType qpackw(0);
if constexpr (std::is_same_v<WType, bestla::storage::gemm::StorageWeightKBlockNInteger>) {
qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym);
} else {
qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
scale2bestladt_map.at(p->scale_type));
}
if (p->enable_act_shuffle) ker.enableShuffle(&qpackw);
ctx->packw_size = qpackw.mSize;
if (task == WOQ_GET_PACKW_SIZE) return;
@@ -33,6 +40,20 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get());
}

template <class GemmCore, BTLA_ISA ISA>
void parse_prob(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
p->weight_type == "int2_clip") {
return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>>(p, ctx, task);
}
if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1") {
TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNFloat<GemmCore, ISA>>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: unsupported bestla packq config, compute_type: " + p->compute_type +
" weight_type: " + p->weight_type);
}

std::string get_dtype_str(BTLA_DTYPE dtype) {
switch (dtype) {
case BTLA_DTYPE::F32:
@@ -183,40 +204,38 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
}

void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
// TODO(zhe): elegant impl.
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
p->weight_type == "int2_clip",
"Qbits: only support Integer WOQ in PACKQ");

if (p->compute_type == "int8") {
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
p->weight_type == "int2_clip",
"Qbits: only support Integer weight-type with int8 compute-type");
if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
}
if (dispatcher_utils::check_avx512_vnni() &&
p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
return parse_prob<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
", ISA support avx2:", dispatcher_utils::check_avx2());
}
if (p->compute_type == "fp32") {
if (dispatcher_utils::check_avx512f()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
return parse_prob<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
}
if (dispatcher_utils::check_avx2()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
return parse_prob<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
}
if (p->compute_type == "bf16") {
if (dispatcher_utils::check_amx()) {
return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
return parse_prob<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: device ISA must support AMX-BF16 when compute_type==bf16");
}
Changed file (path not shown)
@@ -43,12 +43,6 @@ concept quant_PrologueA = requires {
requires !std::is_same_v<T, bestla::utils::bf16>;
};

template <class GemmCore>
constexpr bool is_int8_cmpt_gemmcore() {
return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>;
}

template <class Launcher>
void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start();
@@ -133,7 +127,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) {
using StorageWeight = typename Launcher::PrologueB::StorageWeight;
size_t asym_size = 0, shuf_size = 0;
int8_t* tmpbuf = nullptr;
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
using Parallel = bestla::parallel::gemm::SchedulerKBlockS<GemmCore>;
bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize);
StorageWeight* packedw = dynamic_cast<StorageWeight*>(ctx->deseries_wei);
@@ -236,7 +230,7 @@ void execute_task(woq_config_param* p, woq_runtime_ctx* ctx) {
template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class PrologueB,
template <class _T, BTLA_ISA> class PrologueA, template <BTLA_ISA> class Epilogue>
void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
using Launcher = bestla::wrapper::gemm::LauncherIntKBlock<GemmCore::ISA, GemmCore, PrologueA, PrologueB, Epilogue>;
return execute_task<TASK, Launcher>(p, ctx);
} else {
Expand All @@ -260,7 +254,7 @@ template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class Pro
void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
using namespace bestla::prologue_a::gemm;
if (p->src_dt == dispatcher_utils::QBITS_FP32) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeF32, dispatcher_utils::QBITS_FP32>(
p, ctx);
} else {
@@ -269,7 +263,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
}
}
if (p->src_dt == dispatcher_utils::QBITS_BF16) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeBf16, dispatcher_utils::QBITS_BF16>(
p, ctx);
} else {
@@ -289,7 +283,7 @@ void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" ||
p->weight_type == "fp8_e4m3" || p->weight_type == "fp8_e5m2") {
TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
if constexpr (!is_int8_cmpt_gemmcore<GemmCore>())
if constexpr (!dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>())
return parse_activation<TASK, GemmCore, WeightKBlockNFloat>(p, ctx);
}
TORCH_CHECK(false,
Changed file (path not shown)
@@ -20,7 +20,7 @@
import logging
import math
import os

from ....tools.utils import _ipex_version
from accelerate import init_empty_weights
from datasets import load_dataset
from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear
Expand Down Expand Up @@ -308,14 +308,12 @@ def _replace_linear(
scale_dtype=quantization_config.scale_dtype,
blocksize=quantization_config.group_size,
scheme=quantization_config.scheme,
compression_dtype=getattr(
module, "compression_dtype", torch.int8
),
compression_dim=getattr(module, "compression_dim", 0),
compression_dtype=getattr(module, "compression_dtype",
torch.int8 if _ipex_version < "2.3.10" else torch.int32),
compression_dim=getattr(module, "compression_dim", 0 if _ipex_version < "2.3.10" else 1),
device=device,
use_optimum_format=getattr(
module, "use_optimum_format", False
),
use_optimum_format=getattr(module, "use_optimum_format",
False if _ipex_version < "2.3.10" else True),
)
if quantization_config.quant_method.value == "gptq":
g_idx = getattr(
@@ -340,6 +338,17 @@ def _replace_linear(
quantization_config.compute_dtype
),
device=torch.device(device),
) if _ipex_version < "2.3.10" else torch.ones(
(
math.ceil(
in_features / quantization_config.group_size
),
out_features,
),
dtype=convert_dtype_str2torch(
quantization_config.compute_dtype
),
device=torch.device(device),
)
),
module.qzeros if hasattr(module, "qzeros") else None,
@@ -392,11 +401,13 @@ def _replace_linear(
else:
if not hasattr(module, "qweight"):
n_pack = (
8 // quantization_config.bits
(8 if _ipex_version < "2.3.10" else 32)
// DTYPE_BITS_MAPPING[quantization_config.weight_dtype]
)
weight = torch.zeros(
(math.ceil(out_features / n_pack), in_features),
dtype=torch.int8,
(math.ceil(out_features / n_pack), in_features) if _ipex_version < "2.3.10" else
(math.ceil(in_features / n_pack), out_features),
dtype=torch.int8 if _ipex_version < "2.3.10" else torch.int32,
device=torch.device(device),
)
model._modules[name].set_weights_bias(
@@ -677,7 +688,7 @@ def convert_to_quantized_model(model, config, device="cpu"):
use_optimum_format=False,
scale_dtype=convert_dtype_str2torch(config.scale_dtype),
device="xpu",
)
) if _ipex_version < "2.3.10" else inc_model.export_compressed_model(use_optimum_format=True, device="xpu")

q_model = replace_linear(model, None, None, config, device=device)
else:
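The hunks above key the WeightOnlyLinear packing parameters and the export path off the installed intel-extension-for-pytorch version (the `_ipex_version` string imported at the top of this file). A hedged sketch that collects the visible branches in one place; the helper below is illustrative only, and the real code reads the bit width from DTYPE_BITS_MAPPING and passes these values straight into WeightOnlyLinear:

```python
# Hedged sketch of the _ipex_version gating added above (helper name is illustrative).
# IPEX >= 2.3.10 switches to the optimum packing layout: int32 storage, packing along
# dim 1, and a qweight shaped (ceil(in_features / n_pack), out_features).
import math
import torch

def woq_packing_defaults(ipex_version: str, bits: int, in_features: int, out_features: int):
    legacy = ipex_version < "2.3.10"        # same string comparison used in the diff
    n_pack = (8 if legacy else 32) // bits  # weights packed per storage element
    return {
        "compression_dtype": torch.int8 if legacy else torch.int32,
        "compression_dim": 0 if legacy else 1,
        "use_optimum_format": not legacy,
        "qweight_shape": (math.ceil(out_features / n_pack), in_features)
        if legacy
        else (math.ceil(in_features / n_pack), out_features),
    }

# With IPEX 2.3.10 or newer, 4-bit weights are packed eight per int32 in the optimum layout.
print(woq_packing_defaults("2.3.10", bits=4, in_features=4096, out_features=11008))
```

On XPU, the same version check also decides whether the model is exported with inc_model.export_compressed_model(use_optimum_format=True, device="xpu") instead of the older export arguments.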
Changed file (path not shown)
@@ -184,7 +184,7 @@ def convert_model_to_public(model):
# reorder weight and scales if they have been transposed
if model.device == "xpu" or (isinstance(model.device, torch.device) and model.device.type == "xpu"):
for name, module in model.named_modules():
if isinstance(module, WeightOnlyQuantizedLinear):
if isinstance(module, WeightOnlyQuantizedLinear) and not module.use_optimum_format:
if module.weight_transposed:
module.qweight.data = module.qweight.t_().contiguous()
module.scales.data = module.scales.t_().contiguous()
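The tightened guard above skips layers that already use the optimum format, since only the legacy XPU layout stores qweight and scales transposed. A short sketch of the intent, using an illustrative helper name; the rest of the original method body is collapsed in this view:

```python
# Illustrative sketch: only legacy (non-optimum) XPU layers need their tensors
# transposed back before the checkpoint is made public.
def restore_public_layout(module):
    if getattr(module, "use_optimum_format", False):
        return  # optimum-format layers already hold the public layout
    if module.weight_transposed:
        module.qweight.data = module.qweight.t_().contiguous()
        module.scales.data = module.scales.t_().contiguous()
```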
Expand Down Expand Up @@ -1768,7 +1768,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# weight dtype is higher priority than bits in config.json when both existed.
if quantization_config.weight_dtype is None:
if quantization_config.bits == 4:
quantization_config.weight_dtype = "int4_clip"
if use_xpu:
quantization_config.weight_dtype = "int4_fullrange"
else:
quantization_config.weight_dtype = "int4_clip"
logger.info(
"{} quantization weight_dtype is used due to bits is 4 in config.json.".format(
quantization_config.weight_dtype
@@ -1825,7 +1828,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"fp4_e2m1",
"fp4_e2m1_bnb",
"nf4",
"int4_fullrange",
]:
model = build_woq_model(model, quantization_config)
else:
@@ -1938,18 +1940,14 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]):

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if (
quantization_config.weight_dtype
not in [
"fp8_e5m2",
"fp8_e4m3",
"nf4",
"fp4_e2m1",
"fp4_e2m1_bnb",
"int4_fullrange",
]
and not quantization_config.use_ipex
):

if quantization_config.weight_dtype not in [
"fp8_e5m2",
"fp8_e4m3",
"nf4",
"fp4_e2m1",
"fp4_e2m1_bnb",
] and not quantization_config.use_ipex:
model = replace_linear(
model,
quantization_config=quantization_config,
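Finally, the hunks above change how a 4-bit checkpoint without an explicit weight_dtype is handled in load_low_bit: XPU devices now default to "int4_fullrange" while other devices keep "int4_clip", and "int4_fullrange" is dropped from the dtype list that bypasses replace_linear. A hedged sketch of the resulting behaviour; the function name is illustrative:

```python
# Illustrative summary of the new load_low_bit defaults for 4-bit checkpoints.
def default_4bit_weight_dtype(use_xpu: bool) -> str:
    return "int4_fullrange" if use_xpu else "int4_clip"

# Because "int4_fullrange" no longer appears in the bypass list, such models now go
# through replace_linear like the other integer dtypes (unless use_ipex is set).
assert default_4bit_weight_dtype(use_xpu=True) == "int4_fullrange"
assert default_4bit_weight_dtype(use_xpu=False) == "int4_clip"
```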
Changed file (diff not loaded)
