Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental][Kleidi] Add GEMM operator tests #1638

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA
include(FetchContent)
# KleidiAI is an open-source library that provides optimized
# performance-critical routines, also known as micro-kernels, for artificial
# intelligence (AI) workloads tailored for Arm® CPUs.
# intelligence (AI) workloads tailored for Arm® CPUs.
FetchContent_Declare(kleidiai
GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this
GIT_TAG v1.2.0)
FetchContent_MakeAvailable(kleidiai)

# Temporarily exposing this to the parent scope until we wire
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_row=*/output_m_stride * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/ n * sizeof(float),
/*dst_stride_row=*/ output_m_stride * sizeof(float),
/*dst_stride_col=*/ sizeof(float),
clamp_min,
clamp_max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_row=*/output_m_stride * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_row=*/output_m_stride * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
Expand Down
22 changes: 22 additions & 0 deletions torchao/experimental/ops/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,34 @@ if(TORCHAO_BUILD_KLEIDIAI)
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()

# When TORCHAO_BUILD_ARM_I8MM is enabled, define TORCHAO_ENABLE_ARM_I8MM for
# sources compiled from this directory.
if(TORCHAO_BUILD_ARM_I8MM)
add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
endif()

if (ANDROID_ABI)
# We are cross compiling, delay test discovery till runtime
set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST)
endif()

# Directory-scoped include paths: applies to the targets declared below.
include_directories(${TORCHAO_INCLUDE_DIRS})

# TORCHAO_PARALLEL_BACKEND is set before the kernels subdirectory and
# Utils.cmake are pulled in.
# NOTE(review): presumably one of them reads this value - confirm which.
set(TORCHAO_PARALLEL_BACKEND "test_dummy")
# The aarch64 kernels are built into a dedicated binary directory under the
# current build tree.
add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

include(${TORCHAO_ROOT}/Utils.cmake)

if(ANDROID_ABI)
  # Given where we are today this is sufficient, but it needs to be revisited.
  # The same flags are also needed for native builds; they are applied only to
  # cross builds for now given the hacky nature of per-source flags.
  #
  # Collect every test source in this directory. The pattern is anchored to
  # the current source directory explicitly so the match does not depend on
  # how this file is included.
  file(GLOB DOTPROD_SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/test*.cpp")
  # Quote the list so the paths are printed separated instead of being
  # concatenated, and log it as a regular status line.
  message(STATUS "SRC_FILES: ${DOTPROD_SRC_FILES}")
  # Append the dotprod architecture flag to each test translation unit.
  set_property(
    SOURCE ${DOTPROD_SRC_FILES}
    APPEND_STRING
    PROPERTY COMPILE_FLAGS " -march=armv8.2-a+dotprod ")
endif()

add_executable(
test_linear_8bit_act_xbit_weight
test_linear_8bit_act_xbit_weight.cpp
Expand Down
41 changes: 39 additions & 2 deletions torchao/experimental/ops/tests/build_and_run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,57 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Build the torchao experimental op tests and, for native builds, run them.
#
# Usage: build_and_run_tests.sh [native|android]
#   native  (default) configure and build for the host, then run the tests.
#   android cross-compile for arm64-v8a with the NDK pointed to by
#           ANDROID_NDK; the tests are built but not run.
target=${1:-"native"}
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
# Single definition of the build directory. It must not be re-exported later,
# otherwise the Android-specific directory chosen below would be discarded.
export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests

IS_ARM64=0
BUILD_ARM_I8MM=0
# Extra CMake arguments, kept as an array so paths with spaces survive.
EXTRA_ARGS=()
if [[ "${target}" == "android" ]]; then
  if [[ -z "${ANDROID_NDK}" ]]; then
    echo "Need to set ANDROID_NDK env variable to build for Android"
    exit 1
  fi
  android_abi=arm64-v8a
  android_platform=28 # must be >=28 for aligned_alloc
  IS_ARM64=1
  BUILD_ARM_I8MM=1 # Hardcoded for now
  CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android}
  toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake"
  # Check that the file actually exists; the path string itself is never empty.
  if [[ ! -f "${toolchain_file}" ]]; then
    echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}"
    exit 1
  fi
  EXTRA_ARGS=(
    "-DCMAKE_TOOLCHAIN_FILE=${toolchain_file}"
    "-DANDROID_ABI=${android_abi}"
    "-DANDROID_PLATFORM=${android_platform}"
  )
  echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}"
fi

# Enable the KleidiAI build when the host itself reports arm64.
if hash arch 2> /dev/null && [[ "$(arch)" == "arm64" ]]; then
  IS_ARM64=1
fi

# Configure from the directory holding this script so the result does not
# depend on the caller's working directory. Stop on failure instead of
# continuing with a stale or missing build tree.
cmake \
  "-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES}" \
  "${EXTRA_ARGS[@]}" \
  -DCMAKE_BUILD_TYPE=Debug \
  -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \
  -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \
  -S "${SCRIPT_DIR}" \
  -B "${CMAKE_OUT}" || exit 1

cmake --build "${CMAKE_OUT}" || exit 1

echo "Successfully built tests."

if [[ "${target}" != "native" ]]; then
  echo "Skip running tests when cross compiling."
  exit 0
fi

# Run
"${CMAKE_OUT}/test_linear_8bit_act_xbit_weight"
128 changes: 128 additions & 0 deletions torchao/experimental/ops/tests/generate_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Simple script to generate test cases for the torchao ops
from string import Template


def add_test_string(kernel, m, n, k, g, has_bias, has_clamp):
    """Render one C++ gtest case for a Kleidi kernel configuration.

    Args:
        kernel: Kleidi kernel identifier, used both in the test name and as
            the template argument of ``get_ukernel_config_kleidi``.
        m, n, k: Problem dimensions substituted into the test call.
        g: Group size substituted into the test call.
        has_bias: Truthy to emit ``true`` for the has_bias template argument
            and add a ``_bias`` suffix to the test name.
        has_clamp: Truthy to emit ``true`` for the has_clamp template argument
            and add a ``_clamp`` suffix to the test name.

    Returns:
        A one-element list holding the rendered test source text.
    """

    def cpp_bool(flag):
        # C++ boolean literal matching the truthiness of ``flag``.
        return "true" if flag else "false"

    suffix = ("_bias" if has_bias else "") + ("_clamp" if has_clamp else "")
    fields = dict(
        name=f"m{m}xn{n}xk{k}xg{g}{suffix}",
        kernel=kernel,
        m=m,
        n=n,
        k=k,
        g=g,
        has_bias=cpp_bool(has_bias),
        has_clamp=cpp_bool(has_clamp),
    )

    source_text = """
TEST(test_linear_8bit_act_xbit_weight, Kleidi_${kernel}_${name}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
$has_bias /*has_bias*/,
$has_clamp /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/$m, /*n=*/$n, /*k=*/$k, /*group_size=*/$g, &ukernel_config);
}
"""
    rendered = Template(source_text).safe_substitute(fields)
    return [rendered]


def get_test_block(kernel):
    """Return the concatenated source of every generated test for ``kernel``.

    The same fixed list of problem shapes is used for every kernel, on the
    assumption that the given Kleidi kernel can run all of them. Tests are
    emitted in the order the cases are listed below.
    """
    # Each row is (m, n, k, group_size, has_bias, has_clamp).
    cases = [
        # GEMV, m == 1
        ## subtile
        (1, 2 * 1, 32, 32, False, False),
        (1, 2 * 2, 32, 32, False, False),
        (1, 2 * 3, 32, 32, True, False),
        (1, 2 * 2, 32, 32, True, True),
        (1, 2 * 3, 32, 32, False, True),
        ## larger: n - must be multiple of 2
        (1, 2 * 11, 32, 32, False, False),
        (1, 2 * 13, 32, 32, True, False),
        (1, 2 * 51, 32, 32, False, True),
        (1, 2 * 111, 32, 32, False, False),
        ## larger: k, g - must be multiple of 32
        (1, 2 * 7, 64, 32, False, False),
        (1, 2 * 11, 128, 32, True, False),
        (1, 2 * 13, 64, 64, False, True),
        (1, 2 * 17, 128, 64, False, False),
        # GEMM, m > 1
        ## subtile
        (2, 2 * 1, 32, 32, False, False),
        (2, 2 * 2, 32, 32, False, False),
        (3, 2 * 3, 32, 32, True, False),
        (4, 2 * 4, 32, 32, True, True),
        (3, 2 * 3, 32, 32, False, True),
        ## larger: m
        (31, 2 * 1, 32, 32, False, False),
        (32, 2 * 2, 32, 32, False, False),
        (33, 2 * 3, 32, 32, True, False),
        (34, 2 * 4, 32, 32, True, True),
        (35, 2 * 3, 32, 32, False, True),
        ## larger: n - must be multiple of 2
        (7, 2 * 11, 32, 32, False, False),
        (17, 2 * 13, 32, 32, True, False),
        (23, 2 * 51, 32, 32, False, True),
        (41, 2 * 111, 32, 32, False, False),
        ## larger: k, g - must be multiple of 32
        (19, 2 * 7, 64, 32, False, False),
        (23, 2 * 11, 128, 32, True, False),
        (29, 2 * 13, 64, 64, False, True),
        (101, 2 * 17, 128, 64, False, False),
    ]

    rendered = []
    for m, n, k, g, has_bias, has_clamp in cases:
        rendered.extend(add_test_string(kernel, m, n, k, g, has_bias, has_clamp))
    return "".join(rendered)


# Print the generated C++ test source to stdout: a "generated" banner, then one
# block of tests per Kleidi kernel, all wrapped in a TORCHAO_ENABLE_KLEIDI
# preprocessor guard.
def main():
# Per-kernel wrapper text; ${prologue} and ${epilogue} carry an optional
# preprocessor guard and ${tests} the rendered test cases.
# NOTE(review): the lines from "Copy link" through "...on the back burner I
# guess." inside this literal look like page review-comment text rather than
# template content - confirm against the real file before using this output.
kleidi_template = Template(
"""
/*****************/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Little late but consider putting a header suggesting this is autogenerated by this particular script and how

Copy link
Contributor Author

@digantdesai digantdesai Jan 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is there, see line 106 in this file.

As per how, script dumps c++ code on stdout right now, and then manual copy-pasta 🍝
Added this as a note in this commit FWIW. We should improve this, but on the back burner I guess.

// ${kernel} tests
/*****************/
${prologue}
${tests}
${epilogue}
"""
)

# Kernel identifiers to generate tests for, in output order.
kleidi_kernels = [
"dotprod_1x4x32",
"dotprod_1x8x32",
"i8mm_4x8x32",
"i8mm_8x4x32",
]

print("/* Generated by generate_tests.py */")
print("/* Do not modify */")
print()
print("#if defined(TORCHAO_ENABLE_KLEIDI)")
for kernel in kleidi_kernels:
prologue, epilogue = "", ""
# Kernels whose name contains "i8mm" are additionally wrapped in a
# TORCHAO_ENABLE_ARM_I8MM guard; the others get no extra guard.
if "i8mm" in kernel:
prologue = "#if defined(TORCHAO_ENABLE_ARM_I8MM)"
epilogue = "#endif // TORCHAO_ENABLE_ARM_I8MM"
tests = get_test_block(kernel)
d = {
"prologue": prologue,
"kernel": kernel,
"tests": tests,
"epilogue": epilogue,
}

# safe_substitute leaves any placeholder without a matching key untouched
# instead of raising.
print(kleidi_template.safe_substitute(d))
print("#endif // TORCHAO_ENABLE_KLEIDI")


if __name__ == "__main__":
main()
Loading
Loading