From 661318013db0f90a5fd4cca21a6faffa147b6f88 Mon Sep 17 00:00:00 2001 From: Ian Date: Mon, 20 Jan 2025 19:06:14 -0600 Subject: [PATCH 1/9] Adds new unet regression test - Adds new unet for sdxl image shape of 960x1024 - Adds new tuning spec for new unet shapes - Adds mi308 to tests Signed-off-by: Ian Signed-off-by: ianNod --- .github/workflows/pkgci_regression_test.yml | 8 + ...ttention_and_matmul_spec_punet_mi300.mlir} | 0 ...ntion_and_matmul_spec_unet_fp16_mi308.mlir | 255 ++++++++++++++++++ .../benchmarks/sdxl/benchmark_sdxl_rocm.py | 21 ++ .../shark-test-suite-models/sdxl/test_unet.py | 130 ++++++++- 5 files changed, 413 insertions(+), 1 deletion(-) rename build_tools/pkgci/external_test_suite/{attention_and_matmul_spec_punet.mlir => attention_and_matmul_spec_punet_mi300.mlir} (100%) create mode 100644 build_tools/pkgci/external_test_suite/attention_and_matmul_spec_unet_fp16_mi308.mlir diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 448c5565d1ff..fd2b21f7f679 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -42,11 +42,18 @@ jobs: - name: amdgpu_rocm_mi250_gfx90a rocm-chip: gfx90a backend: rocm + sku: mi250 runs-on: nodai-amdgpu-mi250-x86-64 - name: amdgpu_rocm_mi300_gfx942 rocm-chip: gfx942 backend: rocm + sku: mi300 runs-on: nodai-amdgpu-mi300-x86-64 + - name: amdgpu_rocm_mi308_gfx942 + rocm-chip: gfx942 + backend: rocm + sku: mi308 + runs-on: nodai-amdgpu-mi308-x86-64 env: PACKAGE_DOWNLOAD_DIR: ${{ github.workspace }}/.packages IREE_TEST_PATH_EXTENSION: ${{ github.workspace }}/build_tools/pkgci/external_test_suite @@ -95,6 +102,7 @@ jobs: --durations=0 env: ROCM_CHIP: ${{ matrix.rocm-chip }} + SKU: ${{ matrix.sku }} - name: "Running SD3 special model tests" if: "!cancelled()" diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir 
b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet_mi300.mlir similarity index 100% rename from build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir rename to build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet_mi300.mlir diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_unet_fp16_mi308.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_unet_fp16_mi308.mlir new file mode 100644 index 000000000000..291778b9c3e1 --- /dev/null +++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec_unet_fp16_mi308.mlir @@ -0,0 +1,255 @@ +module attributes { transform.with_named_sequence } { +//===----------------------------------------------------------------------===// +// Tuning infra +//===----------------------------------------------------------------------===// + +transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}, + %config: !transform.any_param {transform.readonly}) { + transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param + // transform.print %op {name = "Applied"} : !transform.any_op + transform.yield +} + +transform.named_sequence @apply_attn_op_config(%attention: !transform.any_op {transform.readonly}, + %config: !transform.any_param {transform.readonly}, + %decomposition_config: !transform.any_param {transform.readonly}) { + transform.annotate %attention "compilation_info" = %config : !transform.any_op, !transform.any_param + transform.annotate %attention "decomposition_config" = %decomposition_config : !transform.any_op, !transform.any_param + // transform.print %attention {name = "Applied attention config"} : !transform.any_op + transform.yield +} + +transform.named_sequence @match_attention_f16(%attention: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param, !transform.any_param) { + transform.match.operation_name %attention 
["iree_linalg_ext.attention"] : !transform.any_op + %in0 = transform.get_operand %attention[0] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %in0 = tensor : !transform.any_value + + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>, + translation_info = #iree_codegen.translation_info> + -> !transform.any_param + + %decomposition_config = transform.param.constant { + qk_attrs = {attention_qk_matmul, + lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.virtual_mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1] }>}, + pv_attrs = {attention_pv_matmul, + lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1] }>} + } -> !transform.any_param + + transform.yield %attention, %config, %decomposition_config : !transform.any_op, !transform.any_param, !transform.any_param +} + +transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.match.operation_name %root ["linalg.generic"] : !transform.any_op + // transform.print %root {name = "Generic"} : !transform.any_op + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root { + ^bb0(%lhs: tensor, %rhs: tensor, %out: tensor): + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%out : tensor) { + ^bb0(%in: f16, %in_0: f16, %acc: f32): + %18 = arith.extf %in : f16 to f32 + %19 = arith.extf %in_0 : f16 to f32 + %20 = arith.mulf %18, %19 : f32 + %21 = arith.addf %acc, %20 : f32 + linalg.yield %21 : f32 + } -> tensor 
+ } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + transform.yield %root : !transform.any_op +} + +// TUNING_SPEC_BEGIN DO NOT REMOVE + +//===----------------------------------------------------------------------===// +// Matmul tuning +//===----------------------------------------------------------------------===// + +transform.named_sequence @match_mmt_1920x10240x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<1920x1280xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 2, + reduction = [0, 0, 32], + workgroup = [128, 128, 0]}>, + translation_info = #iree_codegen.translation_info, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"} + }>> -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param +} + +transform.named_sequence @match_mmt_1920x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = 
tensor<1920x1280xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 2, + reduction = [0, 0, 32], + workgroup = [128, 128, 0]}>, + translation_info = #iree_codegen.translation_info, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"} + }>> -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param +} + +transform.named_sequence @match_mmt_1920x1280x5120(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<1920x5120xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 4, subgroup_n_count = 2, + reduction = [0, 0, 32], + workgroup = [128, 128, 0]}>, + translation_info = #iree_codegen.translation_info, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"} + }>> -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param +} + +transform.named_sequence @match_mmt_7680x5120x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) 
: (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<7680x640xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<5120x640xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 4, + reduction = [0, 0, 32], + workgroup = [128, 256, 0]}>, + translation_info = #iree_codegen.translation_info, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"} + }>> -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param +} + +transform.named_sequence @match_mmt_128x1280x2048(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<1280x2048xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x2048xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 2, subgroup_n_count = 1, + reduction = [0, 0, 128], + workgroup = [64, 16, 0]}>, + translation_info = #iree_codegen.translation_info, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"} + }>> -> !transform.any_param + transform.yield %matmul, %config : 
!transform.any_op, !transform.any_param +} + +transform.named_sequence @match_mmt_7680x640x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<7680x640xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<640x640xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + subgroup_m_count = 1, subgroup_n_count = 4, + reduction = [0, 0, 32], + workgroup = [256, 128, 0]}>, + translation_info = #iree_codegen.translation_info, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"} + }>> -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param +} + +transform.named_sequence @match_mmt_7680x640x2560(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<7680x2560xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<640x2560xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], + mma_kind = #iree_gpu.mma_layout, + 
subgroup_m_count = 4, subgroup_n_count = 2, + reduction = [0, 0, 32], + workgroup = [256, 128, 0]}>, + translation_info = #iree_codegen.translation_info, + llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"} + }>> -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param +} + +//===----------------------------------------------------------------------===// +// Convolution tuning +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Batch matmul tuning +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Broadcast rhs mmt tuning +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Contraction tuning +//===----------------------------------------------------------------------===// + +// TUNING_SPEC_END DO NOT REMOVE + +//===----------------------------------------------------------------------===// +// Entry point +//===----------------------------------------------------------------------===// + + transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) { + transform.foreach_match in %variant_op + @match_attention_f16 -> @apply_attn_op_config + + // TUNING_MATCH_BEGIN DO NOT REMOVE + + // MMT. 
+ , @match_mmt_1920x10240x1280 -> @apply_op_config + , @match_mmt_1920x1280x1280 -> @apply_op_config + , @match_mmt_1920x1280x5120 -> @apply_op_config + , @match_mmt_7680x5120x640 -> @apply_op_config + , @match_mmt_128x1280x2048 -> @apply_op_config + , @match_mmt_7680x640x640 -> @apply_op_config + , @match_mmt_7680x640x2560 -> @apply_op_config + + // TUNING_MATCH_END DO NOT REMOVE + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} //// module diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index f8d7a1f01a01..aaf82e0f8c2f 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -25,6 +25,7 @@ vae_decode_dir = f"{artifacts_dir}/sdxl_vae" prompt_encoder_dir_compile = f"{vmfb_dir}/sdxl_clip_vmfbs" scheduled_unet_dir_compile = f"{vmfb_dir}/sdxl_unet_fp16_vmfbs" +unet_fp16_960_1024 = f"{vmfb_dir}/sdxl_unet_fp16_960_1024_vmfbs" punet_int8_fp16_dir_compile = f"{vmfb_dir}/sdxl_punet_int8_fp16_vmfbs" punet_int8_fp8_dir_compile = f"{vmfb_dir}/sdxl_punet_int8_fp8_vmfbs" vae_decode_dir_compile = f"{vmfb_dir}/sdxl_vae_vmfbs" @@ -161,6 +162,26 @@ def run_sdxl_punet_int8_fp8_rocm_benchmark(rocm_chip): return run_iree_command(exec_args) +def run_sdxl_unet_fp16_960_1024_rocm_benchmark(rocm_chip): + exec_args = [ + "iree-benchmark-module", + f"--device=hip", + "--device_allocator=caching", + f"--module={unet_fp16_960_1024_dir_compile}/unet_fp16_960_1024.rocm_{rocm_chip}.vmfb", + f"--parameters=model={scheduled_unet_dir}/real_weights.irpa", + "--function=run_forward", + f"--input=1x4x120x128xf16", + f"--input=1xf16", + f"--input=2x64x2048xf16", + f"--input=2x1280xf16", + f"--input=2x6xf16", + f"--input=1xf16", + "--benchmark_repetitions=10", + "--benchmark_min_warmup_time=3.0", + ] + # iree benchmark command for full sdxl pipeline + return run_iree_command(exec_args) + def run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip): 
exec_args = [ "iree-benchmark-module", diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 381db9e69e2a..b26c667c932c 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -61,6 +61,47 @@ group="sdxl_unet_fp16", ) +# FP16 Model for 960x1024 image size + +sdxl_unet_fp16_960_1024_inference_input_0 = fetch_source_fixture( + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg0_latent_model_input.npy" + group="sdxl_unet_fp16_960_1024", +) + +sdxl_unet_fp16_960_1024_inference_input_1 = fetch_source_fixture( + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg1_guidanc_scale.npy" + group="sdxl_unet_fp16_960_1024", +) + +sdxl_unet_fp16_960_1024_inference_input_2 = fetch_source_fixture( + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg2_prompt_embeds.npy" + group="sdxl_unet_fp16_960_1024", +) + +sdxl_unet_fp16_960_1024_inference_input_3 = fetch_source_fixture( + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg3_add_text_embeds.npy" + group="sdxl_unet_fp16_960_1024", +) + +sdxl_unet_fp16_960_1024_inference_input_4 = fetch_source_fixture( + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg4_add_time_ids.npy" + group="sdxl_unet_fp16_960_1024", +) + +sdxl_unet_fp16_960_1024_inference_input_5 = fetch_source_fixture( + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg5_t.npy" + group="sdxl_unet_fp16_960_1024", + +sdxl_unet_fp16_960_1024_inference_output_0 = fetch_source_fixture( + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/inference_output.0.bin" + group="sdxl_unet_fp16_960_1024", +) + +sdxl_unet_fp16_960_1024_mlir = fetch_source_fixture( + 
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/sdxl_960x1024/stable_diffusion_xl_base_1_0_bs1_64_960x1024_fp16_unet.mlir", + group="sdxl_unet_fp16_960_1024", +) + # INT8 Punet + FP16 Attention sdxl_punet_int8_inference_input_0 = fetch_source_fixture( @@ -154,6 +195,27 @@ def SDXL_UNET_FP16_COMMON_RUN_FLAGS( ] +@pytest.fixture +def SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS( + sdxl_unet_fp16_inference_input_0, + sdxl_unet_fp16_inference_input_1, + sdxl_unet_fp16_inference_input_2, + sdxl_unet_fp16_inference_input_3, + sdxl_unet_fp16_inference_input_4, + sdxl_unet_fp16_inference_input_5, + sdxl_unet_fp16_inference_output_0, +): + return [ + f"--input=@{sdxl_unet_fp16_960_1024_inference_input_0.path}", + f"--input=@{sdxl_unet_fp16_960_1024_inference_input_1.path}", + f"--input=@{sdxl_unet_fp16_960_1024_inference_input_2.path}", + f"--input=@{sdxl_unet_fp16_960_1024_inference_input_3.path}", + f"--input=@{sdxl_unet_fp16_960_1024_inference_input_4.path}", + f"--input=@{sdxl_unet_fp16_960_1024_inference_input_5.path}", + f"--expected_output=@{sdxl_unet_fp16_960_1024_inference_output_0.path}", + ] + + @pytest.fixture def SDXL_PUNET_INT8_COMMON_RUN_FLAGS( sdxl_punet_int8_inference_input_0, @@ -216,11 +278,15 @@ def SDXL_PUNET_INT8_FP8_OUT( FP16_UNET_FLAGS = [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", ] +if os.path.isfile(f"{iree_test_path_extension}/attention_and_matmul_spec_unet_fp16_{sku}.mlir"): + FP16_UNET_FLAGS.append(f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_unet_fp16_{sku}.mlir") INT8_PUNET_FLAGS = [ - f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet.mlir", "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", ] +if 
os.path.isfile(f"{iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir"): + INT8_PUNET_FLAGS.append(f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir") + ROCM_UNET_PIPELINE_FP16_COMPILE_FLAGS = [ "--iree-hal-target-backends=rocm", @@ -253,6 +319,14 @@ def test_compile_unet_fp16_cpu(sdxl_unet_fp16_mlir): / Path(sdxl_unet_fp16_mlir.path.name).with_suffix(f".cpu.vmfb"), ) +def test_compile_unet_fp16_960_1024_cpu(sdxl_unet_fp16_960_1024_mlir): + VmfbManager.sdxl_unet_fp16_960_1024_vfmb = iree_compile( + sdxl_unet_fp16_960_1024_mlir, + CPU_COMIPLE_FLAGS, + Path(vmfb_dir) + / Path("sdxl_unet_fp16_960_1024_vmfbs") + / Path(sdxl_unet_fp16_960_1024_mlir.path.name).with_suffix(f".cpu.vmfb"), + ) @pytest.mark.depends( on=["test_compile_unet_fp16_pipeline_cpu", "test_compile_unet_fp16_cpu"] @@ -273,11 +347,38 @@ def test_run_unet_fp16_cpu( ) +@pytest.mark.depends( + on=["test_compile_unet_fp16_cpu"] +): + return iree_run_module( + VmfbManager.sdxl_unet_fp16_960_1024_vfmb, + device="local-task", + function="run_forward", + args=[ + f"--parameters=model={sdxl_unet_fp16_real_weights.path}", + f"--module={VmfbManager.sdxl_unet_fp16_960_1024_vfmb}, + --expected_f16_threshold=0.8f", + ] + + SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, + ) + + ############################################################################### # ROCM ############################################################################### +def test_compile_unet_fp16_pipeline_rocm(sdxl_unet_fp16_pipeline_mlir): + VmfbManager.sdxl_unet_fp16_rocm_pipeline_vmfb = iree_compile( + sdxl_unet_fp16_pipeline_mlir, + ROCM_UNET_PIPELINE_FP16_COMPILE_FLAGS, + Path(vmfb_dir) + / Path("sdxl_unet_fp16_vmfbs") + / Path(sdxl_unet_fp16_pipeline_mlir.path.name).with_suffix( + f".rocm_{rocm_chip}.vmfb" + ), + ) + def test_compile_unet_fp16_pipeline_rocm(sdxl_unet_fp16_pipeline_mlir): VmfbManager.sdxl_unet_fp16_rocm_pipeline_vmfb = iree_compile( 
sdxl_unet_fp16_pipeline_mlir, @@ -300,6 +401,16 @@ def test_compile_unet_fp16_rocm(sdxl_unet_fp16_mlir): ) +def test_compile_unet_fp16_960_1024_rocm(sdxl_unet_fp16_960_1024_mlir): + VmfbManager.sdxl_unet_fp16_960_1024_rocm_vmfb = iree_compile( + sdxl_unet_fp16_960_1024_mlir, + ROCM_COMPILE_FLAGS + FP16_UNET_FLAGS, + Path(vmfb_dir) + / Path("sdxl_unet_fp16_960_1024_vmfbs") + / Path(sdxl_unet_fp16_960_1024_mlir.path.name).with_suffix(f".rocm_{rocm_chip}.vmfb"), + ) + + @pytest.mark.depends( on=["test_compile_unet_fp16_pipeline_rocm", "test_compile_unet_fp16_rocm"] ) @@ -318,6 +429,23 @@ def test_run_unet_fp16_rocm( + SDXL_UNET_FP16_COMMON_RUN_FLAGS, ) +@pytest.mark.depends( + on=["test_compile_unet_fp16_960_1024_rocm"] +) +def test_run_unet_fp16_rocm( + SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, sdxl_unet_fp16_real_weights +): + return iree_run_module( + VmfbManager.sdxl_unet_fp16_960_1024_rocm_vmfb, + device="hip", + function="run_forward", + args=[ + f"--parameters=model={sdxl_unet_fp16_real_weights.path}", + f"--module={VmfbManager.sdxl_unet_fp16_960_1024_rocm_vmfb}", + "--expected_f16_threshold=0.705f", + ] + + SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, + ) def test_compile_punet_int8_fp16_rocm(request, sdxl_punet_int8_fp16_mlir): if rocm_chip == "gfx90a": From 590832d947a2c1ba3430e1942861a7df6844ed8b Mon Sep 17 00:00:00 2001 From: Ian Date: Tue, 21 Jan 2025 10:22:42 -0600 Subject: [PATCH 2/9] fix typos Signed-off-by: Ian Signed-off-by: ianNod --- .../benchmarks/sdxl/benchmark_sdxl_rocm.py | 2 +- .../shark-test-suite-models/conftest.py | 2 + .../shark-test-suite-models/sdxl/test_unet.py | 43 +++++++++++-------- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index aaf82e0f8c2f..04af1d4aeda2 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -25,7 +25,7 @@ vae_decode_dir = 
f"{artifacts_dir}/sdxl_vae" prompt_encoder_dir_compile = f"{vmfb_dir}/sdxl_clip_vmfbs" scheduled_unet_dir_compile = f"{vmfb_dir}/sdxl_unet_fp16_vmfbs" -unet_fp16_960_1024 = f"{vmfb_dir}/sdxl_unet_fp16_960_1024_vmfbs" +unet_fp16_960_1024_dir_compile = f"{vmfb_dir}/sdxl_unet_fp16_960_1024_vmfbs" punet_int8_fp16_dir_compile = f"{vmfb_dir}/sdxl_punet_int8_fp16_vmfbs" punet_int8_fp8_dir_compile = f"{vmfb_dir}/sdxl_punet_int8_fp8_vmfbs" vae_decode_dir_compile = f"{vmfb_dir}/sdxl_vae_vmfbs" diff --git a/experimental/regression_suite/shark-test-suite-models/conftest.py b/experimental/regression_suite/shark-test-suite-models/conftest.py index 8e62bcb97274..b1292fb6ad97 100644 --- a/experimental/regression_suite/shark-test-suite-models/conftest.py +++ b/experimental/regression_suite/shark-test-suite-models/conftest.py @@ -10,6 +10,7 @@ class VmfbManager: sdxl_vae_cpu_vmfb = None sdxl_unet_fp16_cpu_vmfb = None sdxl_unet_fp16_cpu_pipeline_vmfb = None + sdxl_unet_fp16_960_1024_cpu_vfmb = None sdxl_scheduler_cpu_vmfb = None sdxl_clip_rocm_vmfb = None sdxl_vae_rocm_vmfb = None @@ -17,6 +18,7 @@ class VmfbManager: sdxl_punet_int8_fp16_rocm_vmfb = None sdxl_punet_int8_fp8_rocm_vmfb = None sdxl_unet_fp16_rocm_pipeline_vmfb = None + sdxl_unet_fp16_960_1024_rocm_vmfb = None sdxl_scheduler_rocm_vmfb = None sd3_clip_cpu_vmfb = None sd3_vae_cpu_vmfb = None diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index b26c667c932c..8f91e52bd14f 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -13,6 +13,7 @@ vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd()) rocm_chip = os.getenv("ROCM_CHIP", default="gfx942") +sku = os.getenv("SKU", default="mi300") iree_test_path_extension = os.getenv("IREE_TEST_PATH_EXTENSION", default=Path.cwd()) 
############################################################################### @@ -64,36 +65,37 @@ # FP16 Model for 960x1024 image size sdxl_unet_fp16_960_1024_inference_input_0 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg0_latent_model_input.npy" + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg0_latent_model_input.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_1 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg1_guidanc_scale.npy" + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg1_guidanc_scale.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_2 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg2_prompt_embeds.npy" + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg2_prompt_embeds.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_3 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg3_add_text_embeds.npy" + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg3_add_text_embeds.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_4 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg4_add_time_ids.npy" + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg4_add_time_ids.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_5 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg5_t.npy" + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg5_t.npy", group="sdxl_unet_fp16_960_1024", +) sdxl_unet_fp16_960_1024_inference_output_0 = fetch_source_fixture( - 
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/inference_output.0.bin" + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/inference_output_0.npy", group="sdxl_unet_fp16_960_1024", ) @@ -197,13 +199,13 @@ def SDXL_UNET_FP16_COMMON_RUN_FLAGS( @pytest.fixture def SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS( - sdxl_unet_fp16_inference_input_0, - sdxl_unet_fp16_inference_input_1, - sdxl_unet_fp16_inference_input_2, - sdxl_unet_fp16_inference_input_3, - sdxl_unet_fp16_inference_input_4, - sdxl_unet_fp16_inference_input_5, - sdxl_unet_fp16_inference_output_0, + sdxl_unet_fp16_960_1024_inference_input_0, + sdxl_unet_fp16_960_1024_inference_input_1, + sdxl_unet_fp16_960_1024_inference_input_2, + sdxl_unet_fp16_960_1024_inference_input_3, + sdxl_unet_fp16_960_1024_inference_input_4, + sdxl_unet_fp16_960_1024_inference_input_5, + sdxl_unet_fp16_960_1024_inference_output_0, ): return [ f"--input=@{sdxl_unet_fp16_960_1024_inference_input_0.path}", @@ -320,9 +322,9 @@ def test_compile_unet_fp16_cpu(sdxl_unet_fp16_mlir): ) def test_compile_unet_fp16_960_1024_cpu(sdxl_unet_fp16_960_1024_mlir): - VmfbManager.sdxl_unet_fp16_960_1024_vfmb = iree_compile( + VmfbManager.sdxl_unet_fp16_960_1024_cpu_vfmb = iree_compile( sdxl_unet_fp16_960_1024_mlir, - CPU_COMIPLE_FLAGS, + CPU_COMPILE_FLAGS, Path(vmfb_dir) / Path("sdxl_unet_fp16_960_1024_vmfbs") / Path(sdxl_unet_fp16_960_1024_mlir.path.name).with_suffix(f".cpu.vmfb"), @@ -349,15 +351,18 @@ def test_run_unet_fp16_cpu( @pytest.mark.depends( on=["test_compile_unet_fp16_cpu"] +) +def test_run_unet_fp16_960_1024_cpu( + SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, sdxl_unet_fp16_real_weights ): return iree_run_module( - VmfbManager.sdxl_unet_fp16_960_1024_vfmb, + VmfbManager.sdxl_unet_fp16_960_1024_cpu_vfmb, device="local-task", function="run_forward", args=[ f"--parameters=model={sdxl_unet_fp16_real_weights.path}", - f"--module={VmfbManager.sdxl_unet_fp16_960_1024_vfmb}, - --expected_f16_threshold=0.8f", + 
f"--module={VmfbManager.sdxl_unet_fp16_960_1024_cpu_vfmb}", + "--expected_f16_threshold=0.8f", ] + SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, ) From ddff236cf00e4334b008e7dc0b16c266480314bf Mon Sep 17 00:00:00 2001 From: Ian Date: Tue, 21 Jan 2025 19:00:05 -0600 Subject: [PATCH 3/9] Linting Signed-off-by: Ian Signed-off-by: ianNod --- .../benchmarks/sdxl/benchmark_sdxl_rocm.py | 1 + .../shark-test-suite-models/sdxl/test_unet.py | 35 ++++++++++++------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py index 04af1d4aeda2..9331ef82c334 100644 --- a/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py +++ b/experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py @@ -182,6 +182,7 @@ def run_sdxl_unet_fp16_960_1024_rocm_benchmark(rocm_chip): # iree benchmark command for full sdxl pipeline return run_iree_command(exec_args) + def run_sdxl_prompt_encoder_rocm_benchmark(rocm_chip): exec_args = [ "iree-benchmark-module", diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 8f91e52bd14f..65268f2eff8d 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -280,15 +280,23 @@ def SDXL_PUNET_INT8_FP8_OUT( FP16_UNET_FLAGS = [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", ] -if os.path.isfile(f"{iree_test_path_extension}/attention_and_matmul_spec_unet_fp16_{sku}.mlir"): - FP16_UNET_FLAGS.append(f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_unet_fp16_{sku}.mlir") +if os.path.isfile( + f"{iree_test_path_extension}/attention_and_matmul_spec_unet_fp16_{sku}.mlir" +): + FP16_UNET_FLAGS.append( + 
f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_unet_fp16_{sku}.mlir" + ) INT8_PUNET_FLAGS = [ "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", ] -if os.path.isfile(f"{iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir"): - INT8_PUNET_FLAGS.append(f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir") - +if os.path.isfile( + f"{iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" +): + INT8_PUNET_FLAGS.append( + f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" + ) + ROCM_UNET_PIPELINE_FP16_COMPILE_FLAGS = [ "--iree-hal-target-backends=rocm", @@ -321,6 +329,7 @@ def test_compile_unet_fp16_cpu(sdxl_unet_fp16_mlir): / Path(sdxl_unet_fp16_mlir.path.name).with_suffix(f".cpu.vmfb"), ) + def test_compile_unet_fp16_960_1024_cpu(sdxl_unet_fp16_960_1024_mlir): VmfbManager.sdxl_unet_fp16_960_1024_cpu_vfmb = iree_compile( sdxl_unet_fp16_960_1024_mlir, @@ -330,6 +339,7 @@ def test_compile_unet_fp16_960_1024_cpu(sdxl_unet_fp16_960_1024_mlir): / Path(sdxl_unet_fp16_960_1024_mlir.path.name).with_suffix(f".cpu.vmfb"), ) + @pytest.mark.depends( on=["test_compile_unet_fp16_pipeline_cpu", "test_compile_unet_fp16_cpu"] ) @@ -349,9 +359,7 @@ def test_run_unet_fp16_cpu( ) -@pytest.mark.depends( - on=["test_compile_unet_fp16_cpu"] -) +@pytest.mark.depends(on=["test_compile_unet_fp16_cpu"]) def test_run_unet_fp16_960_1024_cpu( SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, sdxl_unet_fp16_real_weights ): @@ -384,6 +392,7 @@ def test_compile_unet_fp16_pipeline_rocm(sdxl_unet_fp16_pipeline_mlir): ), ) + def test_compile_unet_fp16_pipeline_rocm(sdxl_unet_fp16_pipeline_mlir): VmfbManager.sdxl_unet_fp16_rocm_pipeline_vmfb = iree_compile( sdxl_unet_fp16_pipeline_mlir, @@ 
-412,7 +421,9 @@ def test_compile_unet_fp16_960_1024_rocm(sdxl_unet_fp16_960_1024_mlir): ROCM_COMPILE_FLAGS + FP16_UNET_FLAGS, Path(vmfb_dir) / Path("sdxl_unet_fp16_960_1024_vmfbs") - / Path(sdxl_unet_fp16_960_1024_mlir.path.name).with_suffix(f".rocm_{rocm_chip}.vmfb"), + / Path(sdxl_unet_fp16_960_1024_mlir.path.name).with_suffix( + f".rocm_{rocm_chip}.vmfb" + ), ) @@ -434,9 +445,8 @@ def test_run_unet_fp16_rocm( + SDXL_UNET_FP16_COMMON_RUN_FLAGS, ) -@pytest.mark.depends( - on=["test_compile_unet_fp16_960_1024_rocm"] -) + +@pytest.mark.depends(on=["test_compile_unet_fp16_960_1024_rocm"]) def test_run_unet_fp16_rocm( SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, sdxl_unet_fp16_real_weights ): @@ -452,6 +462,7 @@ def test_run_unet_fp16_rocm( + SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, ) + def test_compile_punet_int8_fp16_rocm(request, sdxl_punet_int8_fp16_mlir): if rocm_chip == "gfx90a": request.node.add_marker( From 70a152c1e0680bb76b74826337f5de9851aced5e Mon Sep 17 00:00:00 2001 From: Ian Date: Wed, 22 Jan 2025 20:12:36 -0600 Subject: [PATCH 4/9] Update inputs and outputs for new unet model Signed-off-by: Ian Signed-off-by: ianNod --- .../shark-test-suite-models/sdxl/test_unet.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 65268f2eff8d..9d3f96a0c8c8 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -65,37 +65,37 @@ # FP16 Model for 960x1024 image size sdxl_unet_fp16_960_1024_inference_input_0 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg0_latent_model_input.npy", + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/input1.npy", group="sdxl_unet_fp16_960_1024", ) 
sdxl_unet_fp16_960_1024_inference_input_1 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg1_guidanc_scale.npy", + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/input2.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_2 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg2_prompt_embeds.npy", + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/input3.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_3 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg3_add_text_embeds.npy", + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/input4.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_4 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg4_add_time_ids.npy", + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/input5.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_input_5 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg5_t.npy", + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/input6.npy", group="sdxl_unet_fp16_960_1024", ) sdxl_unet_fp16_960_1024_inference_output_0 = fetch_source_fixture( - "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/inference_output_0.npy", + "https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/golden_out.npy", group="sdxl_unet_fp16_960_1024", ) @@ -290,8 +290,9 @@ def SDXL_PUNET_INT8_FP8_OUT( INT8_PUNET_FLAGS = [ "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", ] + if os.path.isfile( - 
f"{iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" + f"{iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" ): INT8_PUNET_FLAGS.append( f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" From cb5dcf0a47aed6f22a31cf0056ec4b7e0886e4c8 Mon Sep 17 00:00:00 2001 From: Ian Date: Thu, 23 Jan 2025 10:30:48 -0600 Subject: [PATCH 5/9] If no spec file for sku uses mi300 for punet due to numeric issues without Signed-off-by: Ian Signed-off-by: ianNod --- .../shark-test-suite-models/sdxl/test_unet.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index 9d3f96a0c8c8..a2426b93da19 100644 --- a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -297,6 +297,11 @@ def SDXL_PUNET_INT8_FP8_OUT( INT8_PUNET_FLAGS.append( f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" ) +else: + # TODO: Investigate numerics failure without using the MI300 punet attention spec + INT8_PUNET_FLAGS.append( + f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_mi300.mlir" + ) ROCM_UNET_PIPELINE_FP16_COMPILE_FLAGS = [ From bcb87ffb58eca6bbce25a5b04876b3fd2738a9b7 Mon Sep 17 00:00:00 2001 From: Ian Date: Thu, 23 Jan 2025 14:51:12 -0600 Subject: [PATCH 6/9] Linting Signed-off-by: Ian Signed-off-by: ianNod --- .../shark-test-suite-models/sdxl/test_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py index a2426b93da19..e60836207114 100644 --- 
a/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py +++ b/experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py @@ -292,13 +292,13 @@ def SDXL_PUNET_INT8_FP8_OUT( ] if os.path.isfile( - f"{iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" + f"{iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" ): INT8_PUNET_FLAGS.append( f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_{sku}.mlir" ) else: - # TODO: Investigate numerics failure without using the MI300 punet attention spec + # TODO: Investigate numerics failure without using the MI300 punet attention spec INT8_PUNET_FLAGS.append( f"--iree-codegen-transform-dialect-library={iree_test_path_extension}/attention_and_matmul_spec_punet_mi300.mlir" ) From 250470396ad3d09f525a69df8091694b2ed27793 Mon Sep 17 00:00:00 2001 From: ianNod Date: Thu, 23 Jan 2025 15:37:03 -0600 Subject: [PATCH 7/9] Add MI308 to benchmark regression tests Signed-off-by: ianNod --- .github/workflows/pkgci_regression_test.yml | 29 +++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index fd2b21f7f679..50a6cd64a73f 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -173,3 +173,32 @@ jobs: --timeout=600 \ --retries 7 echo "$(<job_summary.md )" >> $GITHUB_STEP_SUMMARY + # Note: allowing 10% deviation from observed averages here to account for + # different runner conditions.
+ - name: "Running SDXL rocm pipeline benchmark (mi300)" + if: contains(matrix.name, 'rocm_mi308_gfx942') + run: | + source ${VENV_DIR}/bin/activate + pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \ + --goldentime-tolerance-multiplier 1.1 \ + --goldentime-rocm-e2e-ms 797.0 \ + --goldentime-rocm-unet-ms 195.0 \ + --goldentime-rocm-clip-ms 15.0 \ + --goldentime-rocm-vae-ms 190.0 \ + --goldendispatch-rocm-unet 1602 \ + --goldendispatch-rocm-clip 1139 \ + --goldendispatch-rocm-vae 246 \ + --goldensize-rocm-unet-bytes 2270000 \ + --goldensize-rocm-clip-bytes 860000 \ + --goldensize-rocm-vae-bytes 840000 \ + --goldentime-rocm-punet-int8-fp16-ms 138.0 \ + --goldentime-rocm-punet-int8-fp8-ms 147 \ + --goldendispatch-rocm-punet-int8-fp16 1424 \ + --goldendispatch-rocm-punet-int8-fp8 1704 \ + --goldensize-rocm-punet-int8-fp8-bytes 2800000 \ + --goldensize-rocm-punet-int8-fp16-bytes 2560000 \ + --rocm-chip gfx942 \ + --log-cli-level=info \ + --timeout=600 \ + --retries 7 + echo "$(<job_summary.md )" >> $GITHUB_STEP_SUMMARY From dab8fb7456dee45b6ee4a340d726c60a3a6c25ca Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Fri, 24 Jan 2025 15:00:44 -0800 Subject: [PATCH 8/9] Clarify runner name --- .github/workflows/pkgci_regression_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 50a6cd64a73f..81e537ed2733 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -175,7 +175,7 @@ jobs: echo "$(<job_summary.md )" >> $GITHUB_STEP_SUMMARY # Note: allowing 10% deviation from observed averages here to account for # different runner conditions.
- - name: "Running SDXL rocm pipeline benchmark (mi300)" + - name: "Running SDXL rocm pipeline benchmark (mi308)" if: contains(matrix.name, 'rocm_mi308_gfx942') run: | source ${VENV_DIR}/bin/activate From 112d2e8630c9cfc89882dcefa22f4d5c022a29f7 Mon Sep 17 00:00:00 2001 From: ianNod Date: Fri, 24 Jan 2025 17:37:11 -0600 Subject: [PATCH 9/9] Loosen benchmark precision times for runner tolerances Signed-off-by: ianNod --- .github/workflows/pkgci_regression_test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 81e537ed2733..d5828367f4bf 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -181,7 +181,7 @@ jobs: source ${VENV_DIR}/bin/activate pytest ./experimental/benchmarks/sdxl/benchmark_sdxl_rocm.py \ --goldentime-tolerance-multiplier 1.1 \ - --goldentime-rocm-e2e-ms 797.0 \ + --goldentime-rocm-e2e-ms 800.0 \ --goldentime-rocm-unet-ms 195.0 \ --goldentime-rocm-clip-ms 15.0 \ --goldentime-rocm-vae-ms 190.0 \ @@ -191,8 +191,8 @@ jobs: --goldensize-rocm-unet-bytes 2270000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ - --goldentime-rocm-punet-int8-fp16-ms 138.0 \ - --goldentime-rocm-punet-int8-fp8-ms 147 \ + --goldentime-rocm-punet-int8-fp16-ms 140.0 \ + --goldentime-rocm-punet-int8-fp8-ms 150 \ --goldendispatch-rocm-punet-int8-fp16 1424 \ --goldendispatch-rocm-punet-int8-fp8 1704 \ --goldensize-rocm-punet-int8-fp8-bytes 2800000 \