From 91f1260078d6b0244703e0427d072100ade2374d Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 10 Oct 2024 17:15:13 -0400 Subject: [PATCH] Add support for `i8` dtype, add `--raw_accumulators` flag, add `--target=host_cpu` for easy local testing. (#22) A few unrelated things mixed in this PR, but they are separate commits if you'd prefer me to slice it into 3 PRs. 1. Add a `--raw_accumulators` flag that drops the truncation of the results (default False). This leads to lower arithmetic intensity (because the result values are larger) and either higher or lower performance. This is less representative of real workloads, but is sometimes easier to reason about as a microbenchmark. 2. Add support for `i8` dtype accumulating into `i32`. For now only added to the `square` problem set. Also added `bf16` to that set. 3. Add a special value for the existing `--target` flag: `"host_cpu"` for testing on CPU configured for the host. This was mostly for my own use to be able to develop these changes locally without a GPU. --------- Signed-off-by: Benoit Jacob --- gemmbench/gemm_bench.py | 20 +- gemmbench/gemm_utils.py | 76 +++++--- gemmbench/problems.py | 392 ++++++++++++++++++++++------------------ 3 files changed, 277 insertions(+), 211 deletions(-) diff --git a/gemmbench/gemm_bench.py b/gemmbench/gemm_bench.py index 8778c7c..6ce3066 100644 --- a/gemmbench/gemm_bench.py +++ b/gemmbench/gemm_bench.py @@ -38,7 +38,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, help="Set the logging level", ) - parser.add_argument("--target", help="The IREE hip target to compile for", type=str, default="gfx942") + parser.add_argument("--target", help="The IREE hip target to compile for. The special value host_cpu results in a llvm-cpu benchmark instead of HIP, compiled for the host CPU.", type=str, default="gfx942") parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip") parser.add_argument( "--Xiree_compile", @@ -76,10 +76,15 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, default=None, help="Directory to which executable files will be dumped." ) + parser.add_argument( + "--raw_accumulators", + action='store_true', + help="If true, benchmark matmuls returning the raw accumulator type with no truncation. If false (default), the results are truncated and cast to the input element type." + ) args = parser.parse_args() # Handle default values here, since list args are not compatible with defaulted lists. 
- requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes) + requested_dtypes = ["f16", "bf16", "i8"] if not args.dtypes else list(args.dtypes) requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants) logging.basicConfig(level=args.log_level) @@ -91,7 +96,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk = args.tk configs = get_tk_gemm_configs() if tk else get_gemm_configs() - configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex) + configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex, args.raw_accumulators) print(f"Generated {len(configs)} gemm configs.") num_cpus = max(1, max(cpu_count() // 2, 1)) @@ -108,7 +113,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, target = args.target extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)] dump_dir = args.dump_dir - device = args.device + device = "local-task" if args.target == "host_cpu" else args.device compile_args = itertools.starmap( lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs @@ -130,9 +135,12 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, results = [] index = 0 - output_csv = "results/iree_gemm.csv" + output_csv_base = "iree_gemm" + if args.raw_accumulators: + output_csv_base += "_raw_accumulators" if tk: - output_csv = "results/iree_gemm_tk.csv" + output_csv_base += "_tk" + output_csv = f"results/{output_csv_base}.csv" csv_dir = os.path.dirname(output_csv) if not os.path.exists(csv_dir): os.makedirs(csv_dir) diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 86d8aec..a196574 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -8,6 +8,7 @@ from iree.turbine.kernel.lang.global_symbols import * import torch + @dataclass class GemmConfig: M: int @@ -48,13 +49,15 @@ def get_byte_count(self) -> int: } operand_bytes_per_element = dtype_to_bytes[self.operand_element_type] result_bytes_per_element = dtype_to_bytes[self.result_element_type] - byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element - return byte_count + byte_count_input = (self.M + self.N) * self.K * operand_bytes_per_element + byte_count_output = (self.M * self.N) * result_bytes_per_element + return byte_count_input + byte_count_output def get_flops(self) -> int: flops = 2 * self.M * self.N * self.K return flops + def generate_mlir(config: GemmConfig): K = config.K M = config.M @@ -62,59 +65,63 @@ def generate_mlir(config: GemmConfig): operand_element_type = config.operand_element_type acc_element_type = config.accumulator_element_type result_element_type = config.result_element_type - assert not operand_element_type.startswith('i'), "Integer types not supported yet" + is_integer = operand_element_type.startswith('i') + literal_zero = "0" if is_integer else "0.0" + trunc_op = "arith.trunci" if is_integer else "arith.truncf" tA = config.tA tB = config.tB - mlir_template_A = f""" + mlir_template_matmul_transpose_a = f""" module {{ func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ - %cst = arith.constant 0.000000e+00 : {acc_element_type} + %cst = arith.constant {literal_zero} : {acc_element_type} %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> %1 = 
linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> - return %3 : tensor<{M}x{N}x{result_element_type}> - }} -}} """ - mlir_template_B = f""" + mlir_template_matmul_transpose_b = f""" module {{ func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ - %cst = arith.constant 0.000000e+00 : {acc_element_type} + %cst = arith.constant {literal_zero} : {acc_element_type} %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> - return %3 : tensor<{M}x{N}x{result_element_type}> - }} -}} """ - mlir_template = f""" + mlir_template_matmul_normal = f""" module {{ func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ - %cst = arith.constant 0.000000e+00 : {acc_element_type} + %cst = arith.constant {literal_zero} : {acc_element_type} %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) outs(%1 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> - %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> +""" + mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal + + mlir_template_return_truncated = f""" + %3 = {trunc_op} %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> return %3 : tensor<{M}x{N}x{result_element_type}> }} }} """ - if tA == "T": - return mlir_template_A - if tB == "T": - return mlir_template_B - return mlir_template + + mlir_template_return_untruncated = f""" + return %2 : tensor<{M}x{N}x{result_element_type}> + }} +}} +""" + + mlir_template_return = mlir_template_return_untruncated if (acc_element_type == result_element_type) else mlir_template_return_truncated + + return mlir_template_matmul + mlir_template_return + @dataclass class TkTunedConfig: @@ -131,6 +138,7 @@ class TkTunedConfig: DELAY_SHARED: int DELAY_GLOBAL: int + def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig: if config.M == 2048 and config.N == 10240 and config.K == 1280: return TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2) @@ -145,6 +153,7 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig: # Default config return TkTunedConfig(64, 64, 32, 2, 2, 1, 2, 2, 2, 1, 1, 2) + def generate_tk_mlir(config: GemmConfig): # TODO: Enable 
waves_per_eu # TODO: Use scheduling barriers with LLVM patch @@ -166,14 +175,16 @@ def generate_tk_mlir(config: GemmConfig): STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] constraints += [tkw.TilingConstraint(K, BLOCK_K)] constraints += [tkw.WaveConstraint(M, BLOCK_M / tc.RATIO_M)] constraints += [tkw.WaveConstraint(N, BLOCK_N / tc.RATIO_N)] constraints += [ - tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1)) + tkw.HardwareConstraint(threads_per_wave=64, + waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1)) ] # Wave-level micro-kernel. @@ -266,13 +277,22 @@ def compile_gemm_config( exec_args = [ "iree-compile", f"{mlir_file}", - "--iree-hal-target-backends=rocm", - f"--iree-hip-target={target}", - "--iree-llvmgpu-enable-prefetch=true", "-o", f"{vmfb_file}", ] + extra_compiler_args + if target == "host_cpu": + exec_args += [ + "--iree-hal-target-backends=llvm-cpu", + "--iree-llvmcpu-target-cpu=host" + ] + else: + exec_args += [ + "--iree-hal-target-backends=rocm", + f"--iree-hip-target={target}", + "--iree-llvmgpu-enable-prefetch=true", + ] + print(" ".join(exec_args)) ret_value, stdout, stderr = run_iree_command(exec_args) diff --git a/gemmbench/problems.py b/gemmbench/problems.py index ab2dc7d..6b23c6c 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -19,16 +19,17 @@ def get_default_accumulator_element_type(operand_element_type: str) -> str: ] -def get_default_result_element_type(operand_element_type: str) -> str: - return operand_element_type +def get_default_result_element_type(operand_element_type: str, raw_accumulators: bool) -> str: + return get_default_accumulator_element_type(operand_element_type) if raw_accumulators else operand_element_type -def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool: +def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool) -> bool: """Is this GEMM compute (or memory) bound?""" magic_ratio = 64 flops = 2 * M * N * K elem_type_bytes = num_bytes(dtype) - result_bytes = num_bytes(get_default_result_element_type(dtype)) + result_bytes = num_bytes( + get_default_result_element_type(dtype, raw_accumulators)) bytes = elem_type_bytes * (M * K + K * N) + result_bytes * (M * N) return flops > magic_ratio * bytes @@ -680,23 +681,26 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool: (8192, 8192, 8192), ] + def llama13bmatvec(dtype: str) -> list[GemmConfig]: configs = [] """LLAMA 13b, single batch, FP16.""" for m, n, k, model, gcount in LLAMA: if n == 1 and model == "13b": - configs.append( - GemmConfig( - m, - n, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), + for raw_accumulators in [False, True]: + configs.append( + GemmConfig( + m, + n, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type( + dtype, raw_accumulators), + ) ) - ) return configs @@ -705,16 +709,17 @@ def llama13bmatvecbf16(dtype: str) -> list[GemmConfig]: """LLAMA 13b, single batch, BF16.""" for m, n, k, model, gcount in LLAMA: if n == 1 and model == "13b": - configs.append(GemmConfig( - m, - n, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - )) + for 
raw_accumulators in [False, True]: + configs.append(GemmConfig( + m, + n, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + )) return configs @@ -723,16 +728,17 @@ def llama70bmatvec(dtype: str) -> list[GemmConfig]: configs = [] for m, n, k, model, gcount in LLAMA: if n == 1 and model == "70b": - configs.append(GemmConfig( - m, - n, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - )) + for raw_accumulators in [False, True]: + configs.append(GemmConfig( + m, + n, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + )) return configs @@ -741,16 +747,17 @@ def llama70bmatvecbf16(dtype: str) -> list[GemmConfig]: configs = [] for m, n, k, model, gcount in LLAMA: if n == 1 and model == "70b": - configs.append(GemmConfig( - m, - n, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - )) + for raw_accumulators in [False, True]: + configs.append(GemmConfig( + m, + n, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + )) return configs @@ -760,16 +767,18 @@ def llama13bskinny(dtype: str) -> list[GemmConfig]: for m, n, k, model, gcount in LLAMA: if n == 1 and model == "13b": for batch in [2, 4, 8, 16, 32]: - configs.append(GemmConfig( - m, - batch, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - )) + for raw_accumulators in [False, True]: + configs.append(GemmConfig( + m, + batch, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type( + dtype, raw_accumulators), + )) return configs @@ -779,16 +788,18 @@ def llama13bskinnybf16(dtype: str) -> list[GemmConfig]: for m, n, k, model, gcount in LLAMA: if n == 1 and model == "13b": for batch in [2, 4, 8, 16, 32]: - configs.append(GemmConfig( - m, - batch, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - )) + for raw_accumulators in [False, True]: + configs.append(GemmConfig( + m, + batch, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type( + dtype, raw_accumulators), + )) return configs @@ -798,16 +809,18 @@ def llama70bskinny(dtype: str) -> list[GemmConfig]: for m, n, k, model, gcount in LLAMA: if n == 1 and model == "70b": for batch in [2, 4, 8, 16, 32]: - configs.append(GemmConfig( - m, - batch, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - )) + for raw_accumulators in [False, True]: + configs.append(GemmConfig( + m, + batch, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type( + dtype, raw_accumulators), + )) return configs @@ -817,16 +830,18 @@ def llama70bskinnybf16(dtype: str) -> list[GemmConfig]: for m, n, k, model, gcount in LLAMA: if n == 1 and model == "70b": for batch in [2, 4, 8, 16, 32]: - configs.append(GemmConfig( - m, - batch, - k, - "T", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - )) + for raw_accumulators in [False, True]: + configs.append(GemmConfig( + m, + batch, + k, + "T", + "N", + dtype, + 
get_default_accumulator_element_type(dtype), + get_default_result_element_type( + dtype, raw_accumulators), + )) return configs @@ -834,18 +849,19 @@ def gpt4memory(dtype: str) -> list[GemmConfig]: """GPT4 memory bound GEMMs; FP16.""" configs = [] for m, n, k in GPT4: - hgemm = GemmConfig( - m, - n, - k, - "N", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - ) - if not is_compute_bound(m, n, k, dtype): - configs.append(hgemm) + for raw_accumulators in [False, True]: + hgemm = GemmConfig( + m, + n, + k, + "N", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + ) + if not is_compute_bound(m, n, k, dtype, raw_accumulators): + configs.append(hgemm) return configs @@ -853,36 +869,43 @@ def gpt4compute(dtype: str) -> list[GemmConfig]: """GPT4 compute bound GEMMs; FP16.""" configs = [] for m, n, k in GPT4: - hgemm = GemmConfig( - m, - n, - k, - "N", - "N", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - ) - if is_compute_bound(m, n, k, dtype): - configs.append(hgemm) + for raw_accumulators in [False, True]: + hgemm = GemmConfig( + m, + n, + k, + "N", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + ) + if is_compute_bound(m, n, k, dtype, raw_accumulators): + configs.append(hgemm) return configs def tk_default(dtype: str) -> list[GemmConfig]: """TK Shapes.""" - acc_type = get_default_accumulator_element_type(dtype) - res_type = get_default_result_element_type(dtype) - configs = [] - M, N, K = 2048, 10240, 1280 - configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) - M, N, K = 2048, 1280, 1280 - configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) - M, N, K = 2048, 1280, 5120 - configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) - M, N, K = 128, 1280, 2048 - configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) - M, N, K = 8192, 5120, 640 - configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) + for raw_accumulators in [False, True]: + acc_type = get_default_accumulator_element_type(dtype) + res_type = get_default_result_element_type(dtype, raw_accumulators) + configs = [] + M, N, K = 2048, 10240, 1280 + configs.append(GemmConfig(M, N, K, "N", "T", + dtype, acc_type, res_type)) + M, N, K = 2048, 1280, 1280 + configs.append(GemmConfig(M, N, K, "N", "T", + dtype, acc_type, res_type)) + M, N, K = 2048, 1280, 5120 + configs.append(GemmConfig(M, N, K, "N", "T", + dtype, acc_type, res_type)) + M, N, K = 128, 1280, 2048 + configs.append(GemmConfig(M, N, K, "N", "T", + dtype, acc_type, res_type)) + M, N, K = 8192, 5120, 640 + configs.append(GemmConfig(M, N, K, "N", "T", + dtype, acc_type, res_type)) return configs @@ -890,18 +913,19 @@ def tk_unet(dtype: str) -> list[GemmConfig]: """UNET Shapes for TK.""" configs = [] for m, n, k in UNET: - configs.append( - GemmConfig( - m, - n, - k, - "N", - "T", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), + for raw_accumulators in [False, True]: + configs.append( + GemmConfig( + m, + n, + k, + "N", + "T", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + ) ) - ) return configs @@ -909,18 +933,19 @@ def llama70bmemory(dtype: str) -> list[GemmConfig]: """LLAMA 70b memory bound GEMMs; NT; BF16.""" 
configs = [] for n in [1280, 3584, 7168]: - configs.append( - GemmConfig( - 2, - n, - 8192, - "N", - "T", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), + for raw_accumulators in [False, True]: + configs.append( + GemmConfig( + 2, + n, + 8192, + "N", + "T", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + ) ) - ) return configs @@ -928,18 +953,19 @@ def compute(dtype: str) -> list[GemmConfig]: """Compute bound GEMMs.""" configs = [] for tA, tB in [("N", "N"), ("N", "T"), ("T", "N")]: - configs.append( - GemmConfig( - 4096, - 4096, - 8192, - tA, - tB, - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), + for raw_accumulators in [False, True]: + configs.append( + GemmConfig( + 4096, + 4096, + 8192, + tA, + tB, + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), + ) ) - ) return configs @@ -947,39 +973,42 @@ def unet(dtype: str) -> list[GemmConfig]: configs = [] for tA, tB in [("N", "N"), ("N", "T")]: for m, n, k in UNET: + for raw_accumulators in [False, True]: + configs.append( + GemmConfig( + m, + n, + k, + tA, + tB, + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type( + dtype, raw_accumulators), + ) + ) + return configs + + +def square(dtype: str) -> list[GemmConfig]: + configs = [] + for m, n, k in SQUARE: + for raw_accumulators in [False, True]: configs.append( GemmConfig( m, n, k, - tA, - tB, + "N", + "T", dtype, get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), + get_default_result_element_type(dtype, raw_accumulators), ) ) return configs -def square(dtype: str) -> list[GemmConfig]: - configs = [] - for m, n, k in SQUARE: - configs.append( - GemmConfig( - m, - n, - k, - "N", - "T", - dtype, - get_default_accumulator_element_type(dtype), - get_default_result_element_type(dtype), - ) - ) - return configs - - def get_gemm_configs() -> list[tuple[str, GemmConfig]]: llama13bmatvec_configs: list[GemmConfig] = [] llama13bmatvec_configs += llama13bmatvec("f16") @@ -1009,7 +1038,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: unet_configs += unet("f16") unet_configs += unet("bf16") - square_configs: list[GemmConfig] = square("f16") + square_configs: list[GemmConfig] = square("f16") + square("bf16") + square("i8") all_configs: list[tuple[str, GemmConfig]] = [] all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs] @@ -1041,6 +1070,7 @@ def get_matching_configs( dtypes: list[str], variants: list[str], tag_regex: str, + raw_accumulators: bool ) -> list[tuple[str, GemmConfig]]: tag_re = re.compile(tag_regex) matching_configs: list[tuple[str, GemmConfig]] = [] @@ -1051,6 +1081,14 @@ def get_matching_configs( continue if not tag_re.match(tag): continue + # The raw_accumulators arg means "test configs where the result element + # type is different from what it would be in the default mode". + # We can't just test for (result_element_type == accumulator_element_type), + # as that would cause e.g. f32 matmuls to be omitted in the default mode. + default_result_element_type = get_default_result_element_type(config.operand_element_type, False) + is_raw_accumulators_config = (config.result_element_type != default_result_element_type) + if raw_accumulators != is_raw_accumulators_config: + continue matching_configs.append((tag, config)) return matching_configs
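
Note (not part of the patch): the short sketch below condenses the element-type selection and the MLIR epilogue logic this change introduces across gemmbench/problems.py and gemmbench/gemm_utils.py. The function names here are sketch stand-ins, and the accumulator mapping is an assumption consistent with the PR description (i8 accumulates into i32; f16/bf16 into f32) since the full table sits outside the hunks shown above; the rest mirrors the patched code.

def default_accumulator_type(operand_type: str) -> str:
    # Assumed mapping, based on the PR description rather than the full table.
    return {"f16": "f32", "bf16": "f32", "i8": "i32"}.get(operand_type, operand_type)


def default_result_type(operand_type: str, raw_accumulators: bool) -> str:
    # With --raw_accumulators the kernel returns the accumulator type as-is;
    # otherwise the result is truncated back to the operand element type.
    return default_accumulator_type(operand_type) if raw_accumulators else operand_type


def epilogue(operand_type: str, acc_type: str, result_type: str, M: int, N: int) -> str:
    # Mirrors generate_mlir(): no truncation is emitted when the result type
    # already equals the accumulator type (the --raw_accumulators case);
    # otherwise integer inputs use arith.trunci and float inputs arith.truncf.
    if acc_type == result_type:
        return f"    return %2 : tensor<{M}x{N}x{result_type}>\n  }}\n}}\n"
    trunc_op = "arith.trunci" if operand_type.startswith("i") else "arith.truncf"
    return (
        f"    %3 = {trunc_op} %2 : tensor<{M}x{N}x{acc_type}> to tensor<{M}x{N}x{result_type}>\n"
        f"    return %3 : tensor<{M}x{N}x{result_type}>\n  }}\n}}\n"
    )


if __name__ == "__main__":
    for dtype in ["f16", "bf16", "i8"]:
        for raw in [False, True]:
            acc = default_accumulator_type(dtype)
            res = default_result_type(dtype, raw)
            print(f"{dtype}: raw_accumulators={raw} -> acc={acc}, result={res}")
            print(epilogue(dtype, acc, res, 4096, 4096))

With the flag enabled, gemm_bench.py writes results to results/iree_gemm_raw_accumulators.csv (with a _tk suffix in TK mode), and get_matching_configs() keeps only the configs whose result element type differs from the default-mode result type, so default and raw-accumulator runs never overlap.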