diff --git a/README.md b/README.md index 0fc9d8dd..7c99d120 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,36 @@ pytest /python/examples ``` In addition to testing on the tutorial kernels, there are many lit tests covering various scenarios. +## Intermediate Representation (IR) Dumps + +To facilitate debugging and analysis, the triton-shared project now supports emitting all intermediate representations (IRs) generated during the compilation process. This functionality is controlled via the environment variable `TRITON_SHARED_DUMP_PATH`. + +### How It Works + +By setting the `TRITON_SHARED_DUMP_PATH` environment variable, you specify a directory where all intermediate representations will be saved. The Triton compiler will emit IR dumps at various stages of compilation into the specified folder, allowing developers to inspect and analyze the transformations applied to the code. + +### How to Use + +Create a directory where the IR dumps will be stored (e.g., /path/to/dump_dir). +Set the `TRITON_SHARED_DUMP_PATH` environment variable to the directory path: +`export TRITON_SHARED_DUMP_PATH=/path/to/dump_dir` +Run your Triton compilation as usual. The compiler will emit IR dumps into the specified directory. + +### Example + +Suppose your dump directory is `/tmp/ir_dumps`. Before running your code, set the environment variable: + +```sh +export TRITON_SHARED_DUMP_PATH=/tmp/ir_dumps +``` + +After the compilation process completes, you can explore the `/tmp/ir_dumps` directory to find all the intermediate representation files. + +```sh +$ ls /tmp/ir_dumps +ll.ir ll.mlir tt.mlir ttshared.mlir +``` + ## Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a diff --git a/backend/compiler.py b/backend/compiler.py index 0dc38fc3..3c51068b 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -7,6 +7,7 @@ import tempfile import os import re +import shutil import subprocess import functools from pathlib import Path @@ -25,6 +26,14 @@ def _get_llvm_bin_path(bin_name: str) -> str: return os.path.join(path, bin_name) +def _dump_ir_if_needed(files): + path = os.getenv("TRITON_SHARED_DUMP_PATH", "") + if not path: + return + for f in files: + shutil.copy(f, os.path.join(path, os.path.basename(f))) + + def _ttir_to_ttsharedir(mod): # Get Triton-MLIR as string ttir_code = str(mod) @@ -33,8 +42,8 @@ def _ttir_to_ttsharedir(mod): dst_path = os.path.join(tmpdir, "ttshared.mlir") Path(src_path).write_text(ttir_code) triton_shared_opt_path = _get_triton_shared_opt_path() - subprocess.check_call([triton_shared_opt_path, src_path, - "--triton-to-linalg-experimental", "--mlir-print-debuginfo", "-o", dst_path]) + subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg-experimental", "--mlir-print-debuginfo", "-o", dst_path]) + _dump_ir_if_needed([src_path]) return Path(dst_path).read_text() @@ -91,6 +100,7 @@ def _ttsharedir_to_llir(ttsharedir: str): "--mlir-to-llvmir", "-o", llir_path]) + _dump_ir_if_needed([ttshared_path, llmlir_path, llir_path]) return Path(llir_path).read_text()