Use FP8 KV cache when specified by compressed-tensors
The compressed-tensors configuration can also specify how the KV cache should be
quantized. Use an FP8 KV cache when the configuration tells us to do so (all
other options and types are ignored for now).
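For context, a checkpoint requests this through the `quantization_config` section of its `config.json`, which `get_model` reads into `config_dict`. The sketch below is illustrative only: the field names follow the compressed-tensors `QuantizationConfig`/`QuantizationArgs` schema, but the values are assumptions, not taken from this commit.

# Hypothetical `config_dict["quantization_config"]` for a checkpoint that
# requests an FP8 KV cache via compressed-tensors:
quantization_config = {
    "quant_method": "compressed-tensors",
    "config_groups": {},  # per-layer weight/activation schemes would go here
    "kv_cache_scheme": {
        "type": "float",  # QuantizationType.FLOAT
        "num_bits": 8,    # float + 8 bits => FP8 KV cache
    },
}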
danieldk committed Nov 20, 2024
1 parent bd6e8b3 commit af231e1
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions server/text_generation_server/models/__init__.py
@@ -1,6 +1,11 @@
 # ruff: noqa: F821
 # the above line disables the `undefined-name` rule for the model type variables
 
+from compressed_tensors.compressors.model_compressors.model_compressor import (
+    QuantizationConfig,
+)
+from compressed_tensors.quantization import QuantizationType
+from pydantic import ValidationError
 import torch
 import enum
 import os
@@ -23,6 +28,7 @@
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -367,7 +373,8 @@ def get_model(
     model_type = config_dict.get("model_type", None)
 
     quantization_config = config_dict.get("quantization_config", None)
-    compression_config = config_dict.get("compression_config", None)
+    if quantization_config is None:
+        quantization_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
         if method in {"gptq", "awq", "exl2"}:
@@ -381,12 +388,9 @@
                 logger.info, "Auto selecting quantization method compressed-tensors"
             )
             quantize = "compressed-tensors"
+
         else:
             log_master(logger.warning, f"Unknown quantization method {method}")
-    elif compression_config is not None:
-        # `compression_config` renamed to `quantization_config`; support retained for backward compatibility.
-        log_master(logger.info, "Auto selecting quantization method compressed-tensors")
-        quantize = "compressed-tensors"
 
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
@@ -408,8 +412,31 @@
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
+    compressed_tensors_config = None
+    if quantize == "compressed-tensors":
+        try:
+            compressed_tensors_config = QuantizationConfig.model_validate(
+                quantization_config
+            )
+        except ValidationError as e:
+            raise ValueError("Cannot parse compressed-tensors configuration") from e
+
     if kv_cache_dtype is None:
-        kv_cache_dtype = dtype
+        kv_cache_scheme = (
+            compressed_tensors_config.kv_cache_scheme
+            if isinstance(compressed_tensors_config, QuantizationConfig)
+            else None
+        )
+        if (
+            kv_cache_scheme is not None
+            and kv_cache_scheme.type == QuantizationType.FLOAT
+            and kv_cache_scheme.num_bits == 8
+            and SYSTEM == "cuda"
+            and ATTENTION == "flashinfer"
+        ):
+            kv_cache_dtype = torch.float8_e4m3fn
+        else:
+            kv_cache_dtype = dtype
     elif kv_cache_dtype == "fp8_e4m3fn":
         kv_cache_dtype = torch.float8_e4m3fn
     elif kv_cache_dtype == "fp8_e5m2":
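End to end, the new path amounts to the following standalone sketch. It is a minimal approximation, not the server's actual code: it assumes a compressed-tensors version with the pydantic-based `QuantizationConfig`, reuses the illustrative config from above, and uses plain variables as stand-ins for the server's `SYSTEM` and `ATTENTION` globals.

import torch
from compressed_tensors.compressors.model_compressors.model_compressor import (
    QuantizationConfig,
)
from compressed_tensors.quantization import QuantizationType
from pydantic import ValidationError

SYSTEM = "cuda"  # stand-in for TGI's detected-hardware global
ATTENTION = "flashinfer"  # stand-in for text_generation_server.models.globals.ATTENTION
dtype = torch.float16  # model dtype already resolved earlier in get_model

# The illustrative checkpoint config from above.
quantization_config = {
    "quant_method": "compressed-tensors",
    "config_groups": {},
    "kv_cache_scheme": {"type": "float", "num_bits": 8},
}

try:
    config = QuantizationConfig.model_validate(quantization_config)
except ValidationError as e:
    raise ValueError("Cannot parse compressed-tensors configuration") from e

scheme = config.kv_cache_scheme
if (
    scheme is not None
    and scheme.type == QuantizationType.FLOAT
    and scheme.num_bits == 8
    and SYSTEM == "cuda"
    and ATTENTION == "flashinfer"
):
    kv_cache_dtype = torch.float8_e4m3fn  # FP8 (e4m3) KV cache
else:
    kv_cache_dtype = dtype  # fall back to the model dtype

print(kv_cache_dtype)  # torch.float8_e4m3fn on CUDA with flashinfer

Any other scheme (for example an int8 KV cache) or any other system/attention backend falls through to the model dtype, matching the "all other options and types are ignored for now" note in the commit message.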
