Skip to content

Commit

Permalink
Allow for setting device to benchmark on (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar authored Oct 10, 2024
1 parent c5ad991 commit 4fc4145
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
8 changes: 5 additions & 3 deletions attentionbench/attention_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
help="Comma seperated csv file list to generate roofline plot with",
default=None,
)
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
parser.add_argument("--plot", help="location to save plot", default=None)
parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
Expand All @@ -55,14 +56,15 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
repo_root = Path(__file__).parent.parent
kernel_dir = repo_root / "attention" / "mlir"
vmfb_dir = repo_root / "attention" / "vmfb"
device = args.device
kernel_dir.mkdir(parents=True, exist_ok=True)
vmfb_dir.mkdir(parents=True, exist_ok=True)

args = itertools.starmap(
compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
)
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_attention, list(args))))
compilation_results = list(tqdm(pool.starmap(compile_attention, list(compile_args))))

error_count = 0
for tag, config, mlir_file, vmfb_file in compilation_results:
Expand Down Expand Up @@ -93,7 +95,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):

exec_args = [
"iree-benchmark-module",
f"--device=hip",
f"--device={device}",
"--device_allocator=caching",
f"--module={vmfb_filename}",
"--function=main",
Expand Down
8 changes: 5 additions & 3 deletions convbench/shark_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
type=str.upper,
help="Set the logging level",
)
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
parser.add_argument(
"--roofline",
help="Comma seperated csv file list to generate roofline plot with",
Expand Down Expand Up @@ -57,12 +58,13 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
vmfb_dir = repo_root / "conv" / "vmfb"
kernel_dir.mkdir(parents=True, exist_ok=True)
vmfb_dir.mkdir(parents=True, exist_ok=True)
device = args.device

args = itertools.starmap(
compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
)
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_conv, list(args))))
compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))

error_count = 0
for tag, config, mlir_file, vmfb_file in compilation_results:
Expand Down Expand Up @@ -92,7 +94,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):

exec_args = [
"iree-benchmark-module",
f"--device=hip",
f"--device={device}",
"--device_allocator=caching",
f"--module={vmfb_filename}",
"--function=main",
Expand Down
10 changes: 6 additions & 4 deletions gemmbench/gemm_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
)

parser.add_argument("--target", help="The IREE hip target to compile for", type=str, default="gfx942")
parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip")
parser.add_argument(
"--Xiree_compile",
nargs='+',
Expand Down Expand Up @@ -77,7 +78,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
)

args = parser.parse_args()
# Handle default values here, since 'append' is not compatible with defaulted lists.
# Handle default values here, since list args are not compatible with defaulted lists.
requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes)
requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants)

Expand Down Expand Up @@ -107,12 +108,13 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
target = args.target
extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
dump_dir = args.dump_dir
device = args.device

args = itertools.starmap(
compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs
)
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_gemm, list(args))))
compilation_results = list(tqdm(pool.starmap(compile_gemm, list(compile_args))))

error_count = 0
for tag, config, mlir_file, vmfb_file in compilation_results:
Expand Down Expand Up @@ -145,7 +147,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,

exec_args = [
"iree-benchmark-module",
f"--device=hip",
f"--device={device}",
"--device_allocator=caching",
f"--module={vmfb_filename}",
f"--input={inp1}",
Expand Down

0 comments on commit 4fc4145

Please sign in to comment.