From 0cbff68363f2b1cbd0c2d5512b55eb48b305d9f2 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 2 Oct 2024 16:57:18 -0400 Subject: [PATCH] Add missing unet gemm shapes --- gemmbench/problems.py | 40 ++++++++-------------------------------- 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/gemmbench/problems.py b/gemmbench/problems.py index ccdc8b2..25f6843 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -639,6 +639,8 @@ def is_compute_bound(M, N, K, bpe): (2048, 1280, 1280), (2048, 1280, 5120), (128, 1280, 2048), + (8192, 640, 640), + (8192, 640, 2560), (8192, 5120, 640), ] @@ -780,7 +782,7 @@ def gpt4memory(dtype: str) -> list[GemmConfig]: for m, n, k in GPT4: hgemm = GemmConfig(m, n, k, "N", "N", dtype) if not is_compute_bound(m, n, k, 2): - yield configs.append(hgemm) + configs.append(hgemm) return configs @@ -794,26 +796,6 @@ def gpt4compute(dtype: str) -> list[GemmConfig]: return configs -def gpt4clocktest(dtype: str) -> list[GemmConfig]: - """GPT4 compute bound GEMMs; FP16.""" - configs = [] - macM, macN = 128, 128 - M, N, K = 2048, 2048, 8192 - for mult in range(1, M//macM + 1): - configs.append(GemmConfig(mult * macM, mult * macN, K, "N", "N", dtype)) - return configs - - -def test(dtype: str) -> list[GemmConfig]: - """GPT4 compute bound GEMMs; FP16.""" - #M, N, K = 2048, 2048, 8192 - configs = [] - M, N, K = 128, 128, 8192 - configs.append(GemmConfig(M, N, K, "N", "N", dtype)) - M, N, K = 2048, 2048, 8192 - configs.append(GemmConfig(M, N, K, "N", "N", dtype)) - return configs - def tk_default(dtype: str) -> list[GemmConfig]: """TK Shapes.""" configs = [] @@ -845,21 +827,15 @@ def compute(dtype: str) -> list[GemmConfig]: """Compute bound GEMMs.""" #for dtype in ["fp16", "bf16", "fp8"]: configs = [] - for dtype in [dtype]: - for tA in ["N", "T"]: - for tB in ["N", "T"]: - if tA == "N" or tB == "N": - configs.append(GemmConfig(4096, 4096, 8192, tA, tB, dtype)) + for tA, tB in [("N", "N"), ("N", "T"), ("T", "N")]: + configs.append(GemmConfig(4096, 4096, 8192, tA, tB, dtype)) return configs def unet(dtype: str) -> list[GemmConfig]: configs = [] - for dtype in [dtype]: - for tA in ["N", "T"]: - for tB in ["N", "T"]: - for m, n, k in UNET: - if tA == "N" or tB == "N": - configs.append(GemmConfig(m, n, k, tA, tB, dtype)) + for tA, tB in [("N", "N"), ("N", "T")]: + for m, n, k in UNET: + configs.append(GemmConfig(m, n, k, tA, tB, dtype)) return configs def get_gemm_configs() -> list[tuple[str, GemmConfig]]: