diff --git a/attentionbench/attention_bench.py b/attentionbench/attention_bench.py index 5e84db2..0b69d29 100644 --- a/attentionbench/attention_bench.py +++ b/attentionbench/attention_bench.py @@ -31,6 +31,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir): help="Comma seperated csv file list to generate roofline plot with", default=None, ) + parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip") parser.add_argument("--plot", help="location to save plot", default=None) parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None) parser.add_argument("--dtype", help="roofline on certain dtype", default=None) @@ -55,14 +56,15 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir): repo_root = Path(__file__).parent.parent kernel_dir = repo_root / "attention" / "mlir" vmfb_dir = repo_root / "attention" / "vmfb" + device = args.device kernel_dir.mkdir(parents=True, exist_ok=True) vmfb_dir.mkdir(parents=True, exist_ok=True) - args = itertools.starmap( + compile_args = itertools.starmap( lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs ) with Pool(num_cpus) as pool: - compilation_results = list(tqdm(pool.starmap(compile_attention, list(args)))) + compilation_results = list(tqdm(pool.starmap(compile_attention, list(compile_args)))) error_count = 0 for tag, config, mlir_file, vmfb_file in compilation_results: @@ -93,7 +95,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir): exec_args = [ "iree-benchmark-module", - f"--device=hip", + f"--device={device}", "--device_allocator=caching", f"--module={vmfb_filename}", "--function=main", diff --git a/convbench/shark_conv.py b/convbench/shark_conv.py index 215c7f5..8f797c1 100644 --- a/convbench/shark_conv.py +++ b/convbench/shark_conv.py @@ -26,6 +26,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): type=str.upper, help="Set the logging level", ) + parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip") parser.add_argument( "--roofline", help="Comma seperated csv file list to generate roofline plot with", @@ -57,12 +58,13 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): vmfb_dir = repo_root / "conv" / "vmfb" kernel_dir.mkdir(parents=True, exist_ok=True) vmfb_dir.mkdir(parents=True, exist_ok=True) + device = args.device - args = itertools.starmap( + compile_args = itertools.starmap( lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs ) with Pool(num_cpus) as pool: - compilation_results = list(tqdm(pool.starmap(compile_conv, list(args)))) + compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args)))) error_count = 0 for tag, config, mlir_file, vmfb_file in compilation_results: @@ -92,7 +94,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): exec_args = [ "iree-benchmark-module", - f"--device=hip", + f"--device={device}", "--device_allocator=caching", f"--module={vmfb_filename}", "--function=main", diff --git a/gemmbench/gemm_bench.py b/gemmbench/gemm_bench.py index 9bdf1d9..8778c7c 100644 --- a/gemmbench/gemm_bench.py +++ b/gemmbench/gemm_bench.py @@ -39,6 +39,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, ) parser.add_argument("--target", help="The IREE hip target to compile for", type=str, default="gfx942") + parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip") parser.add_argument( "--Xiree_compile", nargs='+', @@ -77,7 +78,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, ) args = parser.parse_args() - # Handle default values here, since 'append' is not compatible with defaulted lists. + # Handle default values here, since list args are not compatible with defaulted lists. requested_dtypes = ["f16", "bf16"] if not args.dtypes else list(args.dtypes) requested_variants = ["NN", "NT", "TN", "TT"] if not args.variants else list(args.variants) @@ -107,12 +108,13 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, target = args.target extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)] dump_dir = args.dump_dir + device = args.device - args = itertools.starmap( + compile_args = itertools.starmap( lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs ) with Pool(num_cpus) as pool: - compilation_results = list(tqdm(pool.starmap(compile_gemm, list(args)))) + compilation_results = list(tqdm(pool.starmap(compile_gemm, list(compile_args)))) error_count = 0 for tag, config, mlir_file, vmfb_file in compilation_results: @@ -145,7 +147,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, exec_args = [ "iree-benchmark-module", - f"--device=hip", + f"--device={device}", "--device_allocator=caching", f"--module={vmfb_filename}", f"--input={inp1}",