diff --git a/convbench/conv_utils.py b/convbench/conv_utils.py index c8c8ad4..267ecf4 100644 --- a/convbench/conv_utils.py +++ b/convbench/conv_utils.py @@ -38,25 +38,71 @@ class ConvConfig: output_dtype: str def get_name(self) -> str: - return self.OP + "_" + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + "_" + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + "_stride" + str(self.S) - + return ( + self.OP + + "_" + + f"{self.N}x{self.H}x{self.W}x{self.C}x{self.P}x{self.Q}x{self.F}" + + "_" + + f"{self.input_dtype}x{self.input_dtype}x{self.output_dtype}" + + "_stride" + + str(self.S) + ) + def get_img_shape(self) -> str: if "nhwc" in self.OP: in_h = self.H * self.S + self.P - 1 in_w = self.W * self.S + self.Q - 1 - return str(self.N) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(self.C) + "x" + self.input_dtype + return ( + str(self.N) + + "x" + + str(in_h) + + "x" + + str(in_w) + + "x" + + str(self.C) + + "x" + + self.input_dtype + ) if "nchw" in self.OP: in_h = self.H * self.S + self.P - 1 in_w = self.W * self.S + self.Q - 1 - return str(self.N) + "x" + str(self.C) + "x" + str(in_h) + "x" + str(in_w) + "x" + self.input_dtype - - + return ( + str(self.N) + + "x" + + str(self.C) + + "x" + + str(in_h) + + "x" + + str(in_w) + + "x" + + self.input_dtype + ) + def get_kernel_shape(self) -> str: if "nhwc" in self.OP: - return str(self.P) + "x" + str(self.Q) + "x" + str(self.C) + "x" + str(self.F) + "x" + self.input_dtype + return ( + str(self.P) + + "x" + + str(self.Q) + + "x" + + str(self.C) + + "x" + + str(self.F) + + "x" + + self.input_dtype + ) if "nchw" in self.OP: - return str(self.F) + "x" + str(self.C) + "x" + str(self.P) + "x" + str(self.Q) + "x" + self.input_dtype - + return ( + str(self.F) + + "x" + + str(self.C) + + "x" + + str(self.P) + + "x" + + str(self.Q) + + "x" + + self.input_dtype + ) def get_byte_count(self) -> int: dtype_bits_map = { @@ -80,7 +126,13 @@ def get_byte_count(self) -> int: k_height = self.P byte_count = ( (batch * input_channels * in_w * in_h * bytes_per_input) - + (batch * output_channels * output_width * output_height * bytes_per_output) + + ( + batch + * output_channels + * output_width + * output_height + * bytes_per_output + ) + (k_width * k_height * input_channels * output_channels * bytes_per_input) ) return byte_count @@ -100,6 +152,7 @@ def get_flops(self) -> int: flops = operation_per_pixel * output_pixels_per_batch * batch return flops + def generate_mlir(config: ConvConfig): n = config.N h = config.H @@ -116,17 +169,77 @@ def generate_mlir(config: ConvConfig): in_w = str(int(w) * int(stride) + int(q) - 1) if "nhwc" in operation: conv_type = "nhwc_hwcf" - lhs = str(n) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(c) + "x" + str(elem_types[0]) - rhs = str(p) + "x" + str(q) + "x" + str(c) + "x" + str(f) + "x" + str(elem_types[1]) - out = str(n) + "x" + str(h) + "x" + str(w) + "x" + str(f) + "x" + str(elem_types[2]) + lhs = ( + str(n) + + "x" + + str(in_h) + + "x" + + str(in_w) + + "x" + + str(c) + + "x" + + str(elem_types[0]) + ) + rhs = ( + str(p) + + "x" + + str(q) + + "x" + + str(c) + + "x" + + str(f) + + "x" + + str(elem_types[1]) + ) + out = ( + str(n) + + "x" + + str(h) + + "x" + + str(w) + + "x" + + str(f) + + "x" + + str(elem_types[2]) + ) if "nchw" in operation: conv_type = "nchw_fchw" - lhs = str(n) + "x" + str(c) + "x" + str(in_h) + "x" + str(in_w) + "x" + str(elem_types[0]) - rhs = str(f) + "x" + str(c) + "x" + str(p) + "x" + str(q) + "x" + str(elem_types[1]) - out = str(n) + "x" + str(f) + "x" + str(h) + "x" + str(w) + "x" + str(elem_types[2]) + lhs = ( + str(n) + + "x" + + str(c) + + "x" + + str(in_h) + + "x" + + str(in_w) + + "x" + + str(elem_types[0]) + ) + rhs = ( + str(f) + + "x" + + str(c) + + "x" + + str(p) + + "x" + + str(q) + + "x" + + str(elem_types[1]) + ) + out = ( + str(n) + + "x" + + str(f) + + "x" + + str(h) + + "x" + + str(w) + + "x" + + str(elem_types[2]) + ) one = "1" zero = "0" - if (elem_types[0][0] == "f"): + if elem_types[0][0] == "f": one = "1.0" zero = "0.0" conv_template = CONV @@ -162,7 +275,7 @@ def generate_mlir(config: ConvConfig): def compile_conv_config( - config: ConvConfig, kernel_dir: Path, vmfb_dir: Path + config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, device: str, target: str ) -> tuple[Path, Optional[Path]]: mlir_file = kernel_dir / (config.get_name() + ".mlir") vmfb_file = vmfb_dir / (config.get_name() + ".vmfb") @@ -185,9 +298,9 @@ def compile_conv_config( "-o", f"{vmfb_file}", # Target Device: hip - "--iree-hal-target-device=hip", + f"--iree-hal-target-device={device}", # Device: MI300x - "--iree-hip-target=gfx942", + f"--iree-hip-target={target}", ] print(" ".join(exec_args)) diff --git a/convbench/shark_conv.py b/convbench/shark_conv.py index 8f797c1..9522af4 100644 --- a/convbench/shark_conv.py +++ b/convbench/shark_conv.py @@ -12,8 +12,10 @@ from problems import get_conv_configs -def compile_conv(tag, config, kernel_dir, vmfb_dir): - mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir) +def compile_conv(tag, config, kernel_dir, vmfb_dir, device, target): + mlir_file, vmfb_file = compile_conv_config( + config, kernel_dir, vmfb_dir, device, target + ) return (tag, config, mlir_file, vmfb_file) @@ -26,14 +28,27 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): type=str.upper, help="Set the logging level", ) - parser.add_argument("--device", help="The IREE device to execute benchmarks on", type=str, default="hip") + parser.add_argument( + "--device", + help="The IREE device to execute benchmarks on", + type=str, + default="hip", + ) + parser.add_argument( + "--target", + help="The device's target to execute benchmarks on", + type=str, + default="gfx942", + ) parser.add_argument( "--roofline", help="Comma seperated csv file list to generate roofline plot with", default=None, ) parser.add_argument("--plot", help="location to save plot", default=None) - parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None) + parser.add_argument( + "--batch", help="roofline on certain batch", type=int, default=None + ) parser.add_argument("--dtype", help="roofline on certain dtype", default=None) parser.add_argument("--model", help="roofline on certain model", default=None) @@ -61,7 +76,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): device = args.device compile_args = itertools.starmap( - lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs + lambda tag, config: (tag, config, kernel_dir, vmfb_dir, device, args.target), + configs, ) with Pool(num_cpus) as pool: compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args)))) @@ -130,7 +146,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir): config.S, config.input_dtype, config.output_dtype, - round(benchmark_gemm_mean_time_us, 4), + round(benchmark_gemm_mean_time_us, 4), + round(arithmetic_intensity, 4), round(tflops_per_second, 4), ok,