Add flags to filter gemm configs (#3)
We can filter based on three criteria:
1. Data types: f16/bf16
2. Matmul variants: NN, NT, TN, and TT (NT means B is transposed)
3. Tag regex, e.g., `"unet"` or `".*skinny"`.
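
For example, a combined invocation of all three filters might look like this (flag names and the script path are taken from the diff below; repeated flags accumulate because the new arguments use argparse's append action):

    python gemmbench/gemm_bench.py --dtypes f16 --variants NN --variants NT --tag-regex ".*skinny"

Note that tags are checked with re.match, so the pattern is anchored at the start of the tag.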
kuhar authored Oct 2, 2024
1 parent ad06e6b commit 4b1345b
Showing 2 changed files with 74 additions and 30 deletions.
37 changes: 26 additions & 11 deletions gemmbench/gemm_bench.py
@@ -15,7 +15,7 @@
 import sys
 from utils import *
 from gemm_utils import *
-from problems import get_gemm_configs, get_tk_gemm_configs
+from problems import get_gemm_configs, get_tk_gemm_configs, get_matching_configs
 
 
 def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk):
@@ -40,33 +40,48 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk):
         default=[],
         help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
     )
+    parser.add_argument(
+        "--dtypes", action='append', help="List of data types to benchmark. Defaults to all supported types."
+    )
+    parser.add_argument(
+        "--variants",
+        action='append',
+        help="List of matmul variants to benchmark. Defaults to all variants: NN, NT, TN, and TT."
+    )
+    parser.add_argument(
+        "--tag-regex",
+        help="Regular expression for allowed benchmark tags. Defaults to all tags allowed.",
+        default=".*"
+    )
     parser.add_argument("--roofline", help="Comma separated csv file list to generate roofline plot with", default=None)
     parser.add_argument("--plot", help="location to save plot", default=None)
     parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
     parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
     parser.add_argument("--model", help="roofline on certain model", default=None)
     parser.add_argument(
         "--tk",
         action="store_true",
         default=False,
-        help="Option to run gemm kernels using Turbine Kernels",
+        help="Run gemm kernels using Turbine Kernels",
     )
 
     args = parser.parse_args()
+    # Handle default values here, since 'append' is not compatible with defaulted lists.
+    requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
+    requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)
 
     logging.basicConfig(level=args.log_level)
 
     if args.roofline:
-        roofline(args.roofline, args.plot, args.batch, args.dtype, args.model)
+        for dtype in requested_dtypes:
+            roofline(args.roofline, f"{args.plot}_{dtype}", args.batch, dtype, args.model)
         sys.exit()
 
     tk = args.tk
-    if tk:
-        configs = get_tk_gemm_configs()
-    else:
-        configs = get_gemm_configs()
+    configs = get_tk_gemm_configs() if tk else get_gemm_configs()
+    configs = get_matching_configs(configs, requested_dtypes, requested_variants, args.tag_regex)
     print(f"Generated {len(configs)} gemm configs.")
 
-    num_cpus = max(1, cpu_count() - 20)
+    num_cpus = max(1, max(cpu_count() // 2, 1))
     print(f"Using {num_cpus} CPUs for parallel processing.")
 
     manager = Manager()
@@ -125,7 +140,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk):
     ]
 
     if tk:
-        exec_args += ["--function=isolated_benchmark"]
+        exec_args += ["--function=isolated_benchmark"]
     else:
         exec_args += ["--function=main"]

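An aside on the default handling above: with action='append', argparse appends to a default list instead of replacing it, which is why this commit leaves the defaults unset and substitutes them after parsing. A minimal standalone sketch of the gotcha and of the commit's workaround (illustrative, runnable on its own, not part of the diff):

import argparse

# Gotcha: a non-empty default combined with action='append' accumulates.
p = argparse.ArgumentParser()
p.add_argument("--dtypes", action="append", default=["f16", "bf16"])
print(p.parse_args(["--dtypes", "f16"]).dtypes)  # ['f16', 'bf16', 'f16']

# The commit's approach: no default, then fill in after parsing.
p2 = argparse.ArgumentParser()
p2.add_argument("--dtypes", action="append")
args = p2.parse_args(["--dtypes", "f16"])
requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
print(requested_dtypes)  # ['f16']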
67 changes: 48 additions & 19 deletions gemmbench/problems.py
@@ -6,6 +6,9 @@

 from gemm_utils import GemmConfig
 
+import re
+
+
 def is_compute_bound(M, N, K, bpe):
     """Is this GEMM compute (or memory) bound?"""
     magic_ratio = 64
@@ -860,34 +863,46 @@ def unet(dtype: str) -> list[GemmConfig]:
     return configs
 
 def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
-    configs: list[tuple[str, GemmConfig]] = []
-    llama13bmatvec_configs = llama13bmatvec("f16")
+    llama13bmatvec_configs: list[GemmConfig] = []
+    llama13bmatvec_configs += llama13bmatvec("f16")
     llama13bmatvec_configs += llama13bmatvecbf16("bf16")
-    llama70bmatvec_configs = llama70bmatvec("f16")
+
+    llama70bmatvec_configs: list[GemmConfig] = []
+    llama70bmatvec_configs += llama70bmatvec("f16")
     llama70bmatvec_configs += llama70bmatvecbf16("bf16")
-    llama13bskinny_configs = llama13bskinny("f16")
+
+    llama13bskinny_configs: list[GemmConfig] = []
+    llama13bskinny_configs += llama13bskinny("f16")
     llama13bskinny_configs += llama13bskinnybf16("bf16")
-    llama70bskinny_configs = llama70bskinny("f16")
+
+    llama70bskinny_configs: list[GemmConfig] = []
+    llama70bskinny_configs += llama70bskinny("f16")
     llama70bskinny_configs += llama70bskinnybf16("bf16")
+
     gpt4compute_configs = gpt4compute("f16")
     llama70bmemory_configs = llama70bmemory("bf16")
     tk_default_configs = tk_default("f16")
-    compute_configs = compute("f16")
+
+    compute_configs: list[GemmConfig] = []
+    compute_configs += compute("f16")
     compute_configs += compute("bf16")
-    unet_configs = unet("f16")
+
+    unet_configs: list[GemmConfig] = []
+    unet_configs += unet("f16")
     unet_configs += unet("bf16")
 
-    configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
-    configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
-    configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
-    configs += [("llama70bskinny", x) for x in llama70bskinny_configs]
-    configs += [("gpt4compute", x) for x in gpt4compute_configs]
-    configs += [("llama70bmemory", x) for x in llama70bmemory_configs]
-    configs += [("compute", x) for x in compute_configs]
-    configs += [("unet", x) for x in unet_configs]
-    configs += [("tk", x) for x in tk_default_configs]
-
-    return configs
+    all_configs: list[tuple[str, GemmConfig]] = []
+    all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
+    all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
+    all_configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
+    all_configs += [("llama70bskinny", x) for x in llama70bskinny_configs]
+    all_configs += [("gpt4compute", x) for x in gpt4compute_configs]
+    all_configs += [("llama70bmemory", x) for x in llama70bmemory_configs]
+    all_configs += [("compute", x) for x in compute_configs]
+    all_configs += [("unet", x) for x in unet_configs]
+    all_configs += [("tk", x) for x in tk_default_configs]
+
+    return all_configs
 
 def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]:
     configs: list[tuple[str, GemmConfig]] = []
@@ -896,5 +911,19 @@ def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]:

configs += [("tk", x) for x in tk_default_configs]
configs += [("unet", x) for x in tk_unet_configs]

return configs

def get_matching_configs(tagged_configs: list[tuple[str, GemmConfig]],
dtypes: list[str], variants: list[str], tag_regex: str) -> list[tuple[str, GemmConfig]]:
tag_re = re.compile(tag_regex)
matching_configs: list[tuple[str, GemmConfig]] = []
for tag, config in tagged_configs:
if config.dtype not in dtypes:
continue
if f"{config.tA}{config.tB}" not in variants:
continue
if not tag_re.match(tag):
continue
matching_configs.append((tag, config))

return matching_configs
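
Since get_matching_configs is a pure function over (tag, config) pairs, it is easy to exercise in isolation. A minimal sketch; the stand-in dataclass below is an assumption about the GemmConfig fields (dtype, tA, tB) inferred from the filter above, not the repo's definition:

from dataclasses import dataclass

from problems import get_matching_configs  # as added in this commit

@dataclass
class FakeGemmConfig:  # illustrative stand-in for gemm_utils.GemmConfig
    dtype: str
    tA: str  # "T" if A is transposed, "N" otherwise
    tB: str  # likewise for B

tagged = [
    ("unet", FakeGemmConfig("f16", "N", "T")),
    ("llama13bskinny", FakeGemmConfig("bf16", "N", "N")),
]
# Keep only f16 NT kernels whose tag matches "unet"; the bf16 NN entry is dropped.
print(get_matching_configs(tagged, ["f16"], ["NT"], "unet"))
# [('unet', FakeGemmConfig(dtype='f16', tA='N', tB='T'))]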
