From 8a5fdd8f8d6e7733ec76446382c1b181134fb1d5 Mon Sep 17 00:00:00 2001
From: Harsh Menon
Date: Tue, 8 Oct 2024 00:45:42 -0700
Subject: [PATCH] Add more shapes to Tk benchmark

Also add a tuned config that specifies which tile sizes and scheduling
params to use.
---
 gemmbench/gemm_utils.py | 51 +++++++++++++++++++++++++++++++++++------
 gemmbench/problems.py   | 10 +++++---
 2 files changed, 51 insertions(+), 10 deletions(-)

diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py
index 82d47dd..30db3ae 100644
--- a/gemmbench/gemm_utils.py
+++ b/gemmbench/gemm_utils.py
@@ -116,8 +116,37 @@ def generate_mlir(config: GemmConfig):
             return mlir_template_B
     return mlir_template
 
+@dataclass
+class TkTunedConfig:
+    BLOCK_M: int
+    BLOCK_N: int
+    BLOCK_K: int
+    RATIO_M: int
+    RATIO_N: int
+    WAVES_PER_EU: int
+    MMA_UNITS: int
+    SHARED_UNITS: int
+    GLOBAL_UNITS: int
+    DELAY_MMA: int
+    DELAY_SHARED: int
+    DELAY_GLOBAL: int
+
+def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
+    if config.M == 2048 and config.N == 10240 and config.K == 1280:
+        return TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)
+    if config.M == 2048 and config.N == 1280 and config.K == 1280:
+        return TkTunedConfig(64, 64, 64, 2, 2, 1, 2, 1, 1, 1, 1, 2)
+    if config.M == 2048 and config.N == 1280 and config.K == 5120:
+        return TkTunedConfig(128, 80, 128, 4, 1, 1, 4, 2, 2, 1, 1, 2)
+    if config.M == 128 and config.N == 1280 and config.K == 2048:
+        return TkTunedConfig(64, 64, 128, 2, 2, 1, 8, 2, 2, 1, 1, 2)
+    if config.M == 8192 and config.N == 5120 and config.K == 640:
+        return TkTunedConfig(128, 128, 32, 2, 2, 1, 4, 2, 2, 1, 1, 2)
 
 def generate_tk_mlir(config: GemmConfig):
+    # TODO: Enable waves_per_eu
+    # TODO: Use scheduling barriers with LLVM patch
+    tc = get_tk_tuned_config(config)
     assert config.operand_element_type == 'f16', "Unsupported problem"
     assert config.accumulator_element_type == 'f32', "Unsupported problem"
     # Input sizes
@@ -138,11 +167,11 @@ def generate_tk_mlir(config: GemmConfig):
     constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
     constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
     constraints += [tkw.TilingConstraint(K, BLOCK_K)]
-    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
-    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / tc.RATIO_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / tc.RATIO_N)]
 
     constraints += [
-        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
+        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
     ]
 
     # Wave-level micro-kernel.
@@ -184,16 +213,24 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
         LOAD_ELEMS_PER_THREAD: 4,
         STORE_ELEMS_PER_THREAD: 4,
-        BLOCK_M: 64,
-        BLOCK_N: 64,
-        BLOCK_K: 32,
+        BLOCK_M: tc.BLOCK_M,
+        BLOCK_N: tc.BLOCK_N,
+        BLOCK_K: tc.BLOCK_K,
         M: shape[0],
         N: shape[1],
         K: shape[2],
+        READ_SHARED_DELAY: tc.DELAY_SHARED,
+        WRITE_SHARED_DELAY: tc.DELAY_SHARED,
+        READ_GLOBAL_DELAY: tc.DELAY_GLOBAL,
+        WRITE_GLOBAL_DELAY: tc.DELAY_GLOBAL,
+        MMA_DELAY: tc.DELAY_MMA,
+        SHARED_MEMORY_UNITS: tc.SHARED_UNITS,
+        GLOBAL_MEMORY_UNITS: tc.GLOBAL_UNITS,
+        MMA_UNITS: tc.MMA_UNITS,
     }
     config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
     with tk.gen.TestLaunchContext(
-        hyperparams, canonicalize=True, run=True, run_config=config
+        hyperparams, canonicalize=True, run=True, run_config=config, schedule=True,
     ):
         a = torch.randn(shape[0], shape[2], dtype=operand_element_type)
         b = torch.randn(shape[1], shape[2], dtype=operand_element_type)
diff --git a/gemmbench/problems.py b/gemmbench/problems.py
index 69dbe47..385a343 100644
--- a/gemmbench/problems.py
+++ b/gemmbench/problems.py
@@ -864,11 +864,15 @@ def tk_default(dtype: str) -> list[GemmConfig]:
     acc_type = get_default_accumulator_element_type(dtype)
     res_type = get_default_result_element_type(dtype)
     configs = []
-    M, N, K = 1024, 5120, 640
-    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
     M, N, K = 2048, 10240, 1280
     configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
-    M, N, K = 4096, 20480, 2560
+    M, N, K = 2048, 1280, 1280
+    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
+    M, N, K = 2048, 1280, 5120
+    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
+    M, N, K = 128, 1280, 2048
+    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
+    M, N, K = 8192, 5120, 640
     configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
 
     return configs
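
Note for reviewers (illustrative, not part of the patch): the tuned values above can also be kept in a table keyed by problem shape. The sketch below is a minimal standalone Python example under that assumption; TK_TUNED_CONFIGS and lookup_tk_tuned_config are hypothetical names, and the numbers are copied from get_tk_tuned_config in this patch.

from dataclasses import dataclass

@dataclass
class TkTunedConfig:
    BLOCK_M: int
    BLOCK_N: int
    BLOCK_K: int
    RATIO_M: int
    RATIO_N: int
    WAVES_PER_EU: int
    MMA_UNITS: int
    SHARED_UNITS: int
    GLOBAL_UNITS: int
    DELAY_MMA: int
    DELAY_SHARED: int
    DELAY_GLOBAL: int

# Tuned tile sizes and scheduling parameters keyed by (M, N, K),
# copied from get_tk_tuned_config in the patch above.
TK_TUNED_CONFIGS = {
    (2048, 10240, 1280): TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2),
    (2048, 1280, 1280):  TkTunedConfig(64, 64, 64, 2, 2, 1, 2, 1, 1, 1, 1, 2),
    (2048, 1280, 5120):  TkTunedConfig(128, 80, 128, 4, 1, 1, 4, 2, 2, 1, 1, 2),
    (128, 1280, 2048):   TkTunedConfig(64, 64, 128, 2, 2, 1, 8, 2, 2, 1, 1, 2),
    (8192, 5120, 640):   TkTunedConfig(128, 128, 32, 2, 2, 1, 4, 2, 2, 1, 1, 2),
}

def lookup_tk_tuned_config(M: int, N: int, K: int) -> TkTunedConfig:
    """Return the tuned config for a known shape; fail loudly otherwise."""
    try:
        return TK_TUNED_CONFIGS[(M, N, K)]
    except KeyError:
        raise ValueError(f"no tuned TK config for M={M}, N={N}, K={K}") from None

# Example: the 2048x10240x1280 problem uses 128x320x32 workgroup tiles split
# across a RATIO_M x RATIO_N (2x2) grid of waves, so each wave computes a
# 64x160 tile of the output.
tc = lookup_tk_tuned_config(2048, 10240, 1280)
assert (tc.BLOCK_M // tc.RATIO_M, tc.BLOCK_N // tc.RATIO_N) == (64, 160)

A dict keyed by shape keeps the tuning table in one place and raises on untuned shapes, whereas the if-chain in get_tk_tuned_config falls through and returns None for any shape not listed in tk_default.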