diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ce78b028..60f77f16 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -7,8 +7,14 @@ on: branches: [ "main" ] jobs: - call-workflow: + test-plugin: uses: ./.github/workflows/test-plugin.yml with: triton-ref: '05dc28be0e72dd496300a31b99a21a5a5118f8e9' # known good commit "[CI] refactor workflows (#2504)" triton-shared-ref: ${{ github.ref }} + + test-cpuref: + uses: ./.github/workflows/test-cpuref.yml + with: + triton-ref: '05dc28be0e72dd496300a31b99a21a5a5118f8e9' # known good commit "[CI] refactor workflows (#2504)" + triton-shared-ref: ${{ github.ref }} diff --git a/.github/workflows/test-cpuref.yml b/.github/workflows/test-cpuref.yml new file mode 100644 index 00000000..4819e559 --- /dev/null +++ b/.github/workflows/test-cpuref.yml @@ -0,0 +1,75 @@ +name: Triton-Shared Plugin Testing + +on: + workflow_call: + inputs: + triton-ref: + required: true + type: string + triton-shared-ref: + required: true + type: string + workflow_dispatch: + inputs: + triton-ref: + required: true + type: string + triton-shared-ref: + required: true + type: string + +jobs: + build_and_test_triton_shared: + runs-on: ubuntu-latest + + steps: + + - name: Checkout Triton + uses: actions/checkout@v4 + with: + repository: 'openai/triton' + ref: ${{ inputs.triton-ref }} + path: triton + submodules: 'recursive' + + - name: Checkout Triton-Shared + uses: actions/checkout@v4 + with: + ref: ${{ inputs.triton-shared-ref }} + path: triton/third_party/triton_shared + + - name: Clear Triton Cache + run: | + rm -rf ~/.triton + + - name: Update PATH + run: | + echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" + + - name: Check pre-commit + run: | + cd triton + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files --verbose + + - name: Build/Install Triton + run: | + export TRITON_CODEGEN_TRITON_SHARED=1 + cd triton/python + python3 -m pip install --upgrade pip + python3 -m pip install cmake==3.24 + python3 -m pip install ninja + python3 -m pip uninstall -y triton + python3 setup.py build + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Install PyTorch + run: | + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + + - name: Run an example + run: | + cd triton/python + export TRITON_SHARED_OPT_PATH="$(pwd)/build/$(ls $(pwd)/build | grep -i cmake)/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt" + export LLVM_BINARY_DIR="${HOME}/.triton/llvm/$(ls ${HOME}/.triton/llvm/ | grep -i llvm)/bin" + python3 ../third_party/triton_shared/python/examples/reduce.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 8185e3b7..69a304b6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,9 @@ set(TRITON_SHARED_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) # Tablegen'd files +set(TRITON_BUILD_PYTHON_MODULE ON) add_subdirectory(include) add_subdirectory(lib) add_subdirectory(test) add_subdirectory(tools) +add_subdirectory(python) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt new file mode 100644 index 00000000..db2ad468 --- /dev/null +++ b/python/CMakeLists.txt @@ -0,0 +1,12 @@ + +# Python module +if(TRITON_BUILD_PYTHON_MODULE) + message(STATUS "Adding Triton-Shared Reference CPU Backend") + file(INSTALL + ${CMAKE_CURRENT_SOURCE_DIR}/__init__.py + 
${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/Msan.h + ${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/CRunnerUtils.h + ${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/CRunnerUtils.cpp + DESTINATION ${PYTHON_THIRD_PARTY_PATH}/cpu/) + # TODO: perhaps we want to install binary files used in __init__.py +endif() diff --git a/python/ExecutionEngine/CRunnerUtils.cpp b/python/ExecutionEngine/CRunnerUtils.cpp new file mode 100644 index 00000000..48e2afbf --- /dev/null +++ b/python/ExecutionEngine/CRunnerUtils.cpp @@ -0,0 +1,192 @@ +//===- CRunnerUtils.cpp - Utils for MLIR execution ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements basic functions to manipulate structured MLIR types at +// runtime. Entities in this file are meant to be retargetable, including on +// targets without a C++ runtime, and must be kept C compatible. +// +//===----------------------------------------------------------------------===// + +#include "CRunnerUtils.h" +#include "Msan.h" + +#ifndef _WIN32 +#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ + defined(__DragonFly__) +#include +#else +#include +#endif +#include +#else +#include "malloc.h" +#endif // _WIN32 + +#include +#include +#include +#include +#include +#include + +#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS + +namespace { +template +void stdSort(uint64_t n, V *p) { + std::sort(p, p + n); +} + +} // namespace + +// Small runtime support "lib" for vector.print lowering. +// By providing elementary printing methods only, this +// library can remain fully unaware of low-level implementation +// details of our vectors. Also useful for direct LLVM IR output. +extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); } +extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); } +extern "C" void printF32(float f) { fprintf(stdout, "%g", f); } +extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); } +extern "C" void printString(char const *s) { fputs(s, stdout); } +extern "C" void printOpen() { fputs("( ", stdout); } +extern "C" void printClose() { fputs(" )", stdout); } +extern "C" void printComma() { fputs(", ", stdout); } +extern "C" void printNewline() { fputc('\n', stdout); } + +extern "C" void memrefCopy(int64_t elemSize, UnrankedMemRefType *srcArg, + UnrankedMemRefType *dstArg) { + DynamicMemRefType src(*srcArg); + DynamicMemRefType dst(*dstArg); + + int64_t rank = src.rank; + MLIR_MSAN_MEMORY_IS_INITIALIZED(src.sizes, rank * sizeof(int64_t)); + + // Handle empty shapes -> nothing to copy. + for (int rankp = 0; rankp < rank; ++rankp) + if (src.sizes[rankp] == 0) + return; + + char *srcPtr = src.data + src.offset * elemSize; + char *dstPtr = dst.data + dst.offset * elemSize; + + if (rank == 0) { + memcpy(dstPtr, srcPtr, elemSize); + return; + } + + int64_t *indices = static_cast(alloca(sizeof(int64_t) * rank)); + int64_t *srcStrides = static_cast(alloca(sizeof(int64_t) * rank)); + int64_t *dstStrides = static_cast(alloca(sizeof(int64_t) * rank)); + + // Initialize index and scale strides. 
+ for (int rankp = 0; rankp < rank; ++rankp) { + indices[rankp] = 0; + srcStrides[rankp] = src.strides[rankp] * elemSize; + dstStrides[rankp] = dst.strides[rankp] * elemSize; + } + + int64_t readIndex = 0, writeIndex = 0; + for (;;) { + // Copy over the element, byte by byte. + memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize); + // Advance index and read position. + for (int64_t axis = rank - 1; axis >= 0; --axis) { + // Advance at current axis. + auto newIndex = ++indices[axis]; + readIndex += srcStrides[axis]; + writeIndex += dstStrides[axis]; + // If this is a valid index, we have our next index, so continue copying. + if (src.sizes[axis] != newIndex) + break; + // We reached the end of this axis. If this is axis 0, we are done. + if (axis == 0) + return; + // Else, reset to 0 and undo the advancement of the linear index that + // this axis had. Then continue with the axis one outer. + indices[axis] = 0; + readIndex -= src.sizes[axis] * srcStrides[axis]; + writeIndex -= dst.sizes[axis] * dstStrides[axis]; + } + } +} + +/// Prints GFLOPS rating. +extern "C" void printFlops(double flops) { + fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9); +} + +/// Returns the number of seconds since Epoch 1970-01-01 00:00:00 +0000 (UTC). +extern "C" double rtclock() { +#ifndef _WIN32 + struct timeval tp; + int stat = gettimeofday(&tp, nullptr); + if (stat != 0) + fprintf(stderr, "Error returning time from gettimeofday: %d\n", stat); + return (tp.tv_sec + tp.tv_usec * 1.0e-6); +#else + fprintf(stderr, "Timing utility not implemented on Windows\n"); + return 0.0; +#endif // _WIN32 +} + +extern "C" void *mlirAlloc(uint64_t size) { return malloc(size); } + +extern "C" void *mlirAlignedAlloc(uint64_t alignment, uint64_t size) { +#ifdef _WIN32 + return _aligned_malloc(size, alignment); +#elif defined(__APPLE__) + // aligned_alloc was added in MacOS 10.15. Fall back to posix_memalign to also + // support older versions. + void *result = nullptr; + (void)::posix_memalign(&result, alignment, size); + return result; +#else + return aligned_alloc(alignment, size); +#endif +} + +extern "C" void mlirFree(void *ptr) { free(ptr); } + +extern "C" void mlirAlignedFree(void *ptr) { +#ifdef _WIN32 + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +extern "C" void *rtsrand(uint64_t s) { + // Standard mersenne_twister_engine seeded with s. + return new std::mt19937(s); +} + +extern "C" uint64_t rtrand(void *g, uint64_t m) { + std::mt19937 *generator = static_cast(g); + std::uniform_int_distribution distrib(0, m); + return distrib(*generator); +} + +extern "C" void rtdrand(void *g) { + std::mt19937 *generator = static_cast(g); + delete generator; +} + +#define IMPL_STDSORT(VNAME, V) \ + extern "C" void _mlir_ciface_stdSort##VNAME(uint64_t n, \ + StridedMemRefType *vref) { \ + assert(vref); \ + assert(vref->strides[0] == 1); \ + V *values = vref->data + vref->offset; \ + stdSort(n, values); \ + } +IMPL_STDSORT(I64, int64_t) +IMPL_STDSORT(F64, double) +IMPL_STDSORT(F32, float) +#undef IMPL_STDSORT + +#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS diff --git a/python/ExecutionEngine/CRunnerUtils.h b/python/ExecutionEngine/CRunnerUtils.h new file mode 100644 index 00000000..76b04145 --- /dev/null +++ b/python/ExecutionEngine/CRunnerUtils.h @@ -0,0 +1,499 @@ +//===- CRunnerUtils.h - Utils for debugging MLIR execution ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares basic classes and functions to manipulate structured MLIR +// types at runtime. Entities in this file must be compliant with C++11 and be +// retargetable, including on targets without a C++ runtime. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_CRUNNERUTILS_H +#define MLIR_EXECUTIONENGINE_CRUNNERUTILS_H + +#ifdef _WIN32 +#ifndef MLIR_CRUNNERUTILS_EXPORT +#ifdef mlir_c_runner_utils_EXPORTS +// We are building this library +#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllexport) +#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS +#else +// We are using this library +#define MLIR_CRUNNERUTILS_EXPORT __declspec(dllimport) +#endif // mlir_c_runner_utils_EXPORTS +#endif // MLIR_CRUNNERUTILS_EXPORT +#else // _WIN32 +// Non-windows: use visibility attributes. +#define MLIR_CRUNNERUTILS_EXPORT __attribute__((visibility("default"))) +#define MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS +#endif // _WIN32 + +#include +#include +#include +#include +#include + +//===----------------------------------------------------------------------===// +// Codegen-compatible structures for Vector type. +//===----------------------------------------------------------------------===// +namespace mlir { +namespace detail { + +constexpr bool isPowerOf2(int n) { return (!(n & (n - 1))); } + +constexpr unsigned nextPowerOf2(int n) { + return (n <= 1) ? 1 : (isPowerOf2(n) ? n : (2 * nextPowerOf2((n + 1) / 2))); +} + +template +struct Vector1D; + +template +struct Vector1D { + Vector1D() { + static_assert(detail::nextPowerOf2(sizeof(T[Dim])) == sizeof(T[Dim]), + "size error"); + } + inline T &operator[](unsigned i) { return vector[i]; } + inline const T &operator[](unsigned i) const { return vector[i]; } + +private: + T vector[Dim]; +}; + +// 1-D vector, padded to the next power of 2 allocation. +// Specialization occurs to avoid zero size arrays (which fail in -Werror). +template +struct Vector1D { + Vector1D() { + static_assert(nextPowerOf2(sizeof(T[Dim])) > sizeof(T[Dim]), "size error"); + static_assert(nextPowerOf2(sizeof(T[Dim])) < 2 * sizeof(T[Dim]), + "size error"); + } + inline T &operator[](unsigned i) { return vector[i]; } + inline const T &operator[](unsigned i) const { return vector[i]; } + +private: + T vector[Dim]; + char padding[nextPowerOf2(sizeof(T[Dim])) - sizeof(T[Dim])]; +}; +} // namespace detail +} // namespace mlir + +// N-D vectors recurse down to 1-D. +template +struct Vector { + inline Vector &operator[](unsigned i) { return vector[i]; } + inline const Vector &operator[](unsigned i) const { + return vector[i]; + } + +private: + Vector vector[Dim]; +}; + +// 1-D vectors in LLVM are automatically padded to the next power of 2. +// We insert explicit padding in to account for this. +template +struct Vector + : public mlir::detail::Vector1D { +}; + +template +using Vector1D = Vector; +template +using Vector2D = Vector; +template +using Vector3D = Vector; +template +using Vector4D = Vector; + +template +void dropFront(int64_t arr[N], int64_t *res) { + for (unsigned i = 1; i < N; ++i) + *(res + i - 1) = arr[i]; +} + +//===----------------------------------------------------------------------===// +// Codegen-compatible structures for StridedMemRef type. 
+//===----------------------------------------------------------------------===// +template +class StridedMemrefIterator; + +/// StridedMemRef descriptor type with static rank. +template +struct StridedMemRefType { + T *basePtr; + T *data; + int64_t offset; + int64_t sizes[N]; + int64_t strides[N]; + + template ().begin())> + T &operator[](Range &&indices) { + assert(indices.size() == N && + "indices should match rank in memref subscript"); + int64_t curOffset = offset; + for (int dim = N - 1; dim >= 0; --dim) { + int64_t currentIndex = *(indices.begin() + dim); + assert(currentIndex < sizes[dim] && "Index overflow"); + curOffset += currentIndex * strides[dim]; + } + return data[curOffset]; + } + + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, -1}; } + + // This operator[] is extremely slow and only for sugaring purposes. + StridedMemRefType operator[](int64_t idx) { + StridedMemRefType res; + res.basePtr = basePtr; + res.data = data; + res.offset = offset + idx * strides[0]; + dropFront(sizes, res.sizes); + dropFront(strides, res.strides); + return res; + } +}; + +/// StridedMemRef descriptor type specialized for rank 1. +template +struct StridedMemRefType { + T *basePtr; + T *data; + int64_t offset; + int64_t sizes[1]; + int64_t strides[1]; + + template ().begin())> + T &operator[](Range indices) { + assert(indices.size() == 1 && + "indices should match rank in memref subscript"); + return (*this)[*indices.begin()]; + } + + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, -1}; } + + T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); } +}; + +/// StridedMemRef descriptor type specialized for rank 0. +template +struct StridedMemRefType { + T *basePtr; + T *data; + int64_t offset; + + template ().begin())> + T &operator[](Range indices) { + assert((indices.size() == 0) && + "Expect empty indices for 0-rank memref subscript"); + return data[offset]; + } + + StridedMemrefIterator begin() { return {*this, offset}; } + StridedMemrefIterator end() { return {*this, offset + 1}; } +}; + +/// Iterate over all elements in a strided memref. +template +class StridedMemrefIterator { +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + + StridedMemrefIterator(StridedMemRefType &descriptor, + int64_t offset = 0) + : offset(offset), descriptor(&descriptor) {} + StridedMemrefIterator &operator++() { + int dim = Rank - 1; + while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) { + offset -= indices[dim] * descriptor->strides[dim]; + indices[dim] = 0; + --dim; + } + if (dim < 0) { + offset = -1; + return *this; + } + ++indices[dim]; + offset += descriptor->strides[dim]; + return *this; + } + + reference operator*() { return descriptor->data[offset]; } + pointer operator->() { return &descriptor->data[offset]; } + + const std::array &getIndices() { return indices; } + + bool operator==(const StridedMemrefIterator &other) const { + return other.offset == offset && other.descriptor == descriptor; + } + + bool operator!=(const StridedMemrefIterator &other) const { + return !(*this == other); + } + +private: + /// Offset in the buffer. This can be derived from the indices and the + /// descriptor. + int64_t offset = 0; + + /// Array of indices in the multi-dimensional memref. 
+ std::array indices = {}; + + /// Descriptor for the strided memref. + StridedMemRefType *descriptor; +}; + +/// Iterate over all elements in a 0-ranked strided memref. +template +class StridedMemrefIterator { +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + + StridedMemrefIterator(StridedMemRefType &descriptor, int64_t offset = 0) + : elt(descriptor.data + offset) {} + + StridedMemrefIterator &operator++() { + ++elt; + return *this; + } + + reference operator*() { return *elt; } + pointer operator->() { return elt; } + + // There are no indices for a 0-ranked memref, but this API is provided for + // consistency with the general case. + const std::array &getIndices() { + // Since this is a 0-array of indices we can keep a single global const + // copy. + static const std::array indices = {}; + return indices; + } + + bool operator==(const StridedMemrefIterator &other) const { + return other.elt == elt; + } + + bool operator!=(const StridedMemrefIterator &other) const { + return !(*this == other); + } + +private: + /// Pointer to the single element in the zero-ranked memref. + T *elt; +}; + +//===----------------------------------------------------------------------===// +// Codegen-compatible structure for UnrankedMemRef type. +//===----------------------------------------------------------------------===// +// Unranked MemRef +template +struct UnrankedMemRefType { + int64_t rank; + void *descriptor; +}; + +//===----------------------------------------------------------------------===// +// DynamicMemRefType type. +//===----------------------------------------------------------------------===// +template +class DynamicMemRefIterator; + +// A reference to one of the StridedMemRef types. +template +class DynamicMemRefType { +public: + int64_t rank; + T *basePtr; + T *data; + int64_t offset; + const int64_t *sizes; + const int64_t *strides; + + explicit DynamicMemRefType(const StridedMemRefType &memRef) + : rank(0), basePtr(memRef.basePtr), data(memRef.data), + offset(memRef.offset), sizes(nullptr), strides(nullptr) {} + template + explicit DynamicMemRefType(const StridedMemRefType &memRef) + : rank(N), basePtr(memRef.basePtr), data(memRef.data), + offset(memRef.offset), sizes(memRef.sizes), strides(memRef.strides) {} + explicit DynamicMemRefType(const ::UnrankedMemRefType &memRef) + : rank(memRef.rank) { + auto *desc = static_cast *>(memRef.descriptor); + basePtr = desc->basePtr; + data = desc->data; + offset = desc->offset; + sizes = rank == 0 ? nullptr : desc->sizes; + strides = sizes + rank; + } + + template ().begin())> + T &operator[](Range &&indices) { + assert(indices.size() == rank && + "indices should match rank in memref subscript"); + if (rank == 0) + return data[offset]; + + int64_t curOffset = offset; + for (int dim = rank - 1; dim >= 0; --dim) { + int64_t currentIndex = *(indices.begin() + dim); + assert(currentIndex < sizes[dim] && "Index overflow"); + curOffset += currentIndex * strides[dim]; + } + return data[curOffset]; + } + + DynamicMemRefIterator begin() { return {*this, offset}; } + DynamicMemRefIterator end() { return {*this, -1}; } + + // This operator[] is extremely slow and only for sugaring purposes. 
+ DynamicMemRefType operator[](int64_t idx) { + assert(rank > 0 && "can't make a subscript of a zero ranked array"); + + DynamicMemRefType res(*this); + --res.rank; + res.offset += idx * res.strides[0]; + ++res.sizes; + ++res.strides; + return res; + } + + // This operator* can be used in conjunction with the previous operator[] in + // order to access the underlying value in case of zero-ranked memref. + T &operator*() { + assert(rank == 0 && "not a zero-ranked memRef"); + return data[offset]; + } +}; + +/// Iterate over all elements in a dynamic memref. +template +class DynamicMemRefIterator { +public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T *; + using reference = T &; + + DynamicMemRefIterator(DynamicMemRefType &descriptor, int64_t offset = 0) + : offset(offset), descriptor(&descriptor) { + indices.resize(descriptor.rank, 0); + } + + DynamicMemRefIterator &operator++() { + if (descriptor->rank == 0) { + offset = -1; + return *this; + } + + int dim = descriptor->rank - 1; + + while (dim >= 0 && indices[dim] == (descriptor->sizes[dim] - 1)) { + offset -= indices[dim] * descriptor->strides[dim]; + indices[dim] = 0; + --dim; + } + + if (dim < 0) { + offset = -1; + return *this; + } + + ++indices[dim]; + offset += descriptor->strides[dim]; + return *this; + } + + reference operator*() { return descriptor->data[offset]; } + pointer operator->() { return &descriptor->data[offset]; } + + const std::vector &getIndices() { return indices; } + + bool operator==(const DynamicMemRefIterator &other) const { + return other.offset == offset && other.descriptor == descriptor; + } + + bool operator!=(const DynamicMemRefIterator &other) const { + return !(*this == other); + } + +private: + /// Offset in the buffer. This can be derived from the indices and the + /// descriptor. + int64_t offset = 0; + + /// Array of indices in the multi-dimensional memref. + std::vector indices = {}; + + /// Descriptor for the dynamic memref. + DynamicMemRefType *descriptor; +}; + +//===----------------------------------------------------------------------===// +// Small runtime support library for memref.copy lowering during codegen. +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void +memrefCopy(int64_t elemSize, ::UnrankedMemRefType *src, + ::UnrankedMemRefType *dst); + +//===----------------------------------------------------------------------===// +// Small runtime support library for vector.print lowering during codegen. 
+//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline(); + +//===----------------------------------------------------------------------===// +// Small runtime support library for timing execution and printing GFLOPS +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops); +extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock(); + +//===----------------------------------------------------------------------===// +// Runtime support library for random number generation. +//===----------------------------------------------------------------------===// +// Uses a seed to initialize a random generator and returns the generator. +extern "C" MLIR_CRUNNERUTILS_EXPORT void *rtsrand(uint64_t s); +// Returns a random number in the range of [0, m). +extern "C" MLIR_CRUNNERUTILS_EXPORT uint64_t rtrand(void *, uint64_t m); +// Deletes the random number generator. +extern "C" MLIR_CRUNNERUTILS_EXPORT void rtdrand(void *); + +//===----------------------------------------------------------------------===// +// Runtime support library to allow the use of std::sort in MLIR program. +//===----------------------------------------------------------------------===// +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortI64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF64(uint64_t n, StridedMemRefType *vref); +extern "C" MLIR_CRUNNERUTILS_EXPORT void +_mlir_ciface_stdSortF32(uint64_t n, StridedMemRefType *vref); +#endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H diff --git a/python/ExecutionEngine/Msan.h b/python/ExecutionEngine/Msan.h new file mode 100644 index 00000000..ee94660a --- /dev/null +++ b/python/ExecutionEngine/Msan.h @@ -0,0 +1,35 @@ +//===- Msan.h - Utils related to the memory sanitizer ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares and defines macros related to msan. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_MSAN_H +#define MLIR_EXECUTIONENGINE_MSAN_H + +// Memory sanitizer currently can't be enabled for the jit-compiled code, and +// to suppress msan warnings we need to unpoison pointers and pointed-to +// datastructures before they can be accessed. 
+ +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#if __has_feature(memory_sanitizer) && !defined(MLIR_MEMORY_SANITIZER) +#define MLIR_MEMORY_SANITIZER +#endif + +#if defined(MLIR_MEMORY_SANITIZER) +#include <sanitizer/msan_interface.h> +#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s) __msan_unpoison((p), (s)) +#else // Memory sanitizer: OFF +#define MLIR_MSAN_MEMORY_IS_INITIALIZED(p, s) +#endif // MLIR_MEMORY_SANITIZER + +#endif // MLIR_EXECUTIONENGINE_MSAN_H diff --git a/python/ExecutionEngine/version.txt b/python/ExecutionEngine/version.txt new file mode 100644 index 00000000..c3f15e55 --- /dev/null +++ b/python/ExecutionEngine/version.txt @@ -0,0 +1 @@ +https://github.com/llvm/llvm-project/commit/3be3883e6d67bf908fd12b51219075293ebb3dff diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 00000000..9c38b4b3 --- /dev/null +++ b/python/__init__.py @@ -0,0 +1,405 @@ +import functools +import os +import sysconfig +import subprocess +import tempfile +from pathlib import Path + +from triton.common.backend import BaseBackend, register_backend +from triton.compiler.make_launcher import make_so_cache_key +from triton.runtime.cache import get_cache_manager +from triton.runtime.jit import version_key + + +def _get_triton_shared_opt_path() -> str: + path = os.getenv("TRITON_SHARED_OPT_PATH", "") + if path == "": + raise Exception("TRITON_SHARED_OPT_PATH is not set.") + return path + + +def _get_llvm_bin_path(bin_name: str) -> str: + path = os.getenv("LLVM_BINARY_DIR", "") + if path == "": + raise Exception("LLVM_BINARY_DIR is not set.") + return f"{path}/{bin_name}" + + +def _ttir_to_ttsharedir(mod): + # Get Triton-MLIR as string + ttir_code = str(mod) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "tt.mlir") + dst_path = os.path.join(tmpdir, "ttshared.mlir") + Path(src_path).write_text(ttir_code) + triton_shared_opt_path = _get_triton_shared_opt_path() + ret = subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path]) + assert ret == 0 + return Path(dst_path).read_text() + + +def _optimize_ttsharedir(ttsharedir: str): + # We don't apply any optimizations now, but we can add passes if needed.
+ return ttsharedir + + +def _ttsharedir_to_llir(ttsharedir: str): + with tempfile.TemporaryDirectory() as tmpdir: + ttshared_path = os.path.join(tmpdir, "ttshared.mlir") + llmlir_path = os.path.join(tmpdir, "ll.mlir") + llir_path = os.path.join(tmpdir, "ll.ir") + Path(ttshared_path).write_text(ttsharedir) + mlir_opt_path = _get_llvm_bin_path("mlir-opt") + # TritonShared-MLIR to LLVM-MLIR + ret = subprocess.check_call([mlir_opt_path, ttshared_path, + "--convert-linalg-to-affine-loops", + "--eliminate-empty-tensors", + "--empty-tensor-to-alloc-tensor", + "--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--lower-affine", + "--convert-linalg-to-loops", + "--convert-scf-to-cf", + "--convert-cf-to-llvm", + "--convert-arith-to-llvm", + "--convert-math-to-llvm", + "--convert-complex-to-llvm", + "--convert-vector-to-llvm", + "--convert-index-to-llvm", + "--memref-expand", + "--expand-strided-metadata", + "--finalize-memref-to-llvm", + "--convert-func-to-llvm", + "--reconcile-unrealized-casts", + "-o", + llmlir_path]) + assert ret == 0 + + # LLVM-MLIR to LLVM-IR + mlir_translate_path = _get_llvm_bin_path("mlir-translate") + ret = subprocess.check_call([mlir_translate_path, llmlir_path, + "--mlir-to-llvmir", + "-o", + llir_path]) + assert ret == 0 + return Path(llir_path).read_text() + + +def _optimize_llir(llir: str): + # We don't apply any optimizations now, but we can add passes if needed. + return llir + + +def _llir_to_bin(llir: str): + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.ll") + dst_path = os.path.join(tmpdir, "kernel.o") + with open(src_path, "w") as f: + f.write(llir) + llc_path = _get_llvm_bin_path("llc") + ret = subprocess.check_call([llc_path, src_path, "-o", dst_path]) + assert ret == 0 + # Actually it's text-format assembly. Use read_text(). + return Path(dst_path).read_text() + + +def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + +def _extracted_ty(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + +def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + +def _generate_launcher(constants, signature, kernel_name): + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + format = "iiiOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) + return f""" +#include +#include +#include +#include "CRunnerUtils.h" +#include "CRunnerUtils.cpp" + +extern "C" {{ + // Pointer type (=Memref) becomes int64_t + MemRef struct + // FIXME: understand what this int64_t is used for. + void {kernel_name}({', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if i not in constants)}, + int, int, int, int, int, int); +}} + +static void _launch(int gridX, int gridY, int gridZ, {arg_decls}) {{ + if (gridX*gridY*gridZ > 0) {{ + // Cast "function" to the real function type. 
+ for(int x = 0; x < gridX; x++) {{ + for(int y = 0; y < gridY; y++) {{ + for(int z = 0; z < gridZ; z++) {{ + // Use some random type "char" here. + {' '.join(f'StridedMemRefType ptr_arg{i} = {{static_cast(arg{i}), static_cast(arg{i}), 0}};' for i, ty in signature.items() if i not in constants and ty[0] == "*")} + {kernel_name}({', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if i not in constants)}, + gridX, gridY, gridZ, x, y, z); + }} + }} + }} + }} +}} + +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); + if(!ptr_info.dev_ptr) + return ptr_info; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *compiled_kernel = NULL; + {' '.join([f"{_extracted_ty(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &launch_enter_hook, &launch_exit_hook, &compiled_kernel + {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''})) {{ + return NULL; + }} + + if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + if (PyErr_Occurred()) {{ + return NULL; + }} + if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_shared_ref_cpu_kernel_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_shared_ref_cpu_kernel_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + + +class TritonSharedRefCPUBackend(BaseBackend): + stub_so_path = "" + + def __init__(self, device_type: str) -> None: + super(TritonSharedRefCPUBackend, 
self).__init__(device_type) + + def add_stages(self, arch, extern_libs, stages): + filter_in_stages = ["ast", "ttir"] + filter_out_stages = [] + for key, _ in stages.items(): + if key not in filter_in_stages: + filter_out_stages.append(key) + for filter_out_key in filter_out_stages: + stages.pop(filter_out_key) + + stages["ttsharedir"] = (lambda path: Path(path).read_text(), + lambda src: _optimize_ttsharedir(_ttir_to_ttsharedir(src))) + stages["llir"] = (lambda path: Path(path).read_text(), + lambda src: _optimize_llir(_ttsharedir_to_llir(src))) + stages["cpuasm"] = (lambda path: Path(path).read_text(), + lambda src: _llir_to_bin(src)) + + def add_meta_info(self, ir, module, next_module, metadata, asm): + metadata["shared"] = 1 + if ir == "llir": + # We can get a function name (C naming) from + # LLVM-IR by getting the first "define void @". + metadata["name"] = asm["llir"].split("define void @")[1].split("(")[0].strip() + + def get_driver(self): + return None + + def get_stream(self, idx=None) -> int: + # Returns int to make Triton happy. + return 0 + + @functools.lru_cache(None) + def get_device_properties(self, device): + # CPU has no property. Return some values to make the Triton runtime happy. + return {"max_shared_mem": 2 ** 20} + + def get_current_device(self): + # CPU doesn't have a device to return. Return something. + return "cpu" + + def set_current_device(self, device): + # CPU doesn't have a device to set + assert device == "cpu" + return + + def get_load_binary_fn(self): + def _load_binary_fn(kernel_name, binary, shared_size, device): + # Returns mod, func, n_regs, n_spills, but this implementation does not use it. + # Note: func is a function pointer. + return None, 0, 0, 0 + return _load_binary_fn + + def get_kernel_bin(self): + return "cpuasm" + + def get_architecture_descriptor(self, **kwargs): + # CPU does not have the following parameters, but we need to pass some values to + # make the Triton runtime happy. + return {"num_warps": 1, "num_stages": 1} + + def make_launcher_stub(self, name, signature, constants, ids): + # name of files that are cached + so_cache_key = make_so_cache_key(version_key(), signature, constants, ids) + so_cache_manager = get_cache_manager(so_cache_key) + so_name = f"{name}.py" + # retrieve stub from cache if it exists + cache_path = so_cache_manager.get_file(so_name) + if cache_path is None: + kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER" + with tempfile.TemporaryDirectory() as tmpdir: + # Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name + # in the following launch function. + launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name) + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. 
+ if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + dst_path = os.path.join(tmpdir, so_name) + py_src = f""" +import os, subprocess, tempfile +import importlib.util +from pathlib import Path + +def launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDim0, clusterDim1, clusterDim2, + shared, stream, cu_function, launch_enter_hook, launch_exit_hook, compiled_kernel, + *args): + # Unlike CUDA/HIP, we cannot easily pass function pointer across different pybind libraries. + # Let's compile a kernel every time. + asm_src = compiled_kernel.asm["{self.get_kernel_bin()}"] + launcher_src = ''' +{launcher_src} +'''.replace("{kernel_placeholder_name}", compiled_kernel.metadata["name"]) + with tempfile.TemporaryDirectory() as tmpdir: + asm_src_path = os.path.join(tmpdir, "kernel.s") + launcher_src_path = os.path.join(tmpdir, "main.cxx") + so_path = os.path.join(tmpdir, "kernel.so") + Path(asm_src_path).write_text(asm_src) + Path(launcher_src_path).write_text(launcher_src) + # Compile it together. + ret = subprocess.check_call(["g++", launcher_src_path, asm_src_path, f"-I{py_include_dir}", f"-I{Path(__file__).resolve().parent}", "-shared", "-fPIC", "-o", so_path]) + if ret != 0: + raise AssertionError("Kernel compilation failed.") + + # Load and launch the compiled kernel. + spec = importlib.util.spec_from_file_location("__triton_shared_ref_cpu_kernel_launcher", so_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod.launch(gridX, gridY, gridZ, launch_enter_hook, launch_exit_hook, compiled_kernel, *args) +""" + Path(dst_path).write_text(py_src) + with open(dst_path, "rb") as f: + return so_cache_manager.put(f.read(), so_name, binary=True) + else: + return cache_path + + +register_backend("cpu", TritonSharedRefCPUBackend) diff --git a/python/examples/reduce.py b/python/examples/reduce.py new file mode 100644 index 00000000..795f0b51 --- /dev/null +++ b/python/examples/reduce.py @@ -0,0 +1,38 @@ + +import torch + +import triton +import triton.language as tl + +@triton.jit +def reduce_kernel_2d( + x_ptr, + output_ptr, + stride, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + x = tl.load(tl.make_block_ptr(base=x_ptr, shape=[n_elements * tl.num_programs(0)], + strides=[1], offsets=[stride * pid0], + block_shape=[BLOCK_SIZE], order=[0]), boundary_check=[0]) + output = triton.language.sum(x, axis=0).to(dtype=x.dtype) + tl.store(output_ptr + pid0, output) + +n_rows = 16 +n_cols = 32 +x = torch.rand([n_cols, n_rows], device="cpu", dtype=torch.float32) +output = torch.empty([n_cols], device=x.device, dtype=x.dtype) +BLOCK_SIZE = n_rows +grid = lambda meta: (n_cols, ) + +reduce_kernel_2d[grid](x, output, x.stride(0), n_rows, BLOCK_SIZE=BLOCK_SIZE) +ans = torch.sum(x, dim=1) +torch.testing.assert_close(output, ans, rtol=0.001, atol=1e-5) +print("Pass!") + +ret = triton.compile(reduce_kernel_2d, signature="*fp32,*fp32,i32,i32", constants={"BLOCK_SIZE": 32}, device_type="cpu") +print(ret.asm["ttir"]) +print(ret.asm["ttsharedir"]) +print(ret.asm["llir"]) +print(ret.asm["cpuasm"])
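
Note on running the reference CPU backend outside CI: python/__init__.py reads TRITON_SHARED_OPT_PATH and LLVM_BINARY_DIR, and the generated launcher stub invokes g++ at kernel-launch time, so both paths must be set in the Python process that executes the kernel. Below is a minimal sketch mirroring the "Run an example" step of test-cpuref.yml; the build-directory and LLVM-cache globs are assumptions about a default TRITON_CODEGEN_TRITON_SHARED=1 build driven by setup.py, and the script is assumed to run from the directory that contains the triton checkout.

# Minimal local-run sketch (assumed paths; mirrors the CI "Run an example" step).
import glob
import os
import runpy

home = os.path.expanduser("~")
# triton-shared-opt is built under triton/python/build/cmake.* (assumed layout).
build_dir = glob.glob("triton/python/build/cmake*")[0]
# LLVM tools downloaded by Triton's build into ~/.triton/llvm (assumed layout).
llvm_dir = glob.glob(os.path.join(home, ".triton", "llvm", "llvm*"))[0]

os.environ["TRITON_SHARED_OPT_PATH"] = os.path.join(
    build_dir, "third_party", "triton_shared", "tools",
    "triton-shared-opt", "triton-shared-opt")
os.environ["LLVM_BINARY_DIR"] = os.path.join(llvm_dir, "bin")

# Run the bundled example; it exercises the ttir -> ttsharedir -> llir -> cpuasm pipeline.
runpy.run_path("triton/third_party/triton_shared/python/examples/reduce.py", run_name="__main__")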