Merge branch 'habana_main' into michalkuligowski-qwen-sync_after_weightload
michalkuligowski authored Jan 10, 2025
2 parents 467b333 + 73aaf71 commit 20a4450
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -149,19 +149,12 @@ def apply_fp8_linear(

     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        if current_platform.is_hpu():
-            #hpu does not support torch._scaled_mm (SW-197036)
-            output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
-                                               False, None, input.dtype,
-                                               x_scale, weight_scale, None,
-                                               False)
-        else:
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      out_dtype=input.dtype,
-                                      scale_a=x_scale,
-                                      scale_b=weight_scale,
-                                      bias=bias)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)

         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
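For context on the path this merge keeps, below is a minimal standalone sketch of the per-tensor FP8 GEMM call and of handling the return-value discrepancy noted in the context comment above. The tensor and scale names (qinput, weight, x_scale, weight_scale) mirror the diff; the wrapper fp8_per_tensor_gemm and the helper _unwrap_scaled_mm_output are hypothetical illustrations, not the upstream vllm implementation.

import torch

def _unwrap_scaled_mm_output(result):
    # Per the comment above: torch._scaled_mm returned a tuple (output first)
    # before torch 2.5 and a single tensor from torch 2.5 onward; accept either.
    return result[0] if isinstance(result, tuple) else result

def fp8_per_tensor_gemm(qinput, weight, x_scale, weight_scale, out_dtype, bias=None):
    # Fused GEMM + dequantization: both operands carry a single per-tensor
    # scale, which torch._scaled_mm applies directly to the FP8 matmul.
    output = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=x_scale,
                              scale_b=weight_scale,
                              bias=bias)
    return _unwrap_scaled_mm_output(output)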
