Skip to content

Commit

Permalink
Update scripts to allow extra flags. (#10)
Browse files Browse the repository at this point in the history
Couple of changes in this PR
1) `--Xiree_compile` can take now a list of flags to pass to
`iree-compile` without having to specify `--Xiree-compile` multiple
times. Note that the flags passed to `iree-compile` should not contain
`--`
2) Capture the `stderr` for `run_iree_command` and pass back to caller.
This allows generating IR dumps.

For example
`--Xiree_compile mlir-disable-threading mlir-print-ir-before=<pass-list>
mlir-print-ir-after=<pass-list>` dumps IR before and after passes to
`*.stderr.mlir` .

---------

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Oct 8, 2024
1 parent 5010a9a commit 11fd8c4
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 12 deletions.
2 changes: 1 addition & 1 deletion attentionbench/attention_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
]

# iree benchmark kernels
ret_value, cmd_out = run_iree_command(exec_args)
ret_value, cmd_out, cmd_err = run_iree_command(exec_args)
ok = ret_value == 0
benchmark_gemm_mean_time_ms = bench_summary_process(ret_value, cmd_out)
benchmark_gemm_mean_time_us = benchmark_gemm_mean_time_ms * 1000
Expand Down
6 changes: 5 additions & 1 deletion attentionbench/attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def compile_attention_config(
) -> tuple[Path, Optional[Path]]:
mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
dump_file = kernel_dir / (config.get_name() + ".stderr.mlir")

# TODO: Use different tuning specs for different configs. This is just a
# general tuning config that worked well for sdxl shapes.
Expand Down Expand Up @@ -181,9 +182,12 @@ def compile_attention_config(

print(" ".join(exec_args))

ret_value, stderr = run_iree_command(exec_args)
ret_value, stdout, stderr = run_iree_command(exec_args)
if ret_value == 0:
print(f"Successfully compiled {mlir_file} to {vmfb_file}")
if stderr:
with open(dump_file, "w") as f:
f.write(stderr.decode("utf-8"))
else:
error_file = vmfb_dir / (config.get_name() + "_error.txt")
print(f"Failed to compile {mlir_file}. Error dumped in {error_file}")
Expand Down
4 changes: 2 additions & 2 deletions common_tools/utils/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def run_iree_command(args: Sequence[str] = ()):
)
return_code = proc.returncode
if return_code == 0:
return 0, proc.stdout
return 0, proc.stdout, proc.stderr
logging.getLogger().error(
f"Command failed!\n"
f"Stderr diagnostics:\n{proc.stderr}\n"
f"Stdout diagnostics:\n{proc.stdout}\n"
)
return 1, proc.stderr
return 1, proc.stdout, proc.stderr

def decode_output(bench_lines):
benchmark_results = []
Expand Down
6 changes: 5 additions & 1 deletion convbench/conv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def compile_conv_config(
) -> tuple[Path, Optional[Path]]:
mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
dump_file = kernel_dir / (config.get_name() + ".stderr.mlir")

# Generate mlir content
mlir_content = generate_mlir(config)
Expand All @@ -191,9 +192,12 @@ def compile_conv_config(

print(" ".join(exec_args))

ret_value, stderr = run_iree_command(exec_args)
ret_value, stdout, stderr = run_iree_command(exec_args)
if ret_value == 0:
print(f"Successfully compiled {mlir_file} to {vmfb_file}")
if stderr:
with open(dump_file, "w") as f:
f.write(stderr.decode("utf-8"))
else:
error_file = vmfb_dir / (config.get_name() + "_error.txt")
print(f"Failed to compile {mlir_file}. Error dumped in {error_file}")
Expand Down
2 changes: 1 addition & 1 deletion convbench/shark_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def compile_conv(tag, config, kernel_dir, vmfb_dir):
]

# iree benchmark kernels
ret_value, cmd_out = run_iree_command(exec_args)
ret_value, cmd_out, cmd_stderr = run_iree_command(exec_args)
ok = ret_value == 0
benchmark_gemm_mean_time_ms = bench_summary_process(ret_value, cmd_out)
benchmark_gemm_mean_time_us = benchmark_gemm_mean_time_ms * 1000
Expand Down
10 changes: 5 additions & 5 deletions gemmbench/gemm_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
parser.add_argument("--target", help="The IREE hip target to compile for", type=str, default="gfx942")
parser.add_argument(
"--Xiree_compile",
action='append',
nargs='+',
default=[],
help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
help="Extra command line arguments passed to the IREE compiler. The flags need to be specified without the `--` or `-`"
)
parser.add_argument(
"--dtypes", action='append', help="List of data types to benchmark. Defaults to all supported types."
Expand Down Expand Up @@ -85,7 +85,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,

if args.roofline:
for dtype in requested_dtypes:
roofline(args.roofline, f"{args.plot}_{dtype}", args.batch, dtype, args.model)
roofline(args.roofline, f"{args.plot.split('.')[0]}_{dtype}.png", args.batch, dtype, args.model)
sys.exit()

tk = args.tk
Expand All @@ -105,7 +105,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
kernel_dir.mkdir(parents=True, exist_ok=True)
vmfb_dir.mkdir(parents=True, exist_ok=True)
target = args.target
extra_compiler_args = list(args.Xiree_compile)
extra_compiler_args = ['--' + x for x in list(args.Xiree_compile)]
dump_dir = args.dump_dir

args = itertools.starmap(
Expand Down Expand Up @@ -159,7 +159,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
exec_args += ["--function=main"]

# iree benchmark kernels
ret_value, cmd_out = run_iree_command(exec_args)
ret_value, cmd_out, cmd_err = run_iree_command(exec_args)
ok = ret_value == 0
benchmark_gemm_mean_time_ms = bench_summary_process(ret_value, cmd_out)
benchmark_gemm_mean_time_us = benchmark_gemm_mean_time_ms * 1000
Expand Down
6 changes: 5 additions & 1 deletion gemmbench/gemm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def compile_gemm_config(
) -> tuple[Path, Optional[Path]]:
mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
dump_file = kernel_dir / (config.get_name() + ".stderr.mlir")

if not os.path.exists(vmfb_dir):
os.makedirs(vmfb_dir)
Expand Down Expand Up @@ -235,9 +236,12 @@ def compile_gemm_config(

print(" ".join(exec_args))

ret_value, stderr = run_iree_command(exec_args)
ret_value, stdout, stderr = run_iree_command(exec_args)
if ret_value == 0:
print(f"Successfully compiled {mlir_file} to {vmfb_file}")
if stderr:
with open(dump_file, "w") as f:
f.write(stderr.decode("utf-8"))
else:
error_file = vmfb_dir / (config.get_name() + "_error.txt")
print(f"Failed to compile {mlir_file}. Error dumped in {error_file}")
Expand Down

0 comments on commit 11fd8c4

Please sign in to comment.