-
Notifications
You must be signed in to change notification settings - Fork 216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
draft ukernel selection logic #1652
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,6 +15,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <kai/kai_common.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#ifdef TORCHAO_ENABLE_ARM_I8MM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#endif // TORCHAO_ENABLE_ARM_I8MM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -43,14 +50,16 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t activation_data_size(const Ukernel ukernel, int m, int k) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t activation_data_size(int mr, int kr, int sr, int m, int k) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto lhs_packing = get_lhs_packing(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return lhs_packing.get_lhs_packed_size( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
m, k, mr, kr, sr); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void prepare_activation_data( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const Ukernel ukernel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int mr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int kr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int sr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void* activation_data, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int m, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -60,29 +69,31 @@ void prepare_activation_data( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
lhs_pack.run_lhs_pack( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
m, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_mr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_kr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_sr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
kr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*m_index_start=*/0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
activations, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*lhs_stride=*/k * sizeof(float), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
activation_data); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t weight_data_size(int nr, int kr, int sr, int n, int k, int group_size) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto rhs_pack = get_rhs_packing(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return rhs_pack.get_rhs_packed_size( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_nr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_kr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_sr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
kr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
group_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
kai_datatype::kai_dt_bf16); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void prepare_weight_data( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const Ukernel ukernel, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int nr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int kr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int sr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void* weight_data, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int n, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -134,9 +145,9 @@ void prepare_weight_data( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*groups=*/1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_nr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_kr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ukernel.get_sr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
kr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
group_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*rhs_stride=*/roundup(k, 2) / 2, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -148,5 +159,99 @@ void prepare_weight_data( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*qparams=*/&qparams); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t get_preferred_alignement() { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return 16; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return weight_data_size(nr, kr, sr, n, k, group_size); \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void* weight_data, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int n, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int k, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int group_size, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const int8_t* weight_qvals, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const float* weight_scales, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const int8_t* weight_zeros, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const float* bias) { \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+168
to
+182
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Torture!
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#define DEFINE_ACTIVATION_DATA_FNS(mr, kr, sr) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t activation_data_size_mr##mr##_kr##kr##_sr##sr(int m, int k, int group_size) { \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(void)group_size; \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return activation_data_size(mr, kr, sr, m, k); \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void prepare_activation_data_mr##mr##_kr##kr##_sr##sr(void* activation_data, int m, int k, int group_size, const float* activations) { \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(void)group_size; \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
prepare_activation_data(mr, kr, sr, activation_data, m, k, activations); \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// TODO: first and suffix need to be better, e.g., parametrized by mr, nr, etc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// But I don't quite follow the naming convention for KleidiAI | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. naming convention - |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#define DEFINE_KERNEL_FNS(first, suffix) \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
namespace impl_##suffix { \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const Ukernel get_ukernel() { \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return Ukernel{ \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+196
to
+200
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing in first ir lhs as Also
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.run_matmul = kai_run_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}; \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void kernel( \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
float32_t* output, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int output_m_stride, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int m, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int n, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int k, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int group_size, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const void* weight_data, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
const void* activation_data, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
float clamp_min, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
float clamp_max) { \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
get_ukernel().run_matmul( \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
m, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
k, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
group_size, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
activation_data, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
weight_data, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output, \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*dst_stride_row=*/ output_m_stride * sizeof(float), \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*dst_stride_col=*/ sizeof(float), \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*clamp_min=*/std::numeric_limits<float>::lowest(), \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
/*clamp_max=*/std::numeric_limits<float>::max() \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
); \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DEFINE_WEIGHT_DATA_FNS(/*nr*/8, /*kr*/16, /*sr*/2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @digantdesai this file is my draft reworking of the kleidiai integration. Weight packing and activation functions are no longer in isa kernel-specific namespaces because many kernels share the same routines. Kernel functions and uconfigs are defined using macros. I would like DEFINE_KERNEL_FNS to be defined by things like mr, nr, instruction (dotprod/i8mm), but I don't follow follow the kleidi naming convention. So now it is indexed by first/suffix. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks good. Alternative to this would be to code gen these wrappers at compile-time but this is clean enough. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DEFINE_ACTIVATION_DATA_FNS(/*mr*/1, /*kr*/16, /*sr*/2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DEFINE_KERNEL_FNS(1x8, 8x8_1x8x32_neon_dotprod) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DEFINE_KERNEL_FNS(1x8, 4x8_1x4x32_neon_dotprod) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#ifdef TORCHAO_ENABLE_ARM_I8MM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DEFINE_KERNEL_FNS(4x8, 4x8_8x4x32_neon_i8mm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DEFINE_KERNEL_FNS(4x8, 8x8_4x8x32_neon_i8mm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#endif // TORCHAO_ENABLE_ARM_I8MM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#undef DEFINE_WEIGHT_DATA_FNS | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#undef DEFINE_ACTIVATION_DATA_FNS | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#undef DEFINE_KERNEL_FNS | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} // namespace torchao::kernels::cpu::aarch64::kleidi |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,13 +8,23 @@ cmake_minimum_required(VERSION 3.19) | |
|
||
include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake) | ||
|
||
add_compile_options(-Wno-unused-function -Wno-unused-variable) # For some reason cpuinfo package has unused functions/variables | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix it upstream? |
||
|
||
include(FetchContent) | ||
FetchContent_Declare(cpuinfo | ||
GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git | ||
GIT_TAG main) # need main for benchmark::benchmark | ||
FetchContent_MakeAvailable( | ||
cpuinfo) | ||
|
||
find_package(Torch REQUIRED) | ||
add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT | ||
linear_8bit_act_xbit_weight.cpp | ||
op_linear_8bit_act_xbit_weight_aten.cpp | ||
) | ||
target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp) | ||
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64) | ||
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo) | ||
target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") | ||
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}") | ||
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE USE_ATEN=1) | ||
|
@@ -37,4 +47,5 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS) | |
target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1) | ||
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}") | ||
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64) | ||
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo) | ||
endif() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: nocommit