Fixup some imports
danieldk committed Feb 4, 2025
1 parent 17b5f57 commit 27decc5
Showing 6 changed files with 17 additions and 12 deletions.
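
Each file in this commit makes the same change: the try/except ImportError guard around load_kernel becomes an explicit platform check, with the SYSTEM import added where it was missing. Assembled from the hunks below, the pattern the commit converges on looks like this (the rationale is inferred rather than stated in the commit: load_kernel fetches a kernel from the Hub and can fail in ways other than ImportError, and on non-CUDA systems the kernels should not be loaded at all):

# The guard pattern this commit standardizes; taken from the hunks below.
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel

if SYSTEM == "cuda":
    marlin_kernels = load_kernel(
        module="quantization", repo_id="kernels-community/quantization"
    )
else:
    marlin_kernels = None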
@@ -6,15 +6,16 @@
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType
 
 from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None

5 changes: 3 additions & 2 deletions server/text_generation_server/layers/fp8.py
@@ -15,14 +15,15 @@
 )
 from text_generation_server.utils.log import log_once
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
 try:
+    # TODO: needs to be ported over to MoE and used on CUDA.
     from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
 except ImportError:
     w8a8_block_fp8_matmul = None
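
Note that the moe_kernels import above keeps its try/except ImportError guard: it is a plain Python import, for which ImportError is the expected failure mode, unlike load_kernel. Call sites can then test the sentinels before use; a minimal illustrative guard, not the exact check in fp8.py:

# Illustrative sentinel check; the real call sites in fp8.py may differ.
def _require_block_fp8_kernels():
    if w8a8_block_fp8_matmul is None or per_token_group_quant_fp8 is None:
        raise ImportError(
            "moe_kernels is not installed; block-FP8 matmul is unavailable"
        )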
5 changes: 3 additions & 2 deletions server/text_generation_server/layers/marlin/fp8.py
@@ -8,13 +8,14 @@
     _check_marlin_kernels,
     permute_scales,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None

5 changes: 3 additions & 2 deletions server/text_generation_server/layers/marlin/gptq.py
@@ -16,13 +16,14 @@
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
+
 try:
     major, _minor = torch.cuda.get_device_capability()
     has_sm_8_0 = major >= 8
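
The trailing try block probes the GPU generation; has_sm_8_0 gates the Marlin kernels to Ampere (compute capability 8.0) or newer. The except branch is collapsed in the diff above, so the fallback shown below is an assumption about what it contains:

# Sketch of the capability probe; the except branch is assumed, since the
# diff collapses it.
try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    has_sm_8_0 = False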
5 changes: 3 additions & 2 deletions server/text_generation_server/layers/marlin/marlin.py
@@ -5,14 +5,15 @@
 import torch.nn as nn
 
 from text_generation_server.layers.marlin.util import _check_marlin_kernels
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None

4 changes: 2 additions & 2 deletions server/text_generation_server/layers/marlin/util.py
@@ -6,11 +6,11 @@
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
 
-try:
+if SYSTEM == "cuda":
     marlin_kernels = load_kernel(
         module="quantization", repo_id="kernels-community/quantization"
     )
-except ImportError:
+else:
     marlin_kernels = None
 
 try:
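
marlin/fp8.py and marlin/marlin.py above import _check_marlin_kernels from this module. Its body is not part of the diff; given the sentinels set up here, a plausible sketch of such a guard (illustrative only, the real function may differ):

# Illustrative availability guard; the actual _check_marlin_kernels in
# marlin/util.py may use different messages and exception types.
def _check_marlin_kernels():
    if SYSTEM != "cuda" or not has_sm_8_0:
        raise NotImplementedError(
            "Marlin kernels require a CUDA GPU with compute capability >= 8.0"
        )
    if marlin_kernels is None:
        raise ImportError("quantization kernels could not be loaded")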
