From 19e23e9eee0f4b4171ddb5e07123a8ab974c6b14 Mon Sep 17 00:00:00 2001 From: dan Date: Tue, 5 Mar 2024 18:06:07 -0500 Subject: [PATCH 1/7] Inital SME opt push --- backend/compiler.py | 32 +++++- backend/driver.py | 6 +- tools/CMakeLists.txt | 1 + tools/triton-sme-opt/CMakeLists.txt | 19 ++++ tools/triton-sme-opt/triton-sme-opt.cpp | 142 ++++++++++++++++++++++++ 5 files changed, 192 insertions(+), 8 deletions(-) create mode 100644 tools/triton-sme-opt/CMakeLists.txt create mode 100644 tools/triton-sme-opt/triton-sme-opt.cpp diff --git a/backend/compiler.py b/backend/compiler.py index 6341a6a1..5395982a 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -1,5 +1,5 @@ from triton.backends.compiler import BaseBackend -from triton._C.libtriton import ir, passes +from triton._C.libtriton import ir, passes, llvm, triton_shared from dataclasses import dataclass from typing import Any import hashlib @@ -10,6 +10,14 @@ import functools from pathlib import Path +def printc(obj, color="cyan"): #makes things easier to see, will remove later + color_code = { + "black": "30", "red": "31", "green": "32", "yellow": "33", + "blue": "34", "magenta": "35", "cyan": "36", "white": "37" + } + colored_text = f"\033[{color_code[color]}m{obj}\033[0m" if color in color_code else obj + print(colored_text) + def _get_triton_shared_opt_path() -> str: path = os.getenv("TRITON_SHARED_OPT_PATH", "") if path == "": @@ -17,6 +25,10 @@ def _get_triton_shared_opt_path() -> str: return path +def _get_triton_SME_path() -> str: + return os.getenv("TRITON_SME_PATH", "") + + def _get_llvm_bin_path(bin_name: str) -> str: path = os.getenv("LLVM_BINARY_DIR", "") if path == "": @@ -32,16 +44,25 @@ def _ttir_to_ttsharedir(mod): dst_path = os.path.join(tmpdir, "ttshared.mlir") Path(src_path).write_text(ttir_code) triton_shared_opt_path = _get_triton_shared_opt_path() - subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-structured", "--canonicalize", "--triton-arith-to-linalg", "--cse", "--structured-to-memref", "-o", dst_path]) + subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path]) return Path(dst_path).read_text() def _optimize_ttsharedir(ttsharedir: str): - # We don't apply any optimizations now, but we can add passes if needed. 
- return ttsharedir + if _get_triton_SME_path() == "": + return ttsharedir + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "ttshared.mlir") + dst_path = os.path.join(tmpdir, "ttsme.mlir") + Path(src_path).write_text(ttsharedir) + triton_shared_opt_path = _get_triton_SME_path() + subprocess.check_call([triton_shared_opt_path, src_path, "-sme-converison", "-o", dst_path]) + output= Path(dst_path).read_text() + printc(output) + return output -def _ttsharedir_to_llir(ttsharedir: str): +def _ttsharedir_to_llir(ttsharedir: str): #going to need to add some flags to this, recent changes to SME feature flags with tempfile.TemporaryDirectory() as tmpdir: ttshared_path = os.path.join(tmpdir, "ttshared.mlir") llmlir_path = os.path.join(tmpdir, "ll.mlir") @@ -151,6 +172,7 @@ def load_dialects(self, ctx): @staticmethod def make_ttir(mod, metadata, opt): + # assert False pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) diff --git a/backend/driver.py b/backend/driver.py index 920da7c7..b7dffde8 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -40,8 +40,8 @@ def _extracted_ty(ty): 'i64': 'int64_t', 'u32': 'uint32_t', 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', + 'fp16': 'float16_t', + 'bf16': 'bfloat16_t', 'fp32': 'float', 'f32': 'float', 'fp64': 'double', @@ -226,7 +226,7 @@ def launch( # Compile it together. subprocess.check_call([ "g++", launcher_src_path, asm_src_path, - f"-I{py_include_dir}", f"-I{include_dir}", + f"-I{py_include_dir}", f"-I{include_dir}" "-mfp16-format=ieee", "-shared", "-fPIC", "-o", so_path ]) diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 3cdf7432..d68f2854 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(triton-shared-opt) +add_subdirectory(triton-sme-opt) diff --git a/tools/triton-sme-opt/CMakeLists.txt b/tools/triton-sme-opt/CMakeLists.txt new file mode 100644 index 00000000..3ac771f3 --- /dev/null +++ b/tools/triton-sme-opt/CMakeLists.txt @@ -0,0 +1,19 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + +add_llvm_executable(triton-sme-opt triton-sme-opt.cpp PARTIAL_SOURCES_INTENDED) + +llvm_update_compile_flags(triton-sme-opt) +target_link_libraries(triton-sme-opt PRIVATE + + ${dialect_libs} + ${conversion_libs} + # tests + TritonTestAnalysis + # MLIR core + MLIROptLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-sme-opt) diff --git a/tools/triton-sme-opt/triton-sme-opt.cpp b/tools/triton-sme-opt/triton-sme-opt.cpp new file mode 100644 index 00000000..ec984ef2 --- /dev/null +++ b/tools/triton-sme-opt/triton-sme-opt.cpp @@ -0,0 +1,142 @@ +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/DialectConversion.h" +#include +#include "mlir/IR/Operation.h" +#include + + +using namespace mlir; + + +namespace { +struct OuterProductVectorizationPass : public PassWrapper> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + 
func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + + // Step 4: Lower vector.multi_reduction to vector.contract (+ some helpful patterns) + vector::VectorTransformsOptions vectorTransformsOptions; + vectorTransformsOptions.setVectorTransformsOptions(vector::VectorContractLowering::OuterProduct); + vector::populateVectorTransferDropUnitDimsPatterns(patterns); + vector::populateVectorReductionToContractPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + // Step 5: Lower vector.contract to vector.outerproduct. Also drop unit dims. + patterns.clear(); + vectorTransformsOptions.setVectorTransformsOptions(vector::VectorContractLowering::OuterProduct); + vector::populateVectorTransferDropUnitDimsPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + + + struct MatmulTileConversion: public OpRewritePattern { + using OpRewritePattern ::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::MatmulOp op, PatternRewriter & rewriter) const override { + + SmallVector tileSizes = {4, 4,1}; // Tile sizes for [M, N, K] dimensions tofo + + + linalg::LinalgTilingOptions tilingOptions = linalg::LinalgTilingOptions().setTileSizes(tileSizes); + auto tiledOpResult = tileLinalgOp(rewriter, op, tilingOptions); + if (failed(tiledOpResult)) { + std::cout << "TILING FAILED" << std::endl; + return failure(); + } + + if (failed(linalg::vectorize(rewriter, cast (tiledOpResult->op.getOperation())))) { + return failure(); + } + MLIRContext *context = getContext(); + + + rewriter.replaceOp(op, tiledOpResult->tensorResults); + return success(); + + } + }; + + class MatmulTileConversionPass + : public PassWrapper > { + public: void getDependentDialects(DialectRegistry & registry) const override { + registry.insert (); + } + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = &getContext(); + + + RewritePatternSet patterns(context); + patterns.add(context); + + ConversionTarget target( * context); + target.addLegalDialect (); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + signalPassFailure(); + } + } + }; + + std::unique_ptr createOuterProductVectorizationPass() { + return std::make_unique (); + } + std::unique_ptr createMatmulTileConversionPass() { + return std::make_unique (); + } +} + +int main(int argc, char ** argv) { + mlir::DialectRegistry registry; + + registry.insert(); + + + registerAllDialects(registry); + registerAllPasses(); + + MLIRContext context; + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + + PassPipelineRegistration<> pipeline( + "sme-converison", + "Converts linalg.matmul to a more optimized form", + [](OpPassManager & pm) { + pm.addPass(createMatmulTileConversionPass()); + pm.addPass(createOuterProductVectorizationPass()); + } + ); + + + return asMainReturnCode( + MlirOptMain(argc, argv, "Optimizer Driver\n", registry) + ); +} \ No newline at end of file From ccae7dfaeb19b058c98286e76d2ea24b7802aa45 Mon Sep 17 00:00:00 2001 From: Daniyal Khan Date: Tue, 5 Mar 2024 18:45:06 -0500 Subject: [PATCH 2/7] reverting half float changes --- backend/driver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/driver.py b/backend/driver.py index b7dffde8..920da7c7 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -40,8 
+40,8 @@ def _extracted_ty(ty): 'i64': 'int64_t', 'u32': 'uint32_t', 'u64': 'uint64_t', - 'fp16': 'float16_t', - 'bf16': 'bfloat16_t', + 'fp16': 'float', + 'bf16': 'float', 'fp32': 'float', 'f32': 'float', 'fp64': 'double', @@ -226,7 +226,7 @@ def launch( # Compile it together. subprocess.check_call([ "g++", launcher_src_path, asm_src_path, - f"-I{py_include_dir}", f"-I{include_dir}" "-mfp16-format=ieee", + f"-I{py_include_dir}", f"-I{include_dir}", "-shared", "-fPIC", "-o", so_path ]) From bf95999f5e8a69f2d9b38806f73211db34e651f8 Mon Sep 17 00:00:00 2001 From: Daniyal Khan Date: Thu, 7 Mar 2024 20:14:28 -0500 Subject: [PATCH 3/7] updated Scalable vectorization --- backend/driver.py | 44 +++++++++++++++++++++---- tools/triton-sme-opt/triton-sme-opt.cpp | 27 +++++++++++++-- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/backend/driver.py b/backend/driver.py index 920da7c7..f58ca3b5 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -2,7 +2,7 @@ import tempfile import sysconfig -import os, subprocess, tempfile +import os, subprocess, tempfile, platform import importlib.util import sysconfig @@ -12,10 +12,32 @@ from triton.backends.driver import DriverBase +# -------------------- Launcher ---------------------------- + +def has_hf(): + resp = os.popen("lscpu").read() + if platform.machine() == "x86_64" and "f16c" in resp: + + return True + elif platform.machine() == "aarch64" and "fphp" in resp: + return True + return False + +def has_bf(): + resp = os.popen("lscpu").read() + if platform.machine() == "x86_64" and "avx512_bf16" in resp: + print("axv512_bf16") + return True + #TODO add aarch64 support for bf16 + return False + + + # -------------------- Launcher ---------------------------- def _ty_to_cpp(ty): if ty[0] == '*': return "void*" + return { "i1": "int32_t", "i8": "int8_t", @@ -24,12 +46,14 @@ def _ty_to_cpp(ty): "i64": "int64_t", "u32": "uint32_t", "u64": "uint64_t", - "fp16": "float", - "bf16": "float", + "fp16": "float16_t" if has_hf() else "float", + "bf16": "bfloat16_t" if has_bf() else "float16_t" if has_hf() else "float", "fp32": "float", "f32": "float", "fp64": "double", }[ty] + + def _extracted_ty(ty): if ty[0] == '*': @@ -40,15 +64,15 @@ def _extracted_ty(ty): 'i64': 'int64_t', 'u32': 'uint32_t', 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', + "fp16": "float16_t" if has_hf() else "float", + "bf16": "bfloat16_t" if has_bf() else "float16_t" if has_hf() else "float", 'fp32': 'float', 'f32': 'float', 'fp64': 'double', }[ty] def _format_of(ty): - return { + format_dict = { "PyObject*": "O", "float": "f", "double": "d", @@ -57,7 +81,12 @@ def _format_of(ty): "int32_t": "i", "uint64_t": "K", "int64_t": "L", - }[ty] + } + if has_bf(): + format_dict["bfloat16"] = "e" + if has_hf(): + format_dict["float16"] = "h" + return format_dict[ty] def _generate_launcher(constants, signature, kernel_name): arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) @@ -68,6 +97,7 @@ def _generate_launcher(constants, signature, kernel_name): #include #include "ExecutionEngine/CRunnerUtils.h" #include "ExecutionEngine/CRunnerUtils.cpp" +#include extern "C" {{ // Pointer type (=Memref) becomes int64_t + MemRef struct diff --git a/tools/triton-sme-opt/triton-sme-opt.cpp b/tools/triton-sme-opt/triton-sme-opt.cpp index ec984ef2..60ab10e6 100644 --- a/tools/triton-sme-opt/triton-sme-opt.cpp +++ b/tools/triton-sme-opt/triton-sme-opt.cpp @@ -51,16 +51,36 @@ struct OuterProductVectorizationPass : public PassWrapper { using OpRewritePattern 
::OpRewritePattern; LogicalResult matchAndRewrite(linalg::MatmulOp op, PatternRewriter & rewriter) const override { + + + linalg::LinalgTilingOptions tilingOptions; + + tilingOptions.setTileSizeComputationFunction([&](OpBuilder& b, Operation*) { + SmallVector sizes; + sizes.reserve(3); + Location loc = op.getLoc(); + + Value vscale = b.create(loc, b.getIndexType()); + + Value tileM = b.create(loc, 4); + Value tileMScaled = b.create(loc, tileM, vscale); + sizes.push_back(tileMScaled); + + Value tileN = b.create(loc, 4); + Value tileNScaled = b.create(loc, tileN, vscale); + sizes.push_back(tileNScaled); + + Value tileK = b.create(loc, 1); + sizes.push_back(tileK); - SmallVector tileSizes = {4, 4,1}; // Tile sizes for [M, N, K] dimensions tofo + return sizes; + }); - linalg::LinalgTilingOptions tilingOptions = linalg::LinalgTilingOptions().setTileSizes(tileSizes); auto tiledOpResult = tileLinalgOp(rewriter, op, tilingOptions); if (failed(tiledOpResult)) { std::cout << "TILING FAILED" << std::endl; @@ -68,6 +88,7 @@ struct OuterProductVectorizationPass : public PassWrapper (tiledOpResult->op.getOperation())))) { + std::cout << "Vectorization FAILED" << std::endl; return failure(); } MLIRContext *context = getContext(); From 08ac19a5761a90fa3955adeded37834bbcfcff47 Mon Sep 17 00:00:00 2001 From: green Date: Fri, 8 Mar 2024 12:16:18 -0500 Subject: [PATCH 4/7] fixed Vectorized Tiling --- backend/driver.py | 5 +- tools/triton-sme-opt/triton-sme-opt.cpp | 89 +++++++++++++------------ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/backend/driver.py b/backend/driver.py index f58ca3b5..e872193d 100644 --- a/backend/driver.py +++ b/backend/driver.py @@ -17,7 +17,6 @@ def has_hf(): resp = os.popen("lscpu").read() if platform.machine() == "x86_64" and "f16c" in resp: - return True elif platform.machine() == "aarch64" and "fphp" in resp: return True @@ -26,9 +25,9 @@ def has_hf(): def has_bf(): resp = os.popen("lscpu").read() if platform.machine() == "x86_64" and "avx512_bf16" in resp: - print("axv512_bf16") return True - #TODO add aarch64 support for bf16 + elif platform.machine() == "aarch64" and "bf16" in resp: + return True return False diff --git a/tools/triton-sme-opt/triton-sme-opt.cpp b/tools/triton-sme-opt/triton-sme-opt.cpp index 60ab10e6..380e8174 100644 --- a/tools/triton-sme-opt/triton-sme-opt.cpp +++ b/tools/triton-sme-opt/triton-sme-opt.cpp @@ -51,55 +51,56 @@ struct OuterProductVectorizationPass : public PassWrapper { - using OpRewritePattern ::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::MatmulOp op, PatternRewriter & rewriter) const override { - - - linalg::LinalgTilingOptions tilingOptions; - - tilingOptions.setTileSizeComputationFunction([&](OpBuilder& b, Operation*) { - SmallVector sizes; - sizes.reserve(3); - Location loc = op.getLoc(); - - Value vscale = b.create(loc, b.getIndexType()); - - Value tileM = b.create(loc, 4); - Value tileMScaled = b.create(loc, tileM, vscale); - sizes.push_back(tileMScaled); - - Value tileN = b.create(loc, 4); - Value tileNScaled = b.create(loc, tileN, vscale); - sizes.push_back(tileNScaled); - - Value tileK = b.create(loc, 1); - sizes.push_back(tileK); + struct MatmulTileConversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::MatmulOp op, + PatternRewriter &rewriter) const override { + linalg::LinalgTilingOptions tilingOptions; + + tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, + Operation *) { + SmallVector sizes; + sizes.reserve(3); + 
+ Location loc = op.getLoc(); + Value vscale = b.create(loc, b.getIndexType()); + Value tileM = b.create(loc, 4); + Value tileMScaled = b.create(loc, tileM, vscale); + sizes.push_back(tileMScaled); + Value tileN = b.create(loc, 4); + Value tileNScaled = b.create(loc, tileN, vscale); + sizes.push_back(tileNScaled); + Value tileK = b.create(loc, 1); + sizes.push_back(tileK); + + return sizes; + }); + + auto tiledOpResult = tileLinalgOp(rewriter, op, tilingOptions); + if (failed(tiledOpResult)) { + std::cout << "TILING FAILED" << std::endl; + return failure(); + } - return sizes; - }); + // Specify vector sizes and scalable dimensions for each dimension + SmallVector inputVectorSizes = {4, 4, 1}; + SmallVector inputScalableVecDims = {true, true, false}; + if (failed(linalg::vectorize(rewriter, + cast( + tiledOpResult->op.getOperation()), + inputVectorSizes, inputScalableVecDims))) { + std::cout << "Vectorization FAILED" << std::endl; + return failure(); + } - auto tiledOpResult = tileLinalgOp(rewriter, op, tilingOptions); - if (failed(tiledOpResult)) { - std::cout << "TILING FAILED" << std::endl; - return failure(); - } + MLIRContext *context = getContext(); + rewriter.replaceOp(op, tiledOpResult->tensorResults); - if (failed(linalg::vectorize(rewriter, cast (tiledOpResult->op.getOperation())))) { - std::cout << "Vectorization FAILED" << std::endl; - return failure(); + return success(); } - MLIRContext *context = getContext(); - - - rewriter.replaceOp(op, tiledOpResult->tensorResults); - return success(); - - } - }; - + }; class MatmulTileConversionPass : public PassWrapper > { public: void getDependentDialects(DialectRegistry & registry) const override { From 83932b6655684ffd731cd53921b5fdc93f37a44b Mon Sep 17 00:00:00 2001 From: green Date: Sun, 14 Apr 2024 12:10:52 -0400 Subject: [PATCH 5/7] SME Emulator needs to work --- backend/compiler.py | 95 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 82 insertions(+), 13 deletions(-) diff --git a/backend/compiler.py b/backend/compiler.py index 5395982a..da9d6be8 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -36,6 +36,8 @@ def _get_llvm_bin_path(bin_name: str) -> str: return os.path.join(path, bin_name) + + def _ttir_to_ttsharedir(mod): # Get Triton-MLIR as string ttir_code = str(mod) @@ -44,25 +46,57 @@ def _ttir_to_ttsharedir(mod): dst_path = os.path.join(tmpdir, "ttshared.mlir") Path(src_path).write_text(ttir_code) triton_shared_opt_path = _get_triton_shared_opt_path() - subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path]) + subprocess.check_call([triton_shared_opt_path, src_path, + "--triton-to-structured", + "--canonicalize", + "--triton-arith-to-linalg", + "--cse", + "--structured-to-memref", + "-o", + dst_path]) return Path(dst_path).read_text() + def _optimize_ttsharedir(ttsharedir: str): + if _get_triton_SME_path() == "": return ttsharedir with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, "ttshared.mlir") - dst_path = os.path.join(tmpdir, "ttsme.mlir") + sme_first_pass = os.path.join(tmpdir, "sme_first_pass.mlir") + sme_second_pass = os.path.join(tmpdir, "sme_second_pass.mlir") Path(src_path).write_text(ttsharedir) - triton_shared_opt_path = _get_triton_SME_path() - subprocess.check_call([triton_shared_opt_path, src_path, "-sme-converison", "-o", dst_path]) - output= Path(dst_path).read_text() - printc(output) - return output + triton_sme_opt_path = _get_triton_SME_path() + mlir_opt_path = _get_llvm_bin_path("mlir-opt") + 
subprocess.check_call([triton_sme_opt_path, src_path, "-sme-conversion" , "-o", sme_first_pass]) + + + subprocess.check_call([mlir_opt_path, sme_first_pass, + "--canonicalize", + "--one-shot-bufferize=allow-return-allocs-from-loops=true", + "--eliminate-empty-tensors", + "--convert-linalg-to-loops", + "--lower-affine", + "--empty-tensor-to-alloc-tensor", + "--expand-strided-metadata", + "--arm-sme-vector-legalization", + "--convert-vector-to-arm-sme", + "--arm-sme-outer-product-fusion", + "--arm-sve-legalize-vector-storage", + "--convert-arith-to-arm-sme", + "--allocate-arm-sme-tiles", + "--convert-arm-sme-to-scf", + "--convert-vector-to-scf=full-unroll", + "--convert-arith-to-arm-sme", + "--convert-arith-to-llvm", + "-o", + sme_second_pass]) + + return Path(sme_second_pass).read_text() -def _ttsharedir_to_llir(ttsharedir: str): #going to need to add some flags to this, recent changes to SME feature flags +def _ttsharedir_to_llir(ttsharedir: str): with tempfile.TemporaryDirectory() as tmpdir: ttshared_path = os.path.join(tmpdir, "ttshared.mlir") llmlir_path = os.path.join(tmpdir, "ll.mlir") @@ -70,7 +104,8 @@ def _ttsharedir_to_llir(ttsharedir: str): #going to need to add some flags to th Path(ttshared_path).write_text(ttsharedir) mlir_opt_path = _get_llvm_bin_path("mlir-opt") # TritonShared-MLIR to LLVM-MLIR - subprocess.check_call([mlir_opt_path, ttshared_path, + if _get_triton_SME_path() == "": + subprocess.check_call([mlir_opt_path, ttshared_path, "--convert-linalg-to-affine-loops", "--eliminate-empty-tensors", "--empty-tensor-to-alloc-tensor", @@ -92,18 +127,49 @@ def _ttsharedir_to_llir(ttsharedir: str): #going to need to add some flags to th # Lowering these affine ops again creates further arith ops, # so we have to run these two passes again here. "--lower-affine", - "--convert-arith-to-llvm", + # "--convert-arith-to-llvm", # Remove all unrealized casts created "--reconcile-unrealized-casts", + # "--debug", "-o", llmlir_path]) + # "--convert-vector-to-llvm=enable-arm-sve", + # "--convert-arm-sme-to-llvm", + else: + subprocess.check_call([mlir_opt_path, ttshared_path, + "--canonicalize", + "--eliminate-empty-tensors", + "--convert-linalg-to-loops", + "--lower-affine", + "--convert-scf-to-cf", + "--convert-cf-to-llvm", + + "--convert-vector-to-llvm=enable-arm-sve", + "--convert-arm-sme-to-llvm", + "--convert-math-to-llvm", + "--convert-complex-to-llvm", + "--convert-index-to-llvm", + "--memref-expand", + + "--expand-strided-metadata", + "--finalize-memref-to-llvm", + "--convert-func-to-llvm", + "--lower-affine", + "--convert-arith-to-llvm", + # Remove all unrealized casts created + "--reconcile-unrealized-casts", + # "--debug", + "-o", + llmlir_path]) + # LLVM-MLIR to LLVM-IR mlir_translate_path = _get_llvm_bin_path("mlir-translate") subprocess.check_call([mlir_translate_path, llmlir_path, - "--mlir-to-llvmir", + "--mlir-to-llvmir", "-o", llir_path]) + return Path(llir_path).read_text() @@ -122,8 +188,11 @@ def _llir_to_bin(llir: str, metadata): dst_path = os.path.join(tmpdir, "kernel.o") Path(src_path).write_text(llir) llc_path = _get_llvm_bin_path("llc") - subprocess.check_call([llc_path, src_path, "-o", dst_path]) - # Actually it's text-format assembly. Use read_text(). 
+ if _get_triton_SME_path() == "": + subprocess.check_call([llc_path, src_path, "-o", dst_path]) + else: + subprocess.check_call(["/usr/bin/qemu-aarch64-static", llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path]) + return Path(dst_path).read_text() From 85c98450f9dc57c813245465aa05fe2efc9e010e Mon Sep 17 00:00:00 2001 From: Daniyal khan <48022006+danikhan632@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:13:24 -0400 Subject: [PATCH 6/7] Added oneshot bufferization fix --- tools/triton-sme-opt/triton-sme-opt.cpp | 358 +++++++++++++++++------- 1 file changed, 252 insertions(+), 106 deletions(-) diff --git a/tools/triton-sme-opt/triton-sme-opt.cpp b/tools/triton-sme-opt/triton-sme-opt.cpp index 380e8174..6da21f08 100644 --- a/tools/triton-sme-opt/triton-sme-opt.cpp +++ b/tools/triton-sme-opt/triton-sme-opt.cpp @@ -1,145 +1,272 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/DialectConversion.h" +#include #include -#include "mlir/IR/Operation.h" #include - +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Pass/Pass.h" using namespace mlir; +namespace matmul_conversion { -namespace { -struct OuterProductVectorizationPass : public PassWrapper> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext *context = funcOp.getContext(); - RewritePatternSet patterns(context); - // Step 4: Lower vector.multi_reduction to vector.contract (+ some helpful patterns) - vector::VectorTransformsOptions vectorTransformsOptions; - vectorTransformsOptions.setVectorTransformsOptions(vector::VectorContractLowering::OuterProduct); - vector::populateVectorTransferDropUnitDimsPatterns(patterns); - vector::populateVectorReductionToContractPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - // Step 5: Lower vector.contract to vector.outerproduct. Also drop unit dims. 
- patterns.clear(); - vectorTransformsOptions.setVectorTransformsOptions(vector::VectorContractLowering::OuterProduct); - vector::populateVectorTransferDropUnitDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + +struct RestrictToTensorOpsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RestrictToTensorOpsPass) + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = &getContext(); + + funcOp.walk([&](bufferization::ToTensorOp op) { + OpBuilder builder(op); + Location loc = op.getLoc(); + Value alloc = op.getMemref(); + Type tensorType = op.getType(); + + Value tensor = builder.create( + loc, tensorType, alloc, true /* restrict */, true /* writable */); + op.replaceAllUsesWith(tensor); + op.erase(); + }); + } +}; +struct OneShotBufferizationPass : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OneShotBufferizationPass) + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + MLIRContext *context = &getContext(); + + // Set up OneShotBufferizationOptions. + bufferization::OneShotBufferizationOptions options; + // auto options = mlir::bufferization::OneShotBufferizationOptions(); + options.allowReturnAllocsFromLoops = true; + options.allowUnknownOps = true; + options.bufferizeFunctionBoundaries = true; + options.unknownTypeConverterFn = + [](mlir::Value value, mlir::Attribute memorySpace, + const mlir::bufferization::BufferizationOptions &options) { + return mlir::bufferization::getMemRefTypeWithStaticIdentityLayout( + value.getType().cast(), memorySpace); + }; + options.setFunctionBoundaryTypeConversion( + mlir::bufferization::LayoutMapOption::IdentityLayoutMap); + // options.getMemorySpaceFn = [](mlir::TensorType t) { + // if (auto rt = t.dyn_cast()) + // return rt.getEncoding(); + // return mlir::Attribute(); + // }; + + // Run One-Shot Bufferize. 
+ if (failed(bufferization::runOneShotBufferize(moduleOp, options))) { return signalPassFailure(); } } }; - struct MatmulTileConversion : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(linalg::MatmulOp op, - PatternRewriter &rewriter) const override { - linalg::LinalgTilingOptions tilingOptions; - tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, - Operation *) { - SmallVector sizes; +struct MatmulTileConversion : public OpRewritePattern { + explicit MatmulTileConversion(MLIRContext *context, bool enableSME) + : OpRewritePattern(context), enableSME(enableSME) {} + + LogicalResult matchAndRewrite(linalg::MatmulOp op, + PatternRewriter &rewriter) const override { + linalg::LinalgTilingOptions tilingOptions; + + tilingOptions.setTileSizeComputationFunction( + [&](OpBuilder &b, Operation *) { + SmallVector sizes; sizes.reserve(3); Location loc = op.getLoc(); Value vscale = b.create(loc, b.getIndexType()); - Value tileM = b.create(loc, 4); - Value tileMScaled = b.create(loc, tileM, vscale); - sizes.push_back(tileMScaled); + + if (enableSME) { + Value tileM = b.create(loc, 4); + Value tileMScaled = b.create(loc, tileM, vscale); + sizes.push_back(tileMScaled); + } else { + Value tileM = b.create(loc, 2); + sizes.push_back(tileM); + } Value tileN = b.create(loc, 4); Value tileNScaled = b.create(loc, tileN, vscale); sizes.push_back(tileNScaled); - Value tileK = b.create(loc, 1); + Value tileK = b.create(loc, 2); sizes.push_back(tileK); return sizes; }); + std::cout << enableSME << std::endl; + + auto tiledOpResult = tileLinalgOp(rewriter, op, tilingOptions); + if (failed(tiledOpResult)) { + std::cout << "TILING FAILED" << std::endl; + return failure(); + } + + SmallVector inputVectorSizes = {enableSME ? 
4 : 2, 4, 2}; + SmallVector inputScalableVecDims = {enableSME, true, false}; + + if (failed(linalg::vectorize( + rewriter, cast(tiledOpResult->op.getOperation()), + inputVectorSizes, inputScalableVecDims))) { + std::cout << "Vectorization FAILED" << std::endl; + return failure(); + } + + MLIRContext *context = getContext(); + rewriter.replaceOp(op, tiledOpResult->tensorResults); + + return success(); + } + + +private: + bool enableSME; +}; + +struct MatmulTileConversionPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatmulTileConversionPass) - auto tiledOpResult = tileLinalgOp(rewriter, op, tilingOptions); - if (failed(tiledOpResult)) { - std::cout << "TILING FAILED" << std::endl; - return failure(); - } - - // Specify vector sizes and scalable dimensions for each dimension - SmallVector inputVectorSizes = {4, 4, 1}; - SmallVector inputScalableVecDims = {true, true, false}; - - if (failed(linalg::vectorize(rewriter, - cast( - tiledOpResult->op.getOperation()), - inputVectorSizes, inputScalableVecDims))) { - std::cout << "Vectorization FAILED" << std::endl; - return failure(); - } - - MLIRContext *context = getContext(); - rewriter.replaceOp(op, tiledOpResult->tensorResults); - - return success(); - } - }; - class MatmulTileConversionPass - : public PassWrapper > { - public: void getDependentDialects(DialectRegistry & registry) const override { - registry.insert (); - } - - void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext *context = &getContext(); - - - RewritePatternSet patterns(context); - patterns.add(context); - - ConversionTarget target( * context); - target.addLegalDialect (); - - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - signalPassFailure(); - } - } - }; - - std::unique_ptr createOuterProductVectorizationPass() { - return std::make_unique (); + explicit MatmulTileConversionPass(bool enableSME) : enableSME(enableSME) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); } - std::unique_ptr createMatmulTileConversionPass() { - return std::make_unique (); + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = &getContext(); + + RewritePatternSet patterns(context); + patterns.add(context, enableSME); + + ConversionTarget target(*context); + target.addLegalDialect(); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + signalPassFailure(); + } } -} -int main(int argc, char ** argv) { - mlir::DialectRegistry registry; +private: + bool enableSME; +}; - registry.insert(); +struct OuterProductVectorizationPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OuterProductVectorizationPass) + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + // Apply patterns for lowering masked transfers + transform::ApplyLowerMaskedTransfersPatternsOp lowerMaskedTransfersPatterns; + lowerMaskedTransfersPatterns.populatePatterns(patterns); + + // Apply patterns for transfer permutation + transform::ApplyTransferPermutationPatternsOp transferPermutationPatterns; + transferPermutationPatterns.populatePatterns(patterns); + + // Apply patterns for reduction to contract + transform::ApplyVectorReductionToContractPatternsOp 
reductionToContractPatterns; + reductionToContractPatterns.populatePatterns(patterns); + transform::ApplyLowerMasksPatternsOp lowerMasksPatterns; + lowerMasksPatterns.populatePatterns(patterns); + + // Apply patterns for rank-reducing subview + transform::ApplyRankReducingSubviewPatternsOp rankReducingSubviewPatterns; + rankReducingSubviewPatterns.populatePatterns(patterns); + + vector::populateVectorContractLoweringPatterns( + patterns, vector::VectorTransformsOptions().setVectorTransformsOptions( + vector::VectorContractLowering::OuterProduct)); + target.addIllegalOp(); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + + + std::unique_ptr createOuterProductVectorizationPass() { + return std::make_unique(); + } + std::unique_ptr createPrefetchPass() { + return std::make_unique(); + } + std::unique_ptr createMatmulTileConversionPass(bool enableSME) { + return std::make_unique(enableSME); + } + + std::unique_ptr createRestrictToTensorOpsPass() { + return std::make_unique(); + } + + std::unique_ptr createOneShotBufferizationPass() { + return std::make_unique(); + } +} // namespace matmul_conversion + +int main(int argc, char **argv) { + DialectRegistry registry; + + registry.insert(); registerAllDialects(registry); registerAllPasses(); @@ -148,17 +275,36 @@ int main(int argc, char ** argv) { context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); - PassPipelineRegistration<> pipeline( - "sme-converison", - "Converts linalg.matmul to a more optimized form", - [](OpPassManager & pm) { - pm.addPass(createMatmulTileConversionPass()); - pm.addPass(createOuterProductVectorizationPass()); - } - ); + + PassPipelineRegistration<> smeConversionPipeline( + "sme-conversion", + "Converts linalg.matmul to a more optimized form using SME", + [](OpPassManager &pm) { + pm.addPass(matmul_conversion::createMatmulTileConversionPass(true)); + pm.addPass(matmul_conversion::createRestrictToTensorOpsPass()); + pm.addPass(matmul_conversion::createOneShotBufferizationPass()); + pm.addPass(matmul_conversion::createOuterProductVectorizationPass()); + }); + + PassPipelineRegistration<> sveConversionPipeline( + "sve-conversion", + "Converts linalg.matmul to a more optimized form using SME", + [](OpPassManager &pm) { + pm.addPass(matmul_conversion::createMatmulTileConversionPass(false)); + pm.addPass(matmul_conversion::createRestrictToTensorOpsPass()); + pm.addPass(matmul_conversion::createOneShotBufferizationPass()); + pm.addPass(matmul_conversion::createOuterProductVectorizationPass()); + }); + + PassPipelineRegistration<> prefConversionPipeline( + "prefetch", + "Converts linalg.matmul to a more optimized form using SME", + [](OpPassManager &pm) { + pm.addPass(matmul_conversion::createPrefetchPass()); + }); + return asMainReturnCode( - MlirOptMain(argc, argv, "Optimizer Driver\n", registry) - ); -} \ No newline at end of file + MlirOptMain(argc, argv, "Optimizer Driver\n", registry)); +} From c431d59f9ef2030193cab2f8a31e090f096a401a Mon Sep 17 00:00:00 2001 From: Daniyal khan <48022006+danikhan632@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:13:52 -0400 Subject: [PATCH 7/7] Update triton-sme-opt.cpp --- tools/triton-sme-opt/triton-sme-opt.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tools/triton-sme-opt/triton-sme-opt.cpp b/tools/triton-sme-opt/triton-sme-opt.cpp index 6da21f08..a99d1345 100644 --- a/tools/triton-sme-opt/triton-sme-opt.cpp +++ b/tools/triton-sme-opt/triton-sme-opt.cpp @@ -243,9 
+243,7 @@ struct OuterProductVectorizationPass std::unique_ptr createOuterProductVectorizationPass() { return std::make_unique(); } - std::unique_ptr createPrefetchPass() { - return std::make_unique(); - } + std::unique_ptr createMatmulTileConversionPass(bool enableSME) { return std::make_unique(enableSME); }
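
Usage sketch (illustrative, not part of the patches above): the SME path introduced in this series is opt-in. backend/compiler.py only invokes triton-sme-opt when TRITON_SME_PATH is set; when it is unset, the stock triton-shared lowering runs unchanged. The Python below is a hypothetical way to exercise the pipeline end to end. The kernel name, shapes, placeholder paths, and the torch-tensor launch are assumptions for illustration; only the three environment variables, the -sme-conversion pipeline name, and the fact that the pass matches linalg.matmul (which triton-shared produces from tl.dot) come from the patches themselves.

# Hypothetical smoke test for the SME path; assumes a triton + triton-shared
# CPU build with this backend active and CPU torch tensors accepted by the
# launcher. All paths below are placeholders.
import os
os.environ["TRITON_SHARED_OPT_PATH"] = "/path/to/triton-shared-opt"  # required by backend/compiler.py
os.environ["LLVM_BINARY_DIR"] = "/path/to/llvm/bin"                  # provides mlir-opt, mlir-translate, llc
os.environ["TRITON_SME_PATH"] = "/path/to/triton-sme-opt"            # leave unset to skip the SME pipeline

import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Single-tile matmul: tl.dot is lowered by triton-shared to linalg.matmul,
    # the op that the -sme-conversion pipeline tiles (with vscale-scaled M/N
    # tile sizes) and vectorizes toward vector.outerproduct / ArmSME.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], tl.dot(a, b))

a = torch.randn(16, 16, dtype=torch.float32)
b = torch.randn(16, 16, dtype=torch.float32)
c = torch.empty(16, 16, dtype=torch.float32)
matmul_kernel[(1,)](a, b, c, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16)
print(torch.allclose(c, a @ b, atol=1e-4))

When TRITON_SME_PATH is unset, _optimize_ttsharedir returns the IR unchanged and the original mlir-opt/llc invocations are used, matching the `if _get_triton_SME_path() == "":` guards added in patches 1 and 5; the emulator path in patch 5 additionally runs llc under qemu-aarch64-static with --mattr=+sve,+sme.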