
Support INT8 SDPA template for CPU #2148

Open · wants to merge 4 commits into base: main

11 changes: 8 additions & 3 deletions test/prototype/inductor/test_int8_sdpa_fusion.py
@@ -121,9 +121,14 @@ def _check_common(
self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
if contains:
# many of the patterns get re-expanded in dispatcher
self.assertIn(
"torchao.scaled_dot_product_int8",
source_code,
self.assertTrue(
any(
op_name in source_code
for op_name in [
"scaled_dot_product_int8",
"cpp_fused_quantize_per_tensor",
]
)
)

# some tests configured with very low dropout where we still want to check equality
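For context on this test change: presumably, once the C++ INT8 SDPA template is selected, the generated Inductor source no longer has to mention the extern op by name, so the assertion now accepts either marker. A minimal Python sketch of the equivalent check (the helper name is illustrative, not part of the PR):

# Illustrative only: mirrors the relaxed assertion above. Generated source is
# accepted if it references either the extern INT8 SDPA op or a fused
# per-tensor quantize kernel emitted along the C++ template path.
EXPECTED_MARKERS = ("scaled_dot_product_int8", "cpp_fused_quantize_per_tensor")

def uses_int8_sdpa(source_code: str) -> bool:
    return any(marker in source_code for marker in EXPECTED_MARKERS)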
27 changes: 15 additions & 12 deletions torchao/csrc/cpu/int8_sdpa.cpp
@@ -26,10 +26,13 @@ namespace torchao {

namespace {

inline double calculate_scale(
inline c10::SymFloat calculate_scale(
const at::Tensor& query,
double scale) {
return scale == 0.0 ? 1.0 / std::sqrt(query.size(-1)) : scale;
std::optional<double> scale) {
const auto softmax_scale = scale.has_value()
? scale.value()
: (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
return c10::SymFloat(softmax_scale);
}

#ifdef CPU_CAPABILITY_AVX512
@@ -736,7 +739,7 @@ sdpa_int8_fused_kernel_impl(
double dropout_p,
bool is_causal,
std::optional<at::Tensor> attention_mask,
double scale,
std::optional<double> scale,
float q_scale,
int32_t q_zp,
float k_scale,
@@ -758,7 +761,7 @@
at::Tensor value = v.transpose(1, 2);

using accum_t = float;
accum_t scaling_factor = calculate_scale(query, scale);
accum_t scaling_factor = calculate_scale(query, scale).expect_float();
int block_64 = 64;
auto u8_dt = at::ScalarType::Byte;

@@ -1103,7 +1106,7 @@ sdpa_int8_fused_kernel_impl(
at::native::cpublas::brgemm(
qSplitSize, block_64, av_gemm_K,
av_gemm_K, // lda
rndHeadSize, //block_64, //ldb
rndHeadSize, //ldb
rndHeadSize, //ldc
s != 0,
qk_reduced_data + s * qk_reduce_strideL,
@@ -1164,7 +1167,7 @@ sdpa_int8_fused_kernel_impl(
double dropout_p,
bool is_causal,
std::optional<at::Tensor> attention_mask,
double scale,
std::optional<double> scale,
float q_scale,
int32_t q_zp,
float k_scale,
@@ -1186,7 +1189,7 @@
at::Tensor value = v.transpose(1, 2);

using accum_t = float;
accum_t scaling_factor = calculate_scale(query, scale);
accum_t scaling_factor = calculate_scale(query, scale).expect_float();
int block_64 = 64;
auto u8_dt = at::ScalarType::Byte;

@@ -1631,7 +1634,7 @@ sdpa_int8_fused_kernel_impl(
double dropout_p,
bool is_causal,
std::optional<at::Tensor> attn_mask,
double scale,
std::optional<double> scale,
float q_scale,
int32_t q_zp,
float k_scale,
@@ -1689,7 +1692,7 @@ void sdpa_int8_fused_kernel(
double dropout_p,
bool is_causal,
std::optional<at::Tensor> attn_mask,
double scale,
std::optional<double> scale,
float q_scale,
int32_t q_zp,
float k_scale,
@@ -1796,7 +1799,7 @@ at::Tensor sdpa_int8_math_kernel(
double dropout_p,
bool is_causal,
std::optional<at::Tensor> attn_mask,
double scale,
std::optional<double> scale,
float q_scale,
int32_t q_zp,
float k_scale,
@@ -1839,7 +1842,7 @@ at::Tensor _scaled_dot_product_int8_cpu(
std::optional<at::Tensor> attn_mask,
double dropout_p,
bool is_causal,
double scale,
std::optional<double> scale,
double q_scale,
int64_t q_zp,
double k_scale,
6 changes: 3 additions & 3 deletions torchao/ops.py
@@ -56,7 +56,7 @@
tags=[torch._C.Tag.needs_fixed_stride_order],
)
lib.define(
"scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float scale=0.0, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor"
"scaled_dot_product_int8(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, float? scale=None, float q_scale=1.0, int q_zp=0, float k_scale=1.0, int k_zp=0, float v_scale=1.0, int v_zp=0, float a_scale=1.0, int a_zp=0, float o_scale=1.0, int o_zp=0) -> Tensor"
Collaborator:

Why does the op schema need to change? It doesn't seem related to this PR.

Collaborator (Author):

Yes, it's an additional modification. It keeps this op consistent with the SDPA schema in PyTorch: func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor. A short sketch of the resulting default-scale behavior follows this hunk.

)


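The behavioral effect of moving from float scale=0.0 to float? scale=None, both in this schema and in the calculate_scale change in int8_sdpa.cpp above, is sketched below in minimal Python (the helper name is illustrative, not part of the PR): a missing scale falls back to 1/sqrt(head_dim), while an explicit value, including 0.0, is used as-is.

import math
from typing import Optional

def resolve_softmax_scale(head_dim: int, scale: Optional[float] = None) -> float:
    # None means "use the conventional default", mirroring calculate_scale in
    # int8_sdpa.cpp; an explicit value (even 0.0) is respected. The old schema
    # used 0.0 as the sentinel for "default", so a literal zero scale could
    # never be requested.
    return scale if scale is not None else 1.0 / math.sqrt(head_dim)

assert resolve_softmax_scale(64) == 0.125     # default: 1 / sqrt(64)
assert resolve_softmax_scale(64, 0.2) == 0.2  # explicit scale wins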
@@ -169,7 +169,7 @@ def scaled_dot_product_int8(
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float = 0.0,
scale: Optional[float] = None,
q_scale: float = 1.0,
q_zp: int = 0,
k_scale: float = 1.0,
@@ -235,7 +235,7 @@ def _(
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float = 0.0,
scale: Optional[float] = None,
q_scale: float = 1.0,
q_zp: int = 0,
k_scale: float = 1.0,
5 changes: 5 additions & 0 deletions torchao/prototype/inductor/codegen/__init__.py
@@ -0,0 +1,5 @@
from .cpp_int8_sdpa_template import CppInt8SdpaTemplate

__all__ = [
"CppInt8SdpaTemplate",
]
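The new package simply re-exports the template class. Assuming a torchao build that contains this PR, a downstream import would look like the sketch below; how the template is registered with Inductor's C++ template machinery is outside the scope of this file and not shown.

# Assumes a torchao build containing this PR.
from torchao.prototype.inductor import codegen

assert "CppInt8SdpaTemplate" in codegen.__all__
from torchao.prototype.inductor.codegen import CppInt8SdpaTemplate  # noqa: F401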