From e01afcb443304f1587d3816b9439f6c17f01edb3 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Sun, 2 Feb 2025 19:42:20 -0800
Subject: [PATCH 1/6] draft ukernel selection logic

---
 .../CMakeLists.txt                            |  11 ++
 .../kernel_selector.h                         | 141 ++++++++++++++++++
 .../linear_8bit_act_xbit_weight.cpp           |  16 +-
 .../linear_8bit_act_xbit_weight.h             |  13 +-
 .../op_linear_8bit_act_xbit_weight-impl.h     |  89 ++---------
 .../packed_weights_header.h                   |  28 ++++
 6 files changed, 210 insertions(+), 88 deletions(-)
 create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt
index 91fcf60621..72d6affdc1 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt
@@ -8,6 +8,15 @@ cmake_minimum_required(VERSION 3.19)

 include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake)

+# cpuinfo has unused functions/variables that trigger warnings; silence them
+add_compile_options(-Wno-unused-function -Wno-unused-variable)
+
+include(FetchContent)
+FetchContent_Declare(cpuinfo
+    GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git
+    GIT_TAG main)
+FetchContent_MakeAvailable(cpuinfo)
+
 find_package(Torch REQUIRED)
 add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT
     linear_8bit_act_xbit_weight.cpp
@@ -15,6 +24,7 @@ add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT
 )
 target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp)
 target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64)
+target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo)
 target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
 target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
 target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE USE_ATEN=1)
@@ -37,4 +47,5 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS)
     target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1)
     target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}")
     target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64)
+    target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo)
 endif()
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
new file mode 100644
index 0000000000..90372d9262
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
@@ -0,0 +1,141 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
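// [Editor's sketch, not part of the patch] Intended call flow for the selector
// introduced below, pieced together from the functions in this file; the
// template arguments (nbit=4, no zeros, no bias, no clamp) are illustrative:
//
//   // packing time: choose a header and serialize it ahead of the weights
//   auto header = select_header<4, false, false>();
//   header.write(packed_weights_ptr);
//
//   // inference time: read the header back, look up a matching ukernel
//   auto read = torchao::ops::PackedWeightsHeader::read(packed_weights_ptr);
//   auto config = select_ukernel_config<4, false, false, false>(read);
//   // config.kernels[0].kernel_fn is what the linear operator invokes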
+ +#pragma once +#include +#include +#include + +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#endif // defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +namespace torchao::ops::linear_8bit_act_xbit_weight { + +namespace { +// If needed in future, we can add uarch to the UKernelConfigCacheKey if needed +using UKernelConfigCacheKey = torchao::ops::PackedWeightsHeader; +struct UKernelConfigCacheKeyHash { + std::size_t operator()(const UKernelConfigCacheKey& k) const { + std::size_t hash = std::hash()(static_cast(k.format)); + for (int i = 0; i < k.params.size(); i++) { + hash ^= std::hash()(k.params[i]); + } + return hash; + } +}; +using UKernelConfigCacheType = std::unordered_map; +} + +template +void register_ukernel_config_universal(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int version) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_header_universal(weight_nbit, has_weight_zeros, has_bias, nr, kr); + + if (cpuinfo_has_arm_neon_dot()) { + if (nr == 8 && kr == 16) { + ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ + &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::activation_data_size, + /*preferred_activation_data_alignment*/16, + &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::prepare_activation_data, + &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::weight_data_size, + /*preferred_weight_data_alignment*/16, + &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::prepare_weight_data, + /*nr*/8, + {{{/*mr*/1, &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::kernel}}} + }; + return; + } + } + + throw std::runtime_error("Cannot register ukernel_config for packing format ukernel because no implementation is available on this platform"); +} + + +template +void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torchao::ops::PackedWeightsHeader header) { + auto it = ukernel_config_cache.find(header); + if (it != ukernel_config_cache.end()) { + throw std::runtime_error("UKernel config already registered"); + } + + switch (header.format) { + case torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal: { + auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_universal_packing_params(header); + if (packing_params.weight_nbit != weight_nbit) { + throw std::runtime_error("Packed weights are not in the expected format"); + } + if (packing_params.has_weight_zeros != has_weight_zeros) { + throw std::runtime_error("Packed weights are not in the expected format"); + } + if (packing_params.has_bias) { + register_ukernel_config_universal(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version); + } else { + register_ukernel_config_universal(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version); + } + break; + } + default: + throw std::runtime_error("No implementation for packed weights format"); + } + + it = ukernel_config_cache.find(header); + if (it == ukernel_config_cache.end()) { + throw std::runtime_error("UKernel config did not register"); + 
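// [Editor's note, not part of the patch] select_ukernel_config below keeps a
// function-local static cache keyed by the packed-weights header and fills it
// lazily through the registration above. If these ops can be entered from
// multiple threads, the find/register pair would need a guard along these
// (hypothetical) lines:
//
//   static std::mutex cache_mutex;
//   std::lock_guard<std::mutex> guard(cache_mutex);  // around find + register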
} +} + + +template +torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) { + static UKernelConfigCacheType ukernel_config_cache; + + // Check cache + auto it = ukernel_config_cache.find(header); + if (it != ukernel_config_cache.end()) { + std::cout << "UKERNEL CONFIG FROM CACHE: " << std::endl; + return it->second; + } + + std::cout << "REGISTERING UKERNEL CONFIG: " << std::endl; + register_ukernel_config(ukernel_config_cache, header); + it = ukernel_config_cache.find(header); + assert(it != ukernel_config_cache.end()); + auto config = it->second; + return config; +} + +// TODO: make packing format and header separate concepts +// Header is a serialized packing format +template +torchao::ops::PackedWeightsHeader select_header(std::optional target = std::nullopt) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + + // Select KleidiAI header + if (!target || *target == "kleidi_ai") { + if (weight_nbit == 4 && !has_weight_zeros) { + } + } + + // Select universal header + if (!target || *target == "universal") { + if (cpuinfo_has_arm_neon_dot()) { + return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_header_universal(weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16, /*version*/1); + } + } + + throw std::runtime_error("No header was selected"); +} + +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 4130d72e32..caadac2cbc 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -98,7 +98,7 @@ LinearTilingParams get_default_linear_tiling_params( TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); tiling_params.mc_by_mr = 1; - int mc = tiling_params.mc_by_mr * ukernel_config.mr; + int mc = tiling_params.mc_by_mr * ukernel_config.kernels[0].mr; int num_mc_panels = (m + mc - 1) / mc; int numerator = n * num_mc_panels; @@ -128,7 +128,7 @@ get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( int k, int group_size) { return ukernel_config.activation_data_size_fn( - tiling_params.mc_by_mr * ukernel_config.mr, k, group_size); + tiling_params.mc_by_mr * ukernel_config.kernels[0].mr, k, group_size); } inline size_t @@ -158,7 +158,7 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( float clamp_min, float clamp_max) { int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); + int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.kernels[0].mr); int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; @@ -184,7 +184,7 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( int output_offset = m_idx * n + n_idx; int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.kernel_fn( + ukernel_config.kernels[0].kernel_fn( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, @@ -214,10 +214,10 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( const float* activations, float clamp_min, float clamp_max) { - int mr = ukernel_config.mr; + int mr = 
ukernel_config.kernels[0].mr; int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int mc = std::min(m, tiling_params.mc_by_mr * mr); + int nc = std::min(n, tiling_params.nc_by_nr * nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; @@ -254,7 +254,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( int output_offset = m_idx * n + n_idx; int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.kernel_fn( + ukernel_config.kernels[0].kernel_fn( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index bcf9446f1b..7645216564 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { @@ -40,6 +41,11 @@ struct UKernelConfig { const void* activation_data, float clamp_min, float clamp_max); + + struct kernel { + int mr{0}; + kernel_fn_type kernel_fn{nullptr}; + }; activation_data_size_fn_type activation_data_size_fn{nullptr}; // preferred_activation_data_alignment is only a preferred alignment for @@ -57,11 +63,10 @@ struct UKernelConfig { size_t preferred_weight_data_alignment{0}; prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; - kernel_fn_type kernel_fn{nullptr}; - int mr{0}; + // kernel_fn_type kernel_fn{nullptr}; + // int mr{0}; int nr{0}; - - torchao::ops::PackedWeightsHeader packed_weights_header; + std::array kernels; }; // Pack weight functions diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 52c3bbae12..ec0d3aa217 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -13,72 +13,13 @@ #include #include #include +#include #include #include #include namespace { -// This selects a UkernelConfig based on the packed weight header -template -inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig -get_ukernel_config(torchao::ops::PackedWeightsHeader header) { - torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config; - - switch (header.format) { -#if defined(__aarch64__) || defined(__ARM_NEON) - case torchao::ops::PackedWeightsFormat:: - linear_8bit_act_xbit_weight_universal: - namespace ukernel - = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - - // Check packing params match the kernel - TORCHAO_CHECK( - header == - torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal( - weight_nbit, - has_weight_zeros, - has_bias, - /*nr=*/8, - /*kr=*/16), - "Packing params do not match what kernel supports"); - - config.packed_weights_header = header; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - 
&ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel:: - prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - return config; - break; -#endif // defined(__aarch64__) || defined(__ARM_NEON) - default: - TORCHAO_CHECK(false, "Unsupported packed weights format"); - } -} - -template -inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig -get_ukernel_config() { - auto header = torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal( - weight_nbit, has_weight_zeros, has_bias, /*nr=*/8, /*kr=*/16); - return get_ukernel_config( - header); -} - #ifdef USE_ATEN template Tensor pack_weights_cpu( @@ -127,11 +68,10 @@ Tensor pack_weights_cpu( using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config< - weight_nbit, - has_weight_zeros, - has_bias, - false /*has_clamp*/>(); + +auto header = select_header(); +auto ukernel_config = select_ukernel_config(header); + auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( ukernel_config, n, /*target_panels_per_thread=*/1); @@ -139,7 +79,7 @@ Tensor pack_weights_cpu( get_packed_weight_data_size(ukernel_config, n, k, group_size); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); - ukernel_config.packed_weights_header.write( + header.write( packed_weights.mutable_data_ptr()); pack_weight_data_operator( ukernel_config, @@ -206,11 +146,10 @@ Tensor pack_weights_meta( using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config< - weight_nbit, - has_weight_zeros, - has_bias, - false /*has_clamp*/>(); + + auto header = select_header(); + auto ukernel_config = select_ukernel_config(header); + auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + get_packed_weight_data_size(ukernel_config, n, k, group_size); @@ -311,17 +250,14 @@ Tensor linear_out_cpu( auto header = torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); - auto ukernel_config = get_ukernel_config< - weight_nbit, - has_weight_zeros /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>(header); + auto ukernel_config = select_ukernel_config(header); auto linear_tiling_params = get_default_linear_tiling_params( ukernel_config, m, n, /*target_tiles_per_thread=*/5); + auto linear_scheduling_policy = LinearTileSchedulingPolicy::single_mc_parallel_nc; @@ -332,6 +268,7 @@ Tensor linear_out_cpu( m, k, group_size); + std::vector activation_data_buffer(activation_data_buffer_size); linear_operator( diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h index d86a429461..9a0a1416b2 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h @@ -35,4 +35,32 @@ inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( 0}); } +struct UniversalPackingParams { + int version; + int weight_nbit; + bool has_weight_zeros; + bool has_bias; + int nr; + int kr; +}; + +inline UniversalPackingParams get_universal_packing_params(torchao::ops::PackedWeightsHeader header) { + if (header.format != torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal) { + throw 
std::runtime_error("Packed weights are not in universal packing format."); + } + return UniversalPackingParams{ + header.params[0], + header.params[1], + static_cast(header.params[2]), + static_cast(header.params[3]), + header.params[4], + header.params[5], + }; +} + + + + + + } // namespace torchao::ops::linear_8bit_act_xbit_weight From eaa21f2338fcfa369709422d7a5d8ef3734b37a2 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:16:15 -0800 Subject: [PATCH 2/6] up --- torchao/experimental/CMakeLists.txt | 2 +- .../embedding_xbit/op_embedding_xbit-impl.h | 24 +-- .../embedding_xbit/packed_weights_header.h | 6 +- .../kernel_selector.h | 200 +++++++++++++++--- .../op_linear_8bit_act_xbit_weight-impl.h | 39 ++-- .../packed_weights_header.h | 67 +++++- .../experimental/ops/packed_weights_header.h | 50 +++-- 7 files changed, 293 insertions(+), 95 deletions(-) diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index a90cc5884a..5dc42aad01 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -22,7 +22,7 @@ if(NOT TORCHAO_INCLUDE_DIRS) set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..) endif() -option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) +option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" ON) if(TORCHAO_BUILD_KLEIDIAI) message(STATUS "Building with Arm KleidiAI library") add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h index 777ec740ca..bf4316a963 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h @@ -37,16 +37,16 @@ void check_embedding_inputs( int packed_embedding_dim = (embedding_dim * weight_nbit) / 8; TORCHAO_CHECK( packed_weight_qvals.size(0) == - (torchao::ops::PackedWeightsHeader::size() + + (torchao::ops::PackedWeightsFormat::serialized_size() + (num_embeddings * packed_embedding_dim)), "packed_weight_qvals is not the correct size"); - // Check header - auto header = torchao::ops::PackedWeightsHeader::read( + // Check packed_weights_format + auto packed_weights_format = torchao::ops::PackedWeightsFormat::deserialize( packed_weight_qvals.const_data_ptr()); TORCHAO_CHECK( - header == - torchao::ops::embedding_xbit::get_packed_weights_header_universal( + packed_weights_format == + torchao::ops::embedding_xbit::get_packed_weights_format_universal( weight_nbit, /*min_value_chunk_size=*/32, /*max_value_chunk_size=*/128), @@ -151,7 +151,7 @@ Tensor embedding_out_cpu( embedding_dim, group_size, packed_weight_qvals.const_data_ptr() + - torchao::ops::PackedWeightsHeader::size(), + torchao::ops::PackedWeightsFormat::serialized_size(), weight_scales.const_data_ptr(), weight_zeros_ptr, index); @@ -222,23 +222,23 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) { weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8"); auto out = torch::empty( - torchao::ops::PackedWeightsHeader::size() + + torchao::ops::PackedWeightsFormat::serialized_size() + (num_embeddings * packed_embedding_dim)) .to(torch::kInt8); - auto header = - torchao::ops::embedding_xbit::get_packed_weights_header_universal( + auto packed_weights_format = + torchao::ops::embedding_xbit::get_packed_weights_format_universal( 
weight_nbit, /*min_value_chunk_size=*/32, /*max_value_chunk_size=*/128); - header.write(out.mutable_data_ptr()); + packed_weights_format.serialize(out.mutable_data_ptr()); torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) { #if defined(__aarch64__) || defined(__ARM_NEON) torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals< weight_nbit>( out.mutable_data_ptr() + - torchao::ops::PackedWeightsHeader::size(), + torchao::ops::PackedWeightsFormat::serialized_size(), embedding_dim, weight_qvals.const_data_ptr(), idx); @@ -261,7 +261,7 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) { embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack"); int packed_embedding_dim = embedding_dim * weight_nbit / 8; return torch::empty( - torchao::ops::PackedWeightsHeader::size() + + torchao::ops::PackedWeightsFormat::serialized_size() + (num_embeddings * packed_embedding_dim)) .to("meta"); } diff --git a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h b/torchao/experimental/ops/embedding_xbit/packed_weights_header.h index 935ee3bfbd..f2c8de6613 100644 --- a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h +++ b/torchao/experimental/ops/embedding_xbit/packed_weights_header.h @@ -10,13 +10,13 @@ namespace torchao::ops::embedding_xbit { -inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( +inline torchao::ops::PackedWeightsFormat get_packed_weights_format_universal( int weight_nbit, int min_value_chunk_size, int max_value_chunk_size, int version = 1) { - return torchao::ops::PackedWeightsHeader( - torchao::ops::PackedWeightsFormat::embedding_xbit_universal, + return torchao::ops::PackedWeightsFormat( + torchao::ops::PackedWeightsType::embedding_xbit_universal, {version, weight_nbit, min_value_chunk_size, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 90372d9262..31cd38ac6f 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -17,21 +17,68 @@ #include #include +#if defined(TORCHAO_ENABLE_KLEIDI) +#include +#include +#if defined (TORCHAO_ENABLE_ARM_I8MM) +#include +#include +#endif // TORCHAO_ENABLE_ARM_I8MM +#endif // TORCHAO_ENABLE_KLEIDI + + +// #if defined(TORCHAO_ENABLE_KLEIDI) + +// enum kai_kernel_id { +// dotprod_1x4x32 = 0, +// dotprod_1x8x32, +// i8mm_4x8x32, +// i8mm_8x4x32 +// }; + +// #define KAI_GEN_UKERNEL(kernel_ns) \ +// namespace kernel = kernel_ns; \ +// auto uk = kernel::get_ukernel(); \ +// config.mr = uk.get_m_step(); \ +// config.nr = uk.get_n_step(); \ +// config.activation_data_size_fn = &kernel::activation_data_size; \ +// config.weight_data_size_fn = &kernel::weight_data_size; \ +// config.preferred_activation_data_alignment = kernel::get_preferred_alignement(); \ +// config.preferred_weight_data_alignment = kernel::get_preferred_alignement(); \ +// config.prepare_activation_data_fn = &kernel::prepare_activation_data; \ +// config.prepare_weight_data_fn = &kernel::prepare_weight_data; \ +// config.kernel_fn = &kernel::kernel; \ + +// template +// UKernelConfig get_ukernel_config_kleidi() { +// UKernelConfig config; +// #if defined (TORCHAO_ENABLE_ARM_I8MM) +// if constexpr (kernel_id == i8mm_4x8x32) { +// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32); +// return config; +// } +// if constexpr 
(kernel_id == i8mm_8x4x32) { +// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32); +// return config; +// } +// #endif // TORCHAO_ENABLE_ARM_I8MM +// if constexpr (kernel_id == dotprod_1x8x32) { +// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32); +// return config; +// } +// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32); +// return config; +// } + +// #endif // TORCHAO_ENABLE_KLEIDI + + + namespace torchao::ops::linear_8bit_act_xbit_weight { namespace { -// If needed in future, we can add uarch to the UKernelConfigCacheKey if needed -using UKernelConfigCacheKey = torchao::ops::PackedWeightsHeader; -struct UKernelConfigCacheKeyHash { - std::size_t operator()(const UKernelConfigCacheKey& k) const { - std::size_t hash = std::hash()(static_cast(k.format)); - for (int i = 0; i < k.params.size(); i++) { - hash ^= std::hash()(k.params[i]); - } - return hash; - } -}; -using UKernelConfigCacheType = std::unordered_map; +using UKernelConfigCacheKey = torchao::ops::PackedWeightsFormat; +using UKernelConfigCacheType = std::unordered_map; } template @@ -39,19 +86,20 @@ void register_ukernel_config_universal(UKernelConfigCacheType& ukernel_config_ca if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_header_universal(weight_nbit, has_weight_zeros, has_bias, nr, kr); + UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal(weight_nbit, has_weight_zeros, has_bias, nr, kr); if (cpuinfo_has_arm_neon_dot()) { if (nr == 8 && kr == 16) { + namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::activation_data_size, + &kernel::activation_data_size, /*preferred_activation_data_alignment*/16, - &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::prepare_activation_data, - &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::weight_data_size, + &kernel::prepare_activation_data, + &kernel::weight_data_size, /*preferred_weight_data_alignment*/16, - &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::prepare_weight_data, + &kernel::prepare_weight_data, /*nr*/8, - {{{/*mr*/1, &torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::kernel}}} + {{{/*mr*/1, &kernel::kernel}}} }; return; } @@ -60,17 +108,90 @@ void register_ukernel_config_universal(UKernelConfigCacheType& ukernel_config_ca throw std::runtime_error("Cannot register ukernel_config for packing format ukernel because no implementation is available on this platform"); } +template +void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int sr) { + std::cout << "register_ukernel_config_kleidi_ai" << std::endl; + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + + // TODO: make better + 
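  // [Editor's note, not part of the patch] The packed-weights format itself is
  // the cache key, so each (weight_nbit, has_weight_zeros, has_bias, nr, kr, sr)
  // tuple registers at most one config. As written below, nr=8, kr=16, sr=2
  // maps to neon_dotprod_1x8x32 (with neon_i8mm_4x8x32 preferred when i8mm is
  // available), and nr=4, kr=8, sr=2 maps to neon_dotprod_1x4x32. For example:
  //
  //   // weight_nbit=4, has_weight_zeros=false, has_bias=true, nr=8, kr=16, sr=2
  //   auto f = get_packed_weights_format_kleidi_ai(4, false, true, 8, 16, 2);
  //   // -> registers/selects the 1x8x32 dotprod path on a dotprod-only CPU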
UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, has_bias, nr, kr, sr); + + #if defined (TORCHAO_ENABLE_ARM_I8MM) + if (cpuinfo_has_arm_i8mm()) { + namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; \ + auto uk = kernel::get_ukernel(); + ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ + &kernel::activation_data_size, + kernel::get_preferred_alignement(), + &kernel::prepare_activation_data, + &kernel::weight_data_size, + kernel::get_preferred_alignement(), + &kernel::prepare_weight_data, + /*nr*/static_cast(uk.get_n_step()), + {{{/*mr*/static_cast(uk.get_m_step()), &kernel::kernel}}} + }; + return; + } + #endif // TORCHAO_ENABLE_ARM_I8MM + + + if (cpuinfo_has_arm_neon_dot()) { + if (nr == 8 && kr == 16 && sr == 2) { + namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; + auto uk = kernel::get_ukernel(); + assert (nr == uk.get_nr()); + assert (kr == uk.get_kr()); + assert (sr == uk.get_sr()); + ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ + &kernel::activation_data_size, + kernel::get_preferred_alignement(), + &kernel::prepare_activation_data, + &kernel::weight_data_size, + kernel::get_preferred_alignement(), + &kernel::prepare_weight_data, + /*nr*/static_cast(uk.get_n_step()), + {{{/*mr*/static_cast(uk.get_m_step()), &kernel::kernel}}} + }; + return; + } + + if (nr == 4 && kr == 8 && sr == 2) { + namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32; + auto uk = kernel::get_ukernel(); + assert (nr == uk.get_nr()); + assert (kr == uk.get_kr()); + assert (sr == uk.get_sr()); + ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ + &kernel::activation_data_size, + kernel::get_preferred_alignement(), + &kernel::prepare_activation_data, + &kernel::weight_data_size, + kernel::get_preferred_alignement(), + &kernel::prepare_weight_data, + /*nr*/static_cast(uk.get_n_step()), + {{{/*mr*/static_cast(uk.get_m_step()), &kernel::kernel}}} + }; + return; + } + } + + +throw std::runtime_error("Cannot register ukernel_config for packing format kleidi_ai because no implementation is available on this platform"); +} + template -void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torchao::ops::PackedWeightsHeader header) { - auto it = ukernel_config_cache.find(header); +void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torchao::ops::PackedWeightsFormat format) { + auto it = ukernel_config_cache.find(format); if (it != ukernel_config_cache.end()) { throw std::runtime_error("UKernel config already registered"); } - switch (header.format) { - case torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal: { - auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_universal_packing_params(header); + switch (format.type) { + case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: { + auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_universal_packing_params(format); if (packing_params.weight_nbit != weight_nbit) { throw std::runtime_error("Packed weights are not in the expected format"); } @@ -84,11 +205,17 @@ void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torch } break; } + case 
torchao::ops::PackedWeightsType::kleidi_ai: { + auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_kleidi_ai_packing_params(format); + assert (packing_params.has_bias == true); + register_ukernel_config_kleidi_ai(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.sr); + break; + } default: throw std::runtime_error("No implementation for packed weights format"); } - it = ukernel_config_cache.find(header); + it = ukernel_config_cache.find(format); if (it == ukernel_config_cache.end()) { throw std::runtime_error("UKernel config did not register"); } @@ -96,46 +223,49 @@ void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torch template -torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) { +torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsFormat format) { static UKernelConfigCacheType ukernel_config_cache; // Check cache - auto it = ukernel_config_cache.find(header); + auto it = ukernel_config_cache.find(format); if (it != ukernel_config_cache.end()) { std::cout << "UKERNEL CONFIG FROM CACHE: " << std::endl; return it->second; } std::cout << "REGISTERING UKERNEL CONFIG: " << std::endl; - register_ukernel_config(ukernel_config_cache, header); - it = ukernel_config_cache.find(header); + register_ukernel_config(ukernel_config_cache, format); + it = ukernel_config_cache.find(format); assert(it != ukernel_config_cache.end()); auto config = it->second; return config; } -// TODO: make packing format and header separate concepts -// Header is a serialized packing format +// TODO: make packing format and format separate concepts +// format is a serialized packing format template -torchao::ops::PackedWeightsHeader select_header(std::optional target = std::nullopt) { +torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional target = std::nullopt) { if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - // Select KleidiAI header + // Select KleidiAI format + #if defined(TORCHAO_ENABLE_KLEIDI) if (!target || *target == "kleidi_ai") { if (weight_nbit == 4 && !has_weight_zeros) { + return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2); } } + #endif // defined(TORCHAO_ENABLE_KLEIDI) - // Select universal header + // Select universal format if (!target || *target == "universal") { if (cpuinfo_has_arm_neon_dot()) { - return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_header_universal(weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16, /*version*/1); + return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal(weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16, /*version*/1); } } - throw std::runtime_error("No header was selected"); + throw std::runtime_error("No format was selected"); } } // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index ec0d3aa217..41fa54e852 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -69,30 +69,39 @@ Tensor 
pack_weights_cpu( using namespace torchao::ops::linear_8bit_act_xbit_weight; -auto header = select_header(); -auto ukernel_config = select_ukernel_config(header); +auto packed_weights_format = select_packed_weights_format(); +auto ukernel_config = select_ukernel_config(packed_weights_format); auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( ukernel_config, n, /*target_panels_per_thread=*/1); - auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + auto packed_weight_data_size = torchao::ops::PackedWeightsFormat::serialized_size() + get_packed_weight_data_size(ukernel_config, n, k, group_size); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); - header.write( + packed_weights_format.serialize( packed_weights.mutable_data_ptr()); + + // TODO: hack to call Kleidi AI kernel + // Need to pipe bias through pack_weight_data_operator + float* bias_ptr = nullptr; + std::vector bias(n, 0.0); + if (packed_weights_format.type == torchao::ops::PackedWeightsType::kleidi_ai) { + bias_ptr = bias.data(); + } + pack_weight_data_operator( ukernel_config, pack_weight_tiling_params, packed_weights.mutable_data_ptr() + - torchao::ops::PackedWeightsHeader::size(), + torchao::ops::PackedWeightsFormat::serialized_size(), n, k, group_size, weight_qvals.const_data_ptr(), weight_scales.const_data_ptr(), weight_zeros_ptr, - /*bias*/ nullptr); + bias_ptr); return packed_weights; } @@ -147,11 +156,11 @@ Tensor pack_weights_meta( using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto header = select_header(); - auto ukernel_config = select_ukernel_config(header); + auto packed_weights_format = select_packed_weights_format(); + auto ukernel_config = select_ukernel_config(packed_weights_format); - auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + auto packed_weight_data_size = torchao::ops::PackedWeightsFormat::serialized_size() + get_packed_weight_data_size(ukernel_config, n, k, group_size); return torch::empty({static_cast(packed_weight_data_size)}) .to("meta"); @@ -245,12 +254,12 @@ Tensor linear_out_cpu( packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); #endif // USE_ATEN TORCHAO_CHECK( - packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), - "packed_weights is not big enough to read the header."); - auto header = - torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); + packed_weights.size(0) >= torchao::ops::PackedWeightsFormat::serialized_size(), + "packed_weights is not big enough to read the packed_weights_format."); + auto packed_weights_format = + torchao::ops::PackedWeightsFormat::deserialize(packed_weights.const_data_ptr()); - auto ukernel_config = select_ukernel_config(header); + auto ukernel_config = select_ukernel_config(packed_weights_format); auto linear_tiling_params = get_default_linear_tiling_params( ukernel_config, @@ -282,7 +291,7 @@ Tensor linear_out_cpu( k, group_size, packed_weights.const_data_ptr() + - torchao::ops::PackedWeightsHeader::size(), + torchao::ops::PackedWeightsFormat::serialized_size(), activations.const_data_ptr(), // Clamp parameters are ignored because config is created from // has_clamp = false diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h index 9a0a1416b2..3b6fa6be02 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ 
b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h @@ -10,15 +10,15 @@ namespace torchao::ops::linear_8bit_act_xbit_weight { -inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( +inline torchao::ops::PackedWeightsFormat get_packed_weights_format_universal( int weight_nbit, bool has_weight_zeros, bool has_bias, int nr, int kr, int version = 1) { - return torchao::ops::PackedWeightsHeader( - torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal, + return torchao::ops::PackedWeightsFormat( + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, {version, weight_nbit, has_weight_zeros, @@ -44,23 +44,68 @@ struct UniversalPackingParams { int kr; }; -inline UniversalPackingParams get_universal_packing_params(torchao::ops::PackedWeightsHeader header) { - if (header.format != torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal) { +inline UniversalPackingParams get_universal_packing_params(torchao::ops::PackedWeightsFormat format) { + if (format.type != torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal) { throw std::runtime_error("Packed weights are not in universal packing format."); } return UniversalPackingParams{ - header.params[0], - header.params[1], - static_cast(header.params[2]), - static_cast(header.params[3]), - header.params[4], - header.params[5], + format.params[0], + format.params[1], + static_cast(format.params[2]), + static_cast(format.params[3]), + format.params[4], + format.params[5], }; } +inline torchao::ops::PackedWeightsFormat get_packed_weights_format_kleidi_ai( + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + int nr, + int kr, + int sr) { + return torchao::ops::PackedWeightsFormat( + torchao::ops::PackedWeightsType::kleidi_ai, + {weight_nbit, + has_weight_zeros, + has_bias, + nr, + kr, + sr, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0}); +} +struct KleidiAIPackingParams { + int weight_nbit; + bool has_weight_zeros; + bool has_bias; + int nr; + int kr; + int sr; +}; +inline KleidiAIPackingParams get_kleidi_ai_packing_params(torchao::ops::PackedWeightsFormat format) { + if (format.type != torchao::ops::PackedWeightsType::kleidi_ai) { + throw std::runtime_error("Packed weights are not in kleidi_ai packing format."); + } + return KleidiAIPackingParams{ + format.params[0], + static_cast(format.params[1]), + static_cast(format.params[2]), + format.params[3], + format.params[4], + format.params[5] + }; +} } // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/packed_weights_header.h b/torchao/experimental/ops/packed_weights_header.h index 7184da4b46..1c8e02872e 100644 --- a/torchao/experimental/ops/packed_weights_header.h +++ b/torchao/experimental/ops/packed_weights_header.h @@ -12,53 +12,54 @@ namespace torchao::ops { -enum class PackedWeightsFormat : uint32_t { +enum class PackedWeightsType : uint32_t { unknown = 0, linear_8bit_act_xbit_weight_universal = 1, - embedding_xbit_universal = 2 + embedding_xbit_universal = 2, + kleidi_ai = 3 }; -class PackedWeightsHeader { +class PackedWeightsFormat { public: using params_type = std::array; const static int magic = 6712; - PackedWeightsFormat format; + PackedWeightsType type; - // 14 bytes of format specific params + // 14 bytes of type specific params params_type params; - PackedWeightsHeader( - PackedWeightsFormat format = PackedWeightsFormat::unknown, + PackedWeightsFormat( + PackedWeightsType type = PackedWeightsType::unknown, params_type params = {0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - : format{format}, params{params} {} + : type{type}, params{params} {} - inline static constexpr int size() { - static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); + inline static constexpr int serialized_size() { + static_assert(sizeof(magic) + sizeof(type) + sizeof(params) == 64); return 64; } - inline void write(void* packed_weights) const { + inline void serialize(void* packed_weights) const { auto header = reinterpret_cast(packed_weights); header[0] = magic; - header[1] = static_cast(format); + header[1] = static_cast(type); for (int i = 0; i < params.size(); i++) { header[i + 2] = params[i]; } } - static PackedWeightsHeader read(const void* packed_weights) { + static PackedWeightsFormat deserialize(const void* packed_weights) { auto header = reinterpret_cast(packed_weights); - assert(header[0] == PackedWeightsHeader::magic); + assert(header[0] == PackedWeightsFormat::magic); params_type params; for (int i = 0; i < params.size(); i++) { params[i] = header[i + 2]; } - return PackedWeightsHeader( - static_cast(header[1]), params); + return PackedWeightsFormat( + static_cast(header[1]), params); } - bool operator==(const PackedWeightsHeader& other) const { - if (format != other.format) { + bool operator==(const PackedWeightsFormat& other) const { + if (type != other.type) { return false; } for (int i = 0; i < params.size(); i++) { @@ -71,3 +72,16 @@ class PackedWeightsHeader { }; } // namespace torchao::ops + +namespace std { + template <> + struct hash { + std::size_t operator()(const torchao::ops::PackedWeightsFormat& f) const { + std::size_t hash = std::hash()(static_cast(f.type)); + for (int i = 0; i < f.params.size(); i++) { + hash ^= std::hash()(f.params[i]); + } + return hash; + }; +}; +} From 0032ca8107af2d1736f4de211cf2a5b1ec9b8e97 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 3 Feb 2025 22:50:29 -0800 Subject: [PATCH 3/6] up --- .../kernel_selector.h | 142 ++++++---------- .../linear_8bit_act_xbit_weight.cpp | 39 ++--- .../linear_8bit_act_xbit_weight.h | 45 +++-- ...8_dynamic_activation_intx_weight_layout.py | 160 ++++++++++++++++++ 4 files changed, 255 insertions(+), 131 deletions(-) create mode 100644 torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 31cd38ac6f..82b9f4bfb9 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -26,54 +26,6 @@ #endif // TORCHAO_ENABLE_ARM_I8MM #endif // TORCHAO_ENABLE_KLEIDI - -// #if defined(TORCHAO_ENABLE_KLEIDI) - -// enum kai_kernel_id { -// dotprod_1x4x32 = 0, -// dotprod_1x8x32, -// i8mm_4x8x32, -// i8mm_8x4x32 -// }; - -// #define KAI_GEN_UKERNEL(kernel_ns) \ -// namespace kernel = kernel_ns; \ -// auto uk = kernel::get_ukernel(); \ -// config.mr = uk.get_m_step(); \ -// config.nr = uk.get_n_step(); \ -// config.activation_data_size_fn = &kernel::activation_data_size; \ -// config.weight_data_size_fn = &kernel::weight_data_size; \ -// config.preferred_activation_data_alignment = kernel::get_preferred_alignement(); \ -// config.preferred_weight_data_alignment = kernel::get_preferred_alignement(); \ -// config.prepare_activation_data_fn = &kernel::prepare_activation_data; \ -// config.prepare_weight_data_fn = 
&kernel::prepare_weight_data; \ -// config.kernel_fn = &kernel::kernel; \ - -// template -// UKernelConfig get_ukernel_config_kleidi() { -// UKernelConfig config; -// #if defined (TORCHAO_ENABLE_ARM_I8MM) -// if constexpr (kernel_id == i8mm_4x8x32) { -// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32); -// return config; -// } -// if constexpr (kernel_id == i8mm_8x4x32) { -// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32); -// return config; -// } -// #endif // TORCHAO_ENABLE_ARM_I8MM -// if constexpr (kernel_id == dotprod_1x8x32) { -// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32); -// return config; -// } -// KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32); -// return config; -// } - -// #endif // TORCHAO_ENABLE_KLEIDI - - - namespace torchao::ops::linear_8bit_act_xbit_weight { namespace { @@ -92,14 +44,22 @@ void register_ukernel_config_universal(UKernelConfigCacheType& ukernel_config_ca if (nr == 8 && kr == 16) { namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - &kernel::activation_data_size, - /*preferred_activation_data_alignment*/16, - &kernel::prepare_activation_data, - &kernel::weight_data_size, - /*preferred_weight_data_alignment*/16, - &kernel::prepare_weight_data, - /*nr*/8, - {{{/*mr*/1, &kernel::kernel}}} + /*preferred_alignment*/16, + /*weight_packing*/ + { + /*nr*/8, + /*weight_data_size_fn*/&kernel::weight_data_size, + /*prepare_weight_data_fn*/&kernel::prepare_weight_data + }, + /*kernels*/ + {{ + { + /*mr*/1, + /*activation_data_size_fn*/&kernel::activation_data_size, + /*prepare_activation_data_fn*/&kernel::prepare_activation_data, + /*kernel*/&kernel::kernel + } + }} }; return; } @@ -120,18 +80,33 @@ void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_ca #if defined (TORCHAO_ENABLE_ARM_I8MM) if (cpuinfo_has_arm_i8mm()) { - namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; \ + if (nr == 8 && kr == 16 && sr == 2) { + namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; auto uk = kernel::get_ukernel(); + assert (nr == uk.get_nr()); + assert (kr == uk.get_kr()); + assert (sr == uk.get_sr()); + ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - &kernel::activation_data_size, - kernel::get_preferred_alignement(), - &kernel::prepare_activation_data, - &kernel::weight_data_size, - kernel::get_preferred_alignement(), - &kernel::prepare_weight_data, + /*preferred_alignment*/16, + /*weight_packing*/ + { /*nr*/static_cast(uk.get_n_step()), - {{{/*mr*/static_cast(uk.get_m_step()), &kernel::kernel}}} + /*weight_data_size_fn*/&kernel::weight_data_size, + /*prepare_weight_data_fn*/&kernel::prepare_weight_data + }, + /*kernels*/ + {{ + { + /*mr*/static_cast(uk.get_m_step()), + /*activation_data_size_fn*/&kernel::activation_data_size, + /*prepare_activation_data_fn*/&kernel::prepare_activation_data, + /*kernel*/&kernel::kernel + } + }} }; + return; + } return; } #endif // TORCHAO_ENABLE_ARM_I8MM @@ -144,35 +119,30 @@ void 
register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_ca assert (nr == uk.get_nr()); assert (kr == uk.get_kr()); assert (sr == uk.get_sr()); + ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - &kernel::activation_data_size, - kernel::get_preferred_alignement(), - &kernel::prepare_activation_data, - &kernel::weight_data_size, - kernel::get_preferred_alignement(), - &kernel::prepare_weight_data, + /*preferred_alignment*/16, + /*weight_packing*/ + { /*nr*/static_cast(uk.get_n_step()), - {{{/*mr*/static_cast(uk.get_m_step()), &kernel::kernel}}} + /*weight_data_size_fn*/&kernel::weight_data_size, + /*prepare_weight_data_fn*/&kernel::prepare_weight_data + }, + /*kernels*/ + {{ + { + /*mr*/static_cast(uk.get_m_step()), + /*activation_data_size_fn*/&kernel::activation_data_size, + /*prepare_activation_data_fn*/&kernel::prepare_activation_data, + /*kernel*/&kernel::kernel + } + }} }; return; } if (nr == 4 && kr == 8 && sr == 2) { - namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32; - auto uk = kernel::get_ukernel(); - assert (nr == uk.get_nr()); - assert (kr == uk.get_kr()); - assert (sr == uk.get_sr()); - ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - &kernel::activation_data_size, - kernel::get_preferred_alignement(), - &kernel::prepare_activation_data, - &kernel::weight_data_size, - kernel::get_preferred_alignement(), - &kernel::prepare_weight_data, - /*nr*/static_cast(uk.get_n_step()), - {{{/*mr*/static_cast(uk.get_m_step()), &kernel::kernel}}} - }; + // TODO return; } } diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index caadac2cbc..1065a922dc 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -23,7 +23,7 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1"); PackWeightDataTilingParams tiling_params; - int nr = ukernel_config.nr; + int nr = ukernel_config.weight_packing.nr; int num_threads = torchao::get_num_threads(); int numerator = n; int denominator = num_threads * target_panels_per_thread; @@ -33,7 +33,7 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( assert(nc >= 1); // Replace nc with the next number nr divides - nc = ((nc + ukernel_config.nr - 1) / ukernel_config.nr) * ukernel_config.nr; + nc = ((nc + nr - 1) / nr) * nr; tiling_params.nc_by_nr = nc / nr; return tiling_params; @@ -55,8 +55,8 @@ void pack_weight_data_operator( TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); - int nr = ukernel_config.nr; - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int nr = ukernel_config.weight_packing.nr; + int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.weight_packing.nr); int num_nc_panels = (n + nc - 1) / nc; torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { @@ -65,12 +65,12 @@ void pack_weight_data_operator( int nc_tile_size = std::min(nc, n - n_idx); int weight_data_offset = - (n_idx / nr) * ukernel_config.weight_data_size_fn(nr, k, group_size); + (n_idx / nr) * 
ukernel_config.weight_packing.weight_data_size_fn(nr, k, group_size); int weight_qvals_offset = n_idx * k; int weight_scales_and_zeros_offset = (n_idx * k / group_size); int bias_offset = n_idx; - ukernel_config.prepare_weight_data_fn( + ukernel_config.weight_packing.prepare_weight_data_fn( (char*)weight_data + weight_data_offset, /*n=*/nc_tile_size, k, @@ -109,9 +109,10 @@ LinearTilingParams get_default_linear_tiling_params( assert(nc >= 1); // Replace nc with next number nr divides - nc = ((nc + ukernel_config.nr - 1) / ukernel_config.nr) * ukernel_config.nr; - assert(nc % ukernel_config.nr == 0); - tiling_params.nc_by_nr = nc / ukernel_config.nr; + int nr = ukernel_config.weight_packing.nr; + nc = ((nc + nr - 1) / nr) * nr; + assert(nc % nr == 0); + tiling_params.nc_by_nr = nc / nr; assert(tiling_params.mc_by_mr >= 1); assert(tiling_params.nc_by_nr >= 1); @@ -127,7 +128,7 @@ get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( int m, int k, int group_size) { - return ukernel_config.activation_data_size_fn( + return ukernel_config.kernels[0].activation_data_size_fn( tiling_params.mc_by_mr * ukernel_config.kernels[0].mr, k, group_size); } @@ -138,7 +139,7 @@ get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_n int m, int k, int group_size) { - return ukernel_config.activation_data_size_fn(m, k, group_size); + return ukernel_config.kernels[0].activation_data_size_fn(m, k, group_size); } inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( @@ -157,19 +158,19 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( // Ignored if has_clamp = false float clamp_min, float clamp_max) { - int nr = ukernel_config.nr; + int nr = ukernel_config.weight_packing.nr; int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.kernels[0].mr); - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int nc = std::min(n, tiling_params.nc_by_nr * nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; size_t weight_data_size = - ukernel_config.weight_data_size_fn(nr, k, group_size); + ukernel_config.weight_packing.weight_data_size_fn(nr, k, group_size); for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) { int m_idx = mc_tile_idx * mc; int mc_tile_size = std::min(mc, m - m_idx); int activations_offset = m_idx * k; - ukernel_config.prepare_activation_data_fn( + ukernel_config.kernels[0].prepare_activation_data_fn( activation_data_buffer, /*m=*/mc_tile_size, k, @@ -215,16 +216,16 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( float clamp_min, float clamp_max) { int mr = ukernel_config.kernels[0].mr; - int nr = ukernel_config.nr; + int nr = ukernel_config.weight_packing.nr; int mc = std::min(m, tiling_params.mc_by_mr * mr); int nc = std::min(n, tiling_params.nc_by_nr * nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; size_t weight_data_size = - ukernel_config.weight_data_size_fn(nr, k, group_size); + ukernel_config.weight_packing.weight_data_size_fn(nr, k, group_size); size_t activation_data_size = - ukernel_config.activation_data_size_fn(mr, k, group_size); + ukernel_config.kernels[0].activation_data_size_fn(mr, k, group_size); torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) { int mc_tile_idx = idx; @@ -233,7 +234,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( int activations_offset = m_idx * k; int activation_data_offset = (m_idx / mr) * 
activation_data_size; - ukernel_config.prepare_activation_data_fn( + ukernel_config.kernels[0].prepare_activation_data_fn( activation_data_buffer + activation_data_offset, /*m=*/mc_tile_size, k, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index 7645216564..8367b54c7c 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -42,31 +42,24 @@ struct UKernelConfig { float clamp_min, float clamp_max); - struct kernel { + struct weight_packing_config { + int nr{0}; + weight_data_size_fn_type weight_data_size_fn{nullptr}; + prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; + }; + struct kernel_config { int mr{0}; + activation_data_size_fn_type activation_data_size_fn{nullptr}; + prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; kernel_fn_type kernel_fn{nullptr}; }; - - activation_data_size_fn_type activation_data_size_fn{nullptr}; - // preferred_activation_data_alignment is only a preferred alignment for - // performance reasons. Integration surfaces are not required to - // respect this alignment, and the ukernel must behave correctly no matter - // how the prepared_activation_data byte-array is aligned - size_t preferred_activation_data_alignment{0}; - prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; - - weight_data_size_fn_type weight_data_size_fn{nullptr}; - // weight_data_alignment is only a preferred alignment for - // performance reasons. Integration surfaces are not required to - // respect this alignment, and the ukernel must behave correctly no matter - // how the prepared_weight_data byte-array is aligned - size_t preferred_weight_data_alignment{0}; - prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; - - // kernel_fn_type kernel_fn{nullptr}; - // int mr{0}; - int nr{0}; - std::array kernels; + + // preferred_alignment for activation and weight data + // Integration surfaces are not required to respect this alignment, and the + // ukernel must behave correctly no matter how buffers are aligned + size_t preferred_alignment{0}; + weight_packing_config weight_packing; + std::array kernels; }; // Pack weight functions @@ -84,12 +77,12 @@ inline size_t get_packed_weight_data_size( int n, int k, int group_size) { - return ukernel_config.weight_data_size_fn(n, k, group_size); + return ukernel_config.weight_packing.weight_data_size_fn(n, k, group_size); } inline size_t get_preferred_packed_weight_data_alignment( const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_weight_data_alignment; + return ukernel_config.preferred_alignment; } void pack_weight_data_operator( @@ -131,9 +124,9 @@ size_t get_activation_data_buffer_size( int k, int group_size); -inline size_t get_preferred_activation_data_buffer_alignment( +inline size_t activation_data_buffer_alignment( const UKernelConfig& ukernel_config) { - return ukernel_config.preferred_activation_data_alignment; + return ukernel_config.preferred_alignment; } void linear_operator( diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py new file mode 100644 index 0000000000..b2688ed00c --- /dev/null +++ 
b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py
@@ -0,0 +1,154 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import tempfile
+import unittest
+
+import torch
+
+from torchao.dtypes import PlainLayout
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.utils import unwrap_tensor_subclass
+
+
+class TestPackedLinearInt8DynamicActivationIntxWeightLayout(unittest.TestCase):
+    def test_accuracy(self):
+        """
+        Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing
+        its results to the results of a reference model that uses PlainLayout()
+        """
+        granularity = PerGroup(128)
+        m = 1
+        n = 1071
+        k = 4096
+        activations = torch.randn(m, k)
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        for weight_dtype in [
+            torch.int1,
+            torch.int2,
+            torch.int3,
+            torch.int4,
+            torch.int5,
+            torch.int6,
+            torch.int7,
+            torch.int8,
+        ]:
+            for has_weight_zeros in [True, False]:
+                print(
+                    f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}"
+                )
+                quantized_model = copy.deepcopy(model)
+                quantize_(
+                    quantized_model,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),  # default
+                    ),
+                )
+
+                quantized_model_reference = copy.deepcopy(model)
+                quantize_(
+                    quantized_model_reference,
+                    int8_dynamic_activation_intx_weight(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        layout=PlainLayout(),
+                    ),
+                )
+
+                with torch.no_grad():
+                    result = quantized_model(activations)
+                    expected_result = quantized_model_reference(activations)
+
+                num_mismatch_at_low_tol = 0
+                num_total = result.reshape(-1).shape[0]
+                for i in range(num_total):
+                    actual_val = result.reshape(-1)[i]
+                    expected_val = expected_result.reshape(-1)[i]
+                    self.assertTrue(torch.abs(actual_val - expected_val) < 0.05)
+                    if not torch.allclose(actual_val, expected_val):
+                        num_mismatch_at_low_tol += 1
+
+                # Assert at most 5% of entries are not close at a low tolerance
+                # self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
+
+    def test_export_compile_aoti(self):
+        """
+        Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with
+        torch.export.export, torch.compile, and AOTI.
+ """ + granularity = PerRow() + m = 3 + k0 = 512 + k1 = 256 + k2 = 128 + k3 = 1024 + weight_dtype = torch.int4 + has_weight_zeros = True + layers = [ + torch.nn.Linear(k0, k1, bias=False), + torch.nn.Linear(k1, k2, bias=False), + torch.nn.Linear(k2, k3, bias=False), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn(2, 1, m, k0, dtype=torch.float32) + + print("Quantizing model") + quantize_( + model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + ), + ) + eager_results = model(activations) + + unwrapped_model = copy.deepcopy(model) + unwrap_tensor_subclass(model) + + print("Exporting quantized model") + exported = torch.export.export(model, (activations,), strict=True) + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) + + print("Compiling quantized model") + compiled = torch.compile(unwrapped_model) + with torch.no_grad(): + compiled_results = compiled(activations) + self.assertTrue(torch.allclose(eager_results, compiled_results)) + + with tempfile.TemporaryDirectory() as tmpdirname: + package_path = f"{tmpdirname}/model.pt2" + print("Exporting quantized model with AOTI") + torch._inductor.aoti_compile_and_package( + exported, package_path=package_path + ) + + print("Running quantized model in AOTI") + fn = torch._inductor.aoti_load_package(package_path) + aoti_results = fn(activations) + self.assertTrue(torch.allclose(eager_results, aoti_results)) + + +if __name__ == "__main__": + unittest.main() From f7f43bd9414b025ff17dcb4dba7b715003528979 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 4 Feb 2025 21:38:39 -0800 Subject: [PATCH 4/6] up --- .../kernel_selector.h | 305 +++++++++++------- .../op_linear_8bit_act_xbit_weight-impl.h | 1 - .../packed_weights_header.h | 111 ------- 3 files changed, 189 insertions(+), 228 deletions(-) delete mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 82b9f4bfb9..874840d6da 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -6,7 +6,6 @@ #pragma once #include -#include #include #if defined(__aarch64__) || defined(__ARM_NEON) @@ -28,105 +27,200 @@ namespace torchao::ops::linear_8bit_act_xbit_weight { -namespace { -using UKernelConfigCacheKey = torchao::ops::PackedWeightsFormat; -using UKernelConfigCacheType = std::unordered_map; -} +struct UniversalPackedWeightsFormat { + int version; + int weight_nbit; + bool has_weight_zeros; + bool has_bias; + int nr; + int kr; + + static UniversalPackedWeightsFormat from_packed_weights_format(torchao::ops::PackedWeightsFormat format) { + if (format.type != torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal) { + throw std::runtime_error("Packed weights are not in universal packing format."); + } + return UniversalPackedWeightsFormat{ + format.params[0], + format.params[1], + static_cast(format.params[2]), + static_cast(format.params[3]), + format.params[4], + format.params[5], + }; + } + inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const { + return torchao::ops::PackedWeightsFormat( + 
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + { + version, + weight_nbit, + has_weight_zeros, + has_bias, + nr, + kr + }); + } +}; + +struct KleidiAIPackedWeightsFormat { + int weight_nbit; + bool has_weight_zeros; + bool has_bias; + int nr; + int kr; + int sr; + + static KleidiAIPackedWeightsFormat from_packed_weights_format(torchao::ops::PackedWeightsFormat format) { + if (format.type != torchao::ops::PackedWeightsType::kleidi_ai) { + throw std::runtime_error("Packed weights are not in kleidi_ai packing format."); + } + return KleidiAIPackedWeightsFormat{ + format.params[0], + static_cast(format.params[1]), + static_cast(format.params[2]), + format.params[3], + format.params[4], + format.params[5] + }; + } + inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const { + return torchao::ops::PackedWeightsFormat( + torchao::ops::PackedWeightsType::kleidi_ai, + {weight_nbit, + has_weight_zeros, + has_bias, + nr, + kr, + sr}); + } +}; + +struct UKernelConfigRegistrationTable { + private: + std::unordered_map registration_table_; + public: + void register_ukernel_config(torchao::ops::PackedWeightsFormat format, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) { + if (registration_table_.find(format) != registration_table_.end()) { + throw std::runtime_error("UKernelConfig is already registered for this format"); + } + registration_table_[format] = config; + } + std::optional get_ukernel_config(torchao::ops::PackedWeightsFormat format) const { + auto it = registration_table_.find(format); + if (it == registration_table_.end()) { + return std::nullopt; + } + return it->second; + } +}; template -void register_ukernel_config_universal(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int version) { +void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) { if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal(weight_nbit, has_weight_zeros, has_bias, nr, kr); + auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format); + if (universal_format.weight_nbit != weight_nbit) { + throw std::runtime_error("Packed weights are not in the expected format"); + } + if (universal_format.has_weight_zeros != has_weight_zeros) { + throw std::runtime_error("Packed weights are not in the expected format"); + } + if (universal_format.has_bias != has_bias) { + throw std::runtime_error("Packed weights are not in the expected format"); + } - if (cpuinfo_has_arm_neon_dot()) { - if (nr == 8 && kr == 16) { + if (universal_format.nr == 8 && universal_format.kr == 16) { + #if defined(__aarch64__) || defined(__ARM_NEON) + if (cpuinfo_has_arm_neon_dot()) { namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - /*preferred_alignment*/16, - /*weight_packing*/ - { - /*nr*/8, - /*weight_data_size_fn*/&kernel::weight_data_size, - /*prepare_weight_data_fn*/&kernel::prepare_weight_data - }, - /*kernels*/ - {{ - { - /*mr*/1, - /*activation_data_size_fn*/&kernel::activation_data_size, - /*prepare_activation_data_fn*/&kernel::prepare_activation_data, - /*kernel*/&kernel::kernel + table.register_ukernel_config( + format, + 
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ + /*preferred_alignment*/16, + /*weight_packing*/ + { + /*nr*/8, + /*weight_data_size_fn*/&kernel::weight_data_size, + /*prepare_weight_data_fn*/&kernel::prepare_weight_data + }, + /*kernels*/ + {{ + { + /*mr*/1, + /*activation_data_size_fn*/&kernel::activation_data_size, + /*prepare_activation_data_fn*/&kernel::prepare_activation_data, + /*kernel*/&kernel::kernel + } + }} + } + ); + return; } - }} - }; - return; - } + #endif // defined(__aarch64__) || defined(__ARM_NEON) } - - throw std::runtime_error("Cannot register ukernel_config for packing format ukernel because no implementation is available on this platform"); } template -void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int sr) { +void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) { std::cout << "register_ukernel_config_kleidi_ai" << std::endl; if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - // TODO: make better - UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, has_bias, nr, kr, sr); + auto kleidi_ai_format = KleidiAIPackedWeightsFormat::from_packed_weights_format(format); + int nr = kleidi_ai_format.nr; + int kr = kleidi_ai_format.kr; + int sr = kleidi_ai_format.sr; - #if defined (TORCHAO_ENABLE_ARM_I8MM) - if (cpuinfo_has_arm_i8mm()) { - if (nr == 8 && kr == 16 && sr == 2) { + if (nr == 8 && kr == 16 && sr == 2) { + #if defined (TORCHAO_ENABLE_ARM_I8MM) + if (cpuinfo_has_arm_i8mm()) { namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; auto uk = kernel::get_ukernel(); assert (nr == uk.get_nr()); assert (kr == uk.get_kr()); assert (sr == uk.get_sr()); - - ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - /*preferred_alignment*/16, - /*weight_packing*/ - { - /*nr*/static_cast(uk.get_n_step()), - /*weight_data_size_fn*/&kernel::weight_data_size, - /*prepare_weight_data_fn*/&kernel::prepare_weight_data - }, - /*kernels*/ - {{ - { - /*mr*/static_cast(uk.get_m_step()), - /*activation_data_size_fn*/&kernel::activation_data_size, - /*prepare_activation_data_fn*/&kernel::prepare_activation_data, - /*kernel*/&kernel::kernel + table.register_ukernel_config( + format, + torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ + /*preferred_alignment*/16, + /*weight_packing*/ + { + /*nr*/static_cast(uk.get_n_step()), + /*weight_data_size_fn*/&kernel::weight_data_size, + /*prepare_weight_data_fn*/&kernel::prepare_weight_data + }, + /*kernels*/ + {{ + { + /*mr*/static_cast(uk.get_m_step()), + /*activation_data_size_fn*/&kernel::activation_data_size, + /*prepare_activation_data_fn*/&kernel::prepare_activation_data, + /*kernel*/&kernel::kernel + } + }} } - }} - }; + ); return; - } - return; - } - #endif // TORCHAO_ENABLE_ARM_I8MM - - - if (cpuinfo_has_arm_neon_dot()) { - if (nr == 8 && kr == 16 && sr == 2) { - namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; - auto uk = kernel::get_ukernel(); - assert (nr == uk.get_nr()); - assert (kr == uk.get_kr()); - assert (sr == uk.get_sr()); + } + #endif // TORCHAO_ENABLE_ARM_I8MM - ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ + if (cpuinfo_has_arm_neon_dot()) { + namespace kernel = 
torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; + auto uk = kernel::get_ukernel(); + assert (nr == uk.get_nr()); + assert (kr == uk.get_kr()); + assert (sr == uk.get_sr()); + table.register_ukernel_config( + format, + torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ /*preferred_alignment*/16, /*weight_packing*/ { - /*nr*/static_cast(uk.get_n_step()), - /*weight_data_size_fn*/&kernel::weight_data_size, - /*prepare_weight_data_fn*/&kernel::prepare_weight_data + /*nr*/static_cast(uk.get_n_step()), + /*weight_data_size_fn*/&kernel::weight_data_size, + /*prepare_weight_data_fn*/&kernel::prepare_weight_data }, /*kernels*/ {{ @@ -136,57 +230,37 @@ void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_ca /*prepare_activation_data_fn*/&kernel::prepare_activation_data, /*kernel*/&kernel::kernel } - }} - }; - return; - } - - if (nr == 4 && kr == 8 && sr == 2) { - // TODO - return; + }} } + ); + return; + } } - - -throw std::runtime_error("Cannot register ukernel_config for packing format kleidi_ai because no implementation is available on this platform"); } template -void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torchao::ops::PackedWeightsFormat format) { - auto it = ukernel_config_cache.find(format); - if (it != ukernel_config_cache.end()) { - throw std::runtime_error("UKernel config already registered"); - } - +void register_ukernel_config(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) { switch (format.type) { case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: { - auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_universal_packing_params(format); - if (packing_params.weight_nbit != weight_nbit) { - throw std::runtime_error("Packed weights are not in the expected format"); - } - if (packing_params.has_weight_zeros != has_weight_zeros) { - throw std::runtime_error("Packed weights are not in the expected format"); - } - if (packing_params.has_bias) { - register_ukernel_config_universal(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version); + auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format); + if (universal_format.has_bias) { + register_ukernel_config_universal(table, format); } else { - register_ukernel_config_universal(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version); + register_ukernel_config_universal(table, format); } break; } case torchao::ops::PackedWeightsType::kleidi_ai: { - auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_kleidi_ai_packing_params(format); - assert (packing_params.has_bias == true); - register_ukernel_config_kleidi_ai(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.sr); + register_ukernel_config_kleidi_ai(table, format); break; } default: throw std::runtime_error("No implementation for packed weights format"); } - it = ukernel_config_cache.find(format); - if (it == ukernel_config_cache.end()) { + auto config = table.get_ukernel_config(format); + if (!config.has_value()) { throw std::runtime_error("UKernel config did not register"); } } @@ -194,21 +268,20 @@ void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torch template torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsFormat format) { - static UKernelConfigCacheType ukernel_config_cache; + static 
UKernelConfigRegistrationTable table; - // Check cache - auto it = ukernel_config_cache.find(format); - if (it != ukernel_config_cache.end()) { - std::cout << "UKERNEL CONFIG FROM CACHE: " << std::endl; - return it->second; + auto ukernel = table.get_ukernel_config(format); + if (ukernel.has_value()) { + std::cout << "FOUND UKERNEL CONFIG IN CACHE" << std::endl; + return ukernel.value(); } std::cout << "REGISTERING UKERNEL CONFIG: " << std::endl; - register_ukernel_config(ukernel_config_cache, format); - it = ukernel_config_cache.find(format); - assert(it != ukernel_config_cache.end()); - auto config = it->second; - return config; + register_ukernel_config(table, format); + + ukernel = table.get_ukernel_config(format); + assert(ukernel.has_value()); + return ukernel.value(); } // TODO: make packing format and format separate concepts @@ -223,7 +296,7 @@ torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional #include -#include #include #include #include diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h deleted file mode 100644 index 3b6fa6be02..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once -#include -#include - -namespace torchao::ops::linear_8bit_act_xbit_weight { - -inline torchao::ops::PackedWeightsFormat get_packed_weights_format_universal( - int weight_nbit, - bool has_weight_zeros, - bool has_bias, - int nr, - int kr, - int version = 1) { - return torchao::ops::PackedWeightsFormat( - torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, - {version, - weight_nbit, - has_weight_zeros, - has_bias, - nr, - kr, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0}); -} - -struct UniversalPackingParams { - int version; - int weight_nbit; - bool has_weight_zeros; - bool has_bias; - int nr; - int kr; -}; - -inline UniversalPackingParams get_universal_packing_params(torchao::ops::PackedWeightsFormat format) { - if (format.type != torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal) { - throw std::runtime_error("Packed weights are not in universal packing format."); - } - return UniversalPackingParams{ - format.params[0], - format.params[1], - static_cast(format.params[2]), - static_cast(format.params[3]), - format.params[4], - format.params[5], - }; -} - - -inline torchao::ops::PackedWeightsFormat get_packed_weights_format_kleidi_ai( - int weight_nbit, - bool has_weight_zeros, - bool has_bias, - int nr, - int kr, - int sr) { - return torchao::ops::PackedWeightsFormat( - torchao::ops::PackedWeightsType::kleidi_ai, - {weight_nbit, - has_weight_zeros, - has_bias, - nr, - kr, - sr, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0}); -} - -struct KleidiAIPackingParams { - int weight_nbit; - bool has_weight_zeros; - bool has_bias; - int nr; - int kr; - int sr; -}; - -inline KleidiAIPackingParams get_kleidi_ai_packing_params(torchao::ops::PackedWeightsFormat format) { - if (format.type != torchao::ops::PackedWeightsType::kleidi_ai) { - throw std::runtime_error("Packed weights are not in kleidi_ai packing format."); - } - return KleidiAIPackingParams{ - format.params[0], - static_cast(format.params[1]), - static_cast(format.params[2]), - 
format.params[3], - format.params[4], - format.params[5] - }; -} - - -} // namespace torchao::ops::linear_8bit_act_xbit_weight From 5c45936174e74653cba599e14a8935de0db5eacf Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 6 Feb 2025 17:02:47 -0800 Subject: [PATCH 5/6] up --- .../kernel_selector.h | 157 +++++++++++++----- 1 file changed, 114 insertions(+), 43 deletions(-) diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 874840d6da..eeb455bfc4 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -28,7 +28,6 @@ namespace torchao::ops::linear_8bit_act_xbit_weight { struct UniversalPackedWeightsFormat { - int version; int weight_nbit; bool has_weight_zeros; bool has_bias; @@ -41,18 +40,16 @@ struct UniversalPackedWeightsFormat { } return UniversalPackedWeightsFormat{ format.params[0], - format.params[1], + static_cast(format.params[1]), static_cast(format.params[2]), - static_cast(format.params[3]), + format.params[3], format.params[4], - format.params[5], }; } inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const { return torchao::ops::PackedWeightsFormat( torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, { - version, weight_nbit, has_weight_zeros, has_bias, @@ -97,16 +94,24 @@ struct KleidiAIPackedWeightsFormat { struct UKernelConfigRegistrationTable { private: - std::unordered_map registration_table_; + using Key = std::pair; + struct KeyHasher { + std::size_t operator()(const Key& k) const { + return std::hash()(k.first) ^ std::hash()(static_cast(k.second)); + } + }; + std::unordered_map registration_table_; public: - void register_ukernel_config(torchao::ops::PackedWeightsFormat format, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) { - if (registration_table_.find(format) != registration_table_.end()) { + void register_ukernel_config(torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) { + auto key = std::make_pair(format, uarch); + if (registration_table_.find(key) != registration_table_.end()) { throw std::runtime_error("UKernelConfig is already registered for this format"); } - registration_table_[format] = config; + registration_table_[key] = config; } - std::optional get_ukernel_config(torchao::ops::PackedWeightsFormat format) const { - auto it = registration_table_.find(format); + std::optional get_ukernel_config(torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) const { + auto key = std::make_pair(format, uarch); + auto it = registration_table_.find(key); if (it == registration_table_.end()) { return std::nullopt; } @@ -115,19 +120,30 @@ struct UKernelConfigRegistrationTable { }; template -void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) { +void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) { + std::cout << "Calling register_ukernel_config_universal" << std::endl; // TODO: remove + if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format); if (universal_format.weight_nbit != weight_nbit) { - 
throw std::runtime_error("Packed weights are not in the expected format"); + throw std::runtime_error( + "Kernel expects weight_nbit=" + std::to_string(weight_nbit) + + ", but packed_weights have weight_nbit=" + std::to_string(universal_format.weight_nbit) + ); } if (universal_format.has_weight_zeros != has_weight_zeros) { - throw std::runtime_error("Packed weights are not in the expected format"); + throw std::runtime_error( + "Kernel expects has_weight_zeros=" + std::to_string(has_weight_zeros) + + ", but packed_weights have has_weight_zeros=" + std::to_string(universal_format.has_weight_zeros) + ); } if (universal_format.has_bias != has_bias) { - throw std::runtime_error("Packed weights are not in the expected format"); + throw std::runtime_error( + "Kernel expects has_bias=" + std::to_string(has_bias) + + ", but packed_weights have has_bias=" + std::to_string(universal_format.has_bias) + ); } if (universal_format.nr == 8 && universal_format.kr == 16) { @@ -136,6 +152,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, to namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; table.register_ukernel_config( format, + uarch, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ /*preferred_alignment*/16, /*weight_packing*/ @@ -161,9 +178,11 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, to } } -template -void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) { - std::cout << "register_ukernel_config_kleidi_ai" << std::endl; +template +void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) { +#ifdef TORCHAO_ENABLE_KLEIDI + std::cout << "Calling register_ukernel_config_kleidi_ai" << std::endl; // TODO: remove + if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } @@ -172,6 +191,23 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to int nr = kleidi_ai_format.nr; int kr = kleidi_ai_format.kr; int sr = kleidi_ai_format.sr; + if (kleidi_ai_format.weight_nbit != weight_nbit) { + throw std::runtime_error( + "Kernel expects weight_nbit=" + std::to_string(weight_nbit) + + ", but packed_weights have weight_nbit=" + std::to_string(kleidi_ai_format.weight_nbit) + ); + } + if (kleidi_ai_format.has_weight_zeros != has_weight_zeros) { + throw std::runtime_error( + "Kernel expects has_weight_zeros=" + std::to_string(has_weight_zeros) + + ", but packed_weights have has_weight_zeros=" + std::to_string(kleidi_ai_format.has_weight_zeros) + ); + } + if (kleidi_ai_format.has_bias != true) { + throw std::runtime_error( + "Kernel expects has_bias=true, but packed_weights have has_bias=" + std::to_string(kleidi_ai_format.has_bias) + ); + } if (nr == 8 && kr == 16 && sr == 2) { #if defined (TORCHAO_ENABLE_ARM_I8MM) @@ -183,8 +219,9 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to assert (sr == uk.get_sr()); table.register_ukernel_config( format, + uarch, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{ - /*preferred_alignment*/16, + /*preferred_alignment*/kernel::get_preferred_alignement(), /*weight_packing*/ { /*nr*/static_cast(uk.get_n_step()), @@ -214,8 +251,9 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to assert (sr == uk.get_sr()); table.register_ukernel_config( format, + 
uarch,
         torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
-          /*preferred_alignment*/16,
+          /*preferred_alignment*/kernel::get_preferred_alignement(),
           /*weight_packing*/
           {
             /*nr*/static_cast<int>(uk.get_n_step()),
@@ -236,32 +274,66 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
       return;
     }
   }
-
+
+  if (nr == 4 && kr == 16 && sr == 2) {
+    if (cpuinfo_has_arm_neon_dot()) {
+      namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
+      auto uk = kernel::get_ukernel();
+      assert (nr == uk.get_nr());
+      assert (kr == uk.get_kr());
+      assert (sr == uk.get_sr());
+      table.register_ukernel_config(
+        format,
+        uarch,
+        torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
+          /*preferred_alignment*/kernel::get_preferred_alignement(),
+          /*weight_packing*/
+          {
+            /*nr*/static_cast<int>(uk.get_n_step()),
+            /*weight_data_size_fn*/&kernel::weight_data_size,
+            /*prepare_weight_data_fn*/&kernel::prepare_weight_data
+          },
+          /*kernels*/
+          {{
+            {
+              /*mr*/static_cast<int>(uk.get_m_step()),
+              /*activation_data_size_fn*/&kernel::activation_data_size,
+              /*prepare_activation_data_fn*/&kernel::prepare_activation_data,
+              /*kernel*/&kernel::kernel
+            }
+          }}
+        }
+      );
+      return;
+    }
+  }
+#endif // TORCHAO_ENABLE_KLEIDI
 }

 template <int weight_nbit, bool has_weight_zeros>
-void register_ukernel_config(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
+void register_ukernel_config(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) {
   switch (format.type) {
     case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: {
       auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format);
       if (universal_format.has_bias) {
-        register_ukernel_config_universal<weight_nbit, has_weight_zeros, true>(table, format);
+        register_ukernel_config_universal<weight_nbit, has_weight_zeros, true>(table, format, uarch);
       } else {
-        register_ukernel_config_universal<weight_nbit, has_weight_zeros, false>(table, format);
+        register_ukernel_config_universal<weight_nbit, has_weight_zeros, false>(table, format, uarch);
       }
       break;
     }
     case torchao::ops::PackedWeightsType::kleidi_ai: {
-      register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros>(table, format);
+      register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros>(table, format, uarch);
       break;
     }
     default:
-      throw std::runtime_error("No implementation for packed weights format");
+      throw std::runtime_error("No registration available for packed_weights_type=" + std::to_string(static_cast<int>(format.type)));
   }

-  auto config = table.get_ukernel_config(format);
+  auto config = table.get_ukernel_config(format, uarch);
   if (!config.has_value()) {
-    throw std::runtime_error("UKernel config did not register");
+    throw std::runtime_error("ukernel_config did not register");
   }
 }

@@ -270,32 +342,33 @@ template <int weight_nbit, bool has_weight_zeros>
 torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsFormat format) {
   static UKernelConfigRegistrationTable table;

-  auto ukernel = table.get_ukernel_config(format);
+  // In future, we can populate this with the current thread's uarch
+  // That will require that select_ukernel_config be called in the lambda
+  // instead of before it on the main thread
+  // Note, cpuinfo_get_current_core() is not currently implemented outside of Linux
+  // XNNPACK often uses non-core specific logic like cpuinfo_get_core(0)->uarch in configs
+  auto uarch = cpuinfo_uarch_unknown;
+  auto ukernel = table.get_ukernel_config(format, uarch);
   if (ukernel.has_value()) {
-    std::cout << "FOUND UKERNEL CONFIG IN CACHE" << std::endl;
+    std::cout << "Found ukernel_config in cache" << std::endl; // TODO: remove cout
     return
ukernel.value(); } - std::cout << "REGISTERING UKERNEL CONFIG: " << std::endl; - register_ukernel_config(table, format); + std::cout << "Registering ukernel config" << std::endl; // TODO: remove cout + register_ukernel_config(table, format, uarch); - ukernel = table.get_ukernel_config(format); + ukernel = table.get_ukernel_config(format, uarch); assert(ukernel.has_value()); return ukernel.value(); } -// TODO: make packing format and format separate concepts -// format is a serialized packing format + template torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional target = std::nullopt) { - if (!cpuinfo_initialize()) { - throw std::runtime_error("Failed to initialize cpuinfo!"); - } - // Select KleidiAI format #if defined(TORCHAO_ENABLE_KLEIDI) if (!target || *target == "kleidi_ai") { - if (weight_nbit == 4 && !has_weight_zeros) { + if constexpr (weight_nbit == 4 && (!has_weight_zeros)) { // TODO: add has_bias here return KleidiAIPackedWeightsFormat({weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2}).to_packed_weights_format(); } } @@ -303,12 +376,10 @@ torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional Date: Thu, 6 Feb 2025 21:56:40 -0800 Subject: [PATCH 6/6] change kleidiai interface --- .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 133 ++++++++++++++++-- .../kernel_selector.h | 91 ++++++------ 2 files changed, 161 insertions(+), 63 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 167ccc47df..92569db4bb 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -15,6 +15,13 @@ #include #include +#include +#include + +#ifdef TORCHAO_ENABLE_ARM_I8MM +#include +#include +#endif // TORCHAO_ENABLE_ARM_I8MM #include @@ -43,14 +50,16 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; -size_t activation_data_size(const Ukernel ukernel, int m, int k) { +size_t activation_data_size(int mr, int kr, int sr, int m, int k) { auto lhs_packing = get_lhs_packing(); return lhs_packing.get_lhs_packed_size( - m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr()); + m, k, mr, kr, sr); } void prepare_activation_data( - const Ukernel ukernel, + int mr, + int kr, + int sr, void* activation_data, int m, int k, @@ -60,29 +69,31 @@ void prepare_activation_data( lhs_pack.run_lhs_pack( m, k, - ukernel.get_mr(), - ukernel.get_kr(), - ukernel.get_sr(), + mr, + kr, + sr, /*m_index_start=*/0, activations, /*lhs_stride=*/k * sizeof(float), activation_data); } -size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) { +size_t weight_data_size(int nr, int kr, int sr, int n, int k, int group_size) { auto rhs_pack = get_rhs_packing(); return rhs_pack.get_rhs_packed_size( n, k, - ukernel.get_nr(), - ukernel.get_kr(), - ukernel.get_sr(), + nr, + kr, + sr, group_size, kai_datatype::kai_dt_bf16); } void prepare_weight_data( - const Ukernel ukernel, + int nr, + int kr, + int sr, void* weight_data, int n, int k, @@ -134,9 +145,9 @@ void prepare_weight_data( /*groups=*/1, n, k, - ukernel.get_nr(), - ukernel.get_kr(), - ukernel.get_sr(), + nr, + kr, + sr, group_size, /*rhs=*/reinterpret_cast(packed_weight_qvals.data()), /*rhs_stride=*/roundup(k, 2) / 2, @@ -148,5 
+159,100 @@ void prepare_weight_data(
       /*qparams=*/&qparams);
 }
+
+size_t get_preferred_alignement() {
+  return 16;
+}
+
+
+#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr) \
+  size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \
+    return weight_data_size(nr, kr, sr, n, k, group_size); \
+  } \
+  void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \
+      void* weight_data, \
+      int n, \
+      int k, \
+      int group_size, \
+      const int8_t* weight_qvals, \
+      const float* weight_scales, \
+      const int8_t* weight_zeros, \
+      const float* bias) { \
+    prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \
+  }
+
+#define DEFINE_ACTIVATION_DATA_FNS(mr, kr, sr) \
+  size_t activation_data_size_mr##mr##_kr##kr##_sr##sr(int m, int k, int group_size) { \
+    (void)group_size; \
+    return activation_data_size(mr, kr, sr, m, k); \
+  } \
+  void prepare_activation_data_mr##mr##_kr##kr##_sr##sr(void* activation_data, int m, int k, int group_size, const float* activations) { \
+    (void)group_size; \
+    prepare_activation_data(mr, kr, sr, activation_data, m, k, activations); \
+  }
+
+// TODO: first and suffix need to be better, e.g., parametrized by mr, nr, etc
+// But I don't quite follow the naming convention for KleidiAI
+#define DEFINE_KERNEL_FNS(first, suffix) \
+  namespace impl_##suffix { \
+    const Ukernel get_ukernel() { \
+      return Ukernel{ \
+        .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .run_matmul = kai_run_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix \
+      }; \
+    } \
+    void kernel( \
+        float32_t* output, \
+        int output_m_stride, \
+        int m, \
+        int n, \
+        int k, \
+        int group_size, \
+        const void* weight_data, \
+        const void* activation_data, \
+        float clamp_min, \
+        float clamp_max) { \
+      get_ukernel().run_matmul( \
+          m, \
+          n, \
+          k, \
+          group_size, \
+          activation_data, \
+          weight_data, \
+          output, \
+          /*dst_stride_row=*/ output_m_stride * sizeof(float), \
+          /*dst_stride_col=*/ sizeof(float), \
+          /*clamp_min=*/std::numeric_limits<float>::lowest(), \
+          /*clamp_max=*/std::numeric_limits<float>::max() \
+      ); \
+    } \
+  }
+
+
+
+DEFINE_WEIGHT_DATA_FNS(/*nr*/8, /*kr*/16, /*sr*/2)
+DEFINE_WEIGHT_DATA_FNS(/*nr*/4, /*kr*/16, /*sr*/2)
+DEFINE_ACTIVATION_DATA_FNS(/*mr*/1, /*kr*/16, /*sr*/2)
+DEFINE_KERNEL_FNS(1x8, 8x8_1x8x32_neon_dotprod)
+DEFINE_KERNEL_FNS(1x8, 4x8_1x4x32_neon_dotprod)
+
+#ifdef TORCHAO_ENABLE_ARM_I8MM
+DEFINE_KERNEL_FNS(4x8, 4x8_8x4x32_neon_i8mm)
+DEFINE_KERNEL_FNS(4x8, 8x8_4x8x32_neon_i8mm)
+#endif // TORCHAO_ENABLE_ARM_I8MM
+
+#undef DEFINE_WEIGHT_DATA_FNS
+#undef DEFINE_ACTIVATION_DATA_FNS
+#undef DEFINE_KERNEL_FNS
+
 } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
 } // namespace torchao::kernels::cpu::aarch64::kleidi
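A note on the DEFINE_* macros above: they rely on token pasting, so each (nr, kr, sr) or (mr, kr, sr) instantiation produces an ordinary function with the tile sizes encoded in its name, and the address of that function can then be stored in a UKernelConfig function-pointer field. The sketch below is not part of the patch; it demonstrates only the expansion pattern, with a toy weight_data_size body standing in for the real kai rhs-packing size call.

#include <cstddef>
#include <iostream>

// Placeholder for get_rhs_packing().get_rhs_packed_size(...); any deterministic
// size computation works for illustrating the macro expansion.
static size_t weight_data_size(int nr, int kr, int sr, int n, int k, int group_size) {
  return static_cast<size_t>(n) * k / group_size * (nr + kr + sr);
}

// Same token-pasting trick as the patch: nr/kr/sr become part of the symbol name.
#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr)                                          \
  size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) {  \
    return weight_data_size(nr, kr, sr, n, k, group_size);                          \
  }

DEFINE_WEIGHT_DATA_FNS(8, 16, 2) // defines weight_data_size_nr8_kr16_sr2
DEFINE_WEIGHT_DATA_FNS(4, 16, 2) // defines weight_data_size_nr4_kr16_sr2

int main() {
  // Each expansion is an ordinary function with a unique name, so its address
  // can be stored in a function-pointer field such as weight_data_size_fn.
  std::cout << weight_data_size_nr8_kr16_sr2(1024, 4096, 32) << "\n";
  std::cout << weight_data_size_nr4_kr16_sr2(1024, 4096, 32) << "\n";
}

This is why the kernel_selector.h diff that follows can reference names like op::weight_data_size_nr8_kr16_sr2 directly: the preprocessor has already materialized them as concrete symbols.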
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
index eeb455bfc4..d380c9e564 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
@@ -17,12 +17,7 @@
 #include
 #if defined(TORCHAO_ENABLE_KLEIDI)
-#include
-#include
-#if defined (TORCHAO_ENABLE_ARM_I8MM)
-#include
-#include
-#endif // TORCHAO_ENABLE_ARM_I8MM
+#include
 #endif // TORCHAO_ENABLE_KLEIDI

 namespace torchao::ops::linear_8bit_act_xbit_weight {
@@ -208,44 +203,43 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
       "Kernel expects has_bias=true, but packed_weights have has_bias=" + std::to_string(kleidi_ai_format.has_bias)
     );
   }
+  namespace op = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p;

   if (nr == 8 && kr == 16 && sr == 2) {
     #if defined (TORCHAO_ENABLE_ARM_I8MM)
     if (cpuinfo_has_arm_i8mm()) {
-      namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;
-      auto uk = kernel::get_ukernel();
-      assert (nr == uk.get_nr());
-      assert (kr == uk.get_kr());
-      assert (sr == uk.get_sr());
-      table.register_ukernel_config(
-        format,
-        uarch,
-        torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
-          /*preferred_alignment*/kernel::get_preferred_alignement(),
-          /*weight_packing*/
-          {
-            /*nr*/static_cast<int>(uk.get_n_step()),
-            /*weight_data_size_fn*/&kernel::weight_data_size,
-            /*prepare_weight_data_fn*/&kernel::prepare_weight_data
-          },
-          /*kernels*/
-          {{
+      auto uk = op::impl_8x8_4x8x32_neon_i8mm::get_ukernel();
+      assert (nr == uk.get_nr());
+      assert (kr == uk.get_kr());
+      assert (sr == uk.get_sr());
+      table.register_ukernel_config(
+        format,
+        uarch,
+        torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
+          /*preferred_alignment*/op::get_preferred_alignement(),
+          /*weight_packing*/
+          {
+            /*nr*/static_cast<int>(uk.get_n_step()),
+            /*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
+            /*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
+          },
+          /*kernels*/
+          {{
             {
               /*mr*/static_cast<int>(uk.get_m_step()),
-              /*activation_data_size_fn*/&kernel::activation_data_size,
-              /*prepare_activation_data_fn*/&kernel::prepare_activation_data,
-              /*kernel*/&kernel::kernel
+              /*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
+              /*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
+              /*kernel*/&op::impl_8x8_4x8x32_neon_i8mm::kernel
             }
-          }}
-          }
-      );
-      return;
+          }}
+        }
+      );
+      return;
     }
     #endif // TORCHAO_ENABLE_ARM_I8MM

     if (cpuinfo_has_arm_neon_dot()) {
-      namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
-      auto uk = kernel::get_ukernel();
+      auto uk = op::impl_8x8_1x8x32_neon_dotprod::get_ukernel();
       assert (nr == uk.get_nr());
       assert (kr == uk.get_kr());
       assert (sr == uk.get_sr());
@@ -253,20 +247,20 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
         format,
         uarch,
         torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
-          /*preferred_alignment*/kernel::get_preferred_alignement(),
+          /*preferred_alignment*/op::get_preferred_alignement(),
           /*weight_packing*/
           {
             /*nr*/static_cast<int>(uk.get_n_step()),
-            /*weight_data_size_fn*/&kernel::weight_data_size,
-            /*prepare_weight_data_fn*/&kernel::prepare_weight_data
+            /*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
+            /*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
           },
           /*kernels*/
           {{
             {
               /*mr*/static_cast<int>(uk.get_m_step()),
-              /*activation_data_size_fn*/&kernel::activation_data_size,
-              /*prepare_activation_data_fn*/&kernel::prepare_activation_data,
-              /*kernel*/&kernel::kernel
+              /*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
+              /*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
+              /*kernel*/&op::impl_8x8_1x8x32_neon_dotprod::kernel
             }
           }}
       }
@@ -274,11 +268,10 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
       return;
     }
   }
-
+
   if (nr == 4 && kr == 16 && sr == 2) {
     if (cpuinfo_has_arm_neon_dot()) {
-      namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
-      auto uk = kernel::get_ukernel();
+      auto uk = op::impl_4x8_1x4x32_neon_dotprod::get_ukernel();
       assert (nr == uk.get_nr());
       assert (kr == uk.get_kr());
       assert (sr == uk.get_sr());
@@ -286,26 +279,26 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
         format,
         uarch,
         torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
-          /*preferred_alignment*/kernel::get_preferred_alignement(),
+          /*preferred_alignment*/op::get_preferred_alignement(),
           /*weight_packing*/
           {
             /*nr*/static_cast<int>(uk.get_n_step()),
-            /*weight_data_size_fn*/&kernel::weight_data_size,
-            /*prepare_weight_data_fn*/&kernel::prepare_weight_data
+            /*weight_data_size_fn*/&op::weight_data_size_nr4_kr16_sr2,
+            /*prepare_weight_data_fn*/&op::prepare_weight_data_nr4_kr16_sr2
           },
           /*kernels*/
           {{
             {
               /*mr*/static_cast<int>(uk.get_m_step()),
-              /*activation_data_size_fn*/&kernel::activation_data_size,
-              /*prepare_activation_data_fn*/&kernel::prepare_activation_data,
-              /*kernel*/&kernel::kernel
+              /*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
+              /*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
+              /*kernel*/&op::impl_4x8_1x4x32_neon_dotprod::kernel
             }
           }}
         }
       );
       return;
-    }
+    }
   }
 #endif // TORCHAO_ENABLE_KLEIDI
 }
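Taken together, the selection flow after these patches is: select_packed_weights_format picks a serialized format when weights are packed, and select_ukernel_config lazily registers and then returns the matching kernel table when the op runs. The following self-contained sketch models just that register-then-lookup pattern; it is illustrative rather than torchao code, with a plain int key standing in for the (PackedWeightsFormat, cpuinfo_uarch) pair and an int standing in for UKernelConfig.

#include <iostream>
#include <optional>
#include <stdexcept>
#include <unordered_map>

using Key = int;    // stands in for std::pair<PackedWeightsFormat, cpuinfo_uarch>
using Config = int; // stands in for UKernelConfig

struct RegistrationTable {
  std::unordered_map<Key, Config> table_;
  void register_config(Key k, Config c) {
    if (table_.count(k) > 0) {
      throw std::runtime_error("config is already registered for this key");
    }
    table_[k] = c;
  }
  std::optional<Config> get_config(Key k) const {
    auto it = table_.find(k);
    if (it == table_.end()) {
      return std::nullopt;
    }
    return it->second;
  }
};

Config select_config(Key k) {
  // Function-local static mirrors select_ukernel_config: one table per process,
  // populated on first use.
  static RegistrationTable table;
  if (auto c = table.get_config(k)) {
    return *c; // fast path: already registered
  }
  table.register_config(k, /*config=*/k * 10); // slow path: register once, then re-read
  return *table.get_config(k);
}

int main() {
  std::cout << select_config(4) << "\n"; // registers on first call, prints 40
  std::cout << select_config(4) << "\n"; // served from the table, prints 40
}

One caveat worth flagging for the PR, visible in the sketch as well: the function-local static table is mutated without synchronization, so two threads racing through their first select_ukernel_config call could modify the unordered_map concurrently; guarding registration with std::mutex or std::call_once would harden it.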