diff --git a/.github/workflows/release_packages.yml b/.github/workflows/release_packages.yml index ba97d161..fb77cd8a 100644 --- a/.github/workflows/release_packages.yml +++ b/.github/workflows/release_packages.yml @@ -1,57 +1,28 @@ name: Release Packages on: - push: - tags: - - 'v[0-9]+.[0-9]+.[0-9]+' + push: + tags: + - 'v[0-9]+.[0-9]+.[0-9]+' + workflow_dispatch: # Needed to create release and upload assets permissions: contents: write jobs: - build-deb: - runs-on: [self-hosted, Linux, X64] - container: - image: registry-1.docker.io/dashinfer/dev-ubuntu-22.04-x86:v1 - defaults: - run: - shell: bash -l {0} - steps: - - name: Check out code - uses: actions/checkout@v3 - - - name: Pull LFS - run: | - git lfs pull - - - name: Build deb package - run: | - git fetch --tags - TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1)) - VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//') - source activate ds_py - AS_RELEASE_VERSION=$VERSION_NUMBER \ - AS_PLATFORM="x86" \ - AS_BUILD_PACKAGE=ON \ - bash build.sh - - - name: Upload deb package - uses: actions/upload-artifact@v3 - with: - name: dashinfer-deb - path: build/*.deb - - build-rpm: + build-tgz: strategy: matrix: arch: [X64, ARM64] - image: ["dev-centos7-x86:v1", "dev-alinux-arm:v1"] + image: ["dev-centos7-x86:v2", "dev-centos7-cu124:v1", "dev-centos8-arm:v2"] exclude: - arch: X64 - image: "dev-alinux-arm:v1" + image: "dev-centos8-arm:v2" + - arch: ARM64 + image: "dev-centos7-x86:v2" - arch: ARM64 - image: "dev-centos7-x86:v1" + image: "dev-centos7-cu124:v1" runs-on: [self-hosted, Linux, "${{ matrix.arch }}"] container: image: registry-1.docker.io/dashinfer/${{ matrix.image }} @@ -67,40 +38,60 @@ jobs: uses: actions/checkout@v3 with: lfs: true - + - name: Pull LFS run: | + git lfs install --force git lfs pull - - - name: Build rpm package + + - name: Init submodule + run: | + git submodule init + git submodule update + + - name: Build tgz package run: | git fetch --tags TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1)) VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//') - source /opt/rh/devtoolset-7/enable - source activate ds_py - AS_RELEASE_VERSION=$VERSION_NUMBER \ - AS_PLATFORM=$( [[ "${{ matrix.arch }}" = "X64" ]] && echo "x86" || echo "armclang" ) \ - AS_BUILD_PACKAGE=ON \ + source /root/.bashrc + + if command -v nvcc &> /dev/null + then + export AS_PLATFORM="cuda" + export AS_CUDA_SM="'70;75;80;86;89;90a'" + else + # export ENABLE_MULTINUMA="ON" + if [[ "${{ matrix.arch }}" == "ARM64" ]]; then + export AS_PLATFORM="armclang" + else + export AS_PLATFORM="x86" + fi + fi + + export AS_RELEASE_VERSION=$VERSION_NUMBER + export AS_BUILD_PACKAGE=ON bash build.sh - - name: Upload rpm package + - name: Upload tgz package uses: actions/upload-artifact@v3 with: - name: dashinfer-rpm-${{ matrix.arch }} - path: build/*.rpm - + name: dashinfer-tgz-${{ matrix.arch }} + path: build/*.tar.gz + build-wheels: strategy: matrix: arch: [X64, ARM64] - image: ["dev-manylinux-x86:v1", "dev-manylinux-arm:v1"] + image: ["dev-centos7-x86:v2", "dev-centos7-cu124:v1", "dev-centos8-arm:v2"] exclude: - arch: X64 - image: "dev-manylinux-arm:v1" + image: "dev-centos8-arm:v2" - arch: ARM64 - image: "dev-manylinux-x86:v1" + image: "dev-centos7-x86:v2" + - arch: ARM64 + image: "dev-centos7-cu124:v1" runs-on: [self-hosted, Linux, "${{ matrix.arch }}"] container: image: registry-1.docker.io/dashinfer/${{ matrix.image }} @@ -114,12 +105,31 @@ jobs: with: lfs: true + - name: Pull LFS + run: | + git lfs install --force + git lfs pull + + - name: Init submodule + run: | + git submodule init + git submodule update + - name: Build manylinux wheels run: | git fetch --tags TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1)) + source /root/.bashrc VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//') - AS_RELEASE_VERSION=$VERSION_NUMBER bash scripts/release/python_manylinux_build.sh + export AS_RELEASE_VERSION=$VERSION_NUMBER + + if command -v nvcc &> /dev/null + then + export AS_CUDA_SM="'70;75;80;86;89;90a'" + bash scripts/release/python_manylinux_build_cuda.sh + else + bash scripts/release/python_manylinux_build.sh + fi - name: Upload wheels uses: actions/upload-artifact@v3 @@ -127,56 +137,50 @@ jobs: name: python-manylinux-wheels-${{ matrix.arch }} path: python/wheelhouse/*-manylinux*.whl - test: - strategy: - matrix: - arch: [X64, ARM64] - image: ["test-ubuntu-x86:v1", "test-centos-arm:v1"] - exclude: - - arch: X64 - image: "test-centos-arm:v1" - - arch: ARM64 - image: "test-ubuntu-x86:v1" - runs-on: [self-hosted, Linux, "${{ matrix.arch }}"] - container: - image: registry-1.docker.io/dashinfer/${{ matrix.image }} - volumes: - - /mnt/data0/models/modelscope:/github/home/.cache/modelscope - options: "--ipc=host --cap-add SYS_NICE --cap-add SYS_PTRACE" - needs: build-wheels - steps: - - name: Check out code - uses: actions/checkout@v3 + # test: + # strategy: + # matrix: + # arch: [X64, ARM64] + # image: ["test-ubuntu-x86:v1", "test-centos-arm:v1"] + # exclude: + # - arch: X64 + # image: "test-centos-arm:v1" + # - arch: ARM64 + # image: "test-ubuntu-x86:v1" + # runs-on: [self-hosted, Linux, "${{ matrix.arch }}"] + # container: + # image: registry-1.docker.io/dashinfer/${{ matrix.image }} + # volumes: + # - /mnt/data0/models/modelscope:/github/home/.cache/modelscope + # options: "--ipc=host --cap-add SYS_NICE --cap-add SYS_PTRACE" + # needs: build-wheels + # steps: + # - name: Check out code + # uses: actions/checkout@v3 - - name: Download wheels - uses: actions/download-artifact@v3 - with: - name: python-manylinux-wheels-${{ matrix.arch }} - path: python/wheelhouse + # - name: Download wheels + # uses: actions/download-artifact@v3 + # with: + # name: python-manylinux-wheels-${{ matrix.arch }} + # path: python/wheelhouse - - name: Test manylinux wheels - run: | - TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1)) - VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//') - AS_RELEASE_VERSION=$VERSION_NUMBER bash scripts/release/python_manylinux_test.sh + # - name: Test manylinux wheels + # run: | + # TAG_NAME=$(git describe --tags $(git rev-list --tags --max-count=1)) + # VERSION_NUMBER=$(echo "$TAG_NAME" | sed 's/^v//') + # AS_RELEASE_VERSION=$VERSION_NUMBER bash scripts/release/python_manylinux_test.sh publish: runs-on: [self-hosted, Linux] - needs: [build-deb, build-rpm, test] + needs: [build-tgz, build-wheels] strategy: matrix: arch: [X64, ARM64] steps: - - name: Download deb packages + - name: Download tgz packages uses: actions/download-artifact@v3 with: - name: dashinfer-deb - path: release/ - - - name: Download rpm packages - uses: actions/download-artifact@v3 - with: - name: dashinfer-rpm-${{ matrix.arch }} + name: dashinfer-tgz-${{ matrix.arch }} path: release/ - name: Download python wheels @@ -189,6 +193,3 @@ jobs: uses: softprops/action-gh-release@v2 with: files: release/* - - - diff --git a/.gitignore b/.gitignore index 21cb64f2..91d613ee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,29 @@ +tests/cpp/model/testcase +tests/cpp/operator/testcase +tests/python/custom_model +tests/testcase build/ python/build/ +ossutil_output/ +__pycache__/ +.ccls +*.qdrep +*.qdstrm +*.h5 +.ccls-cache/ +*.log +compile_commands.json python/dist/ -python/dashinfer.egg-info/ -python/dashinfer.egg-info +python/pyhie.egg-info/ +python/pyhie_allspark.egg-info +*.ascache +*.lock third_party/from_source/*.o -__pycache__/ +third_party/from_source/openssl/* +.idea/ +.vscode/ +*.nsys-rep +log* +*.csv +#*.sh +*.as* diff --git a/CMakeLists.txt b/CMakeLists.txt index 5369d32b..e330c148 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,13 +9,18 @@ endif() string(REGEX REPLACE "-rc[0-9]+" "" STRIPED_VERSION_STRING ${project_version_in_env}) set(project_version_in_env ${STRIPED_VERSION_STRING}) -message("Build AllSpark with version:${project_version_in_env}") +message("Build DashInfer with version: ${project_version_in_env}") project(DashInfer LANGUAGES C CXX VERSION ${project_version_in_env}) include(GNUInstallDirs) set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}-${PROJECT_VERSION} CACHE STRING "Force modify install dir" FORCE) -message(STATUS "CMAKE_INSTALL_PREFIX:${CMAKE_INSTALL_PREFIX} CPACK_PACKAGE_DEVICE_NAME:${CPACK_PACKAGE_DEVICE_NAME}") +message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") +if (BUILD_PYTHON) + # building manylinux pkg need this setting to find local libflash-attn.so + set(CMAKE_INSTALL_RPATH "$ORIGIN") + set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) +endif() if (NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, @@ -63,7 +68,7 @@ option(ENABLE_CUSPARSELT "build with CUSPARSELT lib" OFF) option(BUILD_UTEST "build with unit test" ON) option(BUILD_EXAMPLE "build with examples" ON) option(BUILD_PYTHON "build with python api" ON) -option(PACKAGE_RPM "package with rpm " ON) +option(BUILD_PACKAGE "build cpp package" OFF) option(MEM_CHECK "check memory" OFF) option(LOCK_CHECK "check deadlock" OFF) option(ALWAYS_READ_LOAD_MODEL "load and parse model via every read" OFF) @@ -212,42 +217,44 @@ if (BUILD_PYTHON) add_subdirectory(python) endif() - - -if (PACKAGE_RPM) -set(CPACK_SYSTEM_NAME "alios7") -if(CONFIG_HOST_CPU_TYPE STREQUAL "ARM") - set(CPACK_SYSTEM_ARCHITECTURE "aarch64") -else() - set(CPACK_SYSTEM_ARCHITECTURE "x86_64") -endif() - -if (ENABLE_CUDA) - if(ENABLE_NV_STATIC_LIB) - set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-static") +if (BUILD_PACKAGE) + # config system arch + if(CONFIG_HOST_CPU_TYPE STREQUAL "ARM") + set(CPACK_SYSTEM_ARCHITECTURE "aarch64") else() - set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-shared") + set(CPACK_SYSTEM_ARCHITECTURE "x86_64") endif() -else() - set(CPACK_PACKAGE_DEVICE_NAME "cpu") -endif() - -set(CPACK_PACKAGE_NAME "DashInfer") -set(CPACK_PACKAGE_VERSION ${project_version_in_env}) -set(CPACK_PACKAGE_VENDOR "Alibaba Tongyi") -set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "DashInfer AllSpark is a LLM inference engine.") -set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR}) -set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR}) -set(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH}) -set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") -set(CPACK_RESOURCE_FILE_README "${CMAKE_CURRENT_SOURCE_DIR}/README.md") + if (ENABLE_CUDA) + if(ENABLE_NV_STATIC_LIB) + set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-static") + else() + set(CPACK_PACKAGE_DEVICE_NAME "cuda-${CUDA_VERSION}-shared") + endif() + else() + if (ENABLE_MULTINUMA) + set(CPACK_PACKAGE_DEVICE_NAME "cpu-multinuma") + else() + set(CPACK_PACKAGE_DEVICE_NAME "cpu") + endif() + endif() -set(CPACK_PACKAGING_INSTALL_PREFIX "") -set(CPACK_RPM_PACKAGE_RELOCATABLE ON) - -set(CPACK_PACKAGE_FILE_NAME "${CPACK_PACKAGE_NAME}-${CPACK_PACKAGE_VERSION}.${CPACK_PACKAGE_DEVICE_NAME}.${CPACK_SYSTEM_NAME}.${CPACK_SYSTEM_ARCHITECTURE}") -include(CPack) + set(CPACK_PACKAGE_NAME "DashInfer") + set(CPACK_PACKAGE_VENDOR "Alibaba Tongyi") + set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "DashInfer AllSpark is a LLM inference engine.") + set(CPACK_PACKAGE_VERSION ${project_version_in_env}) + set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR}) + set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR}) + set(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH}) + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") + set(CPACK_RESOURCE_FILE_README "${CMAKE_CURRENT_SOURCE_DIR}/README.md") + set(CPACK_PACKAGING_INSTALL_PREFIX "") + set(CPACK_GENERATOR "TGZ") + set(CPACK_THREADS 16) + + set(CPACK_PACKAGE_FILE_NAME "${CPACK_PACKAGE_NAME}-${CPACK_PACKAGE_VERSION}.${CPACK_PACKAGE_DEVICE_NAME}.${CPACK_SYSTEM_ARCHITECTURE}") + + INCLUDE(CPack) endif() #install diff --git a/build.sh b/build.sh index 08346009..e9a02502 100755 --- a/build.sh +++ b/build.sh @@ -20,7 +20,7 @@ NCCL_VERSION="${AS_NCCL_VERSION:-2.23.4}" system_nv_lib="${AS_SYSTEM_NV_LIB:-OFF}" build_type="${AS_BUILD_TYPE:-Release}" cuda_static="${AS_CUDA_STATIC:-OFF}" -rpm_package="${AS_RPM_PACKAGE:-OFF}" +build_package="${AS_BUILD_PACKAGE:-ON}" enable_glibcxx11_abi="${AS_CXX11_ABI:-OFF}" enable_span_attn="${ENABLE_SPAN_ATTENTION:-ON}" enable_multinuma="${ENABLE_MULTINUMA:-OFF}" @@ -81,6 +81,7 @@ export PATH=`pwd`/bin:$PATH if [ "${with_platform,,}" == "cuda" ]; then cmake .. \ -DCMAKE_BUILD_TYPE=${build_type} \ + -DBUILD_PACKAGE=${build_package} \ -DCONFIG_ACCELERATOR_TYPE=CUDA \ -DCONFIG_HOST_CPU_TYPE=X86 \ -DNCCL_VERSION=${NCCL_VERSION} \ @@ -97,9 +98,11 @@ if [ "${with_platform,,}" == "cuda" ]; then elif [ "${with_platform,,}" == "x86" ]; then cmake .. \ -DCMAKE_BUILD_TYPE=${build_type} \ + -DBUILD_PACKAGE=${build_package} \ -DCONFIG_ACCELERATOR_TYPE=NONE \ -DCONFIG_HOST_CPU_TYPE=X86 \ -DENABLE_GLIBCXX11_ABI=${enable_glibcxx11_abi} \ + -DBUILD_PYTHON=OFF \ -DALLSPARK_CBLAS=MKL \ -DENABLE_CUDA=OFF \ -DENABLE_SPAN_ATTENTION=OFF \ @@ -108,10 +111,12 @@ elif [ "${with_platform,,}" == "x86" ]; then elif [ "${with_platform,,}" == "armclang" ]; then cmake .. \ -DCMAKE_BUILD_TYPE=${build_type} \ + -DBUILD_PACKAGE=${build_package} \ -DCONFIG_ACCELERATOR_TYPE=NONE \ -DCONFIG_HOST_CPU_TYPE=ARM \ -DENABLE_BLADE_AUTH=${enable_blade_auth} \ -DENABLE_GLIBCXX11_ABI=${enable_glibcxx11_abi} \ + -DBUILD_PYTHON=OFF \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DENABLE_ARMCL=ON \ -DALLSPARK_CBLAS=BLIS \ @@ -134,7 +139,7 @@ make -j16 && make install if [ $? -eq 0 ]; then - if [ ${rpm_package} == "ON" ]; then + if [ ${build_package} == "ON" ]; then make package fi else diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake index fb6bf875..85044546 100644 --- a/cmake/FindNCCL.cmake +++ b/cmake/FindNCCL.cmake @@ -5,24 +5,9 @@ if (USE_SYSTEM_NV_LIB) return() endif() include(FindPackageHandleStandardArgs) -include(FetchContent) -set(NCCL_VERSION - "2.11.4" - CACHE STRING "NCCL VERSION") -set(NCCL_URL https://github.com/NVIDIA/nccl/archive/refs/tags/v${NCCL_VERSION}-1.tar.gz) -set(NCCL_PROJECT "extern_nccl") -FetchContent_Declare(${NCCL_PROJECT} URL ${NCCL_URL}) -message(STATUS "Fetch NCCL from ${NCCL_URL}") -FetchContent_MakeAvailable(${NCCL_PROJECT}) - -set(NCCL_ROOT_DIR - "${${NCCL_PROJECT}_SOURCE_DIR}" - CACHE PATH "NVIDIA NCCL") -message(STATUS "NCCL_ROOT_DIR : ${NCCL_ROOT_DIR}") find_path( NCCL_INCLUDE_DIR nccl.h - HINTS ${NCCL_ROOT_DIR} PATH_SUFFIXES cuda/include include nccl-${NCCL_VERSION}-cuda-${CUDA_VERSION}/include) @@ -35,7 +20,6 @@ endif() message("find nccl with ${NCCL_LIBNAME}") find_library( AS_NCCL_LIBRARY ${NCCL_LIBNAME} - HINTS ${NCCL_ROOT_DIR} PATH_SUFFIXES lib lib64 nccl-${NCCL_VERSION}-cuda-${CUDA_VERSION}/lib64) if(ENABLE_NV_STATIC_LIB) @@ -51,14 +35,11 @@ set_property(TARGET CUDA::${NCCL_LIBNAME} PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${NCCL_INCLUDE_DIR}) # install nccl - if(NOT ENABLE_NV_STATIC_LIB) get_filename_component(NCCL_LIB_DIR ${AS_NCCL_LIBRARY} DIRECTORY) -install(DIRECTORY ${NCCL_LIB_DIR}/ - DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} - USE_SOURCE_PERMISSIONS FILES_MATCHING - PATTERN "*nccl.so*" -) +file(GLOB NCCL_LIBS ${NCCL_LIB_DIR}/*nccl.so*) +install(FILES ${NCCL_LIBS} + DESTINATION ${CMAKE_INSTALL_LIBDIR}) endif() @@ -66,5 +47,5 @@ find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR AS_NCCL_LIBRARY) if(NCCL_FOUND) - message(STATUS "Found NCCL: success , library path : ${AS_NCCL_LIBRARY}") + message(STATUS "Found NCCL: success, library path : ${AS_NCCL_LIBRARY}") endif() diff --git a/cmake/flash-attention.cmake b/cmake/flash-attention.cmake index 01dd532d..db646521 100644 --- a/cmake/flash-attention.cmake +++ b/cmake/flash-attention.cmake @@ -90,7 +90,7 @@ if (FLASHATTN_USE_STATIC_LIB) else() add_library(flash-attention::flash-attn SHARED IMPORTED) install(FILES ${FLASHATTN_LIBRARY_PATH}/libflash-attn.so - DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}) + DESTINATION ${CMAKE_INSTALL_LIBDIR}) message(STATUS "libflash-attn.so installing path: ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}") endif() diff --git a/cmake/install.cmake b/cmake/install.cmake index 1c2552e2..d24c330b 100644 --- a/cmake/install.cmake +++ b/cmake/install.cmake @@ -1,44 +1,38 @@ # add install target -SET_TARGET_PROPERTIES(allspark_framework PROPERTIES INSTALL_RPATH "$ORIGIN") install(DIRECTORY ${PROJECT_SOURCE_DIR}/csrc/interface/ - DESTINATION include/allspark/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/allspark USE_SOURCE_PERMISSIONS FILES_MATCHING PATTERN "*.h" ) -if (NOT BUILD_PYTHON) - install(TARGETS allspark_framework_static DESTINATION ${CMAKE_INSTALL_DIR}) -endif() +install(TARGETS allspark_framework DESTINATION ${CMAKE_INSTALL_LIBDIR}) +install(TARGETS allspark_framework_static DESTINATION ${CMAKE_INSTALL_LIBDIR}) if (ENABLE_MULTINUMA) - install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/orterun - DESTINATION bin - RENAME mpirun) - install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/allspark_daemon - DESTINATION bin - RENAME allspark_daemon) - SET_TARGET_PROPERTIES(allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN") - install(TARGETS allspark_client DESTINATION ${CMAKE_INSTALL_DIR}) install(DIRECTORY ${PROJECT_SOURCE_DIR}/csrc/service/ - DESTINATION include/allspark/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/allspark USE_SOURCE_PERMISSIONS FILES_MATCHING PATTERN "allspark_client.h") + install(TARGETS allspark_client DESTINATION ${CMAKE_INSTALL_LIBDIR}) + install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/orterun + DESTINATION ${CMAKE_INSTALL_BINDIR} + RENAME mpirun) + install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/bin/allspark_daemon + DESTINATION ${CMAKE_INSTALL_BINDIR} + RENAME allspark_daemon) endif() if (BUILD_PYTHON) - if (PYTHON_LIB_DIRS) - if(NOT ENABLE_NV_STATIC_LIB) - install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*" PATTERN "libnccl.*" EXCLUDE) - else() - install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*") - endif() - if (ENABLE_MULTINUMA) - install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/bin DESTINATION ${PYTHON_LIB_DIRS} USE_SOURCE_PERMISSIONS FILES_MATCHING PATTERN "*") - SET_TARGET_PROPERTIES(_allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_INSTALL_LIBDIR}") - install(TARGETS _allspark_client DESTINATION ${PYTHON_LIB_DIRS}) - endif() - SET_TARGET_PROPERTIES(_allspark PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_INSTALL_LIBDIR}") - install(TARGETS _allspark DESTINATION ${PYTHON_LIB_DIRS}) +if (PYTHON_LIB_DIRS) + if(NOT ENABLE_NV_STATIC_LIB) + install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*.so" PATTERN "libnccl.*" EXCLUDE) + else() + install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR} DESTINATION ${PYTHON_LIB_DIRS} FILES_MATCHING PATTERN "*.so") + endif() + + if (ENABLE_MULTINUMA) + install(DIRECTORY ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR} DESTINATION ${PYTHON_LIB_DIRS} USE_SOURCE_PERMISSIONS FILES_MATCHING PATTERN "*") + install(TARGETS _allspark_client DESTINATION ${PYTHON_LIB_DIRS}) endif() -else() - install(TARGETS allspark_framework DESTINATION ${CMAKE_INSTALL_DIR}) + install(TARGETS _allspark DESTINATION ${PYTHON_LIB_DIRS}) +endif() endif() diff --git a/csrc/core/kernel/cuda/moe_ppu/moe_ppu_kernel.cu b/csrc/core/kernel/cuda/moe_ppu/moe_ppu_kernel.cu deleted file mode 100644 index b485dbde..00000000 --- a/csrc/core/kernel/cuda/moe_ppu/moe_ppu_kernel.cu +++ /dev/null @@ -1,2826 +0,0 @@ -/*! - * Copyright (c) Alibaba, Inc. and its affiliates. - * @file moe_ppu_kernel.cu - */ -#include -#include -#include -#include - -#include - -#include "../cuda_kernel.h" -#include "allspark.pb.h" -#include "moe_ppu_kernel.h" -namespace allspark { -namespace cuda { -struct alignas(16) BatchInfo { - uint32_t batchId; - uint32_t m; - uint32_t ctaYOffset; - uint32_t COffset; -}; - -__device__ __forceinline__ uint32_t SmemU32Addr(const void* smemptr) { - uint32_t u32addr; - asm("{.reg .u64 u64addr;\n" - " cvta.to.shared.u64 u64addr, %1;\n" - " cvt.u32.u64 %0, u64addr; }\n" - : "=r"(u32addr) - : "l"(smemptr)); - return u32addr; -} - -__device__ __forceinline__ void LdgSts32(const uint32_t& smemAddr, - const void* gmemPtr, bool guard) { - asm volatile( - "{.reg.pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " @p cp.async.ca.shared.global [%0], [%1], 4;}\n" - : - : "r"(smemAddr), "l"(gmemPtr), "r"((int)guard)); -} - -__device__ __forceinline__ void LdgSts32(const uint32_t& smemAddr, - const void* gmemPtr, - const uint32_t& srcSize, bool guard) { - asm volatile( - "{.reg.pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " @p cp.async.ca.shared.global [%0], [%1], 4, %2;}\n" - : - : "r"(smemAddr), "l"(gmemPtr), "r"(srcSize), "r"((int)guard)); -} - -__device__ __forceinline__ void LdgSts64(const uint32_t& smemAddr, - const void* gmemPtr, bool guard) { - asm volatile( - "{.reg.pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " @p cp.async.ca.shared.global [%0], [%1], 8;}\n" - : - : "r"(smemAddr), "l"(gmemPtr), "r"((int)guard)); -} - -__device__ __forceinline__ void LdgSts64(const uint32_t& smemAddr, - const void* gmemPtr, - const uint32_t& srcSize, bool guard) { - asm volatile( - "{.reg.pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " @p cp.async.ca.shared.global [%0], [%1], 8, %2;}\n" - : - : "r"(smemAddr), "l"(gmemPtr), "r"(srcSize), "r"((int)guard)); -} - -__device__ __forceinline__ void LdgSts128(const uint32_t& smemAddr, - const void* gmemPtr, bool guard) { - asm volatile( - "{.reg.pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " @p cp.async.ca.shared.global [%0], [%1], 16;}\n" - : - : "r"(smemAddr), "l"(gmemPtr), "r"((int)guard)); -} - -__device__ __forceinline__ void LdgSts128(const uint32_t& smemAddr, - const void* gmemPtr, - const uint32_t& srcSize, bool guard) { - asm volatile( - "{.reg.pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " @p cp.async.ca.shared.global [%0], [%1], 16, %2;}\n" - : - : "r"(smemAddr), "l"(gmemPtr), "r"(srcSize), "r"((int)guard)); -} - -__device__ __forceinline__ void LdgStsGroupCommit() { - asm volatile("cp.async.commit_group;\n"); -} - -template -__device__ __forceinline__ void LdgStsGroupWait() { - asm volatile("cp.async.wait_group %0;\n" : : "n"(N)); -} - -template -__device__ __forceinline__ void Ldsm4(T& r0, T& r1, T& r2, T& r3, - const uint32_t& addr) { - static_assert(sizeof(T) == 4, "Ldsm4: invalid T"); - asm volatile( - "alippu.ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, " - "[%4];\n" - : "=r"(reinterpret_cast(r0)), - "=r"(reinterpret_cast(r1)), - "=r"(reinterpret_cast(r2)), - "=r"(reinterpret_cast(r3)) - : "r"(addr)); -} - -template -__device__ __forceinline__ void Ldsm4Trans(T& r0, T& r1, T& r2, T& r3, - const uint32_t& addr) { - static_assert(sizeof(T) == 4, "Ldsm4Trans: invalid T"); - asm volatile( - "alippu.ldmatrix.sync.aligned.m16n16.x1.trans.shared.b16 {%0, %1, %2, " - "%3}, [%4];\n" - : "=r"(reinterpret_cast(r0)), - "=r"(reinterpret_cast(r1)), - "=r"(reinterpret_cast(r2)), - "=r"(reinterpret_cast(r3)) - : "r"(addr)); -} - -template -__device__ __forceinline__ void Hmma161616F32(T (&d)[8], const T (&a)[4], - const T (&b)[4]) { - static_assert(sizeof(T) == 4, "Hmma161616F32: invalid T"); - asm volatile ( - "alippu.mma.sync.aligned.m16n16k16.row.col.f32.f16.f16.f32" - " {%0, %1, %2, %3, %4, %5, %6, %7}," - " {%8, %9, %10, %11}," - " {%12, %13, %14, %15}," - " {%0, %1, %2, %3, %4, %5, %6, %7};\n" - : "+r"(reinterpret_cast(d[0])), - "+r"(reinterpret_cast(d[1])) - "+r"(reinterpret_cast(d[2])) - "+r"(reinterpret_cast(d[3])) - "+r"(reinterpret_cast(d[4])) - "+r"(reinterpret_cast(d[5])) - "+r"(reinterpret_cast(d[6])) - "+r"(reinterpret_cast(d[7])) - : "r"(reinterpret_cast(a[0])), - "r"(reinterpret_cast(a[1])), - "r"(reinterpret_cast(a[2])), - "r"(reinterpret_cast(a[3])), - "r"(reinterpret_cast(b[0])), - "r"(reinterpret_cast(b[1])), - "r"(reinterpret_cast(b[2])), - "r"(reinterpret_cast(b[3])) - ); -} - -template -__device__ __forceinline__ void Stg64(const T& r0, const T& r1, const void* ptr, - bool guard) { - static_assert(sizeof(T) == 4, "Stg64: invalid T"); - asm volatile( - "{.reg .pred p;\n" - " setp.ne.b32 p, %1, 0;\n" - " @p st.global.v2.b32 [%0], {%2, %3};}\n" - : - : "l"(ptr), "r"((int)guard), "r"(reinterpret_cast(r0)), - "r"(reinterpret_cast(r1))); -} - -/** - * m_tile: 128 - * n_tile: 256 - * k_tile: 32x5 - * warp_tile: 64x64 - * CTA: 2x4 warps - * smem size: 120KB - */ -__device__ __forceinline__ void hgemm_f32_m128n256_k32x5_hmma161616_ldg8_loop( - const half* A, const half* B, const uint32_t* matARowIdx, half* C, - char* smem, const uint32_t& m, const uint32_t& n, const uint32_t& k, - const uint32_t& tileIdX, const uint32_t& tileIdY, - const uint32_t& BLdgStep) { - uint32_t warpId = threadIdx.x / 32; - uint32_t laneId = threadIdx.x % 32; - - uint32_t matARowId[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int mIdx = tileIdY * 128 + threadIdx.x / 4 + i * 64; - if (mIdx < m) { - asm("ld.global.ca.b32 %0, [%1];" - : "=r"(matARowId[i]) - : "l"(matARowIdx + mIdx)); - } else { - // map the out-of-bound threads to row0 of matrixA, - // to avoid predicated ld instructions - matARowId[i] = 0; - } - } - - const char* ALdgPtr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - ALdgPtr[i] = reinterpret_cast(A + matARowId[i] * k + - threadIdx.x % 4 * 8); - } - const char* BLdgPtr = reinterpret_cast( - B + (threadIdx.x / 8) * n + tileIdX * 256 + (threadIdx.x % 8) * 8); - - // LdgGuard to avoid LDG out of bound - uint32_t BLdgGuard = 0; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int nIdx = tileIdX * 256 + (threadIdx.x % 8) * 8 + i * 64; - if (nIdx < n) { - BLdgGuard |= (1U << i); - } - } - - uint32_t ASmemAddr = SmemU32Addr(smem); - uint32_t BSmemAddr = SmemU32Addr(smem + 128 * 32 * sizeof(half)); - - uint32_t AStsAddr = - ASmemAddr + - sizeof(half) * ((threadIdx.x % 4) * (128 * 8) + - ((threadIdx.x / 4) ^ (threadIdx.x % 4 * 2)) * 8); - uint32_t BStsAddr = - BSmemAddr + - sizeof(half) * ((threadIdx.x / 8) * 256 + - ((threadIdx.x % 8) ^ (threadIdx.x / 8 % 8)) * 8); - - // ATile lds addr - uint32_t ALdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = laneId / 8 % 2 + i * 2; - int row = (laneId / 16 * 8 + laneId % 8) ^ (col * 2); - ALdsAddr[i] = ASmemAddr + sizeof(half) * (col * 128 * 8 + - (warpId / 4) * 64 * 8 + row * 8); - } - - // BTile lds addr - uint32_t BLdsAddr[4]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int col = (laneId / 8 % 2 + i * 2) ^ (laneId % 8); - int row = laneId / 16 * 8 + laneId % 8; - BLdsAddr[i] = - BSmemAddr + sizeof(half) * (row * 256 + (warpId % 4) * 64 + col * 8); - } - - uint32_t kTiles = (k + 31) / 32; - - // load 1'st tile to shared memory - { - uint32_t firstKTile = k - (kTiles * 32 - 32); - uint32_t ASrcSize = threadIdx.x % 4 * 8 < firstKTile ? 16 : 0; - uint32_t BSrcSize = threadIdx.x / 8 < firstKTile ? 16 : 0; - -#pragma unroll - for (int i = 0; i < 2; ++i) { - LdgSts128(AStsAddr + i * 64 * 8 * sizeof(half), ALdgPtr[i], ASrcSize, - true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), BSrcSize, - (BLdgGuard & (1u << i)) != 0); - } - LdgStsGroupCommit(); - -// ldg pointer for the next tile -#pragma unroll - for (int i = 0; i < 2; ++i) { - ALdgPtr[i] += firstKTile * sizeof(half); - } - BLdgPtr += firstKTile * n * sizeof(half); - } - -// load 2'st to (N-stages - 1) tiles to shared memory -#pragma unroll - for (int prefetchIter = 1; prefetchIter < 4; ++prefetchIter) { - if (prefetchIter < kTiles) { -#pragma unroll - for (int i = 0; i < 2; ++i) { - LdgSts128( - AStsAddr + prefetchIter * 1024 * 24 + i * 64 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + prefetchIter * 1024 * 24 + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1u << i)) != 0); - } - -// ldg pointer for the next tile -#pragma unroll - for (int i = 0; i < 2; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - } - LdgStsGroupCommit(); - } - - // wait for the 1'st tile - LdgStsGroupWait<3>(); - __syncthreads(); - - // smem double buffer offset - uint32_t ldsOffset = 0; - uint32_t stsOffset = 96 * 1024; - - // A, B and C register fragment - uint32_t AFrag[2][4][4]; - uint32_t BFrag[2][4][4]; - uint32_t CFrag[4][4][8]; -#pragma unroll - for (int i = 0; i < 4; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { -#pragma unroll - for (int p = 0; p < 8; ++p) { - CFrag[i][j][p] = 0; - } - } - } - -// load 1'st fragment -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4(AFrag[0][i][0], AFrag[0][i][1], AFrag[0][i][2], AFrag[0][i][3], - ALdsAddr[0] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[0][i][0], BFrag[0][i][1], BFrag[0][i][2], BFrag[0][i][3], - BLdsAddr[i] + ldsOffset); - } - - if (tileIdX * 256 + 256 <= n) { - // matrixB CTA tile is full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 96 * 1024 ? ldsOffset + 24 * 1024 : 0; - stsOffset = stsOffset < 96 * 1024 ? stsOffset + 24 * 1024 : 0; - -// ldg pointer for the next tile -#pragma unroll - for (int i = 0; i < 2; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 4; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { -#pragma unroll - for (int i = 0; i < 2; ++i) { - LdgSts128(AStsAddr + stsOffset + i * 64 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), true); - } - LdgStsGroupCommit(); - } - } - } - } else { - // matrixB CTA tile is not full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 96 * 1024 ? ldsOffset + 24 * 1024 : 0; - stsOffset = stsOffset < 96 * 1024 ? stsOffset + 24 * 1024 : 0; - -// ldg pointer for the next tile -#pragma unroll - for (int i = 0; i < 2; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 4; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { -#pragma unroll - for (int i = 0; i < 2; ++i) { - LdgSts128(AStsAddr + stsOffset + i * 64 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1U << i)) != 0); - } - LdgStsGroupCommit(); - } - } - } - } - - // k-tiles loop without prefetch - for (; kTiles > 0; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 96 * 1024 ? ldsOffset + 24 * 1024 : 0; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4( - AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 4; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], BFrag[kFrag % 2][j]); - } - } - - // dummy LdgStsGroupCommit to make LdgStsGroupWait work - if (kFrag == 0) { - LdgStsGroupCommit(); - } - } - } - - uint32_t CStsIdxX = warpId % 4 * 64 + laneId % 4; - uint32_t CStsIdxY = warpId / 4 * 32 + laneId / 4; - uint32_t* CStsPtr = - reinterpret_cast(smem) + CStsIdxY * 260 + CStsIdxX; - const float4* CLdsPtr = reinterpret_cast(smem) + - threadIdx.x / 128 * 32 * 65 + - threadIdx.x % 128 / 64 * 65 + threadIdx.x % 64; - - uint32_t mIdx = - tileIdY * 128 + threadIdx.x / 128 * 64 + threadIdx.x % 128 / 64; - uint32_t nIdx = tileIdX * 256 + threadIdx.x % 64 * 4; - - half* CStgPtr = C + mIdx * n + nIdx; - bool nGuard = nIdx < n; - -#pragma unroll - for (int stgIter = 0; stgIter < 2; ++stgIter) { - // C_tile sts - __syncthreads(); -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - CStsPtr[i * 16 * 260 + j * 16] = CFrag[stgIter * 2 + i][j][0]; - CStsPtr[i * 16 * 260 + j * 16 + 4] = CFrag[stgIter * 2 + i][j][1]; - CStsPtr[i * 16 * 260 + j * 16 + 8] = CFrag[stgIter * 2 + i][j][2]; - CStsPtr[i * 16 * 260 + j * 16 + 12] = CFrag[stgIter * 2 + i][j][3]; - - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16] = CFrag[stgIter * 2 + i][j][4]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 4] = - CFrag[stgIter * 2 + i][j][5]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 8] = - CFrag[stgIter * 2 + i][j][6]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 12] = - CFrag[stgIter * 2 + i][j][7]; - } - } - __syncthreads(); - - // lds - float4 CLdsReg[16]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - CLdsReg[i] = CLdsPtr[i * 2 * 65]; - } - - half2 CStgReg[16][2]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - asm("{.reg .b16 h0, h1, h2, h3;\n" - " cvt.rn.f16.f32 h0, %2;\n" - " cvt.rn.f16.f32 h1, %3;\n" - " cvt.rn.f16.f32 h2, %4;\n" - " cvt.rn.f16.f32 h3, %5;\n" - " mov.b32 %0, {h0, h1};\n" - " mov.b32 %1, {h2, h3};}" - : "=r"(reinterpret_cast(CStgReg[i][0])), - "=r"(reinterpret_cast(CStgReg[i][1])) - : "f"(CLdsReg[i].x), "f"(CLdsReg[i].y), "f"(CLdsReg[i].z), - "f"(CLdsReg[i].w)); - } - -// C_tile stg -#pragma unroll - for (int i = 0; i < 16; ++i) { - Stg64(CStgReg[i][0], CStgReg[i][1], CStgPtr + (stgIter * 32 + i * 2) * n, - mIdx + stgIter * 32 + i * 2 < m && nGuard); - } - } -} - -/** - * m_tile: 96 - * n_tile: 256 - * k_tile: 32x5 - * warp_tile: 48x64 - * CTA: 2x4 warps - * smem size: 110KB - */ -__device__ __forceinline__ void hgemm_f32_m96n256_k32x5_hmma161616_ldg4_loop( - const half* A, const half* B, const uint32_t* matARowIdx, half* C, - char* smem, const uint32_t& m, const uint32_t& n, const uint32_t& k, - const uint32_t& tileIdX, const uint32_t& tileIdY, - const uint32_t& BLdgStep) { - uint32_t warpId = threadIdx.x / 32; - uint32_t laneId = threadIdx.x % 32; - - uint32_t matARowId[3]; -#pragma unroll - for (int i = 0; i < 3; ++i) { - int mIdx = tileIdY * 96 + threadIdx.x / 8 + i * 32; - if (mIdx < m) { - asm("ld.global.ca.b32 %0, [%1];" - : "=r"(matARowId[i]) - : "l"(matARowIdx + mIdx)); - } else { - // map the out-of-bound threads to row0 of matrixA, - // to avoid predicated ld instructions - matARowId[i] = 0; - } - } - - const char* ALdgPtr[3]; -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] = reinterpret_cast(A + matARowId[i] * k + - threadIdx.x % 8 * 4); - } - const char* BLdgPtr = reinterpret_cast( - B + (threadIdx.x / 8) * n + tileIdX * 256 + (threadIdx.x % 8) * 8); - - // LdgGuard to avoid LDG out of bound - uint32_t BLdgGuard = 0; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int nIdx = tileIdX * 256 + (threadIdx.x % 8) * 8 + i * 64; - if (nIdx < n) { - BLdgGuard |= (1U << i); - } - } - - uint32_t ASmemAddr = SmemU32Addr(smem); - uint32_t BSmemAddr = SmemU32Addr(smem + 96 * 32 * sizeof(half)); - - uint32_t AStsAddr = - ASmemAddr + - sizeof(half) * ((threadIdx.x % 8 / 2) * (96 * 8) + - ((threadIdx.x / 8) ^ (threadIdx.x % 8 / 2 * 2)) * 8 + - threadIdx.x % 2 * 4); - uint32_t BStsAddr = - BSmemAddr + - sizeof(half) * ((threadIdx.x / 8) * 256 + - ((threadIdx.x % 8) ^ (threadIdx.x / 8 % 8)) * 8); - - // ATile lds addr - uint32_t ALdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = laneId / 8 % 2 + i * 2; - int row = (laneId / 16 * 8 + laneId % 8) ^ (col * 2); - ALdsAddr[i] = ASmemAddr + sizeof(half) * (col * 96 * 8 + - (warpId / 4) * 48 * 8 + row * 8); - } - - // BTile lds addr - uint32_t BLdsAddr[4]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int col = (laneId / 8 % 2 + i * 2) ^ (laneId % 8); - int row = laneId / 16 * 8 + laneId % 8; - BLdsAddr[i] = - BSmemAddr + sizeof(half) * (row * 256 + (warpId % 4) * 64 + col * 8); - } - - uint32_t kTiles = (k + 31) / 32; - - // load 1'st tile to shared memory - { - uint32_t firstKTile = k - (kTiles * 32 - 32); - uint32_t ASrcSize = threadIdx.x % 8 * 4 < firstKTile ? 8 : 0; - uint32_t BSrcSize = threadIdx.x / 8 < firstKTile ? 16 : 0; - -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts64(AStsAddr + i * 32 * 8 * sizeof(half), ALdgPtr[i], ASrcSize, - true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), BSrcSize, - (BLdgGuard & (1u << i)) != 0); - } - LdgStsGroupCommit(); - -// ldg pointer for the the next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += firstKTile * sizeof(half); - } - BLdgPtr += firstKTile * n * sizeof(half); - } - -// load 2'st to (N-stages - 1) tiles to shared memory -#pragma unroll - for (int prefetchIter = 1; prefetchIter < 4; ++prefetchIter) { - if (prefetchIter < kTiles) { -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts64( - AStsAddr + prefetchIter * 1024 * 22 + i * 32 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + prefetchIter * 1024 * 22 + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1u << i)) != 0); - } - -// ldg pointer for the the next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - } - LdgStsGroupCommit(); - } - - // wait for the 1'st tile - LdgStsGroupWait<3>(); - __syncthreads(); - - // smem double buffer offset - uint32_t ldsOffset = 0; - uint32_t stsOffset = 88 * 1024; - - // A, B and C register fragment - uint32_t AFrag[2][3][4]; - uint32_t BFrag[2][4][4]; - uint32_t CFrag[3][4][8]; -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { -#pragma unroll - for (int p = 0; p < 8; ++p) { - CFrag[i][j][p] = 0; - } - } - } - -// load 1'st fragment -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4(AFrag[0][i][0], AFrag[0][i][1], AFrag[0][i][2], AFrag[0][i][3], - ALdsAddr[0] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[0][i][0], BFrag[0][i][1], BFrag[0][i][2], BFrag[0][i][3], - BLdsAddr[i] + ldsOffset); - } - - if (tileIdX * 256 + 256 <= n) { - // matrixB CTA tile is full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 88 * 1024 ? ldsOffset + 22 * 1024 : 0; - stsOffset = stsOffset < 88 * 1024 ? stsOffset + 22 * 1024 : 0; - -// ldg pointer for the next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts64(AStsAddr + stsOffset + i * 32 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), true); - } - LdgStsGroupCommit(); - } - } - } - } else { - // matrixB CTA tile is not full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 88 * 1024 ? ldsOffset + 22 * 1024 : 0; - stsOffset = stsOffset < 88 * 1024 ? stsOffset + 22 * 1024 : 0; - -// ldg pointer for next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts64(AStsAddr + stsOffset + i * 32 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1U << i)) != 0); - } - LdgStsGroupCommit(); - } - } - } - } - - // k-tiles loop without prefetch - for (; kTiles > 0; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 88 * 1024 ? ldsOffset + 22 * 1024 : 0; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4( - AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], BFrag[kFrag % 2][j]); - } - } - - // dummy LdgStsGroupCommit to make LdgStsGroupWait work - if (kFrag == 0) { - LdgStsGroupCommit(); - } - } - } - - uint32_t CStsIdxX = warpId % 4 * 64 + laneId % 4; - uint32_t CStsIdxY = warpId / 4 * 48 + laneId / 4; - uint32_t* CStsPtr = - reinterpret_cast(smem) + CStsIdxY * 260 + CStsIdxX; - const float4* CLdsPtr = reinterpret_cast(smem) + - threadIdx.x / 64 * 65 + threadIdx.x % 64; - - uint32_t mIdx = tileIdY * 96 + threadIdx.x / 64; - uint32_t nIdx = tileIdX * 256 + threadIdx.x % 64 * 4; - - half* CStgPtr = C + mIdx * n + nIdx; - bool nGuard = nIdx < n; - - __syncthreads(); -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - CStsPtr[i * 16 * 260 + j * 16] = CFrag[i][j][0]; - CStsPtr[i * 16 * 260 + j * 16 + 4] = CFrag[i][j][1]; - CStsPtr[i * 16 * 260 + j * 16 + 8] = CFrag[i][j][2]; - CStsPtr[i * 16 * 260 + j * 16 + 12] = CFrag[i][j][3]; - - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16] = CFrag[i][j][4]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 4] = CFrag[i][j][5]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 8] = CFrag[i][j][6]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 12] = CFrag[i][j][7]; - } - } - __syncthreads(); - - float4 CLdsReg[24]; -#pragma unroll - for (int i = 0; i < 24; ++i) { - CLdsReg[i] = CLdsPtr[i * 4 * 65]; - } - - half2 CStgReg[24][2]; -#pragma unroll - for (int i = 0; i < 24; ++i) { - asm("{.reg .b16 h0, h1, h2, h3;\n" - " cvt.rn.f16.f32 h0, %2;\n" - " cvt.rn.f16.f32 h1, %3;\n" - " cvt.rn.f16.f32 h2, %4;\n" - " cvt.rn.f16.f32 h3, %5;\n" - " mov.b32 %0, {h0, h1};\n" - " mov.b32 %1, {h2, h3};}" - : "=r"(reinterpret_cast(CStgReg[i][0])), - "=r"(reinterpret_cast(CStgReg[i][1])) - : "f"(CLdsReg[i].x), "f"(CLdsReg[i].y), "f"(CLdsReg[i].z), - "f"(CLdsReg[i].w)); - } - -// C_tile stg -#pragma unroll - for (int i = 0; i < 24; ++i) { - Stg64(CStgReg[i][0], CStgReg[i][1], CStgPtr + i * 4 * n, - mIdx + i * 4 < m && nGuard); - } -} - -/** - * m_tile: 64 - * n_tile: 256 - * k_tile: 32x5 - * warp_tile: 32x64 - * CTA: 2x4 warps - * smem size: 100KB - */ -__device__ __forceinline__ void hgemm_f32_m64n256_k32x5_hmma161616_ldg8_loop( - const half* A, const half* B, const uint32_t* matARowIdx, half* C, - char* smem, const uint32_t& m, const uint32_t& n, const uint32_t& k, - const uint32_t& tileIdX, const uint32_t& tileIdY, - const uint32_t& BLdgStep) { - uint32_t warpId = threadIdx.x / 32; - uint32_t laneId = threadIdx.x % 32; - - uint32_t matARowId; - if (tileIdY * 64 + threadIdx.x / 4 < m) { - asm("ld.global.ca.b32 %0, [%1];" - : "=r"(matARowId) - : "l"(matARowIdx + tileIdY * 64 + threadIdx.x / 4)); - } else { - // map the out-of-bound threads to row0 of matrixA, - // to avoid predicated ld instructions - matARowId = 0; - } - - const char* ALdgPtr = - reinterpret_cast(A + matARowId * k + threadIdx.x % 4 * 8); - const char* BLdgPtr = reinterpret_cast( - B + (threadIdx.x / 8) * n + tileIdX * 256 + (threadIdx.x % 8) * 8); - - // LdgGuard to avoid LDG out of bound - uint32_t BLdgGuard = 0; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int nIdx = tileIdX * 256 + (threadIdx.x % 8) * 8 + i * 64; - if (nIdx < n) { - BLdgGuard |= (1U << i); - } - } - - uint32_t ASmemAddr = SmemU32Addr(smem); - uint32_t BSmemAddr = SmemU32Addr(smem + 64 * 32 * sizeof(half)); - - uint32_t AStsAddr = - ASmemAddr + - sizeof(half) * ((threadIdx.x % 4) * (64 * 8) + - ((threadIdx.x / 4) ^ (threadIdx.x % 4 * 2)) * 8); - uint32_t BStsAddr = - BSmemAddr + - sizeof(half) * ((threadIdx.x / 8) * 256 + - ((threadIdx.x % 8) ^ (threadIdx.x / 8 % 8)) * 8); - - // ATile lds addr - uint32_t ALdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = laneId / 8 % 2 + i * 2; - int row = (laneId / 16 * 8 + laneId % 8) ^ (col * 2); - ALdsAddr[i] = ASmemAddr + sizeof(half) * (col * 64 * 8 + - (warpId / 4) * 32 * 8 + row * 8); - } - - // BTile lds addr - uint32_t BLdsAddr[4]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int col = (laneId / 8 % 2 + i * 2) ^ (laneId % 8); - int row = laneId / 16 * 8 + laneId % 8; - BLdsAddr[i] = - BSmemAddr + sizeof(half) * (row * 256 + (warpId % 4) * 64 + col * 8); - } - - uint32_t kTiles = (k + 31) / 32; - - // load 1'st tile to shared memory - { - uint32_t firstKTile = k - (kTiles * 32 - 32); - uint32_t ASrcSize = threadIdx.x % 4 * 8 < firstKTile ? 16 : 0; - uint32_t BSrcSize = threadIdx.x / 8 < firstKTile ? 16 : 0; - - LdgSts128(AStsAddr, ALdgPtr, ASrcSize, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), BSrcSize, - (BLdgGuard & (1u << i)) != 0); - } - LdgStsGroupCommit(); - - // ldg pointer for the next tile - ALdgPtr += firstKTile * sizeof(half); - BLdgPtr += firstKTile * n * sizeof(half); - } - -// load 2'st to (N-stages - 1) tiles to shared memory -#pragma unroll - for (int prefetchIter = 1; prefetchIter < 4; ++prefetchIter) { - if (prefetchIter < kTiles) { - LdgSts128(AStsAddr + prefetchIter * 1024 * 20, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + prefetchIter * 1024 * 20 + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1u << i)) != 0); - } - - // ldg pointer for the next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - } - LdgStsGroupCommit(); - } - - // wait for the 1'st tile - LdgStsGroupWait<3>(); - __syncthreads(); - - // smem double buffer offset - uint32_t ldsOffset = 0; - uint32_t stsOffset = 80 * 1024; - - // A, B and C register fragment - uint32_t AFrag[2][2][4]; - uint32_t BFrag[2][4][4]; - uint32_t CFrag[2][4][8]; -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { -#pragma unroll - for (int p = 0; p < 8; ++p) { - CFrag[i][j][p] = 0; - } - } - } - -// load 1'st fragment -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4(AFrag[0][i][0], AFrag[0][i][1], AFrag[0][i][2], AFrag[0][i][3], - ALdsAddr[0] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[0][i][0], BFrag[0][i][1], BFrag[0][i][2], BFrag[0][i][3], - BLdsAddr[i] + ldsOffset); - } - - if (tileIdX * 256 + 256 <= n) { - // matrixB CTA tile is full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 80 * 1024 ? ldsOffset + 20 * 1024 : 0; - stsOffset = stsOffset < 80 * 1024 ? stsOffset + 20 * 1024 : 0; - - // ldg pointer for the next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { - LdgSts128(AStsAddr + stsOffset, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), true); - } - LdgStsGroupCommit(); - } - } - } - } else { - // matrixB CTA tile is not full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 80 * 1024 ? ldsOffset + 20 * 1024 : 0; - stsOffset = stsOffset < 80 * 1024 ? stsOffset + 20 * 1024 : 0; - - // ldg pointer for the next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { - LdgSts128(AStsAddr + stsOffset, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1U << i)) != 0); - } - LdgStsGroupCommit(); - } - } - } - } - - // k-tiles loop without prefetch - for (; kTiles > 0; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 80 * 1024 ? ldsOffset + 20 * 1024 : 0; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4( - AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], BFrag[kFrag % 2][j]); - } - } - - // dummy LdgStsGroupCommit to make LdgStsGroupWait work - if (kFrag == 0) { - LdgStsGroupCommit(); - } - } - } - - uint32_t CStsIdxX = warpId % 4 * 64 + laneId % 4; - uint32_t CStsIdxY = warpId / 4 * 32 + laneId / 4; - uint32_t* CStsPtr = - reinterpret_cast(smem) + CStsIdxY * 260 + CStsIdxX; - const float4* CLdsPtr = reinterpret_cast(smem) + - threadIdx.x / 64 * 65 + threadIdx.x % 64; - - uint32_t mIdx = tileIdY * 64 + threadIdx.x / 64; - uint32_t nIdx = tileIdX * 256 + threadIdx.x % 64 * 4; - - half* CStgPtr = C + mIdx * n + nIdx; - bool nGuard = nIdx < n; - - __syncthreads(); -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 4; ++j) { - CStsPtr[i * 16 * 260 + j * 16] = CFrag[i][j][0]; - CStsPtr[i * 16 * 260 + j * 16 + 4] = CFrag[i][j][1]; - CStsPtr[i * 16 * 260 + j * 16 + 8] = CFrag[i][j][2]; - CStsPtr[i * 16 * 260 + j * 16 + 12] = CFrag[i][j][3]; - - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16] = CFrag[i][j][4]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 4] = CFrag[i][j][5]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 8] = CFrag[i][j][6]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 12] = CFrag[i][j][7]; - } - } - __syncthreads(); - - float4 CLdsReg[16]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - CLdsReg[i] = CLdsPtr[i * 4 * 65]; - } - - half2 CStgReg[16][2]; -#pragma unroll - for (int i = 0; i < 16; ++i) { - asm("{.reg .b16 h0, h1, h2, h3;\n" - " cvt.rn.f16.f32 h0, %2;\n" - " cvt.rn.f16.f32 h1, %3;\n" - " cvt.rn.f16.f32 h2, %4;\n" - " cvt.rn.f16.f32 h3, %5;\n" - " mov.b32 %0, {h0, h1};\n" - " mov.b32 %1, {h2, h3};}" - : "=r"(reinterpret_cast(CStgReg[i][0])), - "=r"(reinterpret_cast(CStgReg[i][1])) - : "f"(CLdsReg[i].x), "f"(CLdsReg[i].y), "f"(CLdsReg[i].z), - "f"(CLdsReg[i].w)); - } - -// C_tile stg -#pragma unroll - for (int i = 0; i < 16; ++i) { - Stg64(CStgReg[i][0], CStgReg[i][1], CStgPtr + i * 4 * n, - mIdx + i * 4 < m && nGuard); - } -} - -/** - * m_tile: 48 - * n_tile: 256 - * k_tile: 32x5 - * warp_tile: 48x32 - * CTA: 1x8 warps - * smem size: 95KB - */ -__device__ __forceinline__ void hgemm_f32_m48n256_k32x5_hmma161616_ldg2_loop( - const half* A, const half* B, const uint32_t* matARowIdx, half* C, - char* smem, const uint32_t& m, const uint32_t& n, const uint32_t& k, - const uint32_t& tileIdX, const uint32_t& tileIdY, - const uint32_t& BLdgStep) { - uint32_t warpId = threadIdx.x / 32; - uint32_t laneId = threadIdx.x % 32; - - uint32_t matARowId[3]; -#pragma unroll - for (int i = 0; i < 3; ++i) { - int mIdx = tileIdY * 48 + threadIdx.x / 16 + i * 16; - if (mIdx < m) { - asm("ld.global.ca.b32 %0, [%1];" - : "=r"(matARowId[i]) - : "l"(matARowIdx + mIdx)); - } else { - // map the out-of-bound threads to row0 of matrixA, - // to avoid predicated ld instructions - matARowId[i] = 0; - } - } - - const char* ALdgPtr[3]; -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] = reinterpret_cast(A + matARowId[i] * k + - threadIdx.x % 16 * 2); - } - const char* BLdgPtr = reinterpret_cast( - B + (threadIdx.x / 8) * n + tileIdX * 256 + (threadIdx.x % 8) * 8); - - // LdgGuard to avoid LDG out of bound - uint32_t BLdgGuard = 0; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int nIdx = tileIdX * 256 + (threadIdx.x % 8) * 8 + i * 64; - if (nIdx < n) { - BLdgGuard |= (1U << i); - } - } - - uint32_t ASmemAddr = SmemU32Addr(smem); - uint32_t BSmemAddr = SmemU32Addr(smem + 48 * 32 * sizeof(half)); - - uint32_t AStsAddr = - ASmemAddr + - sizeof(half) * ((threadIdx.x % 16 / 4) * (48 * 8) + - ((threadIdx.x / 16) ^ (threadIdx.x % 16 / 4 * 2)) * 8 + - threadIdx.x % 4 * 2); - uint32_t BStsAddr = - BSmemAddr + - sizeof(half) * ((threadIdx.x / 8) * 256 + - ((threadIdx.x % 8) ^ (threadIdx.x / 8 % 8)) * 8); - - // ATile lds addr - uint32_t ALdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = laneId / 8 % 2 + i * 2; - int row = (laneId / 16 * 8 + laneId % 8) ^ (col * 2); - ALdsAddr[i] = ASmemAddr + sizeof(half) * (col * 48 * 8 + row * 8); - } - - // BTile lds addr - uint32_t BLdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = (laneId / 8 % 2 + warpId % 2 * 4 + i * 2) ^ (laneId % 8); - int row = laneId / 16 * 8 + laneId % 8; - BLdsAddr[i] = - BSmemAddr + sizeof(half) * (row * 256 + (warpId / 2) * 64 + col * 8); - } - - uint32_t kTiles = (k + 31) / 32; - - // load 1'st tile to shared memory - { - uint32_t firstKTile = k - (kTiles * 32 - 32); - uint32_t ASrcSize = threadIdx.x % 16 * 2 < firstKTile ? 4 : 0; - uint32_t BSrcSize = threadIdx.x / 8 < firstKTile ? 16 : 0; - -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts32(AStsAddr + i * 16 * 8 * sizeof(half), ALdgPtr[i], ASrcSize, - true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), BSrcSize, - (BLdgGuard & (1u << i)) != 0); - } - LdgStsGroupCommit(); - -// ldg pointer for the the next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += firstKTile * sizeof(half); - } - BLdgPtr += firstKTile * n * sizeof(half); - } - -// load 2'st to (N-stages - 1) tiles to shared memory -#pragma unroll - for (int prefetchIter = 1; prefetchIter < 4; ++prefetchIter) { - if (prefetchIter < kTiles) { -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts32( - AStsAddr + prefetchIter * 1024 * 19 + i * 16 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + prefetchIter * 1024 * 19 + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1u << i)) != 0); - } - -// ldg pointer for the the next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - } - LdgStsGroupCommit(); - } - - // wait for the 1'st tile - LdgStsGroupWait<3>(); - __syncthreads(); - - // smem double buffer offset - uint32_t ldsOffset = 0; - uint32_t stsOffset = 76 * 1024; - - // A, B and C register fragment - uint32_t AFrag[2][3][4]; - uint32_t BFrag[2][2][4]; - uint32_t CFrag[3][2][8]; -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { -#pragma unroll - for (int p = 0; p < 8; ++p) { - CFrag[i][j][p] = 0; - } - } - } - -// load 1'st fragment -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4(AFrag[0][i][0], AFrag[0][i][1], AFrag[0][i][2], AFrag[0][i][3], - ALdsAddr[0] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[0][i][0], BFrag[0][i][1], BFrag[0][i][2], BFrag[0][i][3], - BLdsAddr[i] + ldsOffset); - } - - if (tileIdX * 256 + 256 <= n) { - // matrixB CTA tile is full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 76 * 1024 ? ldsOffset + 19 * 1024 : 0; - stsOffset = stsOffset < 76 * 1024 ? stsOffset + 19 * 1024 : 0; - -// ldg pointer for the next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts32(AStsAddr + stsOffset + i * 16 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), true); - } - LdgStsGroupCommit(); - } - } - } - } else { - // matrixB CTA tile is not full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 76 * 1024 ? ldsOffset + 19 * 1024 : 0; - stsOffset = stsOffset < 76 * 1024 ? stsOffset + 19 * 1024 : 0; - -// ldg pointer for next tile -#pragma unroll - for (int i = 0; i < 3; ++i) { - ALdgPtr[i] += 32 * sizeof(half); - } - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { -#pragma unroll - for (int i = 0; i < 3; ++i) { - LdgSts32(AStsAddr + stsOffset + i * 16 * 8 * sizeof(half), - ALdgPtr[i], true); - } -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1U << i)) != 0); - } - LdgStsGroupCommit(); - } - } - } - } - - // k-tiles loop without prefetch - for (; kTiles > 0; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 76 * 1024 ? ldsOffset + 19 * 1024 : 0; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 3; ++i) { - Ldsm4( - AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], BFrag[kFrag % 2][j]); - } - } - - // dummy LdgStsGroupCommit to make LdgStsGroupWait work - if (kFrag == 0) { - LdgStsGroupCommit(); - } - } - } - - uint32_t CStsIdxX = warpId * 32 + laneId % 4; - uint32_t CStsIdxY = laneId / 4; - uint32_t* CStsPtr = - reinterpret_cast(smem) + CStsIdxY * 260 + CStsIdxX; - const float4* CLdsPtr = reinterpret_cast(smem) + - threadIdx.x / 64 * 65 + threadIdx.x % 64; - - uint32_t mIdx = tileIdY * 48 + threadIdx.x / 64; - uint32_t nIdx = tileIdX * 256 + threadIdx.x % 64 * 4; - - half* CStgPtr = C + mIdx * n + nIdx; - bool nGuard = nIdx < n; - - __syncthreads(); -#pragma unroll - for (int i = 0; i < 3; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - CStsPtr[i * 16 * 260 + j * 16] = CFrag[i][j][0]; - CStsPtr[i * 16 * 260 + j * 16 + 4] = CFrag[i][j][1]; - CStsPtr[i * 16 * 260 + j * 16 + 8] = CFrag[i][j][2]; - CStsPtr[i * 16 * 260 + j * 16 + 12] = CFrag[i][j][3]; - - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16] = CFrag[i][j][4]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 4] = CFrag[i][j][5]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 8] = CFrag[i][j][6]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 12] = CFrag[i][j][7]; - } - } - __syncthreads(); - - float4 CLdsReg[12]; -#pragma unroll - for (int i = 0; i < 12; ++i) { - CLdsReg[i] = CLdsPtr[i * 4 * 65]; - } - - half2 CStgReg[12][2]; -#pragma unroll - for (int i = 0; i < 12; ++i) { - asm("{.reg .b16 h0, h1, h2, h3;\n" - " cvt.rn.f16.f32 h0, %2;\n" - " cvt.rn.f16.f32 h1, %3;\n" - " cvt.rn.f16.f32 h2, %4;\n" - " cvt.rn.f16.f32 h3, %5;\n" - " mov.b32 %0, {h0, h1};\n" - " mov.b32 %1, {h2, h3};}" - : "=r"(reinterpret_cast(CStgReg[i][0])), - "=r"(reinterpret_cast(CStgReg[i][1])) - : "f"(CLdsReg[i].x), "f"(CLdsReg[i].y), "f"(CLdsReg[i].z), - "f"(CLdsReg[i].w)); - } - -// C_tile stg -#pragma unroll - for (int i = 0; i < 12; ++i) { - Stg64(CStgReg[i][0], CStgReg[i][1], CStgPtr + i * 4 * n, - mIdx + i * 4 < m && nGuard); - } -} - -/** - * m_tile: 32 - * n_tile: 256 - * k_tile: 32x5 - * warp_tile: 32x32 - * CTA: 1x8 warps - * smem size: 90KB - */ -__device__ __forceinline__ void hgemm_f32_m32n256_k32x5_hmma161616_ldg4_loop( - const half* A, const half* B, const uint32_t* matARowIdx, half* C, - char* smem, const uint32_t& m, const uint32_t& n, const uint32_t& k, - const uint32_t& tileIdX, const uint32_t& tileIdY, - const uint32_t& BLdgStep) { - uint32_t warpId = threadIdx.x / 32; - uint32_t laneId = threadIdx.x % 32; - - uint32_t matARowId; - if (tileIdY * 32 + threadIdx.x / 8 < m) { - asm("ld.global.ca.b32 %0, [%1];" - : "=r"(matARowId) - : "l"(matARowIdx + tileIdY * 32 + threadIdx.x / 8)); - } else { - // map the out-of-bound threads to row0 of matrixA, - // to avoid predicated ld instructions - matARowId = 0; - } - - const char* ALdgPtr = - reinterpret_cast(A + matARowId * k + threadIdx.x % 8 * 4); - const char* BLdgPtr = reinterpret_cast( - B + (threadIdx.x / 8) * n + tileIdX * 256 + (threadIdx.x % 8) * 8); - - // LdgGuard to avoid LDG out of bound - uint32_t BLdgGuard = 0; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int nIdx = tileIdX * 256 + (threadIdx.x % 8) * 8 + i * 64; - if (nIdx < n) { - BLdgGuard |= (1U << i); - } - } - - uint32_t ASmemAddr = SmemU32Addr(smem); - uint32_t BSmemAddr = SmemU32Addr(smem + 32 * 32 * sizeof(half)); - - uint32_t AStsAddr = - ASmemAddr + - sizeof(half) * ((threadIdx.x % 8 / 2) * (32 * 8) + - ((threadIdx.x / 8) ^ (threadIdx.x % 8 / 2 * 2)) * 8 + - threadIdx.x % 2 * 4); - uint32_t BStsAddr = - BSmemAddr + - sizeof(half) * ((threadIdx.x / 8) * 256 + - ((threadIdx.x % 8) ^ (threadIdx.x / 8 % 8)) * 8); - - // ATile lds addr - uint32_t ALdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = laneId / 8 % 2 + i * 2; - int row = (laneId / 16 * 8 + laneId % 8) ^ (col * 2); - ALdsAddr[i] = ASmemAddr + sizeof(half) * (col * 32 * 8 + row * 8); - } - - // BTile lds addr - uint32_t BLdsAddr[2]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int col = (laneId / 8 % 2 + warpId % 2 * 4 + i * 2) ^ (laneId % 8); - int row = laneId / 16 * 8 + laneId % 8; - BLdsAddr[i] = - BSmemAddr + sizeof(half) * (row * 256 + (warpId / 2) * 64 + col * 8); - } - - uint32_t kTiles = (k + 31) / 32; - - // load 1'st tile to shared memory - { - uint32_t firstKTile = k - (kTiles * 32 - 32); - uint32_t ASrcSize = threadIdx.x % 8 * 4 < firstKTile ? 8 : 0; - uint32_t BSrcSize = threadIdx.x / 8 < firstKTile ? 16 : 0; - - LdgSts64(AStsAddr, ALdgPtr, ASrcSize, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), BSrcSize, - (BLdgGuard & (1u << i)) != 0); - } - LdgStsGroupCommit(); - - // ldg pointer for the the next tile - ALdgPtr += firstKTile * sizeof(half); - BLdgPtr += firstKTile * n * sizeof(half); - } - -// load 2'st to (N-stages - 1) tiles to shared memory -#pragma unroll - for (int prefetchIter = 1; prefetchIter < 4; ++prefetchIter) { - if (prefetchIter < kTiles) { - LdgSts64(AStsAddr + prefetchIter * 1024 * 18, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + prefetchIter * 1024 * 18 + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1u << i)) != 0); - } - - // ldg pointer for the the next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - } - LdgStsGroupCommit(); - } - - // wait for the 1'st tile - LdgStsGroupWait<3>(); - __syncthreads(); - - // smem double buffer offset - uint32_t ldsOffset = 0; - uint32_t stsOffset = 72 * 1024; - - // A, B and C register fragment - uint32_t AFrag[2][2][4]; - uint32_t BFrag[2][2][4]; - uint32_t CFrag[2][2][8]; -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { -#pragma unroll - for (int p = 0; p < 8; ++p) { - CFrag[i][j][p] = 0; - } - } - } - -// load 1'st fragment -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4(AFrag[0][i][0], AFrag[0][i][1], AFrag[0][i][2], AFrag[0][i][3], - ALdsAddr[0] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[0][i][0], BFrag[0][i][1], BFrag[0][i][2], BFrag[0][i][3], - BLdsAddr[i] + ldsOffset); - } - - if (tileIdX * 256 + 256 <= n) { - // matrixB CTA tile is full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 72 * 1024 ? ldsOffset + 18 * 1024 : 0; - stsOffset = stsOffset < 72 * 1024 ? stsOffset + 18 * 1024 : 0; - - // ldg pointer for the next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { - LdgSts64(AStsAddr + stsOffset, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), true); - } - LdgStsGroupCommit(); - } - } - } - } else { - // matrixB CTA tile is not full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 72 * 1024 ? ldsOffset + 18 * 1024 : 0; - stsOffset = stsOffset < 72 * 1024 ? stsOffset + 18 * 1024 : 0; - - // ldg pointer for next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4(AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + - i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], - BFrag[kFrag % 2][j]); - } - } - - // tile prefetch - if (kFrag == 0) { - LdgSts64(AStsAddr + stsOffset, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1U << i)) != 0); - } - LdgStsGroupCommit(); - } - } - } - } - - // k-tiles loop without prefetch - for (; kTiles > 0; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 72 * 1024 ? ldsOffset + 18 * 1024 : 0; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - -// load next A&B fragment from shared memory to register -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4( - AFrag[(kFrag + 1) % 2][i][0], AFrag[(kFrag + 1) % 2][i][1], - AFrag[(kFrag + 1) % 2][i][2], AFrag[(kFrag + 1) % 2][i][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset + i * 16 * 8 * sizeof(half)); - } -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - Hmma161616F32(CFrag[i][j], AFrag[kFrag % 2][i], BFrag[kFrag % 2][j]); - } - } - - // dummy LdgStsGroupCommit to make LdgStsGroupWait work - if (kFrag == 0) { - LdgStsGroupCommit(); - } - } - } - - uint32_t CStsIdxX = warpId * 32 + laneId % 4; - uint32_t CStsIdxY = laneId / 4; - uint32_t* CStsPtr = - reinterpret_cast(smem) + CStsIdxY * 260 + CStsIdxX; - const float4* CLdsPtr = reinterpret_cast(smem) + - threadIdx.x / 64 * 65 + threadIdx.x % 64; - - uint32_t mIdx = tileIdY * 32 + threadIdx.x / 64; - uint32_t nIdx = tileIdX * 256 + threadIdx.x % 64 * 4; - - half* CStgPtr = C + mIdx * n + nIdx; - bool nGuard = nIdx < n; - - __syncthreads(); -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int j = 0; j < 2; ++j) { - CStsPtr[i * 16 * 260 + j * 16] = CFrag[i][j][0]; - CStsPtr[i * 16 * 260 + j * 16 + 4] = CFrag[i][j][1]; - CStsPtr[i * 16 * 260 + j * 16 + 8] = CFrag[i][j][2]; - CStsPtr[i * 16 * 260 + j * 16 + 12] = CFrag[i][j][3]; - - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16] = CFrag[i][j][4]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 4] = CFrag[i][j][5]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 8] = CFrag[i][j][6]; - CStsPtr[i * 16 * 260 + 8 * 260 + j * 16 + 12] = CFrag[i][j][7]; - } - } - __syncthreads(); - - float4 CLdsReg[8]; -#pragma unroll - for (int i = 0; i < 8; ++i) { - CLdsReg[i] = CLdsPtr[i * 4 * 65]; - } - - half2 CStgReg[8][2]; -#pragma unroll - for (int i = 0; i < 8; ++i) { - asm("{.reg .b16 h0, h1, h2, h3;\n" - " cvt.rn.f16.f32 h0, %2;\n" - " cvt.rn.f16.f32 h1, %3;\n" - " cvt.rn.f16.f32 h2, %4;\n" - " cvt.rn.f16.f32 h3, %5;\n" - " mov.b32 %0, {h0, h1};\n" - " mov.b32 %1, {h2, h3};}" - : "=r"(reinterpret_cast(CStgReg[i][0])), - "=r"(reinterpret_cast(CStgReg[i][1])) - : "f"(CLdsReg[i].x), "f"(CLdsReg[i].y), "f"(CLdsReg[i].z), - "f"(CLdsReg[i].w)); - } - -// C_tile stg -#pragma unroll - for (int i = 0; i < 8; ++i) { - Stg64(CStgReg[i][0], CStgReg[i][1], CStgPtr + i * 4 * n, - mIdx + i * 4 < m && nGuard); - } -} - -/** - * m_tile: 16 - * n_tile: 256 - * k_tile: 32x5 - * warp_tile: 16x32 - * CTA: 1x8 warps - * smem size: 85KB - */ -__device__ __forceinline__ void hgemm_f32_m16n256_k32x5_hmma161616_ldg2_loop( - const half* A, const half* B, const uint32_t* matARowIdx, half* C, - char* smem, const uint32_t& m, const uint32_t& n, const uint32_t& k, - const uint32_t& tileIdX, const uint32_t& tileIdY, - const uint32_t& BLdgStep) { - uint32_t warpId = threadIdx.x / 32; - uint32_t laneId = threadIdx.x % 32; - - uint32_t matARowId; - if (tileIdY * 16 + threadIdx.x / 16 < m) { - asm("ld.global.ca.b32 %0, [%1];" - : "=r"(matARowId) - : "l"(matARowIdx + tileIdY * 16 + threadIdx.x / 16)); - } else { - // map the out-of-bound threads to row0 of matrixA, - // to avoid predicated ld instructions - matARowId = 0; - } - - const char* ALdgPtr = - reinterpret_cast(A + matARowId * k + threadIdx.x % 16 * 2); - const char* BLdgPtr = reinterpret_cast( - B + (threadIdx.x / 8) * n + tileIdX * 256 + (threadIdx.x % 8) * 8); - - // LdgGuard to avoid LDG out of bound - uint32_t BLdgGuard = 0; -#pragma unroll - for (int i = 0; i < 4; ++i) { - int nIdx = tileIdX * 256 + (threadIdx.x % 8) * 8 + i * 64; - if (nIdx < n) { - BLdgGuard |= (1U << i); - } - } - - uint32_t ASmemAddr = SmemU32Addr(smem); - uint32_t BSmemAddr = SmemU32Addr(smem + 16 * 32 * sizeof(half)); - - uint32_t AStsAddr = - ASmemAddr + - sizeof(half) * ((threadIdx.x % 16 / 4) * (16 * 8) + - ((threadIdx.x / 16) ^ (threadIdx.x % 16 / 4 * 2)) * 8 + - threadIdx.x % 4 * 2); - uint32_t BStsAddr = - BSmemAddr + - sizeof(half) * ((threadIdx.x / 8) * 256 + - ((threadIdx.x % 8) ^ (threadIdx.x / 8 % 8)) * 8); - - // ATile lds addr - uint32_t ALdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = laneId / 8 % 2 + i * 2; - int row = (laneId / 16 * 8 + laneId % 8) ^ (col * 2); - ALdsAddr[i] = ASmemAddr + sizeof(half) * (col * 16 * 8 + row * 8); - } - - // BTile lds addr - uint32_t BLdsAddr[2]; -#pragma unroll - for (int i = 0; i < 2; ++i) { - int col = (laneId / 8 % 2 + warpId % 2 * 4 + i * 2) ^ (laneId % 8); - int row = laneId / 16 * 8 + laneId % 8; - BLdsAddr[i] = - BSmemAddr + sizeof(half) * (row * 256 + (warpId / 2) * 64 + col * 8); - } - - uint32_t kTiles = (k + 31) / 32; - - // load 1'st tile to shared memory - { - uint32_t firstKTile = k - (kTiles * 32 - 32); - uint32_t ASrcSize = threadIdx.x % 16 * 2 < firstKTile ? 4 : 0; - uint32_t BSrcSize = threadIdx.x / 8 < firstKTile ? 16 : 0; - - LdgSts32(AStsAddr, ALdgPtr, ASrcSize, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), BSrcSize, - (BLdgGuard & (1u << i)) != 0); - } - LdgStsGroupCommit(); - - // ldg pointer for the the next tile - ALdgPtr += firstKTile * sizeof(half); - BLdgPtr += firstKTile * n * sizeof(half); - } - -// load 2'st to (N-stages - 1) tiles to shared memory -#pragma unroll - for (int prefetchIter = 1; prefetchIter < 4; ++prefetchIter) { - if (prefetchIter < kTiles) { - LdgSts32(AStsAddr + prefetchIter * 1024 * 17, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + prefetchIter * 1024 * 17 + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1u << i)) != 0); - } - - // ldg pointer for the the next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - } - LdgStsGroupCommit(); - } - - // wait for the 1'st tile - LdgStsGroupWait<3>(); - __syncthreads(); - - // smem double buffer offset - uint32_t ldsOffset = 0; - uint32_t stsOffset = 68 * 1024; - - // A, B and C register fragment - uint32_t AFrag[2][4]; - uint32_t BFrag[2][2][4]; - uint32_t CFrag[2][8]; -#pragma unroll - for (int i = 0; i < 2; ++i) { -#pragma unroll - for (int p = 0; p < 8; ++p) { - CFrag[i][p] = 0; - } - } - - // load 1'st fragment - Ldsm4(AFrag[0][0], AFrag[0][1], AFrag[0][2], AFrag[0][3], - ALdsAddr[0] + ldsOffset); -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[0][i][0], BFrag[0][i][1], BFrag[0][i][2], BFrag[0][i][3], - BLdsAddr[i] + ldsOffset); - } - - if (tileIdX * 256 + 256 <= n) { - // matrixB CTA tile is full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 68 * 1024 ? ldsOffset + 17 * 1024 : 0; - stsOffset = stsOffset < 68 * 1024 ? stsOffset + 17 * 1024 : 0; - - // ldg pointer for the next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - - // load next A&B fragment from shared memory to register - Ldsm4(AFrag[(kFrag + 1) % 2][0], AFrag[(kFrag + 1) % 2][1], - AFrag[(kFrag + 1) % 2][2], AFrag[(kFrag + 1) % 2][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset); -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { - Hmma161616F32(CFrag[i], AFrag[kFrag % 2], BFrag[kFrag % 2][i]); - } - - // tile prefetch - if (kFrag == 0) { - LdgSts32(AStsAddr + stsOffset, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), true); - } - LdgStsGroupCommit(); - } - } - } - } else { - // matrixB CTA tile is not full - for (; kTiles > 4; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 68 * 1024 ? ldsOffset + 17 * 1024 : 0; - stsOffset = stsOffset < 68 * 1024 ? stsOffset + 17 * 1024 : 0; - - // ldg pointer for next tile - ALdgPtr += 32 * sizeof(half); - BLdgPtr += BLdgStep; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - - // load next A&B fragment from shared memory to register - Ldsm4(AFrag[(kFrag + 1) % 2][0], AFrag[(kFrag + 1) % 2][1], - AFrag[(kFrag + 1) % 2][2], AFrag[(kFrag + 1) % 2][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset); -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { - Hmma161616F32(CFrag[i], AFrag[kFrag % 2], BFrag[kFrag % 2][i]); - } - - // tile prefetch - if (kFrag == 0) { - LdgSts32(AStsAddr + stsOffset, ALdgPtr, true); -#pragma unroll - for (int i = 0; i < 4; ++i) { - LdgSts128(BStsAddr + stsOffset + i * 64 * sizeof(half), - BLdgPtr + i * 64 * sizeof(half), - (BLdgGuard & (1U << i)) != 0); - } - LdgStsGroupCommit(); - } - } - } - } - - // k-tiles loop without prefetch - for (; kTiles > 0; --kTiles) { -#pragma unroll - for (int kFrag = 0; kFrag < 2; ++kFrag) { - // store next A&B tile to shared memory - if (kFrag == 1) { - // switch double buffer - ldsOffset = ldsOffset < 68 * 1024 ? ldsOffset + 17 * 1024 : 0; - - LdgStsGroupWait<3>(); - __syncthreads(); - } - - // load next A&B fragment from shared memory to register - Ldsm4(AFrag[(kFrag + 1) % 2][0], AFrag[(kFrag + 1) % 2][1], - AFrag[(kFrag + 1) % 2][2], AFrag[(kFrag + 1) % 2][3], - ALdsAddr[(kFrag + 1) % 2] + ldsOffset); -#pragma unroll - for (int i = 0; i < 2; ++i) { - Ldsm4Trans(BFrag[(kFrag + 1) % 2][i][0], BFrag[(kFrag + 1) % 2][i][1], - BFrag[(kFrag + 1) % 2][i][2], BFrag[(kFrag + 1) % 2][i][3], - BLdsAddr[i] + ldsOffset + - ((kFrag + 1) % 2) * (16 * 256) * sizeof(half)); - } - -// HMMA loop -#pragma unroll - for (int i = 0; i < 2; ++i) { - Hmma161616F32(CFrag[i], AFrag[kFrag % 2], BFrag[kFrag % 2][i]); - } - - // dummy LdgStsGroupCommit to make LdgStsGroupWait work - if (kFrag == 0) { - LdgStsGroupCommit(); - } - } - } - - uint32_t CStsIdxX = warpId * 32 + laneId % 4; - uint32_t CStsIdxY = laneId / 4; - uint32_t* CStsPtr = - reinterpret_cast(smem) + CStsIdxY * 260 + CStsIdxX; - const float4* CLdsPtr = reinterpret_cast(smem) + - threadIdx.x / 64 * 65 + threadIdx.x % 64; - - uint32_t mIdx = tileIdY * 16 + threadIdx.x / 64; - uint32_t nIdx = tileIdX * 256 + threadIdx.x % 64 * 4; - - half* CStgPtr = C + mIdx * n + nIdx; - bool nGuard = nIdx < n; - - __syncthreads(); -#pragma unroll - for (int i = 0; i < 2; ++i) { - CStsPtr[i * 16] = CFrag[i][0]; - CStsPtr[i * 16 + 4] = CFrag[i][1]; - CStsPtr[i * 16 + 8] = CFrag[i][2]; - CStsPtr[i * 16 + 12] = CFrag[i][3]; - - CStsPtr[8 * 260 + i * 16] = CFrag[i][4]; - CStsPtr[8 * 260 + i * 16 + 4] = CFrag[i][5]; - CStsPtr[8 * 260 + i * 16 + 8] = CFrag[i][6]; - CStsPtr[8 * 260 + i * 16 + 12] = CFrag[i][7]; - } - __syncthreads(); - - float4 CLdsReg[4]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - CLdsReg[i] = CLdsPtr[i * 4 * 65]; - } - - half2 CStgReg[4][2]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - asm("{.reg .b16 h0, h1, h2, h3;\n" - " cvt.rn.f16.f32 h0, %2;\n" - " cvt.rn.f16.f32 h1, %3;\n" - " cvt.rn.f16.f32 h2, %4;\n" - " cvt.rn.f16.f32 h3, %5;\n" - " mov.b32 %0, {h0, h1};\n" - " mov.b32 %1, {h2, h3};}" - : "=r"(reinterpret_cast(CStgReg[i][0])), - "=r"(reinterpret_cast(CStgReg[i][1])) - : "f"(CLdsReg[i].x), "f"(CLdsReg[i].y), "f"(CLdsReg[i].z), - "f"(CLdsReg[i].w)); - } - -// C_tile stg -#pragma unroll - for (int i = 0; i < 4; ++i) { - Stg64(CStgReg[i][0], CStgReg[i][1], CStgPtr + i * 4 * n, - mIdx + i * 4 < m && nGuard); - } -} - -__global__ - __launch_bounds__(256) void hgemm_f32_n256_k32x5_hmma161616_ldg8_kernel( - const half* A, const half* B, half* C, const uint32_t* ctaIdYBarrier, - const BatchInfo* batchInfos, const uint32_t* matARowIdx, - uint32_t matARows, uint32_t n, uint32_t k, - uint32_t BLdgStep) { // 32 * n * sizeof(half) - /** - * CTA Tile Configuration: - * - * n_tile: 256, k_tile: 32x5 - * --------------------------- - * m_tile 128: 64x64 warp tile, 2x4 warps - * m_tile 96: 48x64 warp tile, 2x4 warps - * m_tile 64: 32x64 warp tile, 2x4 warps - * m_tile 48: 48x32 warp tile, 1x8 warps - * m_tile 32: 32x32 warp tile, 1x8 warps - * m_tile 16: 16x32 warp tile, 1x8 warps - */ - - // 24KB*5=120KB smem - extern __shared__ char smem[]; - - uint32_t ctaIdZ; - uint32_t laneId = threadIdx.x % 32; - asm( - // for nMatB <= 64 - "{.reg .b32 r0, r1;\n" - " .reg .pred p0, p1;\n" - " ld.global.ca.b32 r0, [%1];\n" - " ld.global.ca.b32 r1, [%1 + 128];\n" - " setp.ge.u32 p0, %%ctaid.y, r0;\n" - " setp.ge.u32 p1, %%ctaid.y, r1;\n" - " vote.sync.ballot.b32 r0, p0, 0xffffffff;\n" - " vote.sync.ballot.b32 r1, p1, 0xffffffff;\n" - " popc.b32 r0, r0;\n" - " popc.b32 r1, r1;\n" - " add.u32 %0, r0, r1;}\n" - : "=r"(ctaIdZ) - : "l"(ctaIdYBarrier + laneId)); - - // GEMM tile info - BatchInfo batchInfo; - asm("ld.global.ca.v4.b32 {%0, %1, %2, %3}, [%4];" - : "=r"(reinterpret_cast(batchInfo).x), - "=r"(reinterpret_cast(batchInfo).y), - "=r"(reinterpret_cast(batchInfo).z), - "=r"(reinterpret_cast(batchInfo).w), - : "l"(batchInfos + ctaIdZ)); - uint32_t batchId = batchInfo.batchId; - uint32_t m = batchInfo.m; - uint32_t COffset = batchInfo.COffset; - uint32_t ctaIdX = blockIdx.x; - uint32_t ctaIdY = blockIdx.y - batchInfo.ctaYOffset; - - if (m > 96) { - // m_tile 128, n_tile 256, k_tile 32x5, warp_tile 64x64, 2x4 warps - hgemm_f32_m128n256_k32x5_hmma161616_ldg8_loop( - A, B + batchId * k * n, matARowIdx + batchId * matARows, C + COffset, - smem, m, n, k, ctaIdX, ctaIdY, BLdgStep); - } else if (m > 64) { - // m_tile 96, n_tile 256, k_tile 32x5, warp_tile 48x64, 2x4 warps - hgemm_f32_m96n256_k32x5_hmma161616_ldg4_loop( - A, B + batchId * k * n, matARowIdx + batchId * matARows, C + COffset, - smem, m, n, k, ctaIdX, ctaIdY, BLdgStep); - } else if (m > 48) { - // m_tile 64, n_tile 256, k_tile 32x5, warp_tile 32x64, 2x4 warps - hgemm_f32_m64n256_k32x5_hmma161616_ldg8_loop( - A, B + batchId * k * n, matARowIdx + batchId * matARows, C + COffset, - smem, m, n, k, ctaIdX, ctaIdY, BLdgStep); - } else if (m > 32) { - // m_tile 48, n_tile 256, k_tile 32x5, warp_tile 48x32, 1x8 warps - hgemm_f32_m48n256_k32x5_hmma161616_ldg2_loop( - A, B + batchId * k * n, matARowIdx + batchId * matARows, C + COffset, - smem, m, n, k, ctaIdX, ctaIdY, BLdgStep); - } else if (m > 16) { - // m_tile 32, n_tile 256, k_tile 32x5, warp_tile 32x32, 1x8 warps - hgemm_f32_m32n256_k32x5_hmma161616_ldg4_loop( - A, B + batchId * k * n, matARowIdx + batchId * matARows, C + COffset, - smem, m, n, k, ctaIdX, ctaIdY, BLdgStep); - } else { - // m_tile 16, n_tile 256, k_tile 32x5, warp_tile 16x32, 1x8 warps - hgemm_f32_m16n256_k32x5_hmma161616_ldg2_loop( - A, B + batchId * k * n, matARowIdx + batchId * matARows, C + COffset, - smem, m, n, k, ctaIdX, ctaIdY, BLdgStep); - } -} - -template -__global__ void matA_row_idx_kernel( - const uint32_t* matBIndices, uint32_t* matARowIndices, - uint32_t* batchedGemmM, uint32_t* matCRowBatchOffset, - uint32_t size, // m * nMatBPerMatARow - uint32_t matARowIdxRShift, // log(nMatBPerMatARow) - uint32_t m, uint32_t nMatB) { - __shared__ uint32_t smem[CTA]; - smem[threadIdx.x] = 0; - __syncthreads(); - - uint32_t matARowIdx, matBIdx, stgOffset; - uint32_t idx = blockIdx.x * CTA + threadIdx.x; - if (idx < size) { - matARowIdx = idx >> matARowIdxRShift; - matBIdx = matBIndices[idx]; - stgOffset = atomicAdd(smem + matBIdx, 1); - } - __syncthreads(); - - if (threadIdx.x < nMatB) { - int ctaMatBCount = smem[threadIdx.x]; - if (ctaMatBCount != 0) { - smem[threadIdx.x] = atomicAdd(batchedGemmM + threadIdx.x, ctaMatBCount); - } - } - __syncthreads(); - - if (idx < size) { - stgOffset += smem[matBIdx]; - matCRowBatchOffset[idx] = stgOffset; - matARowIndices[matBIdx * m + stgOffset] = matARowIdx; - } -} - -void MatARowIndex(const uint32_t* matBIndices, // {m, nMatBPerMatARow} - uint32_t* matARowIndices, // {nMatB, m} - uint32_t* batchedGemmM, // {nMatB} - uint32_t* matCRowBatchOffset, // {m, nMatBPerMatARow} - uint32_t m, uint32_t nMatB, uint32_t nMatBPerMatARow, - cudaStream_t stream) { - const int CTA = 256; - if (nMatB > CTA || (nMatBPerMatARow & (nMatBPerMatARow - 1)) != 0) { - // inavlid nMatB or nMatBPerMatARow. - // nMatBPerMatARow must be power of 2 - return; - } -#ifdef __GNUC__ - uint32_t matARowIdxRShift = __builtin_ctz(nMatBPerMatARow); -#else - uint32_t matARowIdxRShift; - for (int i = 0; i < 32; ++i) { - if (nMatBPerMatARow >> i == 1) { - matARowIdxRShift = i; - } - } -#endif - uint32_t size = m * nMatBPerMatARow; - int grid = (size + CTA - 1) / CTA; - cudaMemsetAsync(batchedGemmM, 0, nMatB * sizeof(uint32_t), stream); - matA_row_idx_kernel<<>>( - matBIndices, matARowIndices, batchedGemmM, matCRowBatchOffset, size, - matARowIdxRShift, m, nMatB); -} - -template -__global__ void update_matCRowIndices_kernel( - const uint32_t* matBIndices, // {m, nMatBPerMatARow} - const uint32_t* gemmMPrefixSum, // {nMatB} - uint32_t* matCRowOffset, // {m, nMatBPerMatARow} - uint32_t size) { - int idx = blockIdx.x * CTA + threadIdx.x; - if (idx >= size) { - return; - } - asm volatile( - "{.reg .b32 r0, r1;\n" - " .reg .b64 r2;\n" - " ld.global.cg.b32 r0, [%0];\n" - " ld.global.cg.b32 r1, [%2];\n" - " cvta.to.global.u64 r2, %1;\n" - " mad.wide.u32 r2, r0, %3, r2;\n" - " ld.global.ca.b32 r0, [r2];\n" - " add.u32 r1, r1, r0;\n" - " st.global.b32 [%2], r1;}" - : - : "l"(matBIndices + idx), "l"(gemmMPrefixSum), "l"(matCRowOffset + idx), - "n"(sizeof(uint32_t))); -} - -// 128Byte aligned memory size -size_t AlignedMemSize(size_t requestedSize) { - return (requestedSize + 127) / 128 * 128; -} - -void GetWorkspaceSize(size_t* hostWsSize, size_t* deviceWsSize, uint32_t m, - uint32_t nMatB) { - size_t ctaIdYBarrierSize = AlignedMemSize(nMatB * sizeof(uint32_t)); - size_t batchInfoSize = AlignedMemSize(nMatB * sizeof(BatchInfo)); - size_t batchedGemmMSize = AlignedMemSize(nMatB * sizeof(uint32_t)); - size_t matARowIdxSize = AlignedMemSize(nMatB * m * sizeof(uint32_t)); - *hostWsSize = ctaIdYBarrierSize + batchInfoSize + batchedGemmMSize; - *deviceWsSize = - ctaIdYBarrierSize + batchInfoSize + batchedGemmMSize + matARowIdxSize; -} - -void MoeBatchedGemm(const half* A, const half* B, const uint32_t* matBIndices, - half* C, uint32_t* matCRowIndices, void* hostWs, - size_t hostWsSize, void* deviceWs, size_t deviceWsSize, - uint32_t matARows, uint32_t n, uint32_t k, uint32_t nMatB, - uint32_t nMatBPerMatARow, cudaStream_t stream) { - if (nMatB > 64) { - // invalid nMatB - return; - } - - size_t ctaIdYBarrierSize = AlignedMemSize(nMatB * sizeof(uint32_t)); - size_t batchInfoSize = AlignedMemSize(nMatB * sizeof(BatchInfo)); - size_t batchedGemmMSize = AlignedMemSize(nMatB * sizeof(uint32_t)); - size_t matARowIdxSize = AlignedMemSize(nMatB * matARows * sizeof(uint32_t)); - - if (hostWsSize < ctaIdYBarrierSize + batchInfoSize + batchedGemmMSize || - deviceWsSize < ctaIdYBarrierSize + batchInfoSize + batchedGemmMSize + - matARowIdxSize) { - // invalid workspace size - return; - } - // workspace: - // host: ctaIdYBarrier, batchInfos, batchedGemmM - // device: ctaIdYBarrier, batchInfos, batchedGemmM, matARowIdx - char* hWs = static_cast(hostWs); - char* dWs = static_cast(deviceWs); - uint32_t* hCtaIdYBarrier = reinterpret_cast(hWs); - uint32_t* dCtaIdYBarrier = reinterpret_cast(dWs); - BatchInfo* hBatchInfos = - reinterpret_cast(hWs + ctaIdYBarrierSize); - BatchInfo* dBatchInfos = - reinterpret_cast(dWs + ctaIdYBarrierSize); - uint32_t* hBatchedGemmM = - reinterpret_cast(hWs + ctaIdYBarrierSize + batchInfoSize); - uint32_t* dBatchedGemmM = - reinterpret_cast(dWs + ctaIdYBarrierSize + batchInfoSize); - uint32_t* dMatARowIdx = reinterpret_cast( - dWs + ctaIdYBarrierSize + batchInfoSize + batchedGemmMSize); - - // preprocess - // matCRowBatchOffset: reuse matCRowIndices - MatARowIndex(matBIndices, dMatARowIdx, dBatchedGemmM, matCRowIndices, - matARows, nMatB, nMatBPerMatARow, stream); - - cudaMemcpyAsync(hBatchedGemmM, dBatchedGemmM, batchedGemmMSize, - cudaMemcpyDefault, stream); - cudaStreamSynchronize(stream); - uint32_t gemmBatchIt = 0; - uint32_t mAcc = 0; - uint32_t gridYAcc = 0; - for (uint32_t matBIt = 0; matBIt < nMatB; ++matBIt) { - if (hBatchedGemmM[matBIt] != 0) { - hBatchInfos[gemmBatchIt].batchId = matBIt; - hBatchInfos[gemmBatchIt].m = hBatchedGemmM[matBIt]; - hBatchInfos[gemmBatchIt].ctaYOffset = gridYAcc; - hBatchInfos[gemmBatchIt].COffset = mAcc * n; - - uint32_t tileY = hBatchedGemmM[matBIt] > 96 ? 128 - : hBatchedGemmM[matBIt] > 64 ? 96 - : hBatchedGemmM[matBIt] > 48 ? 64 - : hBatchedGemmM[matBIt] > 32 ? 48 - : hBatchedGemmM[matBIt] > 16 ? 32 - : 16; - uint32_t gridY = (hBatchedGemmM[matBIt] + tileY - 1) / tileY; - gridYAcc += gridY; - mAcc += hBatchedGemmM[matBIt]; - hCtaIdYBarrier[gemmBatchIt] = gridYAcc; - ++gemmBatchIt; - } - } - for (uint32_t i = gemmBatchIt; i < 64; ++i) { - hCtaIdYBarrier[i] = gridYAcc; - } - - // m exclusive prefix sum for the postprocess, reuse batchedGemmM - for (uint32_t i = 0, mPrefix = 0; i < nMatB; ++i) { - uint32_t m = hBatchedGemmM[i]; - hBatchedGemmM[i] = mPrefix; - mPrefix += m; - } - - // H2D copy: ctaIdYBarrier, batchInfos, batchedGemmM (gemmMPrefixSum) - cudaMemcpyAsync(dWs, hWs, - ctaIdYBarrierSize + batchInfoSize + batchedGemmMSize, - cudaMemcpyDefault, stream); - - uint32_t smemSize = 120 * 1024; - dim3 grid((n + 255) / 256, gridYAcc); - hgemm_f32_n256_k32x5_hmma161616_ldg8_kernel<<>>( - A, B, C, dCtaIdYBarrier, dBatchInfos, dMatARowIdx, matARows, n, k, - 32 * n * sizeof(half)); - - // postprocess - update_matCRowIndices_kernel<256> - <<<(matARows * nMatBPerMatARow + 255) / 256, 256, 0, stream>>>( - matBIndices, dBatchedGemmM, matCRowIndices, - matARows * nMatBPerMatARow); -} -template <> -void MoeBatchedGemmLauncher( - const float* A, const float* B, const uint32_t* matBIndices, float* C, - uint32_t* matCRowIndices, void* hostWs, size_t hostWsSize, void* deviceWs, - size_t deviceWsSize, uint32_t matARows, uint32_t n, uint32_t k, - uint32_t nMatB, uint32_t nMatBPerMatARow, cudaStream_t stream) { - // TODO -} -#ifdef ENABLE_FP16 -template <> -void MoeBatchedGemmLauncher( - const half* A, const half* B, const uint32_t* matBIndices, half* C, - uint32_t* matCRowIndices, void* hostWs, size_t hostWsSize, void* deviceWs, - size_t deviceWsSize, uint32_t matARows, uint32_t n, uint32_t k, - uint32_t nMatB, uint32_t nMatBPerMatARow, cudaStream_t stream) { - MoeBatchedGemm(A, B, matBIndices, C, matCRowIndices, hostWs, hostWsSize, - deviceWs, deviceWsSize, matARows, n, k, nMatB, nMatBPerMatARow, - stream); -} -#endif -#ifdef ENABLE_BF16 -template <> -void MoeBatchedGemmLauncher( - const hie::bfloat16* A, const hie::bfloat16* B, const uint32_t* matBIndices, - hie::bfloat16* C, uint32_t* matCRowIndices, void* hostWs, size_t hostWsSize, - void* deviceWs, size_t deviceWsSize, uint32_t matARows, uint32_t n, - uint32_t k, uint32_t nMatB, uint32_t nMatBPerMatARow, cudaStream_t stream) { - // TODO -} -#endif -} // namespace cuda -} // namespace allspark \ No newline at end of file diff --git a/csrc/core/kernel/cuda/moe_ppu/moe_ppu_kernel.h b/csrc/core/kernel/cuda/moe_ppu/moe_ppu_kernel.h deleted file mode 100644 index d35a3510..00000000 --- a/csrc/core/kernel/cuda/moe_ppu/moe_ppu_kernel.h +++ /dev/null @@ -1,26 +0,0 @@ -/*! - * Copyright (c) Alibaba, Inc. and its affiliates. - * @file moe_ppu_kernel.h - */ - -#pragma once -#include -#include - -#include "../cuda_common.h" -// #include "../hie/cuda_activation.hpp" - -namespace allspark { -namespace cuda { -void GetWorkspaceSize(size_t* hostWsSize, size_t* deviceWsSize, uint32_t m, - uint32_t nMatB); - -template -void MoeBatchedGemmLauncher(const T* A, const T* B, const uint32_t* matBIndices, - T* C, uint32_t* matCRowIndices, void* hostWs, - size_t hostWsSize, void* deviceWs, - size_t deviceWsSize, uint32_t matARows, uint32_t n, - uint32_t k, uint32_t nMatB, - uint32_t nMatBPerMatARow, cudaStream_t stream); -} // namespace cuda -} // namespace allspark \ No newline at end of file diff --git a/csrc/service/CMakeLists.txt b/csrc/service/CMakeLists.txt index 3c64975e..72efc073 100644 --- a/csrc/service/CMakeLists.txt +++ b/csrc/service/CMakeLists.txt @@ -32,9 +32,10 @@ target_include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR}/../common ${CMAKE_CURRENT_SOURCE_DIR}/../interface) +set_target_properties(allspark_daemon PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/../${CMAKE_INSTALL_LIBDIR}") add_library(allspark_client STATIC ${PROTO_SVC_SRCS} ${PROTO_SVC_GRPC_SRC} allspark_client.cpp allspark_client_impl.cpp allspark_service_parallel.cpp) -target_link_libraries(allspark_client CONAN_PKG::grpc CONAN_PKG::protobuf CONAN_PKG::glog ${THREAD_LIB}) +target_link_libraries(allspark_client allspark_framework CONAN_PKG::grpc CONAN_PKG::protobuf CONAN_PKG::glog ${THREAD_LIB}) if (MEM_CHECK) target_link_options(allspark_client PUBLIC "-fsanitize=address") diff --git a/docs/sphinx/devel/source_code_build_en.rst b/docs/sphinx/devel/source_code_build_en.rst index 9689f844..b5d2a7d4 100644 --- a/docs/sphinx/devel/source_code_build_en.rst +++ b/docs/sphinx/devel/source_code_build_en.rst @@ -30,10 +30,10 @@ CUDA - CUDA sdk version >= 11.4 - cuBLAS: CUDA sdk provided -conan +Conan ,,,,, - + **conan**: C++ package management tools, can be installed by : `pip install conan==1.60.0`, only 1.60.0 is supported. + + **conan**: C++ package management tools, can be installed by : ``pip install conan==1.60.0``, only 1.60.0 is supported. .. note:: if there is any package-not-found issue, please make sure your conan center is available. Reset it with this command: `conan remote add conancenter https://center.conan.io` @@ -51,7 +51,7 @@ Leak check tool CPU ,,, -For multi-NUMA inference, `numactl`, `openmpi` are required: +For multi-NUMA inference, ``numactl``, ``openmpi`` are required: - for Ubuntu: @@ -77,76 +77,96 @@ We have build some Docker image for easier development setup. .. code-block:: shell docker run -d --name="dashinfer-dev-cu124-${USER}" \ - --shm-size=8g \ + --shm-size=8g --gpus all \ --network=host \ - --gpus all \ - -v $(pwd):/root/workspace/HIE-AllSpark \ + -v $(pwd):/root/workspace/DashInfer \ -w /root/workspace \ -it registry-1.docker.io/dashinfer/dev-centos7-cu124 docker exec -it "dashinfer-dev-cu124-${USER}" /bin/bash -- YiTian 710 Develoment +- CPU-only (Linux x86 server) .. code-block:: shell docker run -d --name="dashinfer-dev-${USER}" \ --network=host \ - -v $(pwd):/root/workspace/HIE-AllSpark \ + -v $(pwd):/root/workspace/DashInfer \ + -w /root/workspace \ + -it registry-1.docker.io/dashinfer/dev-centos7-x86 + docker exec -it "dashinfer-dev-${USER}" /bin/bash + +- CPU-only (Linux ARM server) + +.. code-block:: shell + + docker run -d --name="dashinfer-dev-${USER}" \ + --network=host \ + -v $(pwd):/root/workspace/DashInfer \ -w /root/workspace \ -it registry-1.docker.io/dashinfer/dev-centos8-arm docker exec -it "dashinfer-dev-${USER}" /bin/bash +.. note:: When creating a container for multi-NUMA inference, ``--cap-add SYS_NICE --cap-add SYS_PTRACE --ipc=host`` arguments are required, because components such as numactl and openmpi need the appropriate permissions to run. If you only need to use the single NUMA API, you may not grant this permission. + Build from Source Code ====================== -.. tip:: Here we use CUDA 12.4 as the default CUDA version. If you want to change to a different version, you can use enviroment variable to control CUDA dependency version. +Build Python Package +,,,,,,,,,,,,,,,,,,,, +1. Build python package for CUDA: -Python package build -,,,,,,,,,,,,,,,,,,,, +.. code-block:: bash -CUDA normal build: + cd python + AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" AS_PLATFORM="cuda" \ + python3 setup.py bdist_wheel + +2. Build python package for x86: .. code-block:: bash cd python - AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" AS_PLATFORM="cuda" python3 setup.py bdist_wheel + AS_PLATFORM="x86" python3 setup.py bdist_wheel -.. note:: The Python build only performs the `conan install` operation at the first time; subsequent builds will not conduct `conan install`. If you encounter issues, consider using `rm -rf ./python/build/temp.*` to re-run the entire process. +3. Build python package for arm: -.. note:: Change `AS_RELEASE_VERSION` enviroment var to change package version. +.. code-block:: bash -.. note:: To build an x86 or arm CPU only Python package, it's similar to CUDA build, but change the `AS_PLATFORM` environment variable to `x86` or `arm`. + cd python + AS_PLATFORM="armclang" python3 setup.py bdist_wheel +.. note:: + - We use CUDA 12.4 as the default CUDA version. If you want to change to a different version, set ``AS_CUDA_VERSION`` to the target CUDA version. + - Set ``AS_RELEASE_VERSION`` enviroment variable to change package version. + - Set ``ENABLE_MULTINUMA=ON`` enviroment variable to enable multi-NUMA inference in CPU-only version. -C++ package build +Build C++ Libraries ,,,,,,,,,,,,,,,,,,, -1. C++ lib build for CUDA +1. Build C++ libraries for CUDA .. code-block:: bash - mkdir build; - AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" ./build.sh + AS_CUDA_VERSION="12.4" AS_NCCL_VERSION="2.23.4" AS_CUDA_SM="'80;86;89;90a'" AS_PLATFORM="cuda" AS_BUILD_PACKAGE="ON" ./build.sh -2. C++ lib build for x86 +2. Build C++ libraries for x86 .. code-block:: bash - AS_PLATFORM="x86" ./build.sh - -3. C++ lib build for armclang + AS_PLATFORM="x86" AS_BUILD_PACKAGE="ON" ./build.sh -ARM Compile require armcc to archive best performance, setup the compiler in enviroment var. +3. Build C++ libraries for arm .. code-block:: bash export ARM_COMPILER_ROOT=/opt/arm/arm-linux-compiler-24.04_RHEL-8/ # change this path to your own export PATH=$PATH:$ARM_COMPILER_ROOT/bin - AS_PLATFORM="armclang" ./build.sh + + AS_PLATFORM="armclang" AS_BUILD_PACKAGE="ON" ./build.sh Profiling --------- @@ -156,9 +176,9 @@ Operator Profiling This section describes how to enable and utilize the operator profiling functionality. -1. Enable OP profile data collection +1. Enable OP profiling data collection -To enable OP profiling, set the environment variable as follows: +To enable OP profiling, set the environment variable ``AS_PROFILE=ON`` before running DashInfer. .. code-block:: bash @@ -166,9 +186,9 @@ To enable OP profiling, set the environment variable as follows: # Then, run any Python program utilizing the DashInfer Engine. -2. Print OP profile data +2. Print OP pro - To view the profiling information, insert the following function call before deinitializing the engine, replacing model_name with your actual model's name: +To view the profiling information, call the following function before deinitializing the engine: .. code-block:: bash @@ -177,15 +197,14 @@ To enable OP profiling, set the environment variable as follows: .. tip:: Replace *model_name* with the name of your model. -3. Analyze OP profile data +3. Analyze OP profiling data - An OP profile data report begins with a section header marked by *****
***** followed by a detailed table. The report consists of three main sections: + An OP profiling data report begins with a section header marked by \*\*\*
\*\*\* followed by a detailed table. The report consists of three main sections: - reshape: Statistics on the cost of reshaping inputs for operators. - alloc: Measures the cost of memory allocation for paged KV cache. - forward: Focuses on the execution time of operators' forward passes; developers should closely examine this section. - Below is an illustration of the table structure and the meaning of each column: 1. **opname**: The name of the operator. @@ -193,7 +212,6 @@ To enable OP profiling, set the environment variable as follows: 3. **(min/max/ave)**: Minimum, maximum, and average execution times in milliseconds. 4. **total_ms**: The cumulative time spent on this operator. 5. **percentage**: The operator's total time as a percentage of the overall profiling duration. - 6. **rank**: This column is deprecated. An example snippet of the profiling output is shown below: @@ -243,10 +261,10 @@ This section describes how to use controlled Nsys profiling to obtain decoder an **Steps:** 0. **Disable Warm-up:** Set the environment variable `ALLSPARK_DISABLE_WARMUP=1` to disable the warm-up phase. -1. **Enable Nsys Profiling Call:** In the file `cuda_context.cpp`, uncomment line 14 to enable the Nsys profiling call. +1. **Enable Nsys Profiling Call:** Set ``#define ENABLE_NSYS_PROFILE 1`` in file `cuda_context.cpp`. 2. **Model.cpp Configuration:** - - **Context Phase Profiling:** To profile the context phase, set the variable `PROFILE_CONTEXT_TIME_GPU` to `1`. This will initiate Nsys profiling on the 10th request and terminate the process after one context loop completes. - - **Generation Phase Profiling:** To profile the generation phase, set the variable `PROFILE_GENERATION_TIME_GPU` to `1`. Profiling will commence after reaching a concurrency (or batch size) specified by `PROFILE_GENERATION_TIME_BS` (adjust this value according to your needs). This allows you to profile the system under a fixed concurrency level. + - **Context Phase Profiling:** To profile the context phase, set ``#define PROFILE_CONTEXT_TIME_GPU 1`` in file `model.cpp`. This will initiate Nsys profiling on the 10th request and terminate the process after one context loop completes. + - **Generation Phase Profiling:** To profile the generation phase, set ``#define PROFILE_GENERATION_TIME_GPU 1`` in file `model.cpp`. Profiling will commence after reaching a concurrency (or batch size) specified by `PROFILE_GENERATION_TIME_BS` (adjust this value according to your needs). This allows you to profile the system under a fixed concurrency level. 3. **ReCompile:** Recompile your package and install 4. **Start Profiling:** Execute your benchmark or server using the following command: diff --git a/docs/sphinx/get_started/env_var_options_en.rst b/docs/sphinx/get_started/env_var_options_en.rst index 49fe8336..9b3c8ced 100644 --- a/docs/sphinx/get_started/env_var_options_en.rst +++ b/docs/sphinx/get_started/env_var_options_en.rst @@ -56,7 +56,7 @@ Memory Mangament store kv cache. - float - ``0.0`` - - float value between (0.0,1.0] + - float value between [0.0, 1.0] Logging ======= diff --git a/docs/sphinx/get_started/install_en.md b/docs/sphinx/get_started/install_en.md index 346d39d4..a1ec8dd2 100644 --- a/docs/sphinx/get_started/install_en.md +++ b/docs/sphinx/get_started/install_en.md @@ -33,14 +33,6 @@ Install python package by following command: - Install local package: `pip install dashinfer-allspark--xxx.whl` - Uninstall: `pip uninstall dashinfer-allspark -y` -## Install C++ Pacakge - -for Ubuntu: - -- Install: `dpkg -i DashInfer--ubuntu.deb` -- Uninstall: `dpkg -r DashInfer` - -for CentOS: - -- Install: `rpm -i DashInfer--centos.rpm` +## C++ Library +Download the *.tar.gz package, unzip it, and add it to the compile search path. diff --git a/docs/sphinx/llm/prefix_caching.rst b/docs/sphinx/llm/prefix_caching.rst index e346fe0b..842c6fec 100644 --- a/docs/sphinx/llm/prefix_caching.rst +++ b/docs/sphinx/llm/prefix_caching.rst @@ -2,4 +2,50 @@ Prefix Caching ===================== -TODO +What is Prefix Caching +********************** + +Prefix caching stores kv-caches in GPU or CPU memory for extended periods to reduce redundant calculations. When a new prompt shares the same prefix as a previous one, it can directly use the cached kv-caches, avoiding unnecessary computation and improving performance. + +Enable Prefix Caching +********************* + +Runtime Configuration +--------------------- + +- ``prefill_cache(enable=True)``: Enables or disables the prefix cache, default value is True. +- ``prefix_cache_ttl(ttl: int)``: Prefix cache time to live, default value is 300s. + +Environment Variable +-------------------- + +- ``CPU_CACHE_RATIO`` + - Description: DashInfer will set CPU_CACHE_RATIO * 100% of the current remaining CPU memory for kv-cache storage, and when CPU_CACHE_RATIO=0, no CPU memory is used to store kv cache. + - Data type: float + - Default value: ``0.0`` + - Range: float value between [0.0, 1.0] + + +Performance +*********** + +Run `benchmark_throughput.py` in `examples/benchmark` by following command: + +.. code-block:: shell + + model=qwen/Qwen2-7B-Instruct && \ + python3 benchmark_throughput.py --model_path=${model} --modelscope \ + --engine_max_batch=1 --engine_max_length=4003 --device_ids=0 \ + --test_qps=250 --test_random_input --test_sample_size=20 --test_max_output=3 \ + --engine_enable_prefix_cache --prefix_cache_rate_list 0.99,0.9,0.6,0.3 + +On Nvidia-A100 GPU we get following result: + +.. csv-table:: + + Batch_size,Request_num,In_tokens,Out_tokens,Avg_context_time(s),Avg_generate_time(s),Prefix_Cache(hit rate) + 1,20,4000,3,0.030,0.040,96.0% + 1,20,4000,3,0.044,0.040,89.6% + 1,20,4000,3,0.121,0.040,57.6% + 1,20,4000,3,0.185,0.040,28.8% + 1,20,4000,3,0.254,0.040,0.0% diff --git a/docs/sphinx/llm/runtime_config.rst b/docs/sphinx/llm/runtime_config.rst index a8540871..59f1fd12 100644 --- a/docs/sphinx/llm/runtime_config.rst +++ b/docs/sphinx/llm/runtime_config.rst @@ -78,11 +78,10 @@ Sequence Length and Batch Size - ``max_prefill_length(length: int)``: Sets the maximum prefill length that will be processed in one context inference; if input length is greater than this length, it will be process in multiple context inference steps. -Prefix Cache Configuration +Prefix Caching Configuration -------------------------- -- ``prefill_cache(enable=True)``: Enables or disables the prefix cache. -- ``prefix_cache_ttl(ttl: int)``: Prefix cache time to live, default value is 300s. +See :doc:`Prefix Caching <../llm/prefix_caching>`. KV Cache Quantization Configuration ----------------------------------- diff --git a/examples/benchmark/README.md b/examples/benchmark/README.md deleted file mode 100644 index da31882e..00000000 --- a/examples/benchmark/README.md +++ /dev/null @@ -1,35 +0,0 @@ - -# 准备工作 - -## Dataset - -从这个git里面下载: -`git clone https://code.alibaba-inc.com/HCI/dashscope-data` - -## 模型下载 -可以使用本地下载好的HF格式的模型,或者使用 modelscope的ID去下载模型,如果可以访问hf,可以使用hf的的路径。 - -把这个地址填到 --model_path 这个参数即可。 - -# Benchmark工具 - -## 使用数据集测试 - -使用 `--test_data_path` 这个参数进行选择数据集,`--test_dataset_id` 来选择数据集类型, 默认sample 100条,可以通过--test_sample_size 进行改变。 - -使用方法: - - -例如, 一个7b模型 1qps, 单卡 压测: -` python3 ./examples/benchmark/benchmark_throughput.py --model_path=qwen/Qwen2-7B-Instruct --modelscope True --test_qps=1 --test_dataset_path=/home/jiejing.zjj/workspace/llm_evaluation/dataset/type0_online_data.json --test_dataset_id=0 --engine_max_batch=10 --test_sample_size=10` - -例如, 一个7b 模型, 双卡 最大batch size压测。 - -`python3 ./examples/benchmark/benchmark_throughput.py --model_path=qwen/Qwen2-7B-Instruct --modelscope True --test_qps=250 --test_dataset_path=/home/jiejing.zjj/workspace/llm_evaluation/dataset/type0_online_data.json --test_dataset_id=0 --engine_max_batch=380 --engine_max_length=819 --device_ids=0,1 --test_sample_size=310 --test_max_output=20` - - -## 使用随机数进行测试 -使用随机数会进行生成固定长度的测试。 -注意: 随机数据会使用输入长度 = engine_max_length - test_max_output, 输出长度: test_max_output -`python3 ./examples/benchmark/benchmark_throughput.py --model_path=qwen/Qwen2-7B-Instruct --modelscope True --test_qps=250 --test_random_input --engine_max_batch=100 --engine_max_length=800 --device_ids=0,1 --test_sample_size=310 --test_max_output=400 ` - diff --git a/examples/cpp/0_basic/example_qwen.cpp b/examples/cpp/0_basic/example_qwen.cpp index 3813b93f..a0b81c17 100644 --- a/examples/cpp/0_basic/example_qwen.cpp +++ b/examples/cpp/0_basic/example_qwen.cpp @@ -76,8 +76,8 @@ int main(int argc, char** argv) { auto all_exists = check_model_file_exists(model_path, tiktoken_file); if (!all_exists) return 1; - std::string dimodel_file = model_path + ".dimodel"; - std::string ditensors_file = model_path + ".ditensors"; + std::string dimodel_file = model_path + ".asgraph"; + std::string ditensors_file = model_path + ".asparam"; // create an inference engine instance. std::unique_ptr as_engine = std::make_unique(); diff --git a/examples/cpp/1_apiserver/apiserver.cpp b/examples/cpp/1_apiserver/apiserver.cpp index 27e6e43e..1adeeb4f 100644 --- a/examples/cpp/1_apiserver/apiserver.cpp +++ b/examples/cpp/1_apiserver/apiserver.cpp @@ -394,8 +394,8 @@ int main(int argc, const char** argv) { auto all_exists = check_model_file_exists(model_path, tiktoken_file); if (!all_exists) return 1; - std::string dimodel_file = model_path + ".dimodel"; - std::string ditensors_file = model_path + ".ditensors"; + std::string dimodel_file = model_path + ".asgraph"; + std::string ditensors_file = model_path + ".asparam"; // create an inference engine instance. setup_tiktoken_tokenizer(tiktoken_file, tokenizer); diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index dde0f332..e639f944 100755 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -12,6 +12,14 @@ set(CMAKE_CXX_EXTENSIONS OFF) # std::string crash. # add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1) +if (DEFINED ENV{DASHINFER_INCLUDE_PATH}) + include_directories($ENV{DASHINFER_INCLUDE_PATH}) +endif() + +if (DEFINED ENV{DASHINFER_LIBRARY_PATH}) + link_directories($ENV{DASHINFER_LIBRARY_PATH}) +endif() + ########################################### # Example 1: Single NUMA or GPU qwen v1 example. ########################################### @@ -19,8 +27,7 @@ add_executable( example_qwen 0_basic/example_qwen.cpp tokenizer/tokenizer.cpp tokenizer/base64.cpp) -target_link_libraries(example_qwen PRIVATE allspark_framework - ) +target_link_libraries(example_qwen PRIVATE allspark_framework) target_include_directories(example_qwen PRIVATE tokenizer utils) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index f3b749e9..ec5f016d 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -33,7 +33,7 @@ target_link_libraries(_allspark PRIVATE allspark_framework_static CONAN_PKG::protobuf CONAN_PKG::zlib) -set_target_properties(_allspark PROPERTIES INSTALL_RPATH "$ORIGIN") +set_target_properties(_allspark PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/${CMAKE_INSTALL_LIBDIR}") set_target_properties(_allspark PROPERTIES CXX_STANDARD 17) if(UNIX AND NOT APPLE) set(ALLSPARK_LINK_MAP ${PROJECT_SOURCE_DIR}/link_python.map) @@ -58,7 +58,7 @@ if (ENABLE_MULTINUMA) -Wl,--no-whole-archive CONAN_PKG::protobuf) # target_link_libraries(_allspark_client PRIVATE allspark_client) - set_target_properties(_allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN") + set_target_properties(_allspark_client PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/${CMAKE_INSTALL_LIBDIR}") set_target_properties(_allspark_client PROPERTIES CXX_STANDARD 17) if(UNIX AND NOT APPLE) set(ALLSPARK_LINK_MAP ${PROJECT_SOURCE_DIR}/link_python.map) diff --git a/scripts/docker/README.md b/scripts/docker/README.md deleted file mode 100644 index e21f4ca6..00000000 --- a/scripts/docker/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Build a Developer docker - -notes: it will run a conan install to make most of conan depends cached in docker image. - -``` bash -cp ../../conan/conanfile.txt . -cp ../../python/requirements_dev.txt . -docker_ver=2.9 -docker build -f dev_cuda_114.Dockerfile . -t hie-allspark-dev:${docker_ver} - -# build with proxy. -docker build \ - --build-arg http_proxy=http://11.169.82.19:80 \ - --build-arg https_proxy=http://11.169.82.19:80 \ - --build-arg no_proxy="localhost,127.0.0.1,.aliyun.com,.alibaba-inc.com" \ - -f dev_cuda_114.Dockerfile . -t hie-allspark-dev:${docker_ver} -``` - -# Updaete - -``` bash -docker login --username= reg.docker.alibaba-inc.com - -docker tag hie-allspark-dev:${docker_ver} reg.docker.alibaba-inc.com/hci/hie-allspark-dev:${docker_ver} -docker push reg.docker.alibaba-inc.com/hci/hie-allspark-dev:${docker_ver} -``` - diff --git a/scripts/docker/build_fschat_cu121.sh b/scripts/docker/build_fschat_cu121.sh deleted file mode 100644 index 5744816c..00000000 --- a/scripts/docker/build_fschat_cu121.sh +++ /dev/null @@ -1,13 +0,0 @@ -allspark_version=3.0.2 - -cp ../../examples/api_server/fschat/allspark_worker.py ./ - -docker build \ - -f fschat-hie-allspark-cuda.Dockerfile \ - -t fschat-hie-allspark-cuda:${allspark_version} \ - . -docker tag fschat-hie-allspark-cuda:${allspark_version} reg.docker.alibaba-inc.com/hci/fschat-hie-allspark-cuda:${allspark_version} -docker login --username= reg.docker.alibaba-inc.com -docker push reg.docker.alibaba-inc.com/hci/fschat-hie-allspark-cuda:${allspark_version} - -rm ./allspark_worker.py \ No newline at end of file diff --git a/scripts/docker/dev_arm_centos8.Dockerfile b/scripts/docker/dev_arm_centos8.Dockerfile index 17af8d08..04a0f749 100644 --- a/scripts/docker/dev_arm_centos8.Dockerfile +++ b/scripts/docker/dev_arm_centos8.Dockerfile @@ -66,6 +66,15 @@ RUN wget "ftp://ftp.gnu.org/gnu/automake/automake-1.15.1.tar.gz" && \ cd automake-1.15.1 && ./configure --prefix=/usr/ && make -j && make install && \ cd .. && rm -rf automake-1.15.1.tar.gz automake-1.15.1 +RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz && \ + tar -xzf 0.14.5.tar.gz && \ + cd patchelf-0.14.5 && \ + ./bootstrap.sh && \ + ./configure && \ + make install && \ + cd .. && rm -rf patchelf-0.14.5 0.14.5.tar.gz +RUN pip3 install auditwheel==6.1.0 + RUN wget "https://xxxxxx/conan_allspark_source_arm_20241119.tar" && \ tar -xvf conan_allspark_source_arm_20241119.tar && \ mv conan_allspark_source_arm_20241119 /root/.conan && \ diff --git a/scripts/docker/dev_cuda_124.Dockerfile b/scripts/docker/dev_cuda_124.Dockerfile index 25b5047f..e624fda9 100644 --- a/scripts/docker/dev_cuda_124.Dockerfile +++ b/scripts/docker/dev_cuda_124.Dockerfile @@ -72,11 +72,8 @@ RUN conda config --set ssl_verify false RUN curl -LO https://github.com/Kitware/CMake/releases/download/v3.27.9/cmake-3.27.9-linux-x86_64.sh \ && bash ./cmake-3.27.9-linux-x86_64.sh --skip-license --prefix=/usr RUN pip3 install pytest -RUN curl https://gosspublic.alicdn.com/ossutil/install.sh | bash RUN conda install -y pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -RUN ossutil config - RUN yum install -y epel-release && yum install -y dnf RUN dnf makecache && dnf -y install ccache RUN pip3 install jsonlines GitPython editdistance sacrebleu nltk rouge-score @@ -91,26 +88,33 @@ RUN yum install -y bash-completion tig RUN yum install -y build-essential autoconf automake libtool ca-certificates -RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz -RUN tar -xzf 0.14.5.tar.gz && \ - cd patchelf-0.14.5 && \ - ./bootstrap.sh && \ - ./configure && \ - source /opt/rh/devtoolset-10/enable && make install && \ - rm -rf patchelf-0.14.5 0.14.5.tar.gz && rm -rf patchelf-0.14.5 - -RUN pip3 install auditwheel==6.1.0 - RUN yum install -y libtool flex - RUN wget "ftp://ftp.gnu.org/gnu/automake/automake-1.15.1.tar.gz" && \ tar -xvf automake-1.15.1.tar.gz && \ cd automake-1.15.1 && ./configure --prefix=/usr/ && make -j && make install && \ cd .. && rm -rf automake-1.15.1.tar.gz automake-1.15.1 -RUN wget "https://xxxxxx/conan_allspark_source_cuda124_20241121.tar" && \ - tar -xvf conan_allspark_source_cuda124_20241121.tar && \ - mv conan_allspark_source_cuda124_20241121 /root/.conan && \ - rm -rf conan_allspark_source_cuda124_20241121.tar +# git version required by github actions +RUN yum install -y gettext +RUN source /root/.bashrc && \ + wget "https://github.com/git/git/archive/refs/tags/v2.47.0.tar.gz" && \ + tar -xvf v2.47.0.tar.gz && cd git-2.47.0 && \ + make configure && ./configure --prefix=/usr && \ + make -j && make install &&\ + cd .. && rm -rf v2.47.0.tar.gz git-2.47.0 + +RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz && \ + tar -xzf 0.14.5.tar.gz && \ + cd patchelf-0.14.5 && \ + ./bootstrap.sh && \ + ./configure && \ + source /opt/rh/devtoolset-10/enable && make install && \ + cd .. && rm -rf patchelf-0.14.5 0.14.5.tar.gz +RUN pip3 install auditwheel==6.1.0 + +RUN wget "https://xxxxxx/conan_allspark_source_cuda124_20241203_verbose.tar" && \ + tar -xvf conan_allspark_source_cuda124_20241203_verbose.tar && \ + mv conan_allspark_source_cuda124_20241203_verbose /root/.conan && \ + rm -rf conan_allspark_source_cuda124_20241203_verbose.tar WORKDIR /root/ diff --git a/scripts/docker/dev_x86_centos7.Dockerfile b/scripts/docker/dev_x86_centos7.Dockerfile index f785f7ee..74bccdd1 100644 --- a/scripts/docker/dev_x86_centos7.Dockerfile +++ b/scripts/docker/dev_x86_centos7.Dockerfile @@ -21,9 +21,6 @@ RUN yum install devtoolset-7 -y --nogpgcheck RUN echo "source /opt/rh/devtoolset-7/enable" >> /root/.bashrc && source /root/.bashrc ARG PY_VER=3.8 -RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | bash \ - && yum install git-lfs -y - RUN curl -LO https://github.com/Kitware/CMake/releases/download/v3.27.9/cmake-3.27.9-linux-x86_64.sh \ && bash ./cmake-3.27.9-linux-x86_64.sh --skip-license --prefix=/usr @@ -57,6 +54,7 @@ custom_channels:\n\ RUN conda clean -i -y && conda config --show channels && conda create -y --name ds_py python==${PY_VER} && conda update -n base conda # RUN conda run python --version && pip3 install --upgrade pip pyOpenSSL==22.0.0 && conda env list RUN conda run python --version && pip3 install --upgrade pip pyOpenSSL==22.0.0 -i https://mirrors.aliyun.com/pypi/simple && conda env list + SHELL ["conda", "run", "-n", "ds_py", "/bin/bash", "-c"] RUN echo "source activate ds_py" >> /root/.bashrc && source /root/.bashrc @@ -71,15 +69,37 @@ RUN echo -e "[global]\ntrusted-host=mirrors.aliyun.com\nindex-url = http://mirro # engine requirements RUN conda install -y pytorch-cpu -c pytorch -RUN pip3 install modelscope transformers==4.41.0 protobuf==3.18.3 conan==1.60.0 pytest tokenizers scons wheel pandas tabulate +RUN pip3 install modelscope transformers protobuf==3.18.3 conan==1.60.0 pytest scons wheel pandas tabulate -RUN yum install -y libtool flex +SHELL ["/bin/bash", "-c"] +RUN yum install -y libtool flex RUN wget "ftp://ftp.gnu.org/gnu/automake/automake-1.15.1.tar.gz" && \ tar -xvf automake-1.15.1.tar.gz && \ cd automake-1.15.1 && ./configure --prefix=/usr/ && make -j && make install && \ cd .. && rm -rf automake-1.15.1.tar.gz automake-1.15.1 +# git version required by github actions +RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | bash \ + && yum install git-lfs -y + +RUN yum install -y gettext +RUN source /root/.bashrc && \ + wget "https://github.com/git/git/archive/refs/tags/v2.47.0.tar.gz" && \ + tar -xvf v2.47.0.tar.gz && cd git-2.47.0 && \ + make configure && ./configure --prefix=/usr && \ + make -j && make install &&\ + cd .. && rm -rf v2.47.0.tar.gz git-2.47.0 + +RUN curl -LO https://github.com/NixOS/patchelf/archive/refs/tags/0.14.5.tar.gz && \ + tar -xzf 0.14.5.tar.gz && \ + cd patchelf-0.14.5 && \ + ./bootstrap.sh && \ + ./configure && \ + source /opt/rh/devtoolset-7/enable && make install && \ + cd .. && rm -rf patchelf-0.14.5 0.14.5.tar.gz +RUN pip3 install auditwheel==6.1.0 + RUN wget "https://xxxxxx/conan_allspark_source_x86_20241119.tar" && \ tar -xvf conan_allspark_source_x86_20241119.tar && \ mv conan_allspark_source_x86_20241119 /root/.conan && \ diff --git a/scripts/docker/fschat-hie-allspark-cuda.Dockerfile b/scripts/docker/fschat-hie-allspark-cuda.Dockerfile deleted file mode 100644 index ac2a0ec1..00000000 --- a/scripts/docker/fschat-hie-allspark-cuda.Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -FROM reg.docker.alibaba-inc.com/hci/base-hie-allspark-cuda:3.0.2 - -WORKDIR /root/workspace - - -COPY ./fschat_entrypoint.sh ./ -COPY ./allspark_worker.py ./ - -RUN chmod +x ./fschat_entrypoint.sh - -SHELL [ "conda", "run", "--no-capture-output", "-n", "py38", "/bin/bash", "-c" ] -RUN pip3 install -i https://mirrors.aliyun.com/pypi/simple \ - addict \ - modelscope \ - psutil \ - accelerate \ - "fschat==0.2.36" - -# fastchat has a bug in pydantic v2: https://github.com/lm-sys/FastChat/pull/3356 -# downgrade to v1.10.13 -RUN pip3 uninstall pydantic -y \ - && pip3 install -i https://mirrors.aliyun.com/pypi/simple pydantic==1.10.13 - -ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "py38", "./fschat_entrypoint.sh"] \ No newline at end of file diff --git a/scripts/docker/test_cuda_ubuntu.Dockerfile b/scripts/docker/test_cuda_ubuntu.Dockerfile new file mode 100644 index 00000000..ef6bc177 --- /dev/null +++ b/scripts/docker/test_cuda_ubuntu.Dockerfile @@ -0,0 +1,45 @@ +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 + +RUN apt-get update && \ + apt-get install curl -y + +ARG PY_VER=3.10 + +RUN curl -LO https://repo.anaconda.com/miniconda/Miniconda3-py38_23.11.0-2-Linux-x86_64.sh \ + && bash Miniconda3-py38_23.11.0-2-Linux-x86_64.sh -p /miniconda -b \ + && rm -f Miniconda3-py38_23.11.0-2-Linux-x86_64.sh +ENV PATH=/miniconda/bin:${PATH} + +########################################################################## +# uncomment if want to use anaconda mirror +########################################################################## +RUN echo -e "\ +channels:\n\ + - defaults\n\ +show_channel_urls: true\n\ +default_channels:\n\ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main\n\ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r\n\ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2\n\ +custom_channels:\n\ + conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ + msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ + bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ + menpo: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ + pytorch: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ + pytorch-lts: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ + simpleitk: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ + deepmodeling: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud\n\ +" > /root/.condarc + +RUN conda clean -i && conda config --show channels && conda create -y --name test_py python==${PY_VER} && conda update -n base conda +SHELL ["conda", "run", "-n", "test_py", "/bin/bash", "-c"] +RUN echo "source activate test_py" >> /root/.bashrc && source /root/.bashrc + +########################################################################## +# uncomment if want to use pip mirror +########################################################################## +RUN mkdir -p /root/.pip/ +RUN echo -e "[global]\ntrusted-host=mirrors.aliyun.com\nindex-url = http://mirrors.aliyun.com/pypi/simple\n\n[install]\nuse-wheel=yes" > /root/.pip/pip.conf + +WORKDIR /root/ diff --git a/scripts/release/python_manylinux_build.sh b/scripts/release/python_manylinux_build.sh index 03347c7f..6c6a2ff4 100755 --- a/scripts/release/python_manylinux_build.sh +++ b/scripts/release/python_manylinux_build.sh @@ -1,7 +1,8 @@ #!/bin/bash set -e -x -ALL_VERSION="3.8 3.9 3.10 3.11" +# ALL_VERSION="3.8 3.9 3.10 3.11" +ALL_VERSION="3.8" BUILD_VERSION=${@:-$ALL_VERSION} echo " going to build python wheels with version: ${BUILD_VERSION}" @@ -15,13 +16,17 @@ pushd $SCRIPT_DIR # 捕获arch命令的输出 architecture=$(arch) +export AS_PYTHON_PKG_NAME="dashinfer-cpu" + # 使用if-else结构进行条件判断 if [ "${architecture}" == "aarch64" ]; then export PLAT=manylinux_2_28_aarch64 export AS_PLATFORM=armclang + # export ENABLE_MULTINUMA="ON" else export PLAT=manylinux2014_x86_64 export AS_PLATFORM=x86 + # export ENABLE_MULTINUMA="ON" fi if [ -z "$PLAT" ] || [ -z "$AS_PLATFORM" ]; @@ -30,8 +35,6 @@ then exit 1 fi -export AS_PYTHON_MANYLINUX=ON - function repair_wheel { wheel="$1" if ! auditwheel show "$wheel"; then @@ -57,8 +60,9 @@ build_wheel_for_python() { conda install pybind11 -y pip install -r ${REPO_ROOT}/python/requirements_dev_cpu.txt -i https://mirrors.aliyun.com/pypi/simple/ - python ${REPO_ROOT}/python/setup.py bdist_wheel - pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --log wheel_log.txt + ln -sf ${REPO_ROOT}/python/dashinfer . + # python ${REPO_ROOT}/python/setup.py bdist_wheel + pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --verbose conda deactivate # conda remove --name "$env_name" --all -y @@ -69,7 +73,7 @@ build_wheel_for_python() { mkdir -p ${REPO_ROOT}/python/wheelhouse/ for python_version in $BUILD_VERSION; do - build_wheel_for_python ${python_version} 2>&1 | tee whl_build_log_py${python_version//.}.txt + build_wheel_for_python ${python_version} 2>&1 | tee wheel_build_log_py${python_version//.}.txt done diff --git a/scripts/release/python_manylinux_build_cuda.sh b/scripts/release/python_manylinux_build_cuda.sh index e5b9d69a..cd01b52d 100755 --- a/scripts/release/python_manylinux_build_cuda.sh +++ b/scripts/release/python_manylinux_build_cuda.sh @@ -2,8 +2,9 @@ set -e -x # ALL_VERSION="3.8 3.9 3.10 3.11" -ALL_VERSION="3.10" +ALL_VERSION="3.8" BUILD_VERSION=${@:-$ALL_VERSION} +CUDA_VERSION=12.4 echo " going to build python wheels with version: ${BUILD_VERSION}" @@ -19,14 +20,17 @@ architecture=$(arch) export PLAT=manylinux2014_x86_64 export AS_PLATFORM=cuda +mkdir -p local_cuda_libs +ln -sf /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libnvidia-ml.so local_cuda_libs/libnvidia-ml.so.1 +ln -sf /usr/local/cuda-${CUDA_VERSION}/compat/libcuda.so.1 local_cuda_libs/libcuda.so.1 +export LD_LIBRARY_PATH=${PWD}/local_cuda_libs:${LD_LIBRARY_PATH} + if [ -z "$PLAT" ] || [ -z "$AS_PLATFORM" ]; then echo " please set PLAT and AS_PLATFORM env, PLAT can be manylinux_2_28_aarch64 or manylinux2014_x86_64" exit 1 fi -export AS_PYTHON_MANYLINUX=ON - function repair_wheel { wheel="$1" if ! auditwheel show "$wheel"; then @@ -52,8 +56,9 @@ build_wheel_for_python() { conda install pybind11 -y pip install -r ${REPO_ROOT}/python/requirements_dev_cuda.txt -i https://mirrors.aliyun.com/pypi/simple/ - python ${REPO_ROOT}/python/setup.py bdist_wheel - pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --log wheel_log.txt + ln -sf ${REPO_ROOT}/python/dashinfer . + # python ${REPO_ROOT}/python/setup.py bdist_wheel + pip wheel ${REPO_ROOT}/python --no-deps -w ${REPO_ROOT}/python/wheelhouse/ --verbose conda deactivate # conda remove --name "$env_name" --all -y @@ -64,7 +69,7 @@ build_wheel_for_python() { mkdir -p ${REPO_ROOT}/python/wheelhouse/ for python_version in $BUILD_VERSION; do - build_wheel_for_python ${python_version} 2>&1 | tee whl_build_log_py${python_version//.}.txt + build_wheel_for_python ${python_version} 2>&1 | tee wheel_build_log_py${python_version//.}.txt done diff --git a/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl new file mode 100644 index 00000000..526db83e --- /dev/null +++ b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_common.inl @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/layout/tensor.h" +#include "cutlass/arch/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Maps a rank-1 cute::Shape<> representing the cluster shape on to the IM2COL TMA atom that should be used with it +template +constexpr auto +sm90_cluster_shape_to_im2col_tma_atom(UnimodalClusterShape unimodal_cluster_shape) { + static_assert(cute::rank(unimodal_cluster_shape) == 1, + "Use this function to figure out TMA for each mode individually."); + + if constexpr (cute::size(unimodal_cluster_shape) == 1) { + return cute::SM90_TMA_LOAD_IM2COL{}; + } + else { + return cute::SM90_TMA_LOAD_IM2COL_MULTICAST{}; + } +} + +// Collective tile traits struct that serves as a type list containing a tensor's mem layouts and atoms for the +template< + class GmemTiledCopy_, + class SmemLayout_, + class SmemCopyAtom_ = void +> +struct Sm90ImplicitGemmTileTraits { + using GmemTiledCopy = GmemTiledCopy_; + using SmemLayout = SmemLayout_; + using SmemCopyAtom = SmemCopyAtom_; +}; + +// Accepts a cutlass::layout::Tensor tag and computes the corresponding spatial dimension count +template +constexpr int +gmem_layout_tags_to_spatial_dims() { + static_assert(cute::is_same_v); + if constexpr (cute::is_same_v) { + return 1; + } + else if constexpr (cute::is_same_v) { + return 2; + } + else if constexpr (cute::is_same_v) { + return 3; + } + else { + static_assert(cutlass::detail::dependent_false); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective::detail + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl new file mode 100644 index 00000000..a08209ef --- /dev/null +++ b/span-attention/thirdparty/cutlass/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS_FPROP +template < + conv::Operator ConvOp, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t || + cute::is_same_v || + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(cutlass::gemm::collective::detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + // For fprop, majorA = K, major B = K; + // For wgrad, majorA = MN, major B = MN; + // For dgrad, majorA = K, major B = MN; + static constexpr cute::GMMA::Major GmmaMajorA = + (ConvOp == conv::Operator::kWgrad) ? cute::GMMA::Major::MN : cute::GMMA::Major::K; + static constexpr cute::GMMA::Major GmmaMajorB = + (ConvOp == conv::Operator::kFprop) ? cute::GMMA::Major::K : cute::GMMA::Major::MN; + + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + // For wgrad kernel, tensor A uses tma tiled mode and tensor B uses tma im2col mode. + using GmemTiledCopyA = cute::conditional_t(ClusterShape_MNK{}))), + decltype(cutlass::conv::collective::detail::sm90_cluster_shape_to_im2col_tma_atom(cute::shape<1>(ClusterShape_MNK{})))>; + using GmemTiledCopyB = cute::conditional_t(ClusterShape_MNK{}))), + decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(cute::shape<0>(ClusterShape_MNK{})))>; + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}), + Step<_2,_1,_3>{})); + + constexpr static int NumSpatialDimensions = cutlass::conv::collective::detail::gmem_layout_tags_to_spatial_dims(); + + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = CollectiveConv< + DispatchPolicy, + TileShape_MNK, + ElementA, + ElementB, + TiledMma, + detail::Sm90ImplicitGemmTileTraits, + detail::Sm90ImplicitGemmTileTraits + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA auto kernel schedule +template < + conv::Operator ConvOp, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + +/* +#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) + // Cooperative schedule performs best for CUDA Toolkits with version >= 12.1 + + // For TileShape_M == 64, choosing KernelTmaWarpSpecialized as the KernelSchedule + // Since KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 + using KernelWarpSpecializedSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + KernelImplicitTmaWarpSpecializedSm90PingPong, KernelImplicitTmaWarpSpecializedSm90Cooperative>; +#else + using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90; +#endif +*/ + using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90; + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelWarpSpecializedSchedule + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl new file mode 100644 index 00000000..2ca62c97 --- /dev/null +++ b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -0,0 +1,797 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_traits_sm90.hpp" +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90.hpp" + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the parameterized dispatch policy for the TMA epilogue +template +constexpr auto +sm90_get_tma_dispatch_policy() { + using namespace cute; + + constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v ? 256 : 128); + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); + // TMA store delay performs worse with residual loads and compilicates tensormap updates for Ptr-Array GEMMs + constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_tma_ptr_array_v; + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + return cute::conditional_t, + Sm90PtrArrayTmaWarpSpecialized, + Sm90TmaWarpSpecialized>{}; +} + +// Returns the smem layout atom to be used for C or D matrix +template +constexpr auto +sm90_get_epilogue_smem_swizzle_layout_atom() { + using namespace cute; + + // ColMajor C/D (M-major) + if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + // RowMajor C/D (N-major) + else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout."); + } +} + +// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. +template +constexpr auto +sm90_compute_tile_shape_or_override() { + if constexpr (cute::is_same_v) { + auto epi_tile = [&] () { + if constexpr (detail::sm90_is_cooperative_v) { + auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::min(_32{}, size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else if constexpr (detail::sm90_is_warp_specialized_v) { + constexpr int N_perf = sizeof_bits_v == 8 ? 64 : 32; + auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::min(Int{}, size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported schedule."); + } + }(); + + return cute::transform(epi_tile, seq<0,1>{}, + [] (auto epi_tiler, auto I) { + auto cta_tiler = make_layout(get(TileShape_MNK{})); + // This is a multimodal CTA tiler, transform before returning + if constexpr (depth(cta_tiler) > 0) { + // This is an implicit multimodal tiler, match profile and return + if constexpr (tuple_size_v == 1) { + return make_tile(epi_tiler); + } + // This is an explicit multimodal tiler, compose out epi tiler + else { + return composition(cta_tiler, epi_tiler); + } + } + // This is a flat CTA tiler, no need for transformation + else { + return epi_tiler; + } + }); + } + else if constexpr (cute::is_tuple::value) { + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + + static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(M == 64 && detail::sm90_is_warp_specialized_v || + M == 128 && detail::sm90_is_cooperative_v, "Unsupported tile shape"); + static_assert(N % 16 == 0, "Unsupported tile shape"); + + return epi_tile; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); + } +} + +// callbacks builder with TMA aux out +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && sizeof_bits_v == 1> +> { + using Callbacks = fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem + >; +}; + +// Helper for building TMA warp-specialized collective epilogues, specialized by +// the fusion operation performed and the dispatch policy to use. +template < + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class FusionOpOrCallbacks, + class DispatchPolicy +> +struct Sm90TmaBuilderImpl { + // Passing void D disables destination store + smem allocation + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + + // Passing void C disables source load + smem allocation + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using CopyOpS2G = cute::conditional_t, + SM90_TMA_STORE_IM2COL, + SM90_TMA_STORE + >; + using CopyOpG2S = cute::conditional_t, + SM90_TMA_LOAD_IM2COL, + SM90_TMA_LOAD + >; + + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + + using FusionDispatchPolicy = Sm90TmaWarpSpecialized; + + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks + // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination + using FusionCallbacks = + typename CallbacksBuilder< + FusionDispatchPolicy, + FusionOpOrCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator + >::Callbacks; + + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< + DispatchPolicy, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD_, + GmemStrideTypeD, + FusionCallbacks, + CopyOpG2S, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + CopyOpS2G, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()), + CopyAtomC + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Descriptor classes for defining EVT nodes +// Some of the epilogue visitor nodes require non-intuitive template arguments +// such as CopyOpS2R for AuxLoad node. Traditionaly, these are resolved by the +// builder classes. Here we provide a set of descriptor classes that resolve +// these template arguments from more intuitive types such as Stride, Layout + +// Get TileShape, EpilogueTile, Dispatch Policy, StagesC, and STagesD +template< + typename TileShape_MNK, + typename EpilogueTileType, + typename ElementC, + typename ElementD, + typename Schedule +> +struct EpilogueDescriptor { + using TileShape = TileShape_MNK; + using EpilogueTile = + decltype( + detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule, TileShape_MNK + >() + ); + using DispatchPolicy = + decltype( + detail::sm90_get_tma_dispatch_policy< + TileShape_MNK, EpilogueTile, + ElementC, ElementD, Schedule + >() + ); + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpS2R = + decltype(detail::sm90_get_smem_load_op_for_source()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpR2S = + decltype(detail::sm90_get_smem_store_op_for_accumulator()); +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + +// No-smem builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + FloatRoundStyle RoundStyle +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + fusion::LinearCombination, + cute::enable_if_t || + cute::is_same_v >> { + + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? + thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + static constexpr int FragmentSize = 1; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, + ScaleType, RoundStyle, ElementC>; + + using CollectiveOp = cute::conditional_t< + cute::is_same_v, + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueDefault>>, + // Epilogue for Ptr-Array and Grouped Gemm + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogueArray< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + Schedule>> + >; +}; + +// Tma warp-specialized builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class FusionOperation +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + Schedule, + FusionOperation, + cute::enable_if_t || + cute::is_same_v || + cute::is_same_v >> { +private: + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + using EpilogueTile_MN = + decltype(detail::sm90_compute_tile_shape_or_override()); + using DispatchPolicy = + decltype(detail::sm90_get_tma_dispatch_policy()); + +public: + using CollectiveOp = + typename detail::Sm90TmaBuilderImpl< + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + FusionOperation, + DispatchPolicy + >::CollectiveOp; +}; + +// Auto builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class FusionOperation +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleAuto, + FusionOperation, + void> { +private: + static_assert(cute::is_same_v>, + "Auto schedule doesn't support fusion. Use one of the TmaWarpSpecialized schedules instead."); + + // Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance) + // since TMA epilogues are not compatible with non-TMA non-WS mainloops + using EpilogueSchedule = NoSmemWarpSpecialized; + using _CollectiveBuilder = CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueSchedule, + FusionOperation + >; + +public: + using CollectiveOp = typename _CollectiveBuilder::CollectiveOp; +}; + +// DEPRECATED Tma warp-specialized builder for elementwise fusion +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class UnusedFusionOp +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + UnusedFusionOp, + cute::enable_if_t || + cute::is_base_of_v >> { +private: + using FusionOp = + fusion::LinCombEltAct; + using ImplSchedule = + cute::conditional_t, + TmaWarpSpecialized, TmaWarpSpecializedCooperative>; + +public: + using CollectiveOp = + typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + ImplSchedule, + FusionOp + >::CollectiveOp; +}; + +// DEPRECATED Tma warp-specialized builder for bias + elementwise fusion +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class UnusedFusionOp +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]] +CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + UnusedFusionOp, + cute::enable_if_t || + cute::is_base_of_v >> { +private: + using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule, TileShape_MNK>()); + // MSVC doesn't seem to be able to deduce DispatchPolicy correctly if it's + // defined as decltype of a detail::sm90_get_tma_dispatch_policy call. + // Instead, we paste in the contents of that function. A natural refactoring + // would be to create a type alias in the detail namespace. + using DispatchPolicy = Sm90TmaWarpSpecialized< + /* StagesC = */ size(shape_div(take<0, 2>(TileShape_MNK{}), EpilogueTile_MN{})), + /* StagesD = */ 2, + /* FragmentSize = */ size(EpilogueTile_MN{}) / (detail::sm90_is_cooperative_v ? 256 : 128), + /* ReuseSmemC = */ sizeof_bits_v == sizeof_bits_v, + false + >; + + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename Schedule::ElementT, EpilogueTile_MN>()); + using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename Schedule::ElementT>()); + using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute + >; + using FusionCallbacksAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux + >; + + using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct< + Schedule::template ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementBias, ElementC_, ElementCompute + >; + using FusionCallbacksNoAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN + >; + + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = gemm::TagToStrideC_t; + using GmemStrideTypeD = gemm::TagToStrideC_t; + + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + +public: + using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise< + DispatchPolicy::StagesC, + DispatchPolicy::StagesD, + DispatchPolicy::FragmentSize, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + cute::conditional_t, + SM90_TMA_LOAD, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + SM90_TMA_STORE, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()), + CopyAtomC + >; +}; + +// CollectiveBuilder that transposed epilogue below is used for sm90 gmma RS TT kernels +// since swapping NNN kernels input matrix and transposing its output at the same time then +// we can get TTN kernel. +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + FloatRoundStyle RoundStyle +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + cutlass::gemm::EpilogueTransposed, + fusion::LinearCombination, + void> { + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? + thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + static constexpr int FragmentSize = 1; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, + ScaleType, RoundStyle, ElementC>; + + using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueTransposed> + >; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective diff --git a/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl new file mode 100644 index 00000000..cd2639c5 --- /dev/null +++ b/span-attention/thirdparty/cutlass/include/cutlass/epilogue/collective/builders/sm90_common.inl @@ -0,0 +1,80 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Selects the largest vectorized smem store atom available +template +constexpr auto +sm90_get_smem_store_op_for_accumulator() { + using namespace cute; + + if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) { + return SM90_U16x8_STSM_T{}; + } + else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) { + return SM90_U32x4_STSM_N{}; + } + else { + // auto-vectorizing store + return AutoVectorizingCopyWithAssumedAlignment{}; + } +} + +// Selects the largest vectorized smem load atom available +template +constexpr auto +sm90_get_smem_load_op_for_source() { + using namespace cute; + + // Reuse the logic from smem store selector + using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator()); + + if constexpr (cute::is_same_v) { + return SM75_U16x8_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U32x4_LDSM_N{}; + } + else { + // auto-vectorizing load + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } +} + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective::detail diff --git a/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl b/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl new file mode 100644 index 00000000..298793e8 --- /dev/null +++ b/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_common.inl @@ -0,0 +1,364 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/dependent_false.hpp" + +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// Some named constants +// +constexpr int tma_alignment_bytes = 16; +constexpr int cp_async_min_alignment_bytes = 4; +constexpr int sm90_smem_capacity_bytes = 232448; + +// Maps 2.x A matrix layout tag to respective GMMA major mode enum +template +constexpr cute::GMMA::Major +gmma_ss_tag_to_major_A() { + // MN major mode is only valid for non-TF32, non-int and non-fp8 MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_A() && + not cute::is_same_v && + sizeof(ElementA) != 1) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +// Maps 2.x B matrix layout tag to respective GMMA major mode enum +template +constexpr cute::GMMA::Major +gmma_ss_tag_to_major_B() { + // MN major mode is only valid for non-TF32, non-int and non-fp8 MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_B() && + not cute::is_same_v && + sizeof(ElementB) != 1) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +template +constexpr cute::GMMA::Major +gmma_rs_tag_to_major_A() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_A()) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +template +constexpr cute::GMMA::Major +gmma_rs_tag_to_major_B() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_B()) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} +// Maps a rank-1 cute::Shape<> representing the cluster shape on to the TMA atom that should be used with it +template +constexpr auto +sm90_cluster_shape_to_tma_atom(UnimodalClusterShape) { + static_assert(cute::rank(UnimodalClusterShape{}) == 1, + "Use this function to figure out TMA for each mode individually."); + + if constexpr (cute::size(UnimodalClusterShape{}) == 1) { + return cute::SM90_TMA_LOAD{}; + } + else { + return cute::SM90_TMA_LOAD_MULTICAST{}; + } +} + +// Generates the most efficient possible TiledCopy with cp.async copy atom given a set of parameters. +template +constexpr auto +make_cp_async_gmem_tiled_copy() { + using namespace cute; + + using AlignmentType = cute::uint_byte_t(sizeof(Element)) * Alignment>; + constexpr int TileSizeMN = cute::size(TileMN{}); + constexpr int TileSizeK = cute::size(TileK{}); + + // Maximize the number of threads along the gmem major mode to promote coalesced reads + // While making sure our thread layout tiles the threadblock tile evenly + + if constexpr (cutlass::gemm::detail::is_k_major()) { + // K major thread layout for K major gmem + constexpr int threads_major = (ThreadCount >= TileSizeK / Alignment) ? (TileSizeK / Alignment) : ThreadCount; + constexpr int threads_minor = ThreadCount / threads_major; + static_assert(threads_major > 0); + static_assert(ThreadCount % threads_major == 0); + static_assert(threads_minor == 0 || (TileSizeMN % threads_minor == 0)); + return make_tiled_copy( + Copy_Atom, Element>{}, + Layout,Int>, + Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) { + // MN major thread layout for MN major gmem + constexpr int threads_major = (ThreadCount >= TileSizeMN / Alignment) ? (TileSizeMN / Alignment) : ThreadCount; + constexpr int threads_minor = ThreadCount / threads_major; + static_assert(threads_major > 0); + static_assert(ThreadCount % threads_major == 0); + static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0)); + return make_tiled_copy( + Copy_Atom, Element>{}, + Layout,Int>, + Stride< _1,Int>>{}, + Layout,_1>>{}); + } + else { + static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); + } +} + +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the optimal GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +constexpr auto +rs_smem_selector() { + using namespace cute; + + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + if constexpr (major == GMMA::Major::MN) { + if constexpr (sizeof(ElementType) == 4){ + if constexpr (is_ws_transposed_B) { + // only optimized transpositionB(SW32 and SW128 for tf32) can be used, but prefer SW32 due to free bank conflict + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_SW32_Atom{})"); + } + } + else { + // Fall into SW32 due to free bank conflict + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + } + // Used for int8, fp8, fp16 and bf16 I/O kernels + else if constexpr (sizeof(ElementType) == 1 || sizeof(ElementType) == 2) { + if constexpr (sizeof(ElementType) == 1 && is_ws_transposed_B) { + // Only optimized transpositionB (SW32 for int8 and fp8) can be used + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_128_Atom{})"); + } + } + else { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Smem selector does not support this element type"); + } + } + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +CUTE_HOST_DEVICE constexpr +auto +ss_smem_selector() +{ + using namespace cute; + + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + + if constexpr (major == GMMA::Major::MN) { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + +template +constexpr bool +is_input_size_two_bytes() { + return (sizeof(ElementA) == 2 && sizeof(ElementB) == 2); +} + +template +constexpr bool +is_input_fp8() { + return ((cute::is_same_v || cute::is_same_v) && + (cute::is_same_v || cute::is_same_v)); +} + +// We need to handle the tuples in this function since it is used in SFINAE dispatch in the CollectiveBuilder. +// At that point, it is not guaranteed that the tuples have been split out into the required parts. +template +constexpr bool +is_use_rmem_A() { + + using ElementA = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementA>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementB>; + + constexpr bool IsABDifferentWidth = cute::sizeof_bits_v != cute::sizeof_bits_v; + constexpr bool HasScales = cute::is_tuple::value ^ cute::is_tuple::value; + constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); + constexpr bool IsLayoutAkBk = cutlass::gemm::detail::is_k_major_A() && + cutlass::gemm::detail::is_k_major_B(); + constexpr bool IsUseRmemA = (!IsInputSizeTwoBytes && !IsLayoutAkBk) || IsABDifferentWidth || HasScales; + return IsUseRmemA; +} + +template +constexpr bool +is_aligned() { + return ((sizeof(ElementA) * AlignmentA) % RequiredAlignment == 0) && + ((sizeof(ElementB) * AlignmentB) % RequiredAlignment == 0); +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective diff --git a/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl new file mode 100644 index 00000000..25b1f848 --- /dev/null +++ b/span-attention/thirdparty/cutlass/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -0,0 +1,1003 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_single_affine_transformed_input(StageCount stage_count) { + return stages; +} + +template +constexpr int get_bits_for_possibly_void_element() { + if constexpr (cute::is_same_v) { + return 0; + } + else { + return sizeof_bits::value; + } +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout stage_count) { + + // 32 bytes to account for barriers etc. + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr int scale_zero_k_tile = 1; + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto s_bits = get_bits_for_possibly_void_element(); + constexpr auto z_bits = get_bits_for_possibly_void_element(); + + constexpr auto scale_bytes = cutlass::bits_to_bytes(s_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile); + constexpr auto zero_bytes = cutlass::bits_to_bytes(z_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile); + static_assert(scale_bytes % 128 == 0, "Scale bytes must be a multiple of 128"); + static_assert(zero_bytes % 128 == 0, "Zero bytes must be a multiple of 128"); + + // When scales are void, s_bits will be 0 so no smem will be allocated for scales. + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(scale_bytes + zero_bytes + mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +template +constexpr bool +is_swapAB(){ + constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); + constexpr bool IsLayoutAkBmn = cutlass::gemm::detail::is_k_major_A() && + cutlass::gemm::detail::is_mn_major_B(); + constexpr bool SwapAB = !IsInputSizeTwoBytes && IsLayoutAkBmn; + return SwapAB; +} + +template +constexpr bool +is_warpspecialized_transpose_B(){ + constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); + constexpr bool IsLayoutAmnBmn = cutlass::gemm::detail::is_mn_major_A() && + cutlass::gemm::detail::is_mn_major_B(); + constexpr bool IsWarpSpecialized = cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v; + constexpr bool IsWarpSpecializedTransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn && IsWarpSpecialized; + return IsWarpSpecializedTransposeB; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + not detail::is_use_rmem_A()> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), + "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_RS +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + detail::is_use_rmem_A()> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool SwapAB = detail::is_swapAB(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B< + ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>(); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_RS Mixed Scaled GEMM +template < + class ElementPairA_, + class GmemLayoutATag_, + int AlignmentA, + class ElementPairB_, + class GmemLayoutBTag_, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementPairA_, + GmemLayoutATag_, + AlignmentA, + ElementPairB_, + GmemLayoutBTag_, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v)> +> { + +private: + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementPairA_>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementPairB_>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementPairA_>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementPairB_>; + static constexpr bool NeitherIsTuple = !cute::is_tuple::value && !cute::is_tuple::value; + +public: + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>; + static_assert(cute::is_tuple::value ^ cute::is_tuple::value || + (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), + "Either A OR B must be a tuple or the widths of A and B must be different."); + + static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; + + using GmemLayoutATag = GmemLayoutATag_; + using GmemLayoutBTag = GmemLayoutBTag_; + + using ElementPairA = cute::conditional_t, ElementPairA_>; + using ElementPairB = cute::conditional_t, ElementPairB_>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B< + ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>(); + static_assert(!IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B."); + + // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands. + static constexpr bool SwapAB = !IsATransformed; + + // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly. + static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB; + + using ElementMma = cute::conditional_t; + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementMma, ElementMma, ElementAccumulator, TileShape_MNK, GMMA::Major::K, TiledMmaGmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + using RealElementA = cute::conditional_t; + using RealElementB = cute::conditional_t; + static constexpr int PipelineStages = detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}); + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + + // We pack the scale data with the operand that will be optionally scaled and converted before MMA. + using StrideA = TagToStrideA_t; + using StrideB = TagToStrideB_t; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementPairA, + StrideA, + ElementPairB, + StrideB, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); + using AtomLayoutMNK = cute::conditional_t || + IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_SS +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t && + not detail::is_use_rmem_A()> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>())); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmma; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_CpAsync +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct [[deprecated("Use one of KernelCpAsyncWarpSpecialized schedules instead")]] +CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + // Map to warp-specialized kernels for better performance + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelCpAsyncWarpSpecialized + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_CpAsync_WS_SS +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + not detail::is_use_rmem_A() + > +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static_assert(detail::is_aligned(), + "Minimum alignment required for cp.async is 4B."); + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t, + Layout(TileShape_MNK{}) < 128) ? 1 : 2>,_1,_1>>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + static constexpr int NumLoadWarpGroups = cute::is_same_v ? 2 : 1; + + using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override< + detail::sm90_smem_capacity_bytes, ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{}); + + using DispatchPolicy = MainloopSm90CpAsyncGmmaWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_CpAsync_WS_RS +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + detail::is_use_rmem_A() + > +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + static_assert(detail::is_aligned(), + "Minimum alignment required for cp.async is 4B."); + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool SwapAB = detail::is_swapAB(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B< + ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>(); + + using AtomLayoutMNK = cute::conditional_t, + Layout(TileShape_MNK{}) < 128) ? 1 : 2>,_1,_1>>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + + static constexpr int NumLoadWarpGroups = 1; + + using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override< + detail::sm90_smem_capacity_bytes, ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{}); + + using DispatchPolicy = MainloopSm90CpAsyncGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA auto kernel schedule +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + +using ExtractedElementA = detail::deduce_mixed_width_dtype_t<0, ElementA>; +using ExtractedElementB = detail::deduce_mixed_width_dtype_t<0, ElementB>; + +static constexpr bool IsTmaCompatible = detail::is_aligned< + ExtractedElementA, AlignmentA, ExtractedElementB, AlignmentB, detail::tma_alignment_bytes>(); + +// Users opt into scales via the builder by passing a tuple of Elements for the input that will be scaled. We detect +// scale support if ONLY one of the inputs have tuples to describe them. +static constexpr bool OnlyOneIsTuple = cute::is_tuple::value ^ cute::is_tuple::value; +static constexpr bool IsDifferentWidth = sizeof_bits::value != sizeof_bits::value; +static constexpr bool IsMixedWidthInput = IsDifferentWidth || (IsDifferentWidth && OnlyOneIsTuple); + +#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) + // Persistent schedules perform best for CUDA Toolkits with version >= 12.1 + // KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 + using KernelTmaWarpSpecializedScheduleSameInput = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + KernelTmaWarpSpecializedPingpong, KernelTmaWarpSpecializedCooperative>; + + using KernelTmaWarpSpecializedScheduleMixedInput = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + KernelTmaWarpSpecializedPingpongMixedInput, KernelTmaWarpSpecializedCooperativeMixedInput>; + + using KernelTmaWarpSpecializedSchedule = cute::conditional_t; +#else + using KernelTmaWarpSpecializedSchedule = cute::conditional_t; +#endif + + // Non-persistent schedule is a safer choice for CpAsync kernels due to register pressure + using KernelCpAsyncWarpSpecializedSchedule = KernelCpAsyncWarpSpecialized; + using KernelSchedule = cute::conditional_t; + static_assert((cute::is_same_v && IsMixedWidthInput) || !IsMixedWidthInput, "Only TMA warp specialized kernels are supported for mixed width input."); + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelSchedule + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/span-attention/thirdparty/cutlass/media/docs/build/building_in_windows_with_visual_studio.md b/span-attention/thirdparty/cutlass/media/docs/build/building_in_windows_with_visual_studio.md new file mode 100644 index 00000000..2c69e1ac --- /dev/null +++ b/span-attention/thirdparty/cutlass/media/docs/build/building_in_windows_with_visual_studio.md @@ -0,0 +1,90 @@ +[README](../../README.md#documentation) > **CUTLASS 3.0: Building on Windows with Visual Studio** + +# Building on Windows with Visual Studio + +CUTLASS 3.2 reintroduces support for the Microsoft Visual Studio compiler on Windows. +Users and developers may build either +in Visual Studio's graphical integrated development environment, +or on the command line with `cmake --build`. + +# Software prerequisites + +1. Windows 10 or 11 + +2. Visual Studio 2019 version 16.11.27, or Visual Studio 2022 + +3. CUDA Toolkit (at least 12.2; earlier 12.x versions may work) + +4. CMake (at least 3.18) + +5. git + +6. Python (at least 3.6) + +Visual Studio must be installed *before* the CUDA Toolkit. +Otherwise, Visual Studio's build system won't know about CUDA. + +# Operating system settings + +By default, Windows restricts the maximum file path length (`MAX_PATH`) to 260 characters. +CUTLASS has many files and directory paths that challenge this requirement. +As a result, CUTLASS is unlikely to build with this default setting. +The choice of source and build directories affect path lengths, +so the kinds of errors and whether they occur may depend on this. +Symptoms may vary, from errors when running `cmake` +(e.g., during the "generating library instances" step) to build failures. + +CUTLASS recommends changing the maximum file path length setting +and rebooting the computer before attempting to clone or build CUTLASS. +Windows 10 (as of version 1607) and 11 permit changing this setting +by making sure that the following registry key exists, +and that its value is set to 1. + +``` +Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\FileSystem\LongPathsEnabled +``` + +After changing the registry key's value, reboot the computer first +before attempting to clone or build CUTLASS. + +[This Microsoft help article](https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation?tabs=registry) +explains different ways to change the registry setting. + +# Set up build environment + +1. Run "git bash" to get a familiar command-line interface + +2. Edit `~/.profile` and set the environment variables as needed to access the CUTLASS repository + +3. Clone the CUTLASS repository + +4. Create the `build` subdirectory in the CUTLASS clone directory, and run CMake in it, + specifying whatever CMake options are desired, e.g., + `cmake .. -DCUTLASS_NVCC_ARCHS=90a` + +Alternate approaches may rely on the CMake GUI and/or Windows' native command line. + +# Building + +A successful CMake run will create a `CUTLASS.sln` Visual Studio "solution" file in the build directory. +One can open this in Visual Studio and build the entire solution or any subset of projects as desired. +It may be necessary to limit maximum build parallelism by setting the appropriate Visual Studio option. + +Alternately, one can run `cmake --build . --config Release -j 4` in the build directory. +Replace 4 with the desired maximum build parallelism. +It's important to put the `--build` option before the period that signifies the build directory. +The `--config` option specifies the kind of build; +`--config Release` builds a Release build, while `--config Debug` builds a Debug build. +Unlike with CMake's Makefile or Ninja generators, +`CMAKE_BUILD_TYPE` has no effect on the Visual Studio generator, +because the Visual Studio generator creates all build configurations. + +# Tips + +With Windows builds, one may find that CMake reruns unnecessarily. +For example, cancelling a build and starting it again may rerun CMake. +This will in turn touch build files that result in unnecessary rebuilds. +One work-around is to set the CMake option `CMAKE_SUPPRESS_REGENERATION=ON`. +However, this turns off CMake's ability to detect on its own when it needs to rerun. +As a result, one will need to know when to rerun CMake by hand. + diff --git a/span-attention/thirdparty/cutlass/media/docs/build/building_with_clang_as_host_compiler.md b/span-attention/thirdparty/cutlass/media/docs/build/building_with_clang_as_host_compiler.md new file mode 100644 index 00000000..c5350060 --- /dev/null +++ b/span-attention/thirdparty/cutlass/media/docs/build/building_with_clang_as_host_compiler.md @@ -0,0 +1,59 @@ +[README](../../README.md#documentation) > **CUTLASS 3: Building with Clang as host compiler** + +# Building with Clang as host compiler + +CUTLASS 3.2(.1) reintroduces support for building with +Clang as host compiler, and NVCC as device compiler. +This is NOT the same as building with +Clang as both host and device compiler ("CUDA Clang"). + +# Software prerequisites + +1. Clang (regularly tested with Clang 17; + occasionally tested with Clang 10 and greater) + +2. CUDA Toolkit (tested with 12.2; other versions likely work) + +3. CMake (at least 3.18) + +4. git + +5. Python (at least 3.6) + +Experience with Ubuntu 22.04 LTS is that +clang requires the following packages to be installed. + +```bash +$ sudo apt-get install clang cmake ninja-build pkg-config libgtk-3-dev liblzma-dev libstdc++-12-dev +``` + +A symptom of not installing all needed dependencies +is the following error when attempting to use clang: +`"/usr/bin/ld: cannot find -lstdc++: No such file or directory"`. + +# Running CMake + +## Required CMake options + +The Clang build requires specifying the following CMake options. +Replace `` with the path to your `clang++` executable. +You may use `clang++` directly if it is in your `PATH`. + +* `CMAKE_CXX_COMPILER=` +* `CMAKE_CUDA_HOST_COMPILER=` + +One must set both! It's not enough just to set the `CXX` environment +variable, for example. Symptoms of only setting `CMAKE_CXX_COMPILER` +(or only setting the `CXX` environment variable) include `cc1plus` +(GCC's compiler executable) reporting build errors due to it not +understanding Clang's command-line options. + +Users can also specify a particular CUDA Toolkit version +by setting the CMake option `CMAKE_CUDA_COMPILER` +to the path to the `nvcc` executable +that lives in the CUDA Toolkit's directory. For example, +if `${PATH_TO_CUDA_TOOLKIT}` is the CUDA Toolkit directory, +then one can set `CMAKE_CUDA_COMPILER` as follows. + +* `CMAKE_CUDA_COMPILER=${PATH_TO_CUDA_TOOLKIT}/bin/nvcc` + diff --git a/third_party/patch/ppu_flash-attn.patch b/third_party/patch/ppu_flash-attn.patch deleted file mode 100644 index 0ad5797a..00000000 --- a/third_party/patch/ppu_flash-attn.patch +++ /dev/null @@ -1,406 +0,0 @@ -diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt -new file mode 100644 -index 0000000..d5114d5 ---- /dev/null -+++ b/csrc/CMakeLists.txt -@@ -0,0 +1,124 @@ -+cmake_minimum_required(VERSION 3.18) -+ -+project(FLASHATTN LANGUAGES CXX CUDA) -+option(CMAKE_EXPORT_COMPILE_COMMANDS ON) -+ -+set(FLASHATTN_CUDA_VERSION -+ "11.8" -+ CACHE STRING "cuda version") -+set(FLASHATTN_GPU_ARCHS -+ "80;86" -+ CACHE STRING "gpu archs") -+set(FLASHATTN_USE_EXTERNAL_CUTLASS -+ OFF -+ CACHE BOOL "use external cutlass target") -+set(FLASHATTN_USE_CUDA_STATIC -+ OFF -+ CACHE BOOL "use static CUDA") -+# Generate SASS for each architecture -+foreach(arch ${FLASHATTN_GPU_ARCHS}) -+ list(APPEND GENCODES "${arch}-real") -+endforeach() -+# Generate PTX for the last architecture -+list(GET FLASHATTN_GPU_ARCHS -1 LATEST_GPU_ARCH) -+list(APPEND GENCODES "${LATEST_GPU_ARCH}-virtual") -+set(CMAKE_CUDA_ARCHITECTURES ${GENCODES}) -+ -+find_package(CUDAToolkit ${FLASHATTN_CUDA_VERSION} EXACT REQUIRED) -+ -+if(FLASHATTN_USE_CUDA_STATIC) -+ set(FLASHATTN_CUDA_CUDART CUDA::cudart_static) -+else() -+ set(FLASHATTN_CUDA_CUDART CUDA::cudart) -+endif() -+ -+if(FLASHATTN_USE_EXTERNAL_CUTLASS) -+ message("flash attn use external cutlass") -+ find_package(NvidiaCutlass PATHS ${CUTLASS_INSTALL_PATH}) -+ set(CUTLASS_INCLUDE_DIR ${CUTLASS_INSTALL_PATH}/include) -+ set(CUTLASS_LIBRARY NvidiaCutlass) -+else() -+ message("flash attn use internal cutlass") -+ message("========== CUTLASS ==========") -+ set(CUTLASS_ENABLE_TESTS -+ OFF -+ CACHE BOOL "Enable CUTLASS Tests") -+ set(CUTLASS_ENABLE_TOOLS -+ OFF -+ CACHE BOOL "Enable CUTLASS Tools") -+ set(CUTLASS_ENABLE_EXAMPLES -+ OFF -+ CACHE BOOL "Enable CUTLASS Examples") -+ set(CUTLASS_NVCC_ARCHS -+ ${FLASHATTN_GPU_ARCHS} -+ CACHE STRING "The SM architectures requested.") -+ add_subdirectory(${PROJECT_SOURCE_DIR}/cutlass EXCLUDE_FROM_ALL) -+ set(CUTLASS_LIBRARY nvidia::cutlass::cutlass) -+ unset(CUTLASS_ENABLE_TESTS) -+ unset(CUTLASS_ENABLE_TOOLS) -+ unset(CUTLASS_ENABLE_EXAMPLES) -+ unset(CUTLASS_NVCC_ARCHS) -+ message("===========================") -+endif() -+ -+set(FLASHATTN_ROOT ${PROJECT_SOURCE_DIR}/flash_attn/src) -+set(FLASHATTN_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/flash_attn/src -+ ${PROJECT_SOURCE_DIR}) -+file(GLOB_RECURSE FLASHATTN_SRCS ${FLASHATTN_ROOT}/*.cu) -+# no bwd -+file(GLOB_RECURSE FLASHATTN_BWD_SRCS ${FLASHATTN_ROOT}/*_bwd_*.cu) -+foreach(file ${FLASHATTN_BWD_SRCS}) -+ list(REMOVE_ITEM FLASHATTN_SRCS "${file}") -+endforeach() -+ -+list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF_OPERATORS__") -+list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF_CONVERSIONS__") -+list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_HALF2_OPERATORS__") -+list(APPEND FLASHATTN_CUDA_FLAGS "-U__CUDA_NO_BFLOAT16_CONVERSIONS__") -+list(APPEND FLASHATTN_CUDA_FLAGS "-mllvm") -+list(APPEND FLASHATTN_CUDA_FLAGS "-alippu-max-vreg-count=255") -+list(APPEND FLASHATTN_CUDA_FLAGS "-alippu-sink-matrix-addr=true") -+list(APPEND FLASHATTN_CUDA_FLAGS "-alippu-max-alloca-byte-size=320") -+list(APPEND FLASHATTN_CUDA_FLAGS "-alippu-sink-async-addr=true") -+list(APPEND FLASHATTN_CUDA_FLAGS "-alippu-sink-load-addr=true") -+list(APPEND FLASHATTN_CUDA_FLAGS "-alippu-sink-store-addr=true") -+list(APPEND FLASHATTN_CUDA_FLAGS "-alippu-alloca-half-ldst-simplify=true") -+list(APPEND FLASHATTN_CUDA_FLAGS "--expt-relaxed-constexpr") -+list(APPEND FLASHATTN_CUDA_FLAGS "--expt-extended-lambda") -+list(APPEND FLASHATTN_CUDA_FLAGS "--use_fast_math") -+# list(APPEND FLASHATTN_CUDA_FLAGS "-mllvm") -+# list(APPEND FLASHATTN_CUDA_FLAGS "--ptxas-options=-v") -+# list(APPEND FLASHATTN_CUDA_FLAGS "--ptxas-options=-O2") -+# list(APPEND FLASHATTN_CUDA_FLAGS "-lineinfo") -+# list(APPEND FLASHATTN_CUDA_FLAGS "--save-temps") -+list(APPEND FLASHATTN_CUDA_FLAGS "-DUSE_PPU") -+list(APPEND FLASHATTN_CUDA_FLAGS "-DUSE_AIU=1") -+list(APPEND FLASHATTN_CUDA_FLAGS "-DACOMPUTE_VERSION=10000") -+ -+# Create an object library with the source files -+add_library(flash-attn-obj OBJECT ${FLASHATTN_SRCS}) -+set_target_properties(flash-attn-obj PROPERTIES CXX_STANDARD 17 CUDA_STANDARD 17) -+set_target_properties(flash-attn-obj PROPERTIES POSITION_INDEPENDENT_CODE ON) -+target_compile_options(flash-attn-obj PRIVATE $<$:${FLASHATTN_CUDA_FLAGS}>) -+target_include_directories(flash-attn-obj PUBLIC ${FLASHATTN_INCLUDE_DIR} ${CUTLASS_INCLUDE_DIR}) -+ -+# Create STATIC library from the object files -+add_library(flash-attn_static STATIC $) -+set_target_properties(flash-attn_static PROPERTIES OUTPUT_NAME "flash-attn") -+target_link_libraries(flash-attn_static PRIVATE ${FLASHATTN_CUDA_CUDART}) -+ -+# Create SHARED library from the object files -+add_library(flash-attn SHARED $) -+target_link_libraries(flash-attn PRIVATE ${FLASHATTN_CUDA_CUDART}) -+ -+# Create alias for static library -+add_library(flash-attention::flash-attn_static ALIAS flash-attn_static) -+ -+# Create alias for shared library -+add_library(flash-attention::flash-attn ALIAS flash-attn) -+ -+# Install both static and shared libraries -+install(TARGETS flash-attn_static flash-attn -+ EXPORT flash-attn -+ # PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX} -+) -diff --git a/csrc/flash_attn/src/flash.cu b/csrc/flash_attn/src/flash.cu -new file mode 100644 -index 0000000..5cf1214 ---- /dev/null -+++ b/csrc/flash_attn/src/flash.cu -@@ -0,0 +1,16 @@ -+#include "cuda.h" -+#include "flash.h" -+#include "static_switch.h" -+#include -+ -+void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel) { -+ FP16_SWITCH(!params.is_bf16, [&] { -+ FWD_HEADDIM_SWITCH(params.d, [&] { -+ if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 -+ run_mha_fwd_(params, stream); -+ } else { -+ run_mha_fwd_splitkv_dispatch(params, stream); -+ } -+ }); -+ }); -+} -diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h -index 4a33f3d..1af5dfa 100644 ---- a/csrc/flash_attn/src/flash.h -+++ b/csrc/flash_attn/src/flash.h -@@ -5,20 +5,10 @@ - #pragma once - - #include -+#include -+#include - #include - --#ifdef OLD_GENERATOR_PATH --#include --#else --#include --#endif -- --#include // For at::cuda::philox::unpack -- --constexpr int TOTAL_DIM = 0; --constexpr int H_DIM = 1; --constexpr int D_DIM = 2; -- - //////////////////////////////////////////////////////////////////////////////////////////////////// - - struct Qkv_params { -@@ -49,6 +39,7 @@ struct Qkv_params { - //////////////////////////////////////////////////////////////////////////////////////////////////// - - struct Flash_fwd_params : public Qkv_params { -+ void SetCudaConfig(const cudaDeviceProp* dprop_) { dprop = dprop_; } - - // The O matrix (output). - void * __restrict__ o_ptr; -@@ -115,7 +106,7 @@ struct Flash_fwd_params : public Qkv_params { - int window_size_left, window_size_right; - - // Random state. -- at::PhiloxCudaState philox_args; -+ // at::PhiloxCudaState philox_args; - - // Pointer to the RNG seed (idx 0) and offset (idx 1). - uint64_t * rng_state; -@@ -133,6 +124,9 @@ struct Flash_fwd_params : public Qkv_params { - - void * __restrict__ alibi_slopes_ptr; - index_t alibi_slopes_batch_stride; -+ -+ // Cuda Device Properties -+ const cudaDeviceProp* dprop; - }; - - //////////////////////////////////////////////////////////////////////////////////////////////////// -@@ -179,6 +173,8 @@ struct Flash_bwd_params : public Flash_fwd_params { - - //////////////////////////////////////////////////////////////////////////////////////////////////// - -+void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false); -+ - template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); - template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); - -diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h -index ee9b80a..c745ade 100644 ---- a/csrc/flash_attn/src/flash_fwd_kernel.h -+++ b/csrc/flash_attn/src/flash_fwd_kernel.h -@@ -110,11 +110,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi - if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { - // Save seed and offset for backward. If we don't have this here, the 0-th thread block might - // exit early and no one saves the rng state. -- if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { -- auto seeds = at::cuda::philox::unpack(params.philox_args); -- params.rng_state[0] = std::get<0>(seeds); -- params.rng_state[1] = std::get<1>(seeds); -- } -+ // if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { -+ // auto seeds = at::cuda::philox::unpack(params.philox_args); -+ // params.rng_state[0] = std::get<0>(seeds); -+ // params.rng_state[1] = std::get<1>(seeds); -+ // } - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; -@@ -332,6 +332,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi - cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); - } - -+#if 0 - auto seeds = at::cuda::philox::unpack(params.philox_args); - unsigned long long seed = std::get<0>(seeds); - unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; -@@ -341,6 +342,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi - params.rng_state[0] = seed; - params.rng_state[1] = std::get<1>(seeds); - } -+#endif - - clear(acc_o); - -@@ -447,6 +449,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi - - //PPU: move data convert after dropout, dropout need use ppu C-layout as random result. - #ifdef USE_PPU -+#if 0 - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); - if (Return_softmax) { -@@ -469,12 +472,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi - flash::apply_dropout(acc_s_drop, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps); - } -+#endif - - //PPU: convert output C layout to input A and data type convert. - Tensor rP = flash::convert_acc(scores); - Tensor tOrP = make_tensor(rP.data(), make_layout(get<0>(tSrQ.layout()), get<1>(acc_s.layout()), get<2>(acc_s.layout()))); - // if (cute::thread0()) { print(tOrP); } - #else -+#error - // Convert scores from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) -@@ -566,6 +571,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi - - //PPU: move data convert after dropout, dropout need use ppu C-layout as random result. - #ifdef USE_PPU -+#if 0 - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); - if (Return_softmax) { -@@ -588,6 +594,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi - flash::apply_dropout(acc_s_drop, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps); - } -+#endif - - //PPU: convert output C layout to input A and data type convert. - Tensor rP = flash::convert_acc(scores); -diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h -index 4437cdb..584890c 100644 ---- a/csrc/flash_attn/src/flash_fwd_launch_template.h -+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h -@@ -4,12 +4,23 @@ - - #pragma once - --#include -+#include "cuda.h" - - #include "static_switch.h" - #include "flash.h" - #include "flash_fwd_kernel.h" - -+#define CUDA_CHECK(status) \ -+ { \ -+ cudaError_t error = status; \ -+ if (error != cudaSuccess) { \ -+ std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ -+ << " at line: " << __LINE__ << std::endl; \ -+ exit(EXIT_FAILURE); \ -+ } \ -+ } -+#define CUDA_KERNEL_LAUNCH_CHECK() CUDA_CHECK(cudaGetLastError()) -+ - template - __global__ void flash_fwd_kernel(Flash_fwd_params params) { - static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false -@@ -56,7 +67,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { -- C10_CUDA_CHECK(cudaFuncSetAttribute( -+ CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; -@@ -64,7 +75,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); -- C10_CUDA_KERNEL_LAUNCH_CHECK(); -+ CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); -@@ -95,11 +106,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { -- C10_CUDA_CHECK(cudaFuncSetAttribute( -+ CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); -- C10_CUDA_KERNEL_LAUNCH_CHECK(); -+ CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); -@@ -129,7 +140,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } -- C10_CUDA_KERNEL_LAUNCH_CHECK(); -+ CUDA_KERNEL_LAUNCH_CHECK(); - }); - } - } -@@ -179,7 +190,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { - template - void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 96; -- auto dprops = at::cuda::getCurrentDeviceProperties(); -+ auto dprops = params.dprop; - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { -@@ -205,7 +216,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { - template - void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 128; -- auto dprops = at::cuda::getCurrentDeviceProperties(); -+ auto dprops = params.dprop; - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { -@@ -242,7 +253,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { - template - void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 160; -- auto dprops = at::cuda::getCurrentDeviceProperties(); -+ auto dprops = params.dprop; - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { -@@ -297,7 +308,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { -- C10_CUDA_CHECK(status_); -+ CUDA_CHECK(status_); - } - // printf("max_smem_per_block = %d\n", max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { -@@ -328,7 +339,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { - status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { -- C10_CUDA_CHECK(status_); -+ CUDA_CHECK(status_); - } - // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {