diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
index 9363b45cf5f..d54dd5f127f 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
+++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_gpu_woq.py
@@ -142,12 +142,7 @@
 user_model = None

-# tokenizer
-if config.model_type == "llama":
-    from transformers import LlamaTokenizer
-    tokenizer = LlamaTokenizer.from_pretrained(args.model)
-else:
-    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
+tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

 quantization_config = None
 if args.woq:
diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
index 81bf61879e2..f912135db1a 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
@@ -20,9 +20,8 @@
 import gc
 import math
 import os
-from ...utils import CpuInfo
+from ....tools.utils import _ipex_version
 from accelerate import init_empty_weights
-from datasets import load_dataset
 from neural_compressor import quantization
 from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
 from neural_compressor.utils.utility import LazyImport
@@ -31,7 +30,6 @@
     is_ipex_available,
     is_autoround_available,
 )
-from transformers import AutoTokenizer

 if is_ipex_available():
     import intel_extension_for_pytorch as ipex
@@ -273,10 +271,12 @@ def _replace_linear(
                         scale_dtype=quantization_config.scale_dtype,
                         blocksize=quantization_config.group_size,
                         scheme=quantization_config.scheme,
-                        compression_dtype=getattr(module, "compression_dtype", torch.int8),
-                        compression_dim=getattr(module, "compression_dim", 0),
+                        compression_dtype=getattr(module, "compression_dtype",
+                                                  torch.int8 if _ipex_version < "2.3.10" else torch.int32),
+                        compression_dim=getattr(module, "compression_dim", 0 if _ipex_version < "2.3.10" else 1),
                         device=device,
-                        use_optimum_format=getattr(module, "use_optimum_format", False),
+                        use_optimum_format=getattr(module, "use_optimum_format",
+                                                   False if _ipex_version < "2.3.10" else True),
                     )
                     if quantization_config.quant_method.value == "gptq":
                         g_idx = getattr(module, "g_idx", torch.zeros(in_features, dtype=torch.int32).to(device))
@@ -297,6 +297,17 @@
                                     quantization_config.compute_dtype
                                 ),
                                 device=torch.device(device),
+                            ) if _ipex_version < "2.3.10" else torch.ones(
+                                (
+                                    math.ceil(
+                                        in_features / quantization_config.group_size
+                                    ),
+                                    out_features,
+                                ),
+                                dtype=convert_dtype_str2torch(
+                                    quantization_config.compute_dtype
+                                ),
+                                device=torch.device(device),
                             )
                         ),
                         module.qzeros if hasattr(module, "qzeros") else None,
@@ -348,11 +359,13 @@
                 else:
                     if not hasattr(module, "qweight"):
                         n_pack = (
-                            8 // DTYPE_BITS_MAPPING[quantization_config.weight_dtype]
+                            (8 if _ipex_version < "2.3.10" else 32)
+                            // DTYPE_BITS_MAPPING[quantization_config.weight_dtype]
                         )
                         weight = torch.zeros(
-                            (math.ceil(out_features / n_pack), in_features),
-                            dtype=torch.int8,
+                            (math.ceil(out_features / n_pack), in_features) if _ipex_version < "2.3.10" else
+                            (math.ceil(in_features / n_pack), out_features),
+                            dtype=torch.int8 if _ipex_version < "2.3.10" else torch.int32,
                             device=torch.device(device),
                         )
                     model._modules[name].set_weights_bias(
@@ -592,7 +605,7 @@ def default_calib_func(model):
                 use_optimum_format=False,
                 scale_dtype=convert_dtype_str2torch(config.scale_dtype),
                 device="xpu",
-            )
+            ) if _ipex_version < "2.3.10" else inc_model.export_compressed_model(use_optimum_format=True, device="xpu")

             q_model = replace_linear(model, None, None, config, device=device)
         else:
diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index a5be8cdc519..1314e464eff 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -182,7 +182,7 @@ def convert_model_to_public(model):
     # reorder weight and scales if they have been transposed
     if model.device == "xpu" or (isinstance(model.device, torch.device) and model.device.type == "xpu"):
         for name, module in model.named_modules():
-            if isinstance(module, WeightOnlyQuantizedLinear):
+            if isinstance(module, WeightOnlyQuantizedLinear) and not module.use_optimum_format:
                 if module.weight_transposed:
                     module.qweight.data = module.qweight.t_().contiguous()
                     module.scales.data = module.scales.t_().contiguous()
@@ -198,6 +198,7 @@ def convert_model_to_public(model):
     ]:
         model = recover_export_model(model)
+

 def make_contiguous(model):
     for param in model.parameters():
         if param.data.ndimension() > 1:
@@ -1871,7 +1872,10 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         # weight dtype is higher priority than bits in config.json when both existed.
         if quantization_config.weight_dtype is None:
             if quantization_config.bits == 4:
-                quantization_config.weight_dtype = "int4_clip"
+                if use_xpu:
+                    quantization_config.weight_dtype = "int4_fullrange"
+                else:
+                    quantization_config.weight_dtype = "int4_clip"
                 logger.info(
                     "{} quantization weight_dtype is used due to bits is 4 in config.json.".format(
                         quantization_config.weight_dtype)
@@ -1917,7 +1921,6 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 "fp4_e2m1",
                 "fp4_e2m1_bnb",
                 "nf4",
-                "int4_fullrange",
             ]:
                 model = build_woq_model(model, quantization_config)
             else:
@@ -2025,7 +2028,6 @@ def replace_ipex_cpu_woq_linear(model, current_name=[]):
                 "nf4",
                 "fp4_e2m1",
                 "fp4_e2m1_bnb",
-                "int4_fullrange",
             ] and not quantization_config.use_ipex:
                 model = replace_linear(
                     model,
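
Note: the recurring `_ipex_version < "2.3.10"` checks in this patch all toggle the same set of packing defaults for IPEX's WeightOnlyQuantizedLinear: int8 packing along dim 0 without the optimum format on older IPEX, int32 packing along dim 1 with the optimum format on 2.3.10 and newer. The snippet below is only an illustrative sketch of that gate, not code from this patch; the helper name woq_packing_defaults is made up, and it reuses the patch's plain string comparison for the version check.

# Illustrative only: mirrors the version-gated defaults chosen in _replace_linear above.
import torch

def woq_packing_defaults(ipex_version: str) -> dict:
    """Return packing defaults for IPEX's WeightOnlyQuantizedLinear.

    IPEX < 2.3.10 packs quantized weights into int8 along dim 0 and does not
    use the optimum format; 2.3.10+ packs into int32 along dim 1 and does.
    """
    legacy = ipex_version < "2.3.10"  # same string comparison the patch relies on
    return {
        "compression_dtype": torch.int8 if legacy else torch.int32,
        "compression_dim": 0 if legacy else 1,
        "use_optimum_format": not legacy,
        "pack_bits": 8 if legacy else 32,  # container width used to derive n_pack
    }

print(woq_packing_defaults("2.1.30+xpu"))   # legacy layout
print(woq_packing_defaults("2.3.110+xpu"))  # optimum-format layout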