Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing Arm SME/SVE2 Optimization pass #109

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 100 additions & 9 deletions backend/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from triton.backends.compiler import BaseBackend
from triton._C.libtriton import ir, passes
from triton._C.libtriton import ir, passes, llvm, triton_shared
from dataclasses import dataclass
from typing import Any
import hashlib
Expand All @@ -10,20 +10,34 @@
import functools
from pathlib import Path

def printc(obj, color="cyan"):
    """Print *obj* wrapped in an ANSI colour escape sequence.

    Temporary debugging aid (makes things easier to see, will remove later).

    Args:
        obj: Value to print; it is formatted with an f-string.
        color: One of "black", "red", "green", "yellow", "blue",
            "magenta", "cyan" or "white". Any other value prints *obj*
            unchanged, with no escape codes.
    """
    ansi_codes = {
        "black": "30", "red": "31", "green": "32", "yellow": "33",
        "blue": "34", "magenta": "35", "cyan": "36", "white": "37"
    }
    # Unknown colour names fall back to a plain, uncoloured print.
    if color not in ansi_codes:
        print(obj)
        return
    print(f"\033[{ansi_codes[color]}m{obj}\033[0m")

def _get_triton_shared_opt_path() -> str:
    """Return the value of the TRITON_SHARED_OPT_PATH environment variable.

    Raises:
        Exception: If the variable is unset or empty.
    """
    opt_path = os.environ.get("TRITON_SHARED_OPT_PATH", "")
    if not opt_path:
        raise Exception("TRITON_SHARED_OPT_PATH is not set.")
    return opt_path


def _get_triton_SME_path() -> str:
    """Return the TRITON_SME_PATH environment variable, or "" when unset."""
    sme_path = os.environ.get("TRITON_SME_PATH", "")
    return sme_path


def _get_llvm_bin_path(bin_name: str) -> str:
    """Return the path of *bin_name* inside the LLVM_BINARY_DIR directory.

    Args:
        bin_name: File name of the LLVM tool, joined onto LLVM_BINARY_DIR.

    Raises:
        Exception: If LLVM_BINARY_DIR is unset or empty.
    """
    llvm_bin_dir = os.environ.get("LLVM_BINARY_DIR", "")
    if not llvm_bin_dir:
        raise Exception("LLVM_BINARY_DIR is not set.")
    return os.path.join(llvm_bin_dir, bin_name)




def _ttir_to_ttsharedir(mod):
# Get Triton-MLIR as string
ttir_code = str(mod)
Expand All @@ -32,13 +46,54 @@ def _ttir_to_ttsharedir(mod):
dst_path = os.path.join(tmpdir, "ttshared.mlir")
Path(src_path).write_text(ttir_code)
triton_shared_opt_path = _get_triton_shared_opt_path()
subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-structured", "--canonicalize", "--triton-arith-to-linalg", "--cse", "--structured-to-memref", "-o", dst_path])
subprocess.check_call([triton_shared_opt_path, src_path,
"--triton-to-structured",
"--canonicalize",
"--triton-arith-to-linalg",
"--cse",
"--structured-to-memref",
"-o",
dst_path])
return Path(dst_path).read_text()



def _optimize_ttsharedir(ttsharedir: str):
# We don't apply any optimizations now, but we can add passes if needed.
return ttsharedir

if _get_triton_SME_path() == "":
return ttsharedir
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "ttshared.mlir")
sme_first_pass = os.path.join(tmpdir, "sme_first_pass.mlir")
sme_second_pass = os.path.join(tmpdir, "sme_second_pass.mlir")
Path(src_path).write_text(ttsharedir)
triton_sme_opt_path = _get_triton_SME_path()
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
subprocess.check_call([triton_sme_opt_path, src_path, "-sme-conversion" , "-o", sme_first_pass])


subprocess.check_call([mlir_opt_path, sme_first_pass,
"--canonicalize",
"--one-shot-bufferize=allow-return-allocs-from-loops=true",
"--eliminate-empty-tensors",
"--convert-linalg-to-loops",
"--lower-affine",
"--empty-tensor-to-alloc-tensor",
"--expand-strided-metadata",
"--arm-sme-vector-legalization",
"--convert-vector-to-arm-sme",
"--arm-sme-outer-product-fusion",
"--arm-sve-legalize-vector-storage",
"--convert-arith-to-arm-sme",
"--allocate-arm-sme-tiles",
"--convert-arm-sme-to-scf",
"--convert-vector-to-scf=full-unroll",
"--convert-arith-to-arm-sme",
"--convert-arith-to-llvm",
"-o",
sme_second_pass])

return Path(sme_second_pass).read_text()


def _ttsharedir_to_llir(ttsharedir: str):
Expand All @@ -49,7 +104,8 @@ def _ttsharedir_to_llir(ttsharedir: str):
Path(ttshared_path).write_text(ttsharedir)
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
# TritonShared-MLIR to LLVM-MLIR
subprocess.check_call([mlir_opt_path, ttshared_path,
if _get_triton_SME_path() == "":
subprocess.check_call([mlir_opt_path, ttshared_path,
"--convert-linalg-to-affine-loops",
"--eliminate-empty-tensors",
"--empty-tensor-to-alloc-tensor",
Expand All @@ -71,18 +127,49 @@ def _ttsharedir_to_llir(ttsharedir: str):
# Lowering these affine ops again creates further arith ops,
# so we have to run these two passes again here.
"--lower-affine",
"--convert-arith-to-llvm",
# "--convert-arith-to-llvm",
# Remove all unrealized casts created
"--reconcile-unrealized-casts",
# "--debug",
"-o",
llmlir_path])

# "--convert-vector-to-llvm=enable-arm-sve",
# "--convert-arm-sme-to-llvm",
else:
subprocess.check_call([mlir_opt_path, ttshared_path,
"--canonicalize",
"--eliminate-empty-tensors",
"--convert-linalg-to-loops",
"--lower-affine",
"--convert-scf-to-cf",
"--convert-cf-to-llvm",

"--convert-vector-to-llvm=enable-arm-sve",
"--convert-arm-sme-to-llvm",
"--convert-math-to-llvm",
"--convert-complex-to-llvm",
"--convert-index-to-llvm",
"--memref-expand",

"--expand-strided-metadata",
"--finalize-memref-to-llvm",
"--convert-func-to-llvm",
"--lower-affine",
"--convert-arith-to-llvm",
# Remove all unrealized casts created
"--reconcile-unrealized-casts",
# "--debug",
"-o",
llmlir_path])

# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
subprocess.check_call([mlir_translate_path, llmlir_path,
"--mlir-to-llvmir",
"--mlir-to-llvmir",
"-o",
llir_path])

return Path(llir_path).read_text()


Expand All @@ -101,8 +188,11 @@ def _llir_to_bin(llir: str, metadata):
dst_path = os.path.join(tmpdir, "kernel.o")
Path(src_path).write_text(llir)
llc_path = _get_llvm_bin_path("llc")
subprocess.check_call([llc_path, src_path, "-o", dst_path])
# Actually it's text-format assembly. Use read_text().
if _get_triton_SME_path() == "":
subprocess.check_call([llc_path, src_path, "-o", dst_path])
else:
subprocess.check_call(["/usr/bin/qemu-aarch64-static", llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])

return Path(dst_path).read_text()


Expand Down Expand Up @@ -151,6 +241,7 @@ def load_dialects(self, ctx):

@staticmethod
def make_ttir(mod, metadata, opt):
# assert False
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
Expand Down
43 changes: 36 additions & 7 deletions backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tempfile
import sysconfig

import os, subprocess, tempfile
import os, subprocess, tempfile, platform
import importlib.util
import sysconfig

Expand All @@ -12,10 +12,31 @@
from triton.backends.driver import DriverBase


# -------------------- Launcher ----------------------------

def has_hf():
    """Return True when `lscpu` output mentions the half-precision CPU flag.

    The flag searched for depends on `platform.machine()`: "f16c" on
    x86_64 and "fphp" on aarch64. Any other machine type yields False.
    The check is a plain substring search over the whole `lscpu` output,
    so empty output also yields False.

    Returns:
        bool: Whether the flag for the current machine type was found.
    """
    flag = {"x86_64": "f16c", "aarch64": "fphp"}.get(platform.machine())
    if flag is None:
        # No known flag for this machine type: skip spawning lscpu entirely.
        return False
    # Close the pipe deterministically instead of leaving it to the GC.
    with os.popen("lscpu") as pipe:
        return flag in pipe.read()

def has_bf():
    """Return True when `lscpu` output mentions the bfloat16 CPU flag.

    The flag searched for depends on `platform.machine()`: "avx512_bf16"
    on x86_64 and "bf16" on aarch64. Any other machine type yields False.
    The check is a plain substring search over the whole `lscpu` output,
    so empty output also yields False.

    Returns:
        bool: Whether the flag for the current machine type was found.
    """
    flag = {"x86_64": "avx512_bf16", "aarch64": "bf16"}.get(platform.machine())
    if flag is None:
        # No known flag for this machine type: skip spawning lscpu entirely.
        return False
    # Close the pipe deterministically instead of leaving it to the GC.
    with os.popen("lscpu") as pipe:
        return flag in pipe.read()



# -------------------- Launcher ----------------------------
def _ty_to_cpp(ty):
if ty[0] == '*':
return "void*"

return {
"i1": "int32_t",
"i8": "int8_t",
Expand All @@ -24,12 +45,14 @@ def _ty_to_cpp(ty):
"i64": "int64_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp16": "float",
"bf16": "float",
"fp16": "float16_t" if has_hf() else "float",
"bf16": "bfloat16_t" if has_bf() else "float16_t" if has_hf() else "float",
"fp32": "float",
"f32": "float",
"fp64": "double",
}[ty]



def _extracted_ty(ty):
if ty[0] == '*':
Expand All @@ -40,15 +63,15 @@ def _extracted_ty(ty):
'i64': 'int64_t',
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp16': 'float',
'bf16': 'float',
"fp16": "float16_t" if has_hf() else "float",
"bf16": "bfloat16_t" if has_bf() else "float16_t" if has_hf() else "float",
'fp32': 'float',
'f32': 'float',
'fp64': 'double',
}[ty]

def _format_of(ty):
return {
format_dict = {
"PyObject*": "O",
"float": "f",
"double": "d",
Expand All @@ -57,7 +80,12 @@ def _format_of(ty):
"int32_t": "i",
"uint64_t": "K",
"int64_t": "L",
}[ty]
}
if has_bf():
format_dict["bfloat16"] = "e"
if has_hf():
format_dict["float16"] = "h"
return format_dict[ty]

def _generate_launcher(constants, signature, kernel_name):
arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
Expand All @@ -68,6 +96,7 @@ def _generate_launcher(constants, signature, kernel_name):
#include <Python.h>
#include "ExecutionEngine/CRunnerUtils.h"
#include "ExecutionEngine/CRunnerUtils.cpp"
#include <stdfloat>

extern "C" {{
// Pointer type (=Memref) becomes int64_t + MemRef struct
Expand Down
1 change: 1 addition & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Each tool lives in its own subdirectory with its own CMakeLists.txt.
add_subdirectory(triton-shared-opt)
add_subdirectory(triton-sme-opt)
19 changes: 19 additions & 0 deletions tools/triton-sme-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Build configuration for the triton-sme-opt executable.

# Read the global MLIR_DIALECT_LIBS / MLIR_CONVERSION_LIBS properties into
# local variables so they can be passed to target_link_libraries below.
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

# NOTE(review): PARTIAL_SOURCES_INTENDED presumably tells the LLVM CMake
# helpers that not every source file in this directory belongs to this
# target -- confirm against LLVM's AddLLVM.cmake.
add_llvm_executable(triton-sme-opt triton-sme-opt.cpp PARTIAL_SOURCES_INTENDED)

llvm_update_compile_flags(triton-sme-opt)
# All dependencies are linked PRIVATE: they are not propagated to targets
# that depend on triton-sme-opt.
target_link_libraries(triton-sme-opt PRIVATE

${dialect_libs}
${conversion_libs}
# tests
TritonTestAnalysis
# MLIR core
MLIROptLib
MLIRPass
MLIRTransforms
)

# NOTE(review): presumably validates the link line of the target; the macro
# is defined outside this file -- confirm against MLIR's AddMLIR.cmake.
mlir_check_all_link_libraries(triton-sme-opt)
Loading