diff --git a/convbench/conv_utils.py b/convbench/conv_utils.py index c8c8ad4..d1493d5 100644 --- a/convbench/conv_utils.py +++ b/convbench/conv_utils.py @@ -162,7 +162,8 @@ def generate_mlir(config: ConvConfig): def compile_conv_config( - config: ConvConfig, kernel_dir: Path, vmfb_dir: Path + config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, + device: str, target: str ) -> tuple[Path, Optional[Path]]: mlir_file = kernel_dir / (config.get_name() + ".mlir") vmfb_file = vmfb_dir / (config.get_name() + ".vmfb") @@ -185,9 +186,9 @@ def compile_conv_config( "-o", f"{vmfb_file}", # Target Device: hip - "--iree-hal-target-device=hip", + f"--iree-hal-target-device={device}", # Device: MI300x - "--iree-hip-target=gfx942", + f"--iree-hip-target={target}", ] print(" ".join(exec_args)) diff --git a/convbench/shark_conv.py b/convbench/shark_conv.py index 8f797c1..5ad1064 100644 --- a/convbench/shark_conv.py +++ b/convbench/shark_conv.py @@ -12,8 +12,8 @@ from problems import get_conv_configs -def compile_conv(tag, config, kernel_dir, vmfb_dir): - mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir) +def compile_conv(tag, config, kernel_dir, vmfb_dir, device, target): + mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir, device, target) return (tag, config, mlir_file, vmfb_file) @@ -27,6 +27,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): help="Set the logging level", ) parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip") + parser.add_argument("--target", help="The device's target to execute benchmarks on", type=str, default="gfx942") parser.add_argument( "--roofline", help="Comma seperated csv file list to generate roofline plot with", @@ -61,7 +62,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): device = args.device compile_args = itertools.starmap( - lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs + lambda tag, config: (tag, config, kernel_dir, vmfb_dir, device, args.target), configs ) with Pool(num_cpus) as pool: compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))