Update model_loader deps and qqq quantization deps #2220
Open
HandH1998 wants to merge 1 commit into sgl-project:main from HandH1998:sgl_model_loader
@@ -42,6 +42,7 @@
     "Fp8LinearMethod",
     "MarlinLinearMethod",
     "GPTQLinearMethod",
+    "QQQLinearMethod",
 ]
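The hunk above registers the new QQQ method name alongside the existing linear quantization methods. As a hedged illustration of how such an allow-list is typically consulted (the surrounding function is not shown in this diff; the function name below is hypothetical):

# Hypothetical sketch, not code from this PR: membership test by class name.
SUPPORTED_LINEAR_METHODS = [
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQLinearMethod",
    "QQQLinearMethod",  # added by this PR
]

def is_supported_linear_method(quant_method: object) -> bool:
    # Compare by class name so the check needs no backend imports.
    return type(quant_method).__name__ in SUPPORTED_LINEAR_METHODS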
New file (300 lines):

@@ -0,0 +1,300 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/quantization/qqq.py

import logging
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter
from torchao.ops import marlin_qqq_gemm
from torchao.quantization.utils import dynamically_quantize_per_channel
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedvLLMParameter,
)

[Review comment on the vllm.model_executor.parameter import: "This part also needs to be migrated."]

from sglang.srt.layers.quantization.base_config import QuantizationConfig

logger = logging.getLogger(__name__)

MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]

class QQQConfig(QuantizationConfig):
    """Config class for QQQ.

    Reference: https://arxiv.org/pdf/2406.09904
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        is_sym: bool = True,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.is_sym = is_sym

        # Verify the configuration is supported.
        if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"QQQ does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
                "are supported."
            )
        if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"QQQ does not support group_size = {self.group_size}. "
                f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
                "are supported."
            )
        if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM:
            raise ValueError(
                f"QQQ does not support is_sym = {self.is_sym}. "
                f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported."
            )

        # 4 bits packed into a 32-bit datatype.
        self.pack_factor = 32 // self.weight_bits

        # Tile size used by the QQQ kernels.
        self.tile_size = MARLIN_QQQ_TILE

        # Min out_features dim
        self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = MARLIN_QQQ_MAX_PARALLEL

        # Permutation length used by the QQQ kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return "QQQConfig(weight_bits={}, group_size={})".format(
            self.weight_bits, self.group_size
        )

    @classmethod
    def get_name(cls) -> str:
        return "qqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
        weight_bits = cls.get_from_keys(config, ["wbits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QQQLinearMethod"]:
        if isinstance(layer, LinearBase):
            return QQQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

class QQQLinearMethod(LinearMethodBase):
    """Linear method for QQQ.

    Args:
        quant_config: The QQQ quantization config.
    """

    def __init__(self, quant_config: QQQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight_loader = extra_weight_attrs["weight_loader"]
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}"
            )

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}."
            )
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}."
            )
        if (
            self.quant_config.group_size != -1
            and input_size_per_partition % self.quant_config.group_size != 0
        ):
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {self.quant_config.group_size}."
            )

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2
        )
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError("Each permutation group must reside on the same GPU")

        # Quantized 4-bit weights packed into int32.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition
                * self.quant_config.tile_size
                // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            marlin_tile_size=self.quant_config.tile_size,
            weight_loader=weight_loader,
        )

        s_channel = ChannelQuantScaleParameter(
            data=torch.empty(
                1,
                output_size_per_partition,
                device="cuda",
                dtype=torch.float,
            ),
            weight_loader=weight_loader,
            output_dim=1,
        )

        if self.quant_config.group_size == -1:
            s_group_data = torch.tensor(
                [],
                device="cuda",
                dtype=torch.half,
            )
        else:
            s_group_data = torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                device="cuda",
                dtype=torch.half,
            )

        s_group_attr = {"data": s_group_data, "weight_loader": weight_loader}

        if self.quant_config.group_size == -1:
            s_group = BasevLLMParameter(**s_group_attr)
        else:
            s_group = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **s_group_attr
            )

        # Allocate workspace (used for the kernel's internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition // self.quant_config.min_n_threads
        ) * self.quant_config.max_parallel

        workspace = BasevLLMParameter(
            data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
            weight_loader=weight_loader,
        )

        layer.register_parameter("B", qweight)
        layer.register_parameter("s_channel", s_channel)
        layer.register_parameter("s_group", s_group)
        layer.register_parameter("workspace", workspace)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Re-wrap as plain Parameters; required by torch.compile.
        layer.B = Parameter(layer.B.data, requires_grad=False)
        layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False)
        layer.s_group = Parameter(layer.s_group.data, requires_grad=False)
        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B
        s_ch = layer.s_channel
        s_group = layer.s_group
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = s_ch.shape[1]

        x_int8, s_tok, _ = dynamically_quantize_per_channel(
            x_2d, quant_min=-127, quant_max=127, target_dtype=torch.int8
        )
        # TODO(HandH1998): As the `dynamically_quantize_per_channel` function in
        # torchao doesn't support defining the `scale_dtype`, we have to convert
        # `s_tok` to `torch.float32`, which is required by `marlin_qqq_gemm`.
        # Remove this when torchao supports defining the `scale_dtype`.
        s_tok = s_tok.to(torch.float32)

        output_2d = marlin_qqq_gemm(
            x_int8, qweight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
        )

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
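To make the packed layout concrete, here is a small worked example of the shapes create_weights allocates. The import path is assumed from this PR, and the K/N values are made up for illustration:

# Assumed import path for the file added in this PR (illustrative).
from sglang.srt.layers.quantization.qqq import QQQConfig

cfg = QQQConfig(weight_bits=4, group_size=128)
assert cfg.pack_factor == 8   # 32 // 4: eight 4-bit weights per int32
assert cfg.tile_size == 16    # MARLIN_QQQ_TILE

K, N = 4096, 4096             # example per-partition in/out features
# B (packed weights): (K // tile, N * tile // pack_factor) int32
print(K // cfg.tile_size, N * cfg.tile_size // cfg.pack_factor)   # 256 8192
# s_channel: (1, N) float32; s_group: (K // group_size, N) float16
print(K // cfg.group_size, N)                                     # 32 4096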
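The apply() path quantizes activations on the fly, one scale per token. As a from-scratch sketch of that symmetric per-row scheme (an illustration of the semantics only, not the torchao implementation used above):

import torch

def per_token_symmetric_int8(x: torch.Tensor):
    # One scale per row (token): scale = absmax / 127, computed in float32.
    s_tok = x.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.round(x.float() / s_tok).clamp(-127, 127).to(torch.int8)
    return x_int8, s_tok  # float32 scales, as marlin_qqq_gemm expects

x = torch.randn(4, 128, dtype=torch.float16)
x_int8, s_tok = per_token_symmetric_int8(x)
assert x_int8.dtype == torch.int8 and s_tok.shape == (4, 1)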
Review comment: It should be due to the version; the current release of torchao (v0.6.1) does not include qqq.

Review comment: Perhaps we can introduce qqq in the next PR, after torchao releases a new version. How about that, @HandH1998?

Reply: ok
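Since the thread defers this until a torchao release ships the kernels, one option (a sketch based on this discussion, not code from the PR) is to guard the import so the module still loads on torchao v0.6.1:

# Hedged sketch: gate the QQQ path on torchao actually shipping the kernel.
try:
    from torchao.ops import marlin_qqq_gemm  # not present in torchao v0.6.1
    HAS_MARLIN_QQQ = True
except ImportError:
    marlin_qqq_gemm = None
    HAS_MARLIN_QQQ = False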