Remove all references to habana_quantization_toolkit for 1.18 (#229)
tthakkal authored Oct 18, 2024
1 parent 21c13ff commit 46b14e6
Showing 4 changed files with 38 additions and 73 deletions.
README.md: 2 changes (1 addition, 1 deletion)
@@ -285,7 +285,7 @@ curl -N 127.0.0.1:8080/generate_stream \
## Running TGI with FP8 Precision
TGI-Gaudi supports FP8 precision inference with INC (Intel Neural Compressor) and HQT (Habana Quantization Toolkit). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command. From TGI-Gaudi 2.0.4 release, INC is used by default for quantization. HQT will be removed in future releases. To use HQT, disable INC by setting `-e USE_INC=0` in docker command.
TGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html). FP8 inference can be run by setting QUANT_CONFIG environment variable in the docker command.
To run FP8 Inference:
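As a hedged illustration of the INC-based flow that QUANT_CONFIG enables (mirroring the prepare_model_for_quantization changes in habana_quantization_env.py below), here is a minimal sketch; the model id and config path are placeholders rather than the README's documented steps.

# Illustrative sketch only: QUANT_CONFIG-driven FP8 conversion with INC.
# The model id and config file name are placeholders, not values from this commit.
import os

from neural_compressor.torch.quantization import FP8Config, convert
from transformers import AutoModelForCausalLM

quant_config = os.getenv("QUANT_CONFIG", "")  # e.g. a path to an FP8 quantization JSON config

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
if quant_config:  # quantization runs only when QUANT_CONFIG is set
    config = FP8Config.from_json_file(quant_config)
    model = convert(model, config)
model = model.eval()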
server/text_generation_server/habana_quantization_env.py: 38 changes (28 additions, 10 deletions)
@@ -1,6 +1,7 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import os
import habana_frameworks.torch as htorch

quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""
@@ -10,18 +11,35 @@
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
os.environ.setdefault(
"UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")


def patch_scoped_linear_all_reduce(model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce

for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
patch_scoped_linear_all_reduce(module)


def setup_quantization(model):
if is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model


def prepare_model_for_quantization(model):
if is_quantization_enabled:
if os.getenv("USE_INC", "1") != "0":
from neural_compressor.torch.quantization import FP8Config, convert
config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
else:
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
return model
if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2", "gemma"]:
patch_scoped_linear_all_reduce(model)
from neural_compressor.torch.quantization import FP8Config, convert

config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
return model
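The quantization helpers are now centralized in this module; the model classes call them through hq_env in a fixed order: convert the model before moving it to the device, then const-mark and initialize after any HPU graph wrapping or torch.compile. A condensed, hedged sketch of that call order follows (the model id is a placeholder, and running it requires an HPU environment with the Habana stack installed and QUANT_CONFIG set).

# Hedged sketch of the call order used in causal_lm.py / vlm_causal_lm.py below;
# the model id is illustrative and habana_frameworks must be available.
import torch
from transformers import AutoModelForCausalLM

import text_generation_server.habana_quantization_env as hq_env

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = hq_env.prepare_model_for_quantization(model)  # INC convert(); no-op when QUANT_CONFIG is unset
model = model.eval().to(torch.device("hpu"))
# ... HPU graph wrapping / torch.compile would happen here ...
model = hq_env.setup_quantization(model)  # _mark_params_as_const + hpu_initialize; no-op without QUANT_CONFIG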
server/text_generation_server/models/causal_lm.py: 29 changes (3 additions, 26 deletions)
@@ -629,7 +629,7 @@ def __init__(
model = self.get_deepspeed_model(
model_id, dtype, revision
)
model = self.prepare_model_for_quantization(model)
model = hq_env.prepare_model_for_quantization(model)
else:
get_repo_root(model_id)

@@ -648,7 +648,7 @@ def __init__(
trust_remote_code=trust_remote_code,
**model_kwargs
)
model = self.prepare_model_for_quantization(model)
model = hq_env.prepare_model_for_quantization(model)
model = model.eval().to(device)

self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
@@ -667,7 +667,7 @@ def __init__(
"TORCH COMPILE", f'Torch compiling of model')
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})

model = self.setup_quantization(model)
model = hq_env.setup_quantization(model)

if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -799,29 +799,6 @@ def get_rope_scaling(self) -> Optional[Dict]:
'type': rope_scaling, 'factor': float(rope_factor)
}

def setup_quantization(self, model):
if hq_env.is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model

def prepare_model_for_quantization(self, model):
if hq_env.is_quantization_enabled:
if model.config.model_type == "llama":
self.patch_scoped_linear_all_reduce(model)
model = hq_env.prepare_model_for_quantization(model)
return model

def patch_scoped_linear_all_reduce(self, model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
self.patch_scoped_linear_all_reduce(module)

@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch
server/text_generation_server/models/vlm_causal_lm.py: 42 changes (6 additions, 36 deletions)
@@ -193,7 +193,7 @@ def from_tokenized(
device: torch.device,
is_warmup: bool = False,
) -> "VlmCausalLMBatch":

dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]

@@ -536,7 +536,7 @@ def __init__(
model = self.get_deepspeed_model(
model_class, model_id, dtype, revision
)
model = self.prepare_model_for_quantization(model)
model = hq_env.prepare_model_for_quantization(model)
else:
get_repo_root(model_id)

@@ -555,7 +555,7 @@ def __init__(
trust_remote_code=trust_remote_code,
**model_kwargs
)
model = self.prepare_model_for_quantization(model)
model = hq_env.prepare_model_for_quantization(model)
model = model.eval().to(device)

self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
@@ -565,13 +565,13 @@ def __init__(
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else:
if LAZY_MODE == 0:
if LAZY_MODE == 0:
# It is said that "keep_input_mutations" is safe for inference to be done
dbg_trace(
"TORCH COMPILE", f'Torch compiling of model')
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})

model = self.setup_quantization(model)
model = hq_env.setup_quantization(model)

if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
raise ValueError(f"Model type {model.config.model_type} is not supported!")
@@ -703,36 +703,6 @@ def get_rope_scaling(self) -> Optional[Dict]:
'type': rope_scaling, 'factor': float(rope_factor)
}

def setup_quantization(self, model):
if hq_env.is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
htorch.core.quantization._check_params_as_const(model)
htorch.core.hpu_initialize(model)
return model

def prepare_model_for_quantization(self, model):
if hq_env.is_quantization_enabled:
if model.config.model_type == "llama":
self.patch_scoped_linear_all_reduce(model)
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
return model

def finish_quantization_measurements(self, model):
if hq_env.is_quantization_enabled:
import habana_quantization_toolkit
habana_quantization_toolkit.finish_measurements(self.model)
return model

def patch_scoped_linear_all_reduce(self, model):
from deepspeed.module_inject.layers import LinearAllreduce
from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce
for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
self.patch_scoped_linear_all_reduce(module)

def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

@@ -906,7 +876,7 @@ def generate_token(
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
)
elif all([req.stopping_criteria.max_new_tokens == 1 for req in batch.requests]):
# Don't schedule next forward if max_new_tokens for all requests equals 1
# Don't schedule next forward if max_new_tokens for all requests equals 1
# - we've already generated the first and only needed token in the prefill phase
pass
else:
