Skip to content

Commit

Permalink
[CONV] Pass the device and target info
Browse files Browse the repository at this point in the history
The device and target info was hardcoded and now can be passed
viacommand line to the compile command. Also, black formatted the files.

Signed-off-by: Prashant Kumar <[email protected]>
  • Loading branch information
Prashant Kumar authored and pashu123 committed Oct 17, 2024
1 parent 558d9e5 commit 3181239
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
6 changes: 3 additions & 3 deletions convbench/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def generate_mlir(config: ConvConfig):


def compile_conv_config(
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path
config: ConvConfig, kernel_dir: Path, vmfb_dir: Path, device: str, target: str
) -> tuple[Path, Optional[Path]]:
mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
Expand All @@ -298,9 +298,9 @@ def compile_conv_config(
"-o",
f"{vmfb_file}",
# Target Device: hip
"--iree-hal-target-device=hip",
f"--iree-hal-target-device={device}",
# Device: MI300x
"--iree-hip-target=gfx942",
f"--iree-hip-target={target}",
]

print(" ".join(exec_args))
Expand Down
18 changes: 14 additions & 4 deletions convbench/shark_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
from problems import get_conv_configs


def compile_conv(tag, config, kernel_dir, vmfb_dir):
mlir_file, vmfb_file = compile_conv_config(config, kernel_dir, vmfb_dir)
def compile_conv(tag, config, kernel_dir, vmfb_dir, device, target):
mlir_file, vmfb_file = compile_conv_config(
config, kernel_dir, vmfb_dir, device, target
)
return (tag, config, mlir_file, vmfb_file)


Expand All @@ -32,6 +34,12 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
type=str,
default="hip",
)
parser.add_argument(
"--target",
help="The device's target to execute benchmarks on",
type=str,
default="gfx942",
)
parser.add_argument(
"--roofline",
help="Comma seperated csv file list to generate roofline plot with",
Expand Down Expand Up @@ -68,7 +76,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
device = args.device

compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
lambda tag, config: (tag, config, kernel_dir, vmfb_dir, device, args.target),
configs,
)
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_conv, list(compile_args))))
Expand Down Expand Up @@ -137,7 +146,8 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
config.S,
config.input_dtype,
config.output_dtype,
round(benchmark_gemm_mean_time_us, 4),
round(benchmark_gemm_mean_time_us, 4),

round(arithmetic_intensity, 4),
round(tflops_per_second, 4),
ok,
Expand Down

0 comments on commit 3181239

Please sign in to comment.