
Commit

Revert "[inductor][cpp] Add FlexAttention support for CPU inference (p…
Browse files Browse the repository at this point in the history
…ytorch#141453)"

This reverts commit 7edbde3.

Reverted pytorch#141453 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it is failing periodic NO_AVX2 ([comment](pytorch#141453 (comment)))
pytorchmergebot committed Dec 9, 2024
1 parent a108b28 commit 7101dcf
Showing 6 changed files with 174 additions and 1,752 deletions.
84 changes: 14 additions & 70 deletions aten/src/ATen/native/CPUBlas.cpp
@@ -1125,9 +1125,6 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
   if (dtype == ScalarType::Half) {
     static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16;
     return fp16_support;
-  } else if (dtype == ScalarType::Float) {
-    static bool fp32_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx2;
-    return fp32_support;
   } else if (dtype == ScalarType::BFloat16) {
     static bool bf16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core;
     return bf16_support;
@@ -1195,29 +1192,18 @@ void brgemm(
     int64_t ld_b,
     int64_t ld_c,
     const bool add_C,
-    const float* A,
-    const float* B,
-    float* C,
-    bool is_vnni) {
-
-  TORCH_CHECK(!is_vnni,
-    "Float Brgemm does not support vnni layout.");
-
+    const at::Half* A,
+    const at::Half* B,
+    float* C) {
 #if defined(ONEDNN_UKERNEL_ENABLED)
-  if (Brgemm::device_check(ScalarType::Float)) {
-    Brgemm::call<float, float, float>(
+  if (Brgemm::device_check(ScalarType::Half)) {
+    Brgemm::call<at::Half, at::Half, float>(
         M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
     return;
   }
 #endif
-  // fallback path
-  auto beta = add_C ? 1 : 0;
-  gemm(
-    at::native::TransposeType::NoTranspose,
-    at::native::TransposeType::NoTranspose,
-    N, M, K, 1,
-    B, ld_b, A, ld_a,
-    beta, C, ld_c);
+  TORCH_CHECK(false,
+    "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported");
 }
 
 void brgemm(
@@ -1230,64 +1216,22 @@ void brgemm(
     const bool add_C,
     const at::BFloat16* A,
     const at::BFloat16* B,
-    float* C,
-    bool is_vnni) {
+    float* C) {
 #if defined(ONEDNN_UKERNEL_ENABLED)
-  if (is_vnni && Brgemm::device_check(ScalarType::BFloat16)) {
+  if (Brgemm::device_check(ScalarType::BFloat16)) {
     Brgemm::call<at::BFloat16, at::BFloat16, float>(
         M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
     return;
   }
 #endif
-  // fallback path
-  TORCH_CHECK(!is_vnni,
-    "BFloat16 Brgemm VNNI format is only supported on X64 when oneDNN ukernel is enabled and `amx` is supported");
-  auto beta = add_C ? 1 : 0;
-  gemm(
-      at::native::TransposeType::NoTranspose,
-      at::native::TransposeType::NoTranspose,
-      N, M, K, 1,
-      B, ld_b, A, ld_a,
-      beta, C, ld_c);
-}
-
-void brgemm(
-    int64_t M,
-    int64_t N,
-    int64_t K,
-    int64_t ld_a,
-    int64_t ld_b,
-    int64_t ld_c,
-    const bool add_C,
-    const at::Half* A,
-    const at::Half* B,
-    float* C,
-    bool is_vnni) {
-#if defined(ONEDNN_UKERNEL_ENABLED)
-  if (is_vnni && Brgemm::device_check(ScalarType::Half)) {
-    Brgemm::call<at::Half, at::Half, float>(
-        M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
-    return;
-  }
-#endif
-  // fallback path
-  TORCH_CHECK(!is_vnni,
-    "Half Brgemm VNNI format is only supported on X64 when oneDNN ukernel is enabled and `amx_fp16` is supported");
-  auto beta = add_C ? 1 : 0;
-  gemm(
-      at::native::TransposeType::NoTranspose,
-      at::native::TransposeType::NoTranspose,
-      N, M, K, 1,
-      B, ld_b, A, ld_a,
-      beta, C, ld_c);
+  TORCH_CHECK(false,
+    "BFloat16 Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512 is supported");
 }
 
-void brgemm_release(bool is_vnni) {
+void brgemm_release() {
 #if defined(ONEDNN_UKERNEL_ENABLED)
-  if (is_vnni) {
-    dnnl::ukernel::brgemm::release_hw_context();
-    Brgemm::get_current() = nullptr;
-  }
+  dnnl::ukernel::brgemm::release_hw_context();
+  Brgemm::get_current() = nullptr;
 #endif
 }
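Note on the fallback removed above: the change being reverted had brgemm fall back to the generic at::native::cpublas::gemm when the oneDNN ukernel path was unavailable, and it did so by swapping the operands and the M/N dimensions so that a column-major GEMM produces the row-major product. The standalone sketch below only illustrates that argument-order trick; gemm_colmajor is a hypothetical naive stand-in for the real gemm, and the matrix sizes are invented for the example.

#include <cstdio>
#include <vector>

// Naive column-major GEMM: C (m x n, ldc) = alpha * A (m x k, lda) * B (k x n, ldb) + beta * C.
// Illustrative stand-in for cpublas::gemm with NoTranspose on both operands.
static void gemm_colmajor(int m, int n, int k, float alpha,
                          const float* A, int lda,
                          const float* B, int ldb,
                          float beta, float* C, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        acc += A[i + p * lda] * B[p + j * ldb];
      }
      C[i + j * ldc] = alpha * acc + beta * C[i + j * ldc];
    }
  }
}

int main() {
  // Row-major problem, like the brgemm fallback: C (M x N) = A (M x K) * B (K x N).
  const int M = 2, N = 3, K = 4;
  std::vector<float> A = {1, 2, 3, 4,
                          5, 6, 7, 8};   // M x K, ld_a = K
  std::vector<float> B(K * N, 1.0f);     // K x N, ld_b = N
  std::vector<float> C(M * N, 0.0f);     // M x N, ld_c = N

  // Same argument order the removed fallback used: swap M/N and A/B so the
  // column-major GEMM writes the row-major product directly into C.
  const float beta = 0.0f;               // add_C == false
  gemm_colmajor(N, M, K, 1.0f, B.data(), N, A.data(), K, beta, C.data(), N);

  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      std::printf("%6.1f", C[i * N + j]); // prints 10 10 10 / 26 26 26
    }
    std::printf("\n");
  }
  return 0;
}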
23 changes: 4 additions & 19 deletions aten/src/ATen/native/CPUBlas.h
@@ -204,8 +204,7 @@ TORCH_API void brgemm(
     const bool add_C,
     const at::Half* A,
     const at::Half* B,
-    float* C,
-    bool is_vnni = true);
+    float* C);
 
 TORCH_API void brgemm(
     int64_t M,
@@ -217,24 +216,10 @@ TORCH_API void brgemm(
     const bool add_C,
     const at::BFloat16* A,
     const at::BFloat16* B,
-    float* C,
-    bool is_vnni = true);
-
-TORCH_API void brgemm(
-    int64_t M,
-    int64_t N,
-    int64_t K,
-    int64_t ld_a,
-    int64_t ld_b,
-    int64_t ld_c,
-    const bool add_C,
-    const float* A,
-    const float* B,
-    float* C,
-    bool is_vnni = false);
+    float* C);
 
 // Release brgemm hardware context
-TORCH_API void brgemm_release(bool is_vnni = true);
+TORCH_API void brgemm_release();
 
 // Pack B matrix to get better performance if needed
 void pack(
@@ -248,6 +233,6 @@ void pack(
     void* out);
 
 // Whether pack is supported in the platform.
-TORCH_API bool could_pack(ScalarType dt_in);
+bool could_pack(ScalarType dt_in);
 
 } // namespace at::native::cpublas
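For reference, the interface this revert restores in CPUBlas.h is a brgemm taking Half or BFloat16 inputs that accumulates into a float C, with no is_vnni flag, paired with an argument-less brgemm_release(). A minimal usage sketch follows; it assumes an ATen development build where this internal header is reachable as <ATen/native/CPUBlas.h> (it is not public API), and per the restored CPUBlas.cpp the call throws on CPUs without the oneDNN ukernel / required AVX-512 support.

#include <cstdint>
#include <vector>
#include <ATen/native/CPUBlas.h>

int main() {
  // C (M x N, float) = A (M x K, bf16) * B (K x N, bf16); row-major leading dims.
  const int64_t M = 2, N = 4, K = 3;
  std::vector<at::BFloat16> A(M * K, at::BFloat16(1.0f));
  std::vector<at::BFloat16> B(K * N, at::BFloat16(0.5f));
  std::vector<float> C(M * N, 0.0f);

  // add_C = false: overwrite C instead of accumulating into it.
  // With the restored code path this throws unless the oneDNN ukernel is
  // enabled and the CPU reports avx512_core (see CPUBlas.cpp above).
  at::native::cpublas::brgemm(
      M, N, K, /*ld_a=*/K, /*ld_b=*/N, /*ld_c=*/N,
      /*add_C=*/false, A.data(), B.data(), C.data());

  // Release the brgemm hardware context once the kernel is no longer needed.
  at::native::cpublas::brgemm_release();
  return 0;
}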
