draft ukernel selection logic #1652

Draft · wants to merge 6 commits into base: main

2 changes: 1 addition & 1 deletion torchao/experimental/CMakeLists.txt
@@ -22,7 +22,7 @@ if(NOT TORCHAO_INCLUDE_DIRS)
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" ON)
Contributor Author:
TODO: nocommit

if(TORCHAO_BUILD_KLEIDIAI)
message(STATUS "Building with Arm KleidiAI library")
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)

@@ -15,6 +15,13 @@

#include <kai/kai_common.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>

#ifdef TORCHAO_ENABLE_ARM_I8MM
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
#endif // TORCHAO_ENABLE_ARM_I8MM

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h>

@@ -43,14 +50,16 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;

size_t activation_data_size(const Ukernel ukernel, int m, int k) {
size_t activation_data_size(int mr, int kr, int sr, int m, int k) {
auto lhs_packing = get_lhs_packing();
return lhs_packing.get_lhs_packed_size(
m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
m, k, mr, kr, sr);
}

void prepare_activation_data(
const Ukernel ukernel,
int mr,
int kr,
int sr,
void* activation_data,
int m,
int k,
@@ -60,29 +69,31 @@ void prepare_activation_data(
lhs_pack.run_lhs_pack(
m,
k,
ukernel.get_mr(),
ukernel.get_kr(),
ukernel.get_sr(),
mr,
kr,
sr,
/*m_index_start=*/0,
activations,
/*lhs_stride=*/k * sizeof(float),
activation_data);
}

size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
size_t weight_data_size(int nr, int kr, int sr, int n, int k, int group_size) {
auto rhs_pack = get_rhs_packing();
return rhs_pack.get_rhs_packed_size(
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
nr,
kr,
sr,
group_size,
kai_datatype::kai_dt_bf16);
}
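
For context, a call site that previously passed the whole Ukernel now passes the packing parameters directly; a hypothetical example (values taken from the instantiations further down in this file):

// Hypothetical caller, not part of this diff. Only mr/kr/sr (and nr/kr/sr) are needed,
// so ukernels that share packing parameters can reuse the same packed buffers.
int m = 4, n = 4096, k = 4096, group_size = 32;
size_t activation_bytes = activation_data_size(/*mr=*/1, /*kr=*/16, /*sr=*/2, m, k);
size_t weight_bytes = weight_data_size(/*nr=*/8, /*kr=*/16, /*sr=*/2, n, k, group_size);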

void prepare_weight_data(
const Ukernel ukernel,
int nr,
int kr,
int sr,
void* weight_data,
int n,
int k,
@@ -134,9 +145,9 @@ void prepare_weight_data(
/*groups=*/1,
n,
k,
ukernel.get_nr(),
ukernel.get_kr(),
ukernel.get_sr(),
nr,
kr,
sr,
group_size,
/*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
/*rhs_stride=*/roundup(k, 2) / 2,
@@ -148,5 +159,99 @@ void prepare_weight_data(
/*qparams=*/&qparams);
}


size_t get_preferred_alignement() {
return 16;
}


#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr) \
size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \
return weight_data_size(nr, kr, sr, n, k, group_size); \
} \
void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \
void* weight_data, \
int n, \
int k, \
int group_size, \
const int8_t* weight_qvals, \
const float* weight_scales, \
const int8_t* weight_zeros, \
const float* bias) { \
prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \
}
Comment on lines +168 to +182

Contributor:
Torture!

Suggested change
#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr) \
size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \
return weight_data_size(nr, kr, sr, n, k, group_size); \
} \
void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \
void* weight_data, \
int n, \
int k, \
int group_size, \
const int8_t* weight_qvals, \
const float* weight_scales, \
const int8_t* weight_zeros, \
const float* bias) { \
prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \
}
#define DEFINE_WEIGHT_DATA_FN(nr, kr, sr) \
size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \
return weight_data_size(nr, kr, sr, n, k, group_size); \
} \
void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \
void* weight_data, \
int n, \
int k, \
int group_size, \
const int8_t* weight_qvals, \
const float* weight_scales, \
const int8_t* weight_zeros, \
const float* bias) { \
prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \
}


#define DEFINE_ACTIVATION_DATA_FNS(mr, kr, sr) \
size_t activation_data_size_mr##mr##_kr##kr##_sr##sr(int m, int k, int group_size) { \
(void)group_size; \
return activation_data_size(mr, kr, sr, m, k); \
} \
void prepare_activation_data_mr##mr##_kr##kr##_sr##sr(void* activation_data, int m, int k, int group_size, const float* activations) { \
(void)group_size; \
prepare_activation_data(mr, kr, sr, activation_data, m, k, activations); \
}
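
For readability, here is roughly what these macros expand to for the parameters instantiated further down in this file, i.e. DEFINE_WEIGHT_DATA_FNS(8, 16, 2) and DEFINE_ACTIVATION_DATA_FNS(1, 16, 2) (a hand-written expansion for illustration, not code that appears in the PR):

// DEFINE_WEIGHT_DATA_FNS(8, 16, 2) generates:
size_t weight_data_size_nr8_kr16_sr2(int n, int k, int group_size) {
  return weight_data_size(8, 16, 2, n, k, group_size);
}
void prepare_weight_data_nr8_kr16_sr2(
    void* weight_data, int n, int k, int group_size,
    const int8_t* weight_qvals, const float* weight_scales,
    const int8_t* weight_zeros, const float* bias) {
  prepare_weight_data(8, 16, 2, weight_data, n, k, group_size,
                      weight_qvals, weight_scales, weight_zeros, bias);
}

// DEFINE_ACTIVATION_DATA_FNS(1, 16, 2) generates:
size_t activation_data_size_mr1_kr16_sr2(int m, int k, int group_size) {
  (void)group_size;  // group_size is unused for activation packing
  return activation_data_size(1, 16, 2, m, k);
}
void prepare_activation_data_mr1_kr16_sr2(
    void* activation_data, int m, int k, int group_size, const float* activations) {
  (void)group_size;
  prepare_activation_data(1, 16, 2, activation_data, m, k, activations);
}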

// TODO: first and suffix need to be better, e.g., parametrized by mr, nr, etc
// But I don't quite follow the naming convention for KleidiAI
Contributor:
The KleidiAI naming convention is: kai_matmul_<fused_ops>_<dst_info>_<lhs_info>_<rhs_info>_<mr x nr x kacc>_<technology>_<feature>_<instruction>.
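
As an illustration (my reading of that convention, not an authoritative breakdown), the first dotprod kernel included above decomposes as:

// kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod
//   clamp        -> fused op (clamp the output to [clamp_min, clamp_max])
//   f32          -> destination/output type
//   qai8dxp1x8   -> lhs: quantized asymmetric int8, dynamic per-row, packed 1x8
//   qsi4c32p8x8  -> rhs: quantized symmetric int4, 32-element groups, packed 8x8
//   1x8x32       -> mr x nr x kacc of the output tile
//   neon_dotprod -> technology / instruction feature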

#define DEFINE_KERNEL_FNS(first, suffix) \
namespace impl_##suffix { \
const Ukernel get_ukernel() { \
return Ukernel{ \
.get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
Comment on lines +196 to +200

Contributor:
Passing in first (i.e., the lhs) as qai8dxp1x8 instead of 1x8 is better: (1) it is more meaningful, and (2) it can also cover channel-wise 4-bit quantization (QC4W).

Also, suffix should really be three separate things: rhs + output tile x kacc + isa, where the rhs is not 8x8 but qsi4c32p4x8.

Suggested change
#define DEFINE_KERNEL_FNS(first, suffix) \
namespace impl_##suffix { \
const Ukernel get_ukernel() { \
return Ukernel{ \
.get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
#define DEFINE_KLEIDI_KERNEL_FN(lhs, suffix) \
namespace impl_##suffix { \
const Ukernel get_ukernel() { \
return Ukernel{ \
.get_m_step = kai_get_m_step_matmul_clamp_f32_##lhs##_##suffix, \

.get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
.run_matmul = kai_run_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix \
}; \
} \
void kernel( \
float32_t* output, \
int output_m_stride, \
int m, \
int n, \
int k, \
int group_size, \
const void* weight_data, \
const void* activation_data, \
float clamp_min, \
float clamp_max) { \
get_ukernel().run_matmul( \
m, \
n, \
k, \
group_size, \
activation_data, \
weight_data, \
output, \
/*dst_stride_row=*/ output_m_stride * sizeof(float), \
/*dst_stride_col=*/ sizeof(float), \
/*clamp_min=*/std::numeric_limits<float>::lowest(), \
/*clamp_max=*/std::numeric_limits<float>::max() \
); \
} \
}



DEFINE_WEIGHT_DATA_FNS(/*nr*/8, /*kr*/16, /*sr*/2)
Contributor Author:
@digantdesai this file is my draft reworking of the KleidiAI integration. Weight-packing and activation-packing functions are no longer in ISA/kernel-specific namespaces, because many kernels share the same routines.

Kernel functions and ukernel configs are defined with macros. I would like DEFINE_KERNEL_FNS to be parametrized by things like mr, nr, and instruction (dotprod/i8mm), but I don't fully follow the KleidiAI naming convention, so for now it is indexed by first/suffix.

Contributor:

This looks good. An alternative would be to code-gen these wrappers at compile time, but this is clean enough.

DEFINE_ACTIVATION_DATA_FNS(/*mr*/1, /*kr*/16, /*sr*/2)
DEFINE_KERNEL_FNS(1x8, 8x8_1x8x32_neon_dotprod)
DEFINE_KERNEL_FNS(1x8, 4x8_1x4x32_neon_dotprod)

#ifdef TORCHAO_ENABLE_ARM_I8MM
DEFINE_KERNEL_FNS(4x8, 4x8_8x4x32_neon_i8mm)
DEFINE_KERNEL_FNS(4x8, 8x8_4x8x32_neon_i8mm)
#endif // TORCHAO_ENABLE_ARM_I8MM

#undef DEFINE_WEIGHT_DATA_FNS
#undef DEFINE_ACTIVATION_DATA_FNS
#undef DEFINE_KERNEL_FNS
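
Since the PR is about ukernel selection, here is a minimal sketch of how runtime dispatch over these instantiations might look, assuming cpuinfo (added as a dependency in the CMake changes below) is used for feature detection. select_ukernel, its signature, and the m-based heuristic are hypothetical illustrations, not this PR's actual API.

#include <cpuinfo.h>  // would go at the top of the file

// Hypothetical selector, for illustration only. The choice must stay consistent with
// how the weights and activations were packed (nr/kr/sr), which is not checked here.
using kernel_fn = void (*)(
    float32_t* /*output*/, int /*output_m_stride*/, int /*m*/, int /*n*/, int /*k*/,
    int /*group_size*/, const void* /*weight_data*/, const void* /*activation_data*/,
    float /*clamp_min*/, float /*clamp_max*/);

inline kernel_fn select_ukernel(int m) {
  if (!cpuinfo_initialize()) {
    return nullptr;  // caller falls back to a non-KleidiAI path
  }
#ifdef TORCHAO_ENABLE_ARM_I8MM
  // The i8mm kernels are laid out for multi-row activations, so prefer them when m > 1.
  if (m > 1 && cpuinfo_has_arm_i8mm()) {
    return impl_8x8_4x8x32_neon_i8mm::kernel;
  }
#endif // TORCHAO_ENABLE_ARM_I8MM
  if (cpuinfo_has_arm_neon_dot()) {
    // GEMV-shaped cases (m == 1) map directly onto the 1x8x32 dotprod kernel.
    return impl_8x8_1x8x32_neon_dotprod::kernel;
  }
  return nullptr;
}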

} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
} // namespace torchao::kernels::cpu::aarch64::kleidi
24 changes: 12 additions & 12 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
@@ -37,16 +37,16 @@ void check_embedding_inputs(
int packed_embedding_dim = (embedding_dim * weight_nbit) / 8;
TORCHAO_CHECK(
packed_weight_qvals.size(0) ==
(torchao::ops::PackedWeightsHeader::size() +
(torchao::ops::PackedWeightsFormat::serialized_size() +
(num_embeddings * packed_embedding_dim)),
"packed_weight_qvals is not the correct size");

// Check header
auto header = torchao::ops::PackedWeightsHeader::read(
// Check packed_weights_format
auto packed_weights_format = torchao::ops::PackedWeightsFormat::deserialize(
packed_weight_qvals.const_data_ptr());
TORCHAO_CHECK(
header ==
torchao::ops::embedding_xbit::get_packed_weights_header_universal(
packed_weights_format ==
torchao::ops::embedding_xbit::get_packed_weights_format_universal(
weight_nbit,
/*min_value_chunk_size=*/32,
/*max_value_chunk_size=*/128),
@@ -151,7 +151,7 @@ Tensor embedding_out_cpu(
embedding_dim,
group_size,
packed_weight_qvals.const_data_ptr<int8_t>() +
torchao::ops::PackedWeightsHeader::size(),
torchao::ops::PackedWeightsFormat::serialized_size(),
weight_scales.const_data_ptr<float>(),
weight_zeros_ptr,
index);
@@ -222,23 +222,23 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");

auto out = torch::empty(
torchao::ops::PackedWeightsHeader::size() +
torchao::ops::PackedWeightsFormat::serialized_size() +
(num_embeddings * packed_embedding_dim))
.to(torch::kInt8);

auto header =
torchao::ops::embedding_xbit::get_packed_weights_header_universal(
auto packed_weights_format =
torchao::ops::embedding_xbit::get_packed_weights_format_universal(
weight_nbit,
/*min_value_chunk_size=*/32,
/*max_value_chunk_size=*/128);
header.write(out.mutable_data_ptr());
packed_weights_format.serialize(out.mutable_data_ptr());

torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) {
#if defined(__aarch64__) || defined(__ARM_NEON)
torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals<
weight_nbit>(
out.mutable_data_ptr<int8_t>() +
torchao::ops::PackedWeightsHeader::size(),
torchao::ops::PackedWeightsFormat::serialized_size(),
embedding_dim,
weight_qvals.const_data_ptr<int8_t>(),
idx);
@@ -261,7 +261,7 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) {
embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
int packed_embedding_dim = embedding_dim * weight_nbit / 8;
return torch::empty(
torchao::ops::PackedWeightsHeader::size() +
torchao::ops::PackedWeightsFormat::serialized_size() +
(num_embeddings * packed_embedding_dim))
.to("meta");
}

@@ -10,13 +10,13 @@

namespace torchao::ops::embedding_xbit {

inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
inline torchao::ops::PackedWeightsFormat get_packed_weights_format_universal(
int weight_nbit,
int min_value_chunk_size,
int max_value_chunk_size,
int version = 1) {
return torchao::ops::PackedWeightsHeader(
torchao::ops::PackedWeightsFormat::embedding_xbit_universal,
return torchao::ops::PackedWeightsFormat(
torchao::ops::PackedWeightsType::embedding_xbit_universal,
{version,
weight_nbit,
min_value_chunk_size,
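
The PackedWeightsFormat type itself is not part of this diff. Purely as a mental model (the field layout and array length are my guesses; only the method names come from the call sites above), it could look something like:

// Sketch only -- torchao's real PackedWeightsFormat may differ in layout and API details.
#include <array>
#include <cstddef>
#include <cstring>

enum class PackedWeightsType : int { embedding_xbit_universal /*, ... */ };

struct PackedWeightsFormatSketch {
  PackedWeightsType type;
  std::array<int, 14> params;  // e.g. {version, weight_nbit, min/max_value_chunk_size, ...}

  PackedWeightsFormatSketch() = default;
  PackedWeightsFormatSketch(PackedWeightsType t, std::array<int, 14> p) : type(t), params(p) {}

  // The format is written at the front of the packed buffer, ahead of the payload.
  static constexpr size_t serialized_size() { return sizeof(PackedWeightsFormatSketch); }
  void serialize(void* dst) const { std::memcpy(dst, this, serialized_size()); }
  static PackedWeightsFormatSketch deserialize(const void* src) {
    PackedWeightsFormatSketch f;
    std::memcpy(&f, src, serialized_size());
    return f;
  }
  bool operator==(const PackedWeightsFormatSketch& other) const {
    return type == other.type && params == other.params;
  }
};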

@@ -8,13 +8,23 @@ cmake_minimum_required(VERSION 3.19)

include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake)

add_compile_options(-Wno-unused-function -Wno-unused-variable) # For some reason cpuinfo package has unused functions/variables
Contributor:
Fix it upstream?


include(FetchContent)
FetchContent_Declare(cpuinfo
GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git
GIT_TAG main) # need main for benchmark::benchmark
FetchContent_MakeAvailable(
cpuinfo)

find_package(Torch REQUIRED)
add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT
linear_8bit_act_xbit_weight.cpp
op_linear_8bit_act_xbit_weight_aten.cpp
)
target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp)
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64)
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo)
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE USE_ATEN=1)
@@ -37,4 +47,5 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64)
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo)
endif()