diff --git a/backend/compiler.py b/backend/compiler.py
index 62dd4d59..8a85d46c 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -10,6 +10,7 @@ import functools
 
 from pathlib import Path
 
+
 def _get_triton_shared_opt_path() -> str:
     path = os.getenv("TRITON_SHARED_OPT_PATH", "")
     if path == "":
@@ -17,6 +18,13 @@ def _get_triton_shared_opt_path() -> str:
     return path
 
 
+# because of the way triton loads backends, this function is duplicated
+# in compiler and driver
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""
+
+
 def _get_llvm_bin_path(bin_name: str) -> str:
     path = os.getenv("LLVM_BINARY_DIR", "")
     if path == "":
@@ -32,7 +40,10 @@ def _ttir_to_ttsharedir(mod):
         dst_path = os.path.join(tmpdir, "ttshared.mlir")
         Path(src_path).write_text(ttir_code)
         triton_shared_opt_path = _get_triton_shared_opt_path()
-        subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental", "-o", dst_path])
+        extra_pass = ["--triton-to-linear-algebra-subprograms"] if _get_triton_shared_use_openblas() else []
+        subprocess.check_call([triton_shared_opt_path, src_path] + extra_pass + \
+                              ["--triton-to-linalg-experimental",
+                               "-o", dst_path])
         return Path(dst_path).read_text()
diff --git a/backend/driver.py b/backend/driver.py
index 81a0b905..853543d5 100644
--- a/backend/driver.py
+++ b/backend/driver.py
@@ -12,6 +12,14 @@
 from triton.backends.driver import DriverBase
 from triton.backends.compiler import GPUTarget
 
+
+# because of the way triton loads backends, this function is duplicated
+# in compiler and driver
+def _get_triton_shared_use_openblas() -> bool:
+    use_blas = os.getenv("TRITON_SHARED_USE_OPENBLAS", "")
+    return use_blas != ""
+
+
 # -------------------- Launcher ----------------------------
 def _ty_to_cpp(ty):
     if ty[0] == '*':
@@ -250,11 +258,12 @@ def launch(
         so_path = os.path.join(tmpdir, "kernel.so")
         Path(asm_src_path).write_bytes(asm_src)
         Path(launcher_src_path).write_text(src)
+        extra_lib = ["-lopenblas"] if _get_triton_shared_use_openblas() else []
         # Compile it together.
         subprocess.check_call([
             "g++", launcher_src_path, asm_src_path,
             f"-I{py_include_dir}", f"-I{include_dir}",
-            "-shared", "-fPIC", "-o", so_path
+            "-shared", "-fPIC"] + extra_lib + ["-o", so_path
         ])
 
         with open(so_path, "rb") as f:
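With these two changes, setting TRITON_SHARED_USE_OPENBLAS to any non-empty value makes compiler.py schedule the new --triton-to-linear-algebra-subprograms pass ahead of --triton-to-linalg-experimental and makes driver.py link the generated launcher against -lopenblas. A minimal opt-in sketch (the variable must be set before a kernel is compiled; the CPUDriver activation mirrors what benchmark.py below does):

    import os

    # Any non-empty value enables the BLAS path; compiler.py and driver.py read
    # the same variable, so set it before compiling or launching a kernel.
    os.environ["TRITON_SHARED_USE_OPENBLAS"] = "1"

    import triton
    from triton.backends.triton_shared.driver import CPUDriver

    triton.runtime.driver.set_active(CPUDriver())
    # ... define and launch a kernel that uses tl.dot(), e.g. the bare_matmul example below ...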
diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt
index 45da8aca..95f8f173 100644
--- a/include/triton-shared/Conversion/CMakeLists.txt
+++ b/include/triton-shared/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
 add_subdirectory(TritonToStructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(TritonToLinearAlgebraSubprograms)
diff --git a/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
new file mode 100644
index 00000000..dd67aba4
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinearAlgebraSubprograms)
+add_public_tablegen_target(TritonToLinearAlgebraSubprogramsConversionPassIncGen)
diff --git a/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h
new file mode 100644
index 00000000..cceef048
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h
@@ -0,0 +1,15 @@
+#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
+#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif
diff --git a/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.td b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.td
new file mode 100644
index 00000000..724ccd3c
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.td
@@ -0,0 +1,10 @@
+#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
+#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def TritonToLinearAlgebraSubprograms : Pass<"triton-to-linear-algebra-subprograms", "mlir::ModuleOp"> {
+  let summary = "Convert Linalg operations to library calls";
+}
+
+#endif
diff --git a/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h
new file mode 100644
index 00000000..c2377b29
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h
@@ -0,0 +1,25 @@
+#ifndef TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
+#define TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc"
+
+void populateTritonToLinearAlgebraSubprogramsConversionPatterns(bool pidsToFuncArgs,
+                                                                bool addptrToLinalg,
+                                                                bool assertToCf,
+                                                                RewritePatternSet &patterns);
+
+std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinearAlgebraSubprogramsPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_TO_LINEAR_ALGEBRA_SUBPROGRAMS_H
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 45da8aca..95f8f173 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental)
 add_subdirectory(TritonToStructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(TritonToLinearAlgebraSubprograms)
diff --git a/lib/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt b/lib/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
new file mode 100644
index 00000000..6f917ef0
--- /dev/null
+++ b/lib/Conversion/TritonToLinearAlgebraSubprograms/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_triton_library(TritonToLinearAlgebraSubprograms
+  TritonToLinearAlgebraSubprogramsPass.cpp
+
+  DEPENDS
+  TritonToLinearAlgebraSubprogramsConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRLinalgTransforms
+  MLIRArithDialect
+  MLIRDialectUtils
+  MLIRIR
+  MLIRMathDialect
+  MLIRPass
+  MLIRTensorDialect
+  MLIRTransforms
+  MLIRSupport
+  TritonIR
+  TritonTransforms
+  TritonTilingExtIR
+  TritonStructuredIR
+)
diff --git a/lib/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprogramsPass.cpp b/lib/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprogramsPass.cpp
new file mode 100644
index 00000000..f4676389
--- /dev/null
+++ b/lib/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprogramsPass.cpp
@@ -0,0 +1,173 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/TritonToLinearAlgebraSubprograms.h"
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-to-las"
+
+using namespace mlir;
+using namespace triton;
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_TRITONTOLINEARALGEBRASUBPROGRAMS
+#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+namespace {
+
+struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
+  using OpConversionPattern<triton::DotOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value A = op.getA();
+    Value B = op.getB();
+    Value C = op.getC();
+
+    auto tensorA = cast<RankedTensorType>(A.getType());
+    auto tensorB = cast<RankedTensorType>(B.getType());
+    auto tensorC = cast<RankedTensorType>(C.getType());
+
+    if (tensorA.getElementType() != tensorB.getElementType() ||
+        tensorC.getElementType() != tensorB.getElementType()) {
+      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, different element types\n");
+      return failure();
+    }
+
+    if (!tensorA.getElementType().isF32() && !tensorA.getElementType().isF64()) {
+      LLVM_DEBUG(llvm::dbgs() << "Cannot replace, unsupported type\n");
+      return failure();
+    }
+
+    auto floatType = tensorA.getElementType();
+
+    // since tensors are immutable, we need to allocate a buffer for the result
+    Value memrefConst = rewriter.create<bufferization::ToMemrefOp>(loc, MemRefType::get(tensorC.getShape(), tensorC.getElementType()), C);
+    auto memrefType = MemRefType::get(tensorC.getShape(), floatType);
+    Value memrefC = rewriter.create<memref::AllocOp>(loc, memrefType);
+    auto copyOp = rewriter.create<linalg::CopyOp>(loc, ValueRange{memrefConst}, ValueRange{memrefC});
+
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    auto intType = rewriter.getI32Type();
+    auto int64Type = rewriter.getI64Type();
+    auto ptrType = LLVM::LLVMPointerType::get(op.getContext(), 0); // default address space
+
+    auto funcType = FunctionType::get(op.getContext(),
+        {intType, intType, intType, intType, intType, intType, floatType,
+         ptrType, intType, ptrType, intType, floatType,
+         ptrType, intType}, {});
+
+    bool usingF64 = floatType.isF64();
+    const char *funcName = usingF64 ? "cblas_dgemm" : "cblas_sgemm";
+    auto func = module.lookupSymbol<func::FuncOp>(funcName);
+    if (!func) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(module.getBody());
+      func = rewriter.create<func::FuncOp>(loc, funcName, funcType);
+      func.setVisibility(SymbolTable::Visibility::Private);
+    }
+
+    auto memrefToPointer = [&rewriter, &loc, &int64Type, &ptrType](Value &memref) {
+      auto indexPtr = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(loc, memref);
+      auto castOp = rewriter.create<arith::IndexCastOp>(loc, int64Type, indexPtr);
+      return rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, castOp);
+    };
+
+    auto tensorToPointer = [&rewriter, &loc, &memrefToPointer](Value &V, RankedTensorType &T) {
+      Value memref = rewriter.create<bufferization::ToMemrefOp>(loc, MemRefType::get(T.getShape(), T.getElementType()), V);
+      return memrefToPointer(memref);
+    };
+
+    Value ptrA = tensorToPointer(A, tensorA);
+    Value ptrB = tensorToPointer(B, tensorB);
+    Value ptrC = memrefToPointer(memrefC);
+
+    int32_t M = tensorA.getShape()[0];
+    int32_t K = tensorA.getShape()[1];
+    int32_t N = tensorB.getShape()[1];
+
+    // cblas_{s,d}gemm computes C = alpha * A * B + beta * C, so alpha = beta = 1
+    // adds the product to the tt.dot accumulator operand copied into memrefC.
+    Value alpha = rewriter.create<arith::ConstantOp>(loc, floatType, usingF64 ? rewriter.getF64FloatAttr(1.0) : rewriter.getF32FloatAttr(1.0));
+    Value beta = alpha;
+
+    auto constOp = [&rewriter, &loc, &intType](int32_t V) {
+      return rewriter.create<arith::ConstantOp>(loc, intType, rewriter.getI32IntegerAttr(V));
+    };
+    // 101 and 111 are the CBLAS enum values of CblasRowMajor and CblasNoTrans.
+    Value CblasRowMajor = constOp(101), CblasNoTrans = constOp(111);
+    Value MVal = constOp(M), NVal = constOp(N), KVal = constOp(K);
+    Value LDA = KVal, LDB = NVal, LDC = NVal;
+
+    auto funcOp = rewriter.create<func::CallOp>(loc, func, ValueRange{
+        CblasRowMajor, CblasNoTrans, CblasNoTrans,
+        MVal, NVal, KVal,
+        alpha, ptrA, LDA,
+        ptrB, LDB, beta,
+        ptrC, LDC
+    });
+
+    auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc,
+        tensorC, memrefC, true /* restrict */, true /* writable */);
+    rewriter.replaceOp(op, toTensorOp);
+    return success();
+  }
+};
+
+class TritonToLinearAlgebraSubprogramsPass
+    : public triton::impl::TritonToLinearAlgebraSubprogramsBase<
+          TritonToLinearAlgebraSubprogramsPass> {
+  using TritonToLinearAlgebraSubprogramsBase<
+      TritonToLinearAlgebraSubprogramsPass>::TritonToLinearAlgebraSubprogramsBase;
+
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<func::FuncDialect, arith::ArithDialect, math::MathDialect,
+                affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect,
+                cf::ControlFlowDialect, tensor::TensorDialect,
+                bufferization::BufferizationDialect, memref::MemRefDialect,
+                LLVM::LLVMDialect>();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+
+    patterns.add<MatmulConverter>(patterns.getContext());
+
+    target.addLegalDialect<
+        func::FuncDialect, arith::ArithDialect, math::MathDialect,
+        affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect,
+        cf::ControlFlowDialect, tensor::TensorDialect,
+        bufferization::BufferizationDialect, memref::MemRefDialect, LLVM::LLVMDialect>();
+
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+triton::createTritonToLinearAlgebraSubprogramsPass() {
+  return std::make_unique<TritonToLinearAlgebraSubprogramsPass>();
+}
diff --git a/python/examples/bare_matmul.py b/python/examples/bare_matmul.py
new file mode 100644
index 00000000..c3902fc6
--- /dev/null
+++ b/python/examples/bare_matmul.py
@@ -0,0 +1,39 @@
+import torch
+
+import triton
+import triton.language as tl
+import benchmark
+
+
+@triton.jit
+def bare_matmul(X, Y, Z, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+
+    offs_x = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_y = tl.arange(0, BLOCK_SIZE)
+
+    x = tl.load(X + offs_x[:, None])
+    y = tl.load(Y + offs_y[None, :])
+
+    z = tl.dot(x, y)
+    tl.store(Z + offs_x[:, None] + offs_y[None, :], z)
+
+
+@benchmark.measure()
+def bench_matmul(M, N, K, provider):
+    device = 'cpu'
+    dtype = torch.float32
+    a = torch.randn((M, K), device=device, dtype=dtype)
+    b = torch.randn((K, N), device=device, dtype=dtype)
+    c = torch.empty((K, N), device=device, dtype=dtype)
+    if provider == 'torch':
+        torch.matmul(a, b)
+    if provider == 'triton':
+        bare_matmul[(1,)](a, b, c, N)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [2**i for i in range(7, 11, 1)]:
+        for provider in ['torch', 'triton']:
+            bench_matmul(X, X, X, provider)
diff --git a/python/examples/benchmark.py b/python/examples/benchmark.py
new file mode 100644
index 00000000..7c900906
--- /dev/null
+++ b/python/examples/benchmark.py
@@ -0,0 +1,50 @@
+import time
+import numpy as np
+from functools import wraps
+import triton
+from triton.backends.triton_shared.driver import CPUDriver
+
+
+def select_cpu_backend():
+    triton.runtime.driver.set_active(CPUDriver())
+
+
+# Unfortunately, we can't use triton.testing.perf_report and triton.testing.do_bench for CPU backend because
+# they are very specific to cuda
+
+def measure(repeats=20, percentiles=(20, 50, 90)):
+    """
+    Decorator to benchmark a function.
+
+    Parameters:
+    - repeats: int, the number of times the function should be executed for each set of parameters.
+    - percentiles: tuple, the percentiles to compute on the execution times.
+
+    Returns:
+    - A decorated function that prints the average execution time and the requested percentiles.
+    """
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            print(f"{func.__name__}{args} {kwargs}, {repeats} times, all results in seconds")
+            times = []
+            for _ in range(repeats):
+                start_time = time.perf_counter()
+                result = func(*args, **kwargs)
+                end_time = time.perf_counter()
+                elapsed_time = end_time - start_time
+                times.append(elapsed_time)
+
+            average_time = np.mean(times)
+            min_time = np.min(times)
+            max_time = np.max(times)
+            computed_percentiles = np.percentile(times, percentiles)
+
+            print(f"Avg={average_time:.6f}, min={min_time:.6f},", end=" ")
+            for p, value in zip(percentiles, computed_percentiles):
+                print(f"{p}pp={value:.6f},", end=" ")
+            print(f"max={max_time:.6f}")
+
+            return result
+        return wrapper
+    return decorator
\ No newline at end of file
diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py
index 2281f7d4..470006a5 100644
--- a/python/examples/test_matmul.py
+++ b/python/examples/test_matmul.py
@@ -2,7 +2,7 @@
 
 import triton
 import triton.language as tl
-
+import benchmark
 
 # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
 #   - A list of `triton.Config` objects that define different configurations of
@@ -155,3 +155,20 @@ def test_matmul(device):
     triton_output = matmul(a, b)
     torch_output = torch.matmul(a, b)
     torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
+
+
+@benchmark.measure()
+def bench_matmul(M, N, K, provider):
+    a = torch.randn((M, K), device='cpu', dtype=torch.float32)
+    b = torch.randn((K, N), device='cpu', dtype=torch.float32)
+    if provider == 'torch':
+        torch.matmul(a, b)
+    if provider == 'triton':
+        matmul(a, b)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [128 * i for i in range(2, 7)]:
+        for provider in ['torch', 'triton']:
+            bench_matmul(X, X, X, provider)
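The bench_matmul additions above only time matmul(a, b) and never compare the result against torch, so when toggling TRITON_SHARED_USE_OPENBLAS it can be worth re-running a quick check first. A hypothetical sketch, not part of this patch, reusing the tolerance from the existing test_matmul():

    import torch
    import benchmark
    from test_matmul import matmul

    benchmark.select_cpu_backend()
    a = torch.randn((256, 256), device='cpu', dtype=torch.float32)
    b = torch.randn((256, 256), device='cpu', dtype=torch.float32)
    # same tolerance as the existing test_matmul() assertion
    torch.testing.assert_close(matmul(a, b), torch.matmul(a, b), atol=1e-2, rtol=0)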
diff --git a/python/examples/test_vec_add.py b/python/examples/test_vec_add.py
index 4dc43139..db2fa098 100644
--- a/python/examples/test_vec_add.py
+++ b/python/examples/test_vec_add.py
@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+import benchmark
 
 
 @triton.jit
@@ -66,3 +67,19 @@ def test(device):
         f"The maximum difference between torch and triton is "
         f"{torch.max(torch.abs(output_torch - output_triton))}"
     )
+
+@benchmark.measure()
+def bench_vecadd(size, provider):
+    a = torch.rand(size, device='cpu', dtype=torch.float32)
+    b = torch.rand(size, device='cpu', dtype=torch.float32)
+    if provider == 'torch':
+        a + b
+    if provider == 'triton':
+        add(a, b)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [2**i for i in range(22, 25, 1)]:
+        for provider in ['torch', 'triton']:
+            bench_vecadd(X, provider)
\ No newline at end of file
diff --git a/tools/RegisterTritonSharedDialects.h b/tools/RegisterTritonSharedDialects.h
index 82ba4f39..653b32b4 100644
--- a/tools/RegisterTritonSharedDialects.h
+++ b/tools/RegisterTritonSharedDialects.h
@@ -18,6 +18,7 @@
 #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h"
 #include "triton-shared/Conversion/TritonToLinalg/Passes.h"
 #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h"
+#include "triton-shared/Conversion/TritonToLinearAlgebraSubprograms/Passes.h"
 #include "triton-shared/Conversion/TritonToStructured/Passes.h"
 #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
@@ -44,6 +45,7 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestMembarPass();
   mlir::triton::registerTritonToLinalgPass();
   mlir::triton::registerTritonToLinalgExperimentalPass();
+  mlir::triton::registerTritonToLinearAlgebraSubprogramsPass();
   mlir::triton::registerTritonToStructuredPass();
   mlir::triton::registerTritonArithToLinalgPasses();
   mlir::triton::registerConvertTritonToTritonGPUPass();
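With the registration above, the pass is also available directly from triton-shared-opt, in the same order compiler.py assembles the pipeline. A sketch of a standalone invocation (input and output file names are illustrative, and TRITON_SHARED_OPT_PATH must point at the built binary):

    import os
    import subprocess

    triton_shared_opt = os.environ["TRITON_SHARED_OPT_PATH"]
    subprocess.check_call([
        triton_shared_opt, "kernel.ttir",          # hypothetical Triton IR dump
        "--triton-to-linear-algebra-subprograms",  # rewrites tt.dot into cblas_{s,d}gemm calls
        "--triton-to-linalg-experimental",
        "-o", "kernel.mlir",
    ])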