Skip to content

Commit

Permalink
fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
IanNod committed Jan 21, 2025
1 parent 6c54992 commit 83fbc55
Showing 1 changed file with 21 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

vmfb_dir = os.getenv("TEST_OUTPUT_ARTIFACTS", default=Path.cwd())
rocm_chip = os.getenv("ROCM_CHIP", default="gfx942")
sku = os.getenv("SKU", default="mi300")
iree_test_path_extension = os.getenv("IREE_TEST_PATH_EXTENSION", default=Path.cwd())

###############################################################################
Expand Down Expand Up @@ -64,36 +65,37 @@
# FP16 Model for 960x1024 image size

sdxl_unet_fp16_960_1024_inference_input_0 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg0_latent_model_input.npy"
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg0_latent_model_input.npy",
group="sdxl_unet_fp16_960_1024",
)

sdxl_unet_fp16_960_1024_inference_input_1 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg1_guidanc_scale.npy"
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg1_guidanc_scale.npy",
group="sdxl_unet_fp16_960_1024",
)

sdxl_unet_fp16_960_1024_inference_input_2 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg2_prompt_embeds.npy"
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg2_prompt_embeds.npy",
group="sdxl_unet_fp16_960_1024",
)

sdxl_unet_fp16_960_1024_inference_input_3 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg3_add_text_embeds.npy"
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg3_add_text_embeds.npy",
group="sdxl_unet_fp16_960_1024",
)

sdxl_unet_fp16_960_1024_inference_input_4 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg4_add_time_ids.npy"
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg4_add_time_ids.npy",
group="sdxl_unet_fp16_960_1024",
)

sdxl_unet_fp16_960_1024_inference_input_5 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg5_t.npy"
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/arg5_t.npy",
group="sdxl_unet_fp16_960_1024",
)

sdxl_unet_fp16_960_1024_inference_output_0 = fetch_source_fixture(
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/inference_output.0.bin"
"https://sharkpublic.blob.core.windows.net/sharkpublic/ian/unet_npys/inference_output.0.bin",
group="sdxl_unet_fp16_960_1024",
)

Expand Down Expand Up @@ -197,13 +199,13 @@ def SDXL_UNET_FP16_COMMON_RUN_FLAGS(

@pytest.fixture
def SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS(
sdxl_unet_fp16_inference_input_0,
sdxl_unet_fp16_inference_input_1,
sdxl_unet_fp16_inference_input_2,
sdxl_unet_fp16_inference_input_3,
sdxl_unet_fp16_inference_input_4,
sdxl_unet_fp16_inference_input_5,
sdxl_unet_fp16_inference_output_0,
sdxl_unet_fp16_960_1024_inference_input_0,
sdxl_unet_fp16_960_1024_inference_input_1,
sdxl_unet_fp16_960_1024_inference_input_2,
sdxl_unet_fp16_960_1024_inference_input_3,
sdxl_unet_fp16_960_1024_inference_input_4,
sdxl_unet_fp16_960_1024_inference_input_5,
sdxl_unet_fp16_960_1024_inference_output_0,
):
return [
f"--input=@{sdxl_unet_fp16_960_1024_inference_input_0.path}",
Expand Down Expand Up @@ -349,15 +351,18 @@ def test_run_unet_fp16_cpu(

@pytest.mark.depends(
on=["test_compile_unet_fp16_cpu"]
)
def test_run_unet_fp16_960_1024_cpu(
SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS, sdxl_unet_fp16_real_weights
):
return iree_run_module(
VmfbManager.sdxl_unet_fp16_960_1024_vfmb,
device="local-task",
function="run_forward",
args=[
f"--parameters=model={sdxl_unet_fp16_real_weights.path}",
f"--module={VmfbManager.sdxl_unet_fp16_960_1024_vfmb},
--expected_f16_threshold=0.8f",
f"--module={VmfbManager.sdxl_unet_fp16_960_1024_vfmb}",
"--expected_f16_threshold=0.8f",
]
+ SDXL_UNET_FP16_960_1024_COMMON_RUN_FLAGS,
)
Expand Down

0 comments on commit 83fbc55

Please sign in to comment.