Add support for i8 dtype, add --raw_accumulators flag, add `--target=host_cpu` for easy local testing. (#22)

A few unrelated things are mixed into this PR, but they are separate commits,
so I can slice it into 3 PRs if you'd prefer.
1. Add a `--raw_accumulators` flag (default False) that skips truncating the
results to the input element type. This lowers the arithmetic intensity
(because the result values are wider), and performance can come out either
higher or lower. It is less representative of real workloads, but is
sometimes easier to reason about as a microbenchmark.
2. Add support for the `i8` dtype, accumulating into `i32`. For now it is only
added to the `square` problem set; `bf16` is also added to that set.
3. Add a special value for the existing `--target` flag: `"host_cpu"`, for
testing on a CPU target configured for the host. This was mostly for my own
use, to be able to develop these changes locally without a GPU (see the
example invocation after this list).
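As a rough usage sketch (a hypothetical invocation, assuming the script is run
directly and that `--dtypes` accepts a space-separated list, as the existing
code suggests), something like
`python gemmbench/gemm_bench.py --target=host_cpu --dtypes i8 --raw_accumulators`
exercises all three additions at once: it compiles for the host CPU with the
llvm-cpu backend, picks up the new `i8` configs, and keeps the `i32`
accumulators untruncated.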

---------

Signed-off-by: Benoit Jacob <[email protected]>
bjacob authored Oct 10, 2024
1 parent 4fc4145 commit 91f1260
Showing 3 changed files with 277 additions and 211 deletions.
20 changes: 14 additions & 6 deletions gemmbench/gemm_bench.py
@@ -38,7 +38,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
help="Set the logging level",
)

parser.add_argument("--target", help="The IREE hip target to compile for", type=str, default="gfx942")
parser.add_argument("--target", help="The IREE hip target to compile for. The special value host_cpu results in a llvm-cpu benchmark instead of HIP, compiled for the host CPU.", type=str, default="gfx942")
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
parser.add_argument(
"--Xiree_compile",
@@ -76,10 +76,15 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
default=None,
help="Directory to which executable files will be dumped."
)
parser.add_argument(
"--raw_accumulators",
action='store_true',
help="If true, benchmark matmuls returning the raw accumulator type with no truncation. If false (default), the results are truncated and cast to the input element type."
)

args = parser.parse_args()
# Handle default values here, since list args are not compatible with defaulted lists.
requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
requested_dtypes = ["f16", "bf16", "i8"] if not args.dtypes else list(args.dtypes)
requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)

logging.basicConfig(level=args.log_level)
@@ -91,7 +96,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,

tk = args.tk
configs = get_tk_gemm_configs() if tk else get_gemm_configs()
configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex)
configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex, args.raw_accumulators)
print(f"Generated {len(configs)} gemm configs.")

num_cpus = max(1, max(cpu_count() // 2, 1))
@@ -108,7 +113,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
target = args.target
extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
dump_dir = args.dump_dir
device = args.device
device = "local-task" if args.target == "host_cpu" else args.device

compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs
@@ -130,9 +135,12 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,

results = []
index = 0
output_csv = "results/iree_gemm.csv"
output_csv_base = "iree_gemm"
if args.raw_accumulators:
output_csv_base += "_raw_accumulators"
if tk:
output_csv = "results/iree_gemm_tk.csv"
output_csv_base += "_tk"
output_csv = f"results/{output_csv_base}.csv"
csv_dir = os.path.dirname(output_csv)
if not os.path.exists(csv_dir):
os.makedirs(csv_dir)
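With the new naming, results land in `results/iree_gemm.csv`,
`results/iree_gemm_raw_accumulators.csv`, `results/iree_gemm_tk.csv`, or
`results/iree_gemm_raw_accumulators_tk.csv`, depending on whether raw
accumulators and/or the TK path are requested.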
76 changes: 48 additions & 28 deletions gemmbench/gemm_utils.py
@@ -8,6 +8,7 @@
from iree.turbine.kernel.lang.global_symbols import *
import torch


@dataclass
class GemmConfig:
M: int
@@ -48,73 +49,79 @@ def get_byte_count(self) -> int:
}
operand_bytes_per_element = dtype_to_bytes[self.operand_element_type]
result_bytes_per_element = dtype_to_bytes[self.result_element_type]
byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element
return byte_count
byte_count_input = (self.M + self.N) * self.K * operand_bytes_per_element
byte_count_output = (self.M * self.N) * result_bytes_per_element
return byte_count_input + byte_count_output

def get_flops(self) -> int:
flops = 2 * self.M * self.N * self.K
return flops

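As a back-of-the-envelope illustration of the arithmetic-intensity point from
the PR description (a sketch with illustrative numbers, not part of the change),
using the `get_byte_count()`/`get_flops()` formulas above:

```python
# Rough flops/byte for a 4096x4096x4096 matmul, following the formulas above.
M = N = K = 4096
flops = 2 * M * N * K

def intensity(operand_bytes: int, result_bytes: int) -> float:
    byte_count_input = (M + N) * K * operand_bytes
    byte_count_output = M * N * result_bytes
    return flops / (byte_count_input + byte_count_output)

# f16 operands, result truncated back to f16 (the default):
print(intensity(2, 2))  # ~1365 flops/byte
# f16 operands, raw f32 accumulators (--raw_accumulators):
print(intensity(2, 4))  # ~1024 flops/byte (lower, as noted in the description)
```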

def generate_mlir(config: GemmConfig):
K = config.K
M = config.M
N = config.N
operand_element_type = config.operand_element_type
acc_element_type = config.accumulator_element_type
result_element_type = config.result_element_type
assert not operand_element_type.startswith('i'), "Integer types not supported yet"
is_integer = operand_element_type.startswith('i')
literal_zero = "0" if is_integer else "0.0"
trunc_op = "arith.trunci" if is_integer else "arith.truncf"

tA = config.tA
tB = config.tB
mlir_template_A = f"""
mlir_template_matmul_transpose_a = f"""
module {{
func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
%cst = arith.constant 0.000000e+00 : {acc_element_type}
%cst = arith.constant {literal_zero} : {acc_element_type}
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

mlir_template_B = f"""
mlir_template_matmul_transpose_b = f"""
module {{
func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
%cst = arith.constant 0.000000e+00 : {acc_element_type}
%cst = arith.constant {literal_zero} : {acc_element_type}
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

mlir_template = f"""
mlir_template_matmul_normal = f"""
module {{
func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
%cst = arith.constant 0.000000e+00 : {acc_element_type}
%cst = arith.constant {literal_zero} : {acc_element_type}
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
-> tensor<{M}x{N}x{acc_element_type}>
%3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
"""
mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal

mlir_template_return_truncated = f"""
%3 = {trunc_op} %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
return %3 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""
if tA == "T":
return mlir_template_A
if tB == "T":
return mlir_template_B
return mlir_template

mlir_template_return_untruncated = f"""
return %2 : tensor<{M}x{N}x{result_element_type}>
}}
}}
"""

mlir_template_return = mlir_template_return_untruncated if (acc_element_type == result_element_type) else mlir_template_return_truncated

return mlir_template_matmul + mlir_template_return


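For reference, a hypothetical driver for the new integer path (field names are
taken from the attribute accesses above; the real `GemmConfig` constructor may
take different or additional arguments). With `i32` as both the accumulator and
the result element type (the `--raw_accumulators` case), the untruncated return
is selected and no `arith.trunci` is emitted:

```python
# Hypothetical example only; GemmConfig's exact constructor may differ.
cfg = GemmConfig(M=1024, N=1024, K=1024, tA="N", tB="N",
                 operand_element_type="i8",
                 accumulator_element_type="i32",
                 result_element_type="i32")
print(generate_mlir(cfg))  # linalg.matmul on i8 operands accumulating into i32
```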
@dataclass
class TkTunedConfig:
@@ -131,6 +138,7 @@ class TkTunedConfig:
DELAY_SHARED: int
DELAY_GLOBAL: int


def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
if config.M == 2048 and config.N == 10240 and config.K == 1280:
return TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)
@@ -145,6 +153,7 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
# Default config
return TkTunedConfig(64, 64, 32, 2, 2, 1, 2, 2, 2, 1, 1, 2)


def generate_tk_mlir(config: GemmConfig):
# TODO: Enable waves_per_eu
# TODO: Use scheduling barriers with LLVM patch
@@ -166,14 +175,16 @@ def generate_tk_mlir(config: GemmConfig):
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Expose user-constraints
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints: list[tkw.Constraint] = [
tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / tc.RATIO_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / tc.RATIO_N)]

constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
tkw.HardwareConstraint(threads_per_wave=64,
waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
]

# Wave-level micro-kernel.
@@ -266,13 +277,22 @@ def compile_gemm_config(
exec_args = [
"iree-compile",
f"{mlir_file}",
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={target}",
"--iree-llvmgpu-enable-prefetch=true",
"-o",
f"{vmfb_file}",
] + extra_compiler_args

if target == "host_cpu":
exec_args += [
"--iree-hal-target-backends=llvm-cpu",
"--iree-llvmcpu-target-cpu=host"
]
else:
exec_args += [
"--iree-hal-target-backends=rocm",
f"--iree-hip-target={target}",
"--iree-llvmgpu-enable-prefetch=true",
]

print(" ".join(exec_args))

ret_value, stdout, stderr = run_iree_command(exec_args)
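With `--target=host_cpu`, the printed `iree-compile` command now ends with
`--iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host` instead of
the rocm/HIP flags, and `gemm_bench.py` (above) switches the runtime device to
`local-task`, so the whole flow runs on a machine without a GPU.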