diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index c024af6b..97ef7923 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -223,7 +223,7 @@ def prompt_attention( if query_heads != kv_heads: key = repeat_kv(key, int(query_heads // kv_heads)) value = repeat_kv(value, int(query_heads // kv_heads)) - softmax_mode = 'None' + softmax_mode = 'fast' recompute_mode = True attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True, scale, softmax_mode, recompute_mode,