Skip to content

Commit

Permalink
Add missing unet gemm shapes (#7)
Browse files Browse the repository at this point in the history
Simplify the generator code and remove dead functions.
Also stop generating matmul_transpose_a configs, as they do not show up in the
real model.
  • Loading branch information
kuhar authored Oct 3, 2024
1 parent 17f1f8b commit 6f5ce46
Showing 1 changed file with 8 additions and 32 deletions.
40 changes: 8 additions & 32 deletions gemmbench/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,8 @@ def is_compute_bound(M, N, K, bpe):
(2048, 1280, 1280),
(2048, 1280, 5120),
(128, 1280, 2048),
(8192, 640, 640),
(8192, 640, 2560),
(8192, 5120, 640),
]

Expand Down Expand Up @@ -780,7 +782,7 @@ def gpt4memory(dtype: str) -> list[GemmConfig]:
for m, n, k in GPT4:
hgemm = GemmConfig(m, n, k, "N", "N", dtype)
if not is_compute_bound(m, n, k, 2):
yield configs.append(hgemm)
configs.append(hgemm)
return configs


Expand All @@ -794,26 +796,6 @@ def gpt4compute(dtype: str) -> list[GemmConfig]:
return configs


def gpt4clocktest(dtype: str) -> list[GemmConfig]:
    """Return a sweep of non-transposed GEMM configs of growing M and N.

    M and N grow together in steps of 128 (128, 256, ..., 2048) while K is
    fixed at 8192, giving 16 configs in increasing size order.

    Args:
        dtype: Element type string forwarded unchanged to every GemmConfig.

    Returns:
        The list of generated GemmConfig objects.
    """
    # NOTE(review): the function name suggests these sizes are used for a
    # clock-frequency test -- confirm against the callers.
    tile_m, tile_n = 128, 128
    max_m, k = 2048, 8192
    return [
        GemmConfig(mult * tile_m, mult * tile_n, k, "N", "N", dtype)
        for mult in range(1, max_m // tile_m + 1)
    ]


def test(dtype: str) -> list[GemmConfig]:
    """Return two non-transposed square-MN GEMM configs for quick testing.

    The configs are (M, N, K) = (128, 128, 8192) followed by
    (2048, 2048, 8192), in that order.

    Args:
        dtype: Element type string forwarded unchanged to every GemmConfig.

    Returns:
        A two-element list of GemmConfig objects.
    """
    k = 8192
    return [GemmConfig(size, size, k, "N", "N", dtype) for size in (128, 2048)]

def tk_default(dtype: str) -> list[GemmConfig]:
"""TK Shapes."""
configs = []
Expand Down Expand Up @@ -845,21 +827,15 @@ def compute(dtype: str) -> list[GemmConfig]:
"""Compute bound GEMMs."""
#for dtype in ["fp16", "bf16", "fp8"]:
configs = []
for dtype in [dtype]:
for tA in ["N", "T"]:
for tB in ["N", "T"]:
if tA == "N" or tB == "N":
configs.append(GemmConfig(4096, 4096, 8192, tA, tB, dtype))
for tA, tB in [("N", "N"), ("N", "T"), ("T", "N")]:
configs.append(GemmConfig(4096, 4096, 8192, tA, tB, dtype))
return configs

def unet(dtype: str) -> list[GemmConfig]:
    """Return GEMM configs for every shape in UNET.

    Each (m, n, k) shape in UNET is emitted once per transpose layout, with
    all shapes for ("N", "N") first and then all shapes for ("N", "T").
    Transposed-A layouts are intentionally not generated.

    Args:
        dtype: Element type string forwarded unchanged to every GemmConfig.

    Returns:
        The list of generated GemmConfig objects, two per UNET shape.
    """
    layouts = (("N", "N"), ("N", "T"))
    return [
        GemmConfig(m, n, k, tA, tB, dtype)
        for tA, tB in layouts
        for m, n, k in UNET
    ]

def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
Expand Down

0 comments on commit 6f5ce46

Please sign in to comment.