Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shintaro-iwasaki committed Nov 6, 2023
1 parent 7a737e1 commit c5b7d8b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
branches: [ "main" ]

jobs:
test-plugin:
call-workflow:
uses: ./.github/workflows/test-plugin.yml
with:
triton-ref: '05dc28be0e72dd496300a31b99a21a5a5118f8e9' # known good commit "[CI] refactor workflows (#2504)"
Expand Down
2 changes: 1 addition & 1 deletion python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ if(TRITON_BUILD_PYTHON_MODULE)
${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/CRunnerUtils.h
${CMAKE_CURRENT_SOURCE_DIR}/ExecutionEngine/CRunnerUtils.cpp
DESTINATION ${PYTHON_THIRD_PARTY_PATH}/cpu/)
# TODO: perhaps we want to install binary files used in __init__.py
# TODO: perhaps we want to install binary files used by __init__.py
endif()
69 changes: 30 additions & 39 deletions python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
def _get_triton_shared_opt_path() -> str:
    """Return the path stored in the ``TRITON_SHARED_OPT_PATH`` variable.

    Returns:
        The value of the ``TRITON_SHARED_OPT_PATH`` environment variable.

    Raises:
        Exception: If the variable is unset or set to an empty string.
    """
    path = os.getenv("TRITON_SHARED_OPT_PATH", "")
    if not path:
        # Must be ``raise``: ``assert Exception(...)`` only checks the
        # truthiness of a freshly built exception object and never fails.
        raise Exception("TRITON_SHARED_OPT_PATH is not set.")
    return path


def _get_llvm_bin_path(bin_name: str) -> str:
    """Return the path of an LLVM tool located under ``LLVM_BINARY_DIR``.

    Args:
        bin_name: File name of the tool, e.g. ``"mlir-opt"`` or ``"llc"``.

    Returns:
        ``bin_name`` joined onto the ``LLVM_BINARY_DIR`` directory.

    Raises:
        Exception: If ``LLVM_BINARY_DIR`` is unset or set to an empty string.
    """
    llvm_dir = os.getenv("LLVM_BINARY_DIR", "")
    if not llvm_dir:
        raise Exception("LLVM_BINARY_DIR is not set.")
    # os.path.join uses the platform separator and avoids a doubled slash
    # when the directory already ends with one.
    return os.path.join(llvm_dir, bin_name)


def _ttir_to_ttsharedir(mod):
Expand All @@ -33,8 +33,7 @@ def _ttir_to_ttsharedir(mod):
dst_path = os.path.join(tmpdir, "ttshared.mlir")
Path(src_path).write_text(ttir_code)
triton_shared_opt_path = _get_triton_shared_opt_path()
ret = subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path])
assert ret == 0
subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path])
return Path(dst_path).read_text()


Expand All @@ -51,7 +50,7 @@ def _ttsharedir_to_llir(ttsharedir: str):
Path(ttshared_path).write_text(ttsharedir)
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
# TritonShared-MLIR to LLVM-MLIR
ret = subprocess.check_call([mlir_opt_path, ttshared_path,
subprocess.check_call([mlir_opt_path, ttshared_path,
"--convert-linalg-to-affine-loops",
"--eliminate-empty-tensors",
"--empty-tensor-to-alloc-tensor",
Expand All @@ -72,15 +71,13 @@ def _ttsharedir_to_llir(ttsharedir: str):
"--reconcile-unrealized-casts",
"-o",
llmlir_path])
assert ret == 0

# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
ret = subprocess.check_call([mlir_translate_path, llmlir_path,
subprocess.check_call([mlir_translate_path, llmlir_path,
"--mlir-to-llvmir",
"-o",
llir_path])
assert ret == 0
return Path(llir_path).read_text()


Expand All @@ -93,11 +90,9 @@ def _llir_to_bin(llir: str):
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "kernel.ll")
dst_path = os.path.join(tmpdir, "kernel.o")
with open(src_path, "w") as f:
f.write(llir)
Path(src_path).write_text(llir)
llc_path = _get_llvm_bin_path("llc")
ret = subprocess.check_call([llc_path, src_path, "-o", dst_path])
assert ret == 0
subprocess.check_call([llc_path, src_path, "-o", dst_path])
# Despite the ".o" suffix, the output is text-format assembly, so read it with read_text().
return Path(dst_path).read_text()

Expand Down Expand Up @@ -346,25 +341,27 @@ def make_launcher_stub(self, name, signature, constants, ids):
so_name = f"{name}.py"
# retrieve stub from cache if it exists
cache_path = so_cache_manager.get_file(so_name)
if cache_path is None:
kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER"
with tempfile.TemporaryDirectory() as tmpdir:
# Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name
# in the following launch function.
launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name)
# This function was renamed and made public in Python 3.10
if hasattr(sysconfig, 'get_default_scheme'):
scheme = sysconfig.get_default_scheme()
else:
scheme = sysconfig._get_default_scheme()
# 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
# path changes to include 'local'. This change is required to use triton with system-wide python.
if scheme == 'posix_local':
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]

dst_path = os.path.join(tmpdir, so_name)
py_src = f"""
if cache_path is not None:
return cache_path

kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER"
with tempfile.TemporaryDirectory() as tmpdir:
# Later KERNEL_NAME_PLACEHOLDER will be used to assign the kernel name
# in the following launch function.
launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name)
# This function was renamed and made public in Python 3.10
if hasattr(sysconfig, 'get_default_scheme'):
scheme = sysconfig.get_default_scheme()
else:
scheme = sysconfig._get_default_scheme()
# 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
# path changes to include 'local'. This change is required to use triton with system-wide python.
if scheme == 'posix_local':
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]

dst_path = os.path.join(tmpdir, so_name)
py_src = f"""
import os, subprocess, tempfile
import importlib.util
from pathlib import Path
Expand All @@ -385,21 +382,15 @@ def launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDim0, clusterDim1, c
Path(asm_src_path).write_text(asm_src)
Path(launcher_src_path).write_text(launcher_src)
# Compile it together.
ret = subprocess.check_call(["g++", launcher_src_path, asm_src_path, f"-I{py_include_dir}", f"-I{Path(__file__).resolve().parent}", "-shared", "-fPIC", "-o", so_path])
if ret != 0:
raise AssertionError("Kernel compilation failed.")
subprocess.check_call(["g++", launcher_src_path, asm_src_path, f"-I{py_include_dir}", f"-I{Path(__file__).resolve().parent}", "-shared", "-fPIC", "-o", so_path])
# Load and launch the compiled kernel.
spec = importlib.util.spec_from_file_location("__triton_shared_ref_cpu_kernel_launcher", so_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod.launch(gridX, gridY, gridZ, launch_enter_hook, launch_exit_hook, compiled_kernel, *args)
"""
Path(dst_path).write_text(py_src)
with open(dst_path, "rb") as f:
return so_cache_manager.put(f.read(), so_name, binary=True)
else:
return cache_path
return so_cache_manager.put(py_src, so_name, binary=False)


register_backend("cpu", TritonSharedRefCPUBackend)

0 comments on commit c5b7d8b

Please sign in to comment.